Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix open_mfdataset() dropping time encoding attrs #309

Merged
merged 6 commits into from
Aug 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,38 @@ def setUp(self, tmp_path):
self.file_path1 = f"{dir}/file1.nc"
self.file_path2 = f"{dir}/file2.nc"

def test_mfdataset_keeps_time_encoding_dict(self):
    """Tests that the merged dataset keeps the time `.encoding` dict.

    `xr.open_mfdataset()` drops the time coordinate's `.encoding` dict when
    merging multiple files with `decode_times=True` (see
    https://github.com/pydata/xarray/issues/2436), so xCDAT's
    `open_mfdataset()` restores it on the final merged dataset.
    """
    ds1 = generate_dataset(cf_compliant=True, has_bounds=True)
    ds1.to_netcdf(self.file_path1)

    # Create another dataset that extends the time coordinates by 1 value,
    # to mimic a multifile dataset.
    # NOTE(review): assigning into coordinate `.values` mutates the array in
    # place and bypasses xarray's index machinery — presumably fine for this
    # fixture, but `assign_coords` would be the more robust way to set it.
    ds2 = generate_dataset(cf_compliant=True, has_bounds=True)
    ds2 = ds2.isel(dict(time=slice(0, 1)))
    ds2["time"].values[:] = np.array(
        ["2002-01-16T12:00:00.000000000"],
        dtype="datetime64[ns]",
    )
    ds2.to_netcdf(self.file_path2)

    result = open_mfdataset([self.file_path1, self.file_path2], decode_times=True)
    expected = ds1.merge(ds2)

    assert result.identical(expected)

    # We mainly care for the "source" and "original_shape" attrs (updated
    # internally by xCDAT), and the "calendar" and "units" attrs. We don't
    # perform equality assertion on the entire time `.encoding` dict because
    # there might be different encoding attributes added or removed between
    # xarray versions (e.g., "bzip2", "zstd", "blosc", and "szip" are added
    # in v2022.06.0), which makes that assertion fragile.
    paths = result.time.encoding["source"]
    assert self.file_path1 in paths[0]
    assert self.file_path2 in paths[1]
    # 1 time step comes from the second (sliced) file; the remaining 15 from
    # the first, for a merged total of 16.
    assert result.time.encoding["original_shape"] == (16,)
    assert result.time.encoding["calendar"] == "standard"
    assert result.time.encoding["units"] == "days since 2000-01-01"

def test_non_cf_compliant_time_is_not_decoded(self):
ds1 = generate_dataset(cf_compliant=False, has_bounds=True)
ds1.to_netcdf(self.file_path1)
Expand Down
139 changes: 102 additions & 37 deletions xcdat/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import datetime
from functools import partial
from glob import glob
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, Callable, Dict, Hashable, List, Literal, Optional, Tuple, Union

import numpy as np
import xarray as xr
Expand All @@ -14,14 +14,24 @@

from xcdat import bounds # noqa: F401
from xcdat.axis import center_times as center_times_func
from xcdat.axis import get_axis_coord, swap_lon_axis
from xcdat.axis import get_axis_coord, get_axis_dim, swap_lon_axis
from xcdat.logger import setup_custom_logger

logger = setup_custom_logger(__name__)

#: List of non-CF compliant time units.
NON_CF_TIME_UNITS: List[str] = ["months", "years"]

# Type annotation for the `paths` arg.
Paths = Union[
str,
pathlib.Path,
List[str],
List[pathlib.Path],
List[List[str]],
List[List[pathlib.Path]],
]
Comment on lines +25 to +33
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extracted Paths type annotation since it was reused in multiple functions and making the function param definitions long.

I noticed xarray refactored the paths type annotation here: https://github.com/pydata/xarray/blob/f8fee902360f2330ab8c002d54480d357365c172/xarray/backends/api.py#L734

paths: str | NestedSequence[str | os.PathLike],



def open_dataset(
path: str,
Expand Down Expand Up @@ -87,10 +97,9 @@ def open_dataset(
"""
if decode_times:
cf_compliant_time: Optional[bool] = _has_cf_compliant_time(path)
# xCDAT attempts to decode non-CF compliant time coordinates.
if cf_compliant_time is False:
# XCDAT handles decoding time values with non-CF units.
ds = xr.open_dataset(path, decode_times=False, **kwargs)
# attempt to decode non-cf-compliant time axis
ds = decode_non_cf_time(ds)
else:
ds = xr.open_dataset(path, decode_times=True, **kwargs)
Expand All @@ -103,14 +112,7 @@ def open_dataset(


def open_mfdataset(
paths: Union[
str,
pathlib.Path,
List[str],
List[pathlib.Path],
List[List[str]],
List[List[pathlib.Path]],
],
paths: Paths,
data_var: Optional[str] = None,
add_bounds: bool = True,
decode_times: bool = True,
Expand Down Expand Up @@ -201,10 +203,19 @@ def open_mfdataset(

.. [2] https://xarray.pydata.org/en/stable/generated/xarray.open_mfdataset.html
"""
# `xr.open_mfdataset()` drops the time coordinates encoding dictionary if
# multiple files are merged with `decode_times=True` (refer to
# https://github.com/pydata/xarray/issues/2436). The workaround is to store
# the time encoding from the first dataset as a variable, and add the time
# encoding back to final merged dataset.
time_encoding = None

if decode_times:
time_encoding = _keep_time_encoding(paths)

cf_compliant_time: Optional[bool] = _has_cf_compliant_time(paths)
# XCDAT handles decoding time values with non-CF units using the
# preprocess kwarg.
# xCDAT attempts to decode non-CF compliant time coordinates using the
# preprocess keyword arg with `xr.open_mfdataset()`.
if cf_compliant_time is False:
decode_times = False
preprocess = partial(_preprocess_non_cf_dataset, callable=preprocess)
Expand All @@ -218,6 +229,12 @@ def open_mfdataset(
)
ds = _postprocess_dataset(ds, data_var, center_times, add_bounds, lon_orient)

if time_encoding is not None:
time_dim = get_axis_dim(ds, "T")
ds[time_dim].encoding = time_encoding
# Update "original_shape" to reflect the final time coordinates shape.
ds[time_dim].encoding["original_shape"] = ds[time_dim].shape
Comment on lines +235 to +236
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use the "original_shape" of the final merged time coordinates.


return ds


Expand Down Expand Up @@ -405,16 +422,43 @@ def decode_non_cf_time(dataset: xr.Dataset) -> xr.Dataset:
return ds


def _has_cf_compliant_time(
path: Union[
str,
pathlib.Path,
List[str],
List[pathlib.Path],
List[List[str]],
List[List[pathlib.Path]],
]
) -> Optional[bool]:
def _keep_time_encoding(paths: Paths) -> Dict[Hashable, Any]:
    """
    Returns the time encoding attributes from the first dataset in a list of
    paths.

    Time encoding information is critical for several xCDAT operations such as
    temporal averaging (e.g., uses the "calendar" attr). This function is a
    workaround to the undesired xarray behavior/quirk with
    `xr.open_mfdataset()`, which drops the `.encoding` dict from the final
    merged dataset (refer to https://github.com/pydata/xarray/issues/2436).

    Parameters
    ----------
    paths: Paths
        The paths to the dataset(s).

    Returns
    -------
    Dict[Hashable, Any]
        The time encoding dictionary.
    """
    first_path = _get_first_path(paths)

    # xcdat.open_dataset() is called instead of xr.open_dataset() because
    # we want to handle decoding non-CF compliant time as well.
    # FIXME: Remove `type: ignore` comment after properly handling the type
    # annotations in `_get_first_path()`.
    ds = open_dataset(first_path, decode_times=True, add_bounds=False)  # type: ignore
    time_coord = get_axis_coord(ds, "T")

    # Copy the encoding so it stays valid independently of `ds`, then close
    # the file handle opened above to avoid leaking it.
    time_encoding: Dict[Hashable, Any] = dict(time_coord.encoding)
    ds.close()

    # Set "source" to the complete `paths` arg, since the encoding above was
    # read from the first file only.
    time_encoding["source"] = paths

    return time_encoding


def _has_cf_compliant_time(paths: Paths) -> Optional[bool]:
"""Checks if a dataset has time coordinates with CF compliant units.

If the dataset does not contain a time dimension, None is returned.
Expand Down Expand Up @@ -444,19 +488,8 @@ def _has_cf_compliant_time(
performance because it is slower to combine all files then check for CF
compliance.
"""
first_file: Optional[Union[pathlib.Path, str]] = None

if isinstance(path, str) and "*" in path:
first_file = glob(path)[0]
elif isinstance(path, str) or isinstance(path, pathlib.Path):
first_file = path
elif isinstance(path, list):
if any(isinstance(sublist, list) for sublist in path):
first_file = path[0][0] # type: ignore
else:
first_file = path[0] # type: ignore

ds = xr.open_dataset(first_file, decode_times=False)
first_path = _get_first_path(paths)
ds = xr.open_dataset(first_path, decode_times=False)

if ds.cf.dims.get("T") is None:
return None
Expand All @@ -474,6 +507,38 @@ def _has_cf_compliant_time(
return cf_compliant


def _get_first_path(path: Paths) -> Optional[Union[pathlib.Path, str]]:
    """Returns the first path from the `path` argument.

    Parameters
    ----------
    path : Paths
        A single path (optionally a glob pattern), a list of paths, or a
        nested list of paths.

    Returns
    -------
    Optional[Union[pathlib.Path, str]]
        The first path, or None if `path` is not a supported type.
    """
    # FIXME: This function should throw an exception if the first file
    # is not a supported type.
    # FIXME: The `type: ignore` comments should be removed after properly
    # handling the types.
    first_file: Optional[Union[pathlib.Path, str]] = None

    if isinstance(path, str) and "*" in path:
        # Sort the glob matches so "first" is deterministic across
        # filesystems, consistent with how `xr.open_mfdataset()` expands
        # glob patterns.
        first_file = sorted(glob(path))[0]
    elif isinstance(path, (str, pathlib.Path)):
        first_file = path
    elif isinstance(path, list):
        if any(isinstance(sublist, list) for sublist in path):
            # Nested list of paths, e.g., [[p1, p2], [p3]].
            first_file = path[0][0]  # type: ignore
        else:
            first_file = path[0]  # type: ignore

    return first_file
Comment on lines +510 to +539
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extracted _get_first_path() to be reused in other functions (_keep_time_encoding() and _has_cf_compliant_time()).



def _postprocess_dataset(
dataset: xr.Dataset,
data_var: Optional[str] = None,
Expand Down