Skip to content

Commit

Permalink
Raise an informative error message when object array has mixed types (pydata#4700)
Browse files Browse the repository at this point in the history

Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com>
Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>
  • Loading branch information
3 people committed Nov 28, 2023
1 parent e7e8c38 commit dc0931a
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
24 changes: 20 additions & 4 deletions xarray/conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,32 @@ def _var_as_tuple(var: Variable) -> T_VarTuple:
return var.dims, var.data, var.attrs.copy(), var.encoding.copy()


def _infer_dtype(array, name: T_Name = None) -> np.dtype:
"""Given an object array with no missing values, infer its dtype from its
first element
"""
def _infer_dtype(array, name=None):
"""Given an object array with no missing values, infer its dtype from all elements."""
if array.dtype.kind != "O":
raise TypeError("infer_type must be called on a dtype=object array")

if array.size == 0:
return np.dtype(float)

native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel()))
if len(native_dtypes) > 1 and native_dtypes != {bytes, str}:
raise ValueError(
"unable to infer dtype on variable {!r}; object array "
"contains mixed native types: {}".format(
name, ", ".join(x.__name__ for x in native_dtypes)
)
)

native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel()))
if len(native_dtypes) > 1 and native_dtypes != {bytes, str}:
raise ValueError(
"unable to infer dtype on variable {!r}; object array "
"contains mixed native types: {}".format(
name, ", ".join(x.__name__ for x in native_dtypes)
)
)

element = array[(0,) * array.ndim]
# We use the base types to avoid subclasses of bytes and str (which might
# not play nice with e.g. hdf5 datatypes), such as those from numpy
Expand Down
12 changes: 12 additions & 0 deletions xarray/tests/test_conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,18 @@ def test_encoding_kwarg_fixed_width_string(self) -> None:
pass


@pytest.mark.parametrize(
    "data",
    [
        # bytes, str and int mixed in a 2-D object array
        np.array([["ab", "cdef", b"X"], [1, 2, "c"]], dtype=object),
        # str and int mixed in a 2-D object array
        np.array([["x", 1], ["y", 2]], dtype="object"),
    ],
)
def test_infer_dtype_error_on_mixed_types(data):
    """Check that ``_infer_dtype`` raises ValueError for mixed-type object arrays."""
    expected_message = "unable to infer dtype on variable"
    with pytest.raises(ValueError, match=expected_message):
        conventions._infer_dtype(data, "test")


class TestDecodeCFVariableWithArrayUnits:
def test_decode_cf_variable_with_array_units(self) -> None:
v = Variable(["t"], [1, 2, 3], {"units": np.array(["foobar"], dtype=object)})
Expand Down

0 comments on commit dc0931a

Please sign in to comment.