Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Preserve chunks in CF Writer #1254

Merged
merged 8 commits on Sep 18, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
111 changes: 60 additions & 51 deletions satpy/tests/writer_tests/test_cf.py
Expand Up @@ -499,6 +499,7 @@ def test_da2cf(self):
attrs, attrs_expected, attrs_expected_flat = self.get_test_attrs()
attrs['area'] = 'some_area'
attrs['prerequisites'] = [make_dsq(name='hej')]
attrs['_satpy_id_name'] = 'myname'

# Adjust expected attributes
expected_prereq = ("DataQuery(name='hej')")
Expand Down Expand Up @@ -1005,57 +1006,6 @@ def test_global_attr_history_and_Conventions(self):
self.assertIn('TEST add history\n', f.attrs['history'])
self.assertIn('Created by pytroll/satpy on', f.attrs['history'])

def test_update_encoding(self):
import xarray as xr
from satpy.writers.cf_writer import CFWriter

# Without time dimension
ds = xr.Dataset({'foo': (('y', 'x'), [[1, 2], [3, 4]]),
'bar': (('y', 'x'), [[3, 4], [5, 6]])},
coords={'y': [1, 2],
'x': [3, 4],
'lon': (('y', 'x'), [[7, 8], [9, 10]])})
ds = ds.chunk(2)
kwargs = {'encoding': {'bar': {'chunksizes': (1, 1)}},
'other': 'kwargs'}
enc, other_kwargs = CFWriter.update_encoding(ds, kwargs)
self.assertDictEqual(enc, {'y': {'_FillValue': None},
'x': {'_FillValue': None},
'lon': {'chunksizes': (2, 2)},
'foo': {'chunksizes': (2, 2)},
'bar': {'chunksizes': (1, 1)}})
self.assertDictEqual(other_kwargs, {'other': 'kwargs'})

# Chunksize may not exceed shape
ds = ds.chunk(8)
kwargs = {'encoding': {}, 'other': 'kwargs'}
enc, other_kwargs = CFWriter.update_encoding(ds, kwargs)
self.assertDictEqual(enc, {'y': {'_FillValue': None},
'x': {'_FillValue': None},
'lon': {'chunksizes': (2, 2)},
'foo': {'chunksizes': (2, 2)},
'bar': {'chunksizes': (2, 2)}})

# With time dimension
ds = ds.expand_dims({'time': [datetime(2009, 7, 1, 12, 15)]})
kwargs = {'encoding': {'bar': {'chunksizes': (1, 1, 1)}},
'other': 'kwargs'}
enc, other_kwargs = CFWriter.update_encoding(ds, kwargs)
self.assertDictEqual(enc, {'y': {'_FillValue': None},
'x': {'_FillValue': None},
'lon': {'chunksizes': (2, 2)},
'foo': {'chunksizes': (1, 2, 2)},
'bar': {'chunksizes': (1, 1, 1)},
'time': {'_FillValue': None,
'calendar': 'proleptic_gregorian',
'units': 'days since 2009-07-01 12:15:00'},
'time_bnds': {'_FillValue': None,
'calendar': 'proleptic_gregorian',
'units': 'days since 2009-07-01 12:15:00'}})

# User-defined encoding may not be altered
self.assertDictEqual(kwargs['encoding'], {'bar': {'chunksizes': (1, 1, 1)}})


class TestCFWriterData(unittest.TestCase):
"""Test case for CF writer where data arrays are needed."""
Expand Down Expand Up @@ -1129,3 +1079,62 @@ def test_collect_datasets_with_latitude_named_lat(self, *mocks):
self.assertRaises(KeyError, getitem, datas['var1'], 'longitude')
self.assertEqual(datas2['var1']['latitude'].attrs['name'], 'latitude')
self.assertEqual(datas2['var1']['longitude'].attrs['name'], 'longitude')


class EncodingUpdateTest(unittest.TestCase):
    """Test update of netCDF encoding."""

    def setUp(self):
        """Create a small dataset with two data variables and a 2d coordinate."""
        import xarray as xr
        data_vars = {'foo': (('y', 'x'), [[1, 2], [3, 4]]),
                     'bar': (('y', 'x'), [[3, 4], [5, 6]])}
        coords = {'y': [1, 2],
                  'x': [3, 4],
                  'lon': (('y', 'x'), [[7, 8], [9, 10]])}
        self.ds = xr.Dataset(data_vars, coords=coords)

    def test_without_time(self):
        """Test encoding update of a dataset without a time dimension."""
        from satpy.writers.cf_writer import update_encoding

        # User-defined chunk sizes win; everything else keeps its dask chunks,
        # and coordinate variables get no fill value.
        chunked = self.ds.chunk(2)
        kwargs = {'encoding': {'bar': {'chunksizes': (1, 1)}},
                  'other': 'kwargs'}
        enc, other_kwargs = update_encoding(chunked, kwargs)
        expected = {'y': {'_FillValue': None},
                    'x': {'_FillValue': None},
                    'lon': {'chunksizes': (2, 2)},
                    'foo': {'chunksizes': (2, 2)},
                    'bar': {'chunksizes': (1, 1)}}
        self.assertDictEqual(enc, expected)
        self.assertDictEqual(other_kwargs, {'other': 'kwargs'})

        # Chunksize may not exceed shape
        overchunked = self.ds.chunk(8)
        enc, other_kwargs = update_encoding(overchunked,
                                            {'encoding': {}, 'other': 'kwargs'})
        expected = {'y': {'_FillValue': None},
                    'x': {'_FillValue': None},
                    'lon': {'chunksizes': (2, 2)},
                    'foo': {'chunksizes': (2, 2)},
                    'bar': {'chunksizes': (2, 2)}}
        self.assertDictEqual(enc, expected)

    def test_with_time(self):
        """Test encoding update of a dataset with a time dimension."""
        from satpy.writers.cf_writer import update_encoding

        ds = self.ds.chunk(8).expand_dims({'time': [datetime(2009, 7, 1, 12, 15)]})
        kwargs = {'encoding': {'bar': {'chunksizes': (1, 1, 1)}},
                  'other': 'kwargs'}
        enc, other_kwargs = update_encoding(ds, kwargs)
        # time and time_bnds must share the same CF units/calendar.
        time_units = 'days since 2009-07-01 12:15:00'
        expected = {'y': {'_FillValue': None},
                    'x': {'_FillValue': None},
                    'lon': {'chunksizes': (2, 2)},
                    'foo': {'chunksizes': (1, 2, 2)},
                    'bar': {'chunksizes': (1, 1, 1)},
                    'time': {'_FillValue': None,
                             'calendar': 'proleptic_gregorian',
                             'units': time_units},
                    'time_bnds': {'_FillValue': None,
                                  'calendar': 'proleptic_gregorian',
                                  'units': time_units}}
        self.assertDictEqual(enc, expected)

        # User-defined encoding may not be altered
        self.assertDictEqual(kwargs['encoding'], {'bar': {'chunksizes': (1, 1, 1)}})
120 changes: 70 additions & 50 deletions satpy/writers/cf_writer.py
Expand Up @@ -414,6 +414,62 @@ def encode_attrs_nc(attrs):
return OrderedDict(encoded_attrs)


def _set_default_chunks(encoding, dataset):
"""Update encoding to preserve current dask chunks.

Existing user-defined chunks take precedence.
"""
for var_name, variable in dataset.variables.items():
if variable.chunks:
chunks = tuple(
np.stack([variable.data.chunksize,
variable.shape]).min(axis=0)
) # Chunksize may not exceed shape
encoding.setdefault(var_name, {})
encoding[var_name].setdefault('chunksizes', chunks)


def update_encoding(dataset, to_netcdf_kwargs):
    """Update encoding.

    Preserve dask chunks, avoid fill values in coordinate variables and make sure that
    time & time bounds have the same units.

    Args:
        dataset (xarray.Dataset): Dataset about to be written to netCDF.
        to_netcdf_kwargs (dict): Keyword arguments destined for
            ``Dataset.to_netcdf``. An ``encoding`` entry, if present, is
            extracted and updated; the caller's dicts are not mutated.

    Returns:
        tuple: ``(encoding, other_to_netcdf_kwargs)`` — the updated encoding
        dict and the remaining ``to_netcdf`` keyword arguments.
    """
    other_to_netcdf_kwargs = to_netcdf_kwargs.copy()
    encoding = other_to_netcdf_kwargs.pop('encoding', {}).copy()

    # If not specified otherwise by the user, preserve current dask chunks.
    _set_default_chunks(encoding, dataset)

    # Avoid _FillValue attribute being added to coordinate variables
    # (https://github.com/pydata/xarray/issues/1865).
    coord_vars = []
    for data_array in dataset.values():
        coord_vars.extend(set(data_array.dims).intersection(data_array.coords))
    for coord_var in coord_vars:
        encoding.setdefault(coord_var, {})
        encoding[coord_var].update({'_FillValue': None})

    # Make sure time coordinates and bounds have the same units. Default is xarray's CF datetime
    # encoding, which can be overridden by user-defined encoding.
    if 'time' in dataset:
        try:
            dtnp64 = dataset['time'].data[0]
        except IndexError:
            # 0-d (scalar) time coordinate cannot be indexed.
            dtnp64 = dataset['time'].data

        default = CFDatetimeCoder().encode(xr.DataArray(dtnp64))
        time_enc = {'units': default.attrs['units'], 'calendar': default.attrs['calendar']}
        time_enc.update(encoding.get('time', {}))
        bounds_enc = {'units': time_enc['units'],
                      'calendar': time_enc['calendar'],
                      '_FillValue': None}
        encoding['time'] = time_enc
        encoding['time_bnds'] = bounds_enc  # FUTURE: Not required anymore with xarray-0.14+

    return encoding, other_to_netcdf_kwargs


class CFWriter(Writer):
"""Writer producing NetCDF/CF compatible datasets."""

Expand All @@ -440,6 +496,11 @@ def da2cf(dataarray, epoch=EPOCH, flatten_attrs=False, exclude_attrs=None, compr
name = new_data.attrs.pop('name')
new_data = new_data.rename(name)

# Remove _satpy* attributes
satpy_attrs = [key for key in new_data.attrs if key.startswith('_satpy')]
for satpy_attr in satpy_attrs:
new_data.attrs.pop(satpy_attr)

# Remove area as well as user-defined attributes
for key in ['area'] + exclude_attrs:
new_data.attrs.pop(key, None)
Expand Down Expand Up @@ -475,7 +536,7 @@ def da2cf(dataarray, epoch=EPOCH, flatten_attrs=False, exclude_attrs=None, compr
new_data['y'].attrs['units'] = 'm'

if 'crs' in new_data.coords:
new_data = new_data.drop('crs')
new_data = new_data.drop_vars('crs')

if 'long_name' not in new_data.attrs and 'standard_name' not in new_data.attrs:
new_data.attrs['long_name'] = new_data.name
Expand All @@ -491,6 +552,13 @@ def da2cf(dataarray, epoch=EPOCH, flatten_attrs=False, exclude_attrs=None, compr

return new_data

@staticmethod
def update_encoding(dataset, to_netcdf_kwargs):
warnings.warn('CFWriter.update_encoding is deprecated. '
'Use satpy.writers.cf_writer.update_encoding instead.',
DeprecationWarning)
return update_encoding(dataset, to_netcdf_kwargs)

    def save_dataset(self, dataset, filename=None, fill_value=None, **kwargs):
        """Save the *dataset* to a given *filename*."""
        # NOTE(review): fill_value is accepted here but not forwarded to
        # save_datasets — confirm whether it is intentionally ignored.
        return self.save_datasets([dataset], filename, **kwargs)
Expand Down Expand Up @@ -530,54 +598,6 @@ def _collect_datasets(self, datasets, epoch=EPOCH, flatten_attrs=False, exclude_

return datas, start_times, end_times

    @staticmethod
    def update_encoding(dataset, to_netcdf_kwargs):
        """Update encoding.

        Preserve chunk sizes, avoid fill values in coordinate variables and make sure that
        time & time bounds have the same units.

        Returns a ``(encoding, other_to_netcdf_kwargs)`` tuple; the caller's
        dicts are copied, not mutated.
        """
        other_to_netcdf_kwargs = to_netcdf_kwargs.copy()
        encoding = other_to_netcdf_kwargs.pop('encoding', {}).copy()

        # If not specified otherwise by the user, preserve current chunks.
        for var_name, variable in dataset.variables.items():
            if variable.chunks:
                # Clip each dask chunk size to the dimension length.
                chunks = tuple(
                    np.stack([variable.data.chunksize,
                              variable.shape]).min(axis=0)
                ) # Chunksize may not exceed shape
                encoding.setdefault(var_name, {})
                encoding[var_name].setdefault('chunksizes', chunks)

        # Avoid _FillValue attribute being added to coordinate variables
        # (https://github.com/pydata/xarray/issues/1865).
        coord_vars = []
        for data_array in dataset.values():
            coord_vars.extend(set(data_array.dims).intersection(data_array.coords))
        for coord_var in coord_vars:
            encoding.setdefault(coord_var, {})
            encoding[coord_var].update({'_FillValue': None})

        # Make sure time coordinates and bounds have the same units. Default is xarray's CF datetime
        # encoding, which can be overridden by user-defined encoding.
        if 'time' in dataset:
            try:
                dtnp64 = dataset['time'].data[0]
            except IndexError:
                # 0-d (scalar) time coordinate cannot be indexed.
                dtnp64 = dataset['time'].data

            default = CFDatetimeCoder().encode(xr.DataArray(dtnp64))
            time_enc = {'units': default.attrs['units'], 'calendar': default.attrs['calendar']}
            time_enc.update(encoding.get('time', {}))
            bounds_enc = {'units': time_enc['units'],
                          'calendar': time_enc['calendar'],
                          '_FillValue': None}
            encoding['time'] = time_enc
            encoding['time_bnds'] = bounds_enc  # FUTURE: Not required anymore with xarray-0.14+

        return encoding, other_to_netcdf_kwargs

def save_datasets(self, datasets, filename=None, groups=None, header_attrs=None, engine=None, epoch=EPOCH,
flatten_attrs=False, exclude_attrs=None, include_lonlats=True, pretty=False,
compression=None, **to_netcdf_kwargs):
Expand Down Expand Up @@ -681,7 +701,7 @@ def save_datasets(self, datasets, filename=None, groups=None, header_attrs=None,
grp_str = ' of group {}'.format(group_name) if group_name is not None else ''
logger.warning('No time dimension in datasets{}, skipping time bounds creation.'.format(grp_str))

encoding, other_to_netcdf_kwargs = self.update_encoding(dataset, to_netcdf_kwargs)
encoding, other_to_netcdf_kwargs = update_encoding(dataset, to_netcdf_kwargs)
res = dataset.to_netcdf(filename, engine=engine, group=group_name, mode='a', encoding=encoding,
**other_to_netcdf_kwargs)
written.append(res)
Expand Down