Skip to content

Commit

Permalink
Use encoding['dtype'] over data.dtype when possible within CFMaskCode…
Browse files Browse the repository at this point in the history
…r.encode (#3652)

* Use encoding['dtype'] over data.dtype when possible

* Add what's new entry

* Fix typo in what's new

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
  • Loading branch information
spencerkclark and dcherian committed Jan 15, 2020
1 parent e0fd480 commit 9959405
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 11 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ Bug fixes
By `Tom Augspurger <https://github.com/TomAugspurger>`_.
- Ensure :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile` issue the correct error
when ``q`` is out of bounds (:issue:`3634`) by `Mathias Hauser <https://github.com/mathause>`_.
- Fix regression in xarray 0.14.1 that prevented encoding times with certain
``dtype``, ``_FillValue``, and ``missing_value`` encodings (:issue:`3624`).
By `Spencer Clark <https://github.com/spencerkclark>`_
- Raise an error when trying to use :py:meth:`Dataset.rename_dims` to
rename to an existing name (:issue:`3438`, :pull:`3645`)
By `Justus Magin <https://github.com/keewis>`_.
Expand Down
5 changes: 3 additions & 2 deletions xarray/coding/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ class CFMaskCoder(VariableCoder):
def encode(self, variable, name=None):
dims, data, attrs, encoding = unpack_for_encoding(variable)

dtype = np.dtype(encoding.get("dtype", data.dtype))
fv = encoding.get("_FillValue")
mv = encoding.get("missing_value")

Expand All @@ -162,14 +163,14 @@ def encode(self, variable, name=None):

if fv is not None:
# Ensure _FillValue is cast to same dtype as data's
encoding["_FillValue"] = data.dtype.type(fv)
encoding["_FillValue"] = dtype.type(fv)
fill_value = pop_to(encoding, attrs, "_FillValue", name=name)
if not pd.isnull(fill_value):
data = duck_array_ops.fillna(data, fill_value)

if mv is not None:
# Ensure missing_value is cast to same dtype as data's
encoding["missing_value"] = data.dtype.type(mv)
encoding["missing_value"] = dtype.type(mv)
fill_value = pop_to(encoding, attrs, "missing_value", name=name)
if not pd.isnull(fill_value) and fv is None:
data = duck_array_ops.fillna(data, fill_value)
Expand Down
36 changes: 27 additions & 9 deletions xarray/tests/test_coding.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from contextlib import suppress

import numpy as np
import pandas as pd
import pytest

import xarray as xr
from xarray.coding import variables
from xarray.conventions import decode_cf_variable, encode_cf_variable

from . import assert_equal, assert_identical, requires_dask

Expand All @@ -20,20 +22,36 @@ def test_CFMaskCoder_decode():
assert_identical(expected, encoded)


def test_CFMaskCoder_encode_missing_fill_values_conflict():
original = xr.Variable(
("x",),
[0.0, -1.0, 1.0],
encoding={"_FillValue": np.float32(1e20), "missing_value": np.float64(1e20)},
)
coder = variables.CFMaskCoder()
encoded = coder.encode(original)
encoding_with_dtype = {
"dtype": np.dtype("float64"),
"_FillValue": np.float32(1e20),
"missing_value": np.float64(1e20),
}
encoding_without_dtype = {
"_FillValue": np.float32(1e20),
"missing_value": np.float64(1e20),
}
CFMASKCODER_ENCODE_DTYPE_CONFLICT_TESTS = {
"numeric-with-dtype": ([0.0, -1.0, 1.0], encoding_with_dtype),
"numeric-without-dtype": ([0.0, -1.0, 1.0], encoding_without_dtype),
"times-with-dtype": (pd.date_range("2000", periods=3), encoding_with_dtype),
}


@pytest.mark.parametrize(
("data", "encoding"),
CFMASKCODER_ENCODE_DTYPE_CONFLICT_TESTS.values(),
ids=list(CFMASKCODER_ENCODE_DTYPE_CONFLICT_TESTS.keys()),
)
def test_CFMaskCoder_encode_missing_fill_values_conflict(data, encoding):
original = xr.Variable(("x",), data, encoding=encoding)
encoded = encode_cf_variable(original)

assert encoded.dtype == encoded.attrs["missing_value"].dtype
assert encoded.dtype == encoded.attrs["_FillValue"].dtype

with pytest.warns(variables.SerializationWarning):
roundtripped = coder.decode(coder.encode(original))
roundtripped = decode_cf_variable("foo", encoded)
assert_identical(roundtripped, original)


Expand Down

0 comments on commit 9959405

Please sign in to comment.