Skip to content

Commit

Permalink
Merge pull request #1254 from sfinkens/cf-writer-chunks
Browse files Browse the repository at this point in the history
Preserve chunks in CF Writer
  • Loading branch information
mraspaud committed Sep 18, 2020
2 parents c5cadc5 + 9525cc9 commit 5063081
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 35 deletions.
60 changes: 60 additions & 0 deletions satpy/tests/writer_tests/test_cf.py
Expand Up @@ -499,6 +499,7 @@ def test_da2cf(self):
attrs, attrs_expected, attrs_expected_flat = self.get_test_attrs()
attrs['area'] = 'some_area'
attrs['prerequisites'] = [make_dsq(name='hej')]
attrs['_satpy_id_name'] = 'myname'

# Adjust expected attributes
expected_prereq = ("DataQuery(name='hej')")
Expand Down Expand Up @@ -1078,3 +1079,62 @@ def test_collect_datasets_with_latitude_named_lat(self, *mocks):
self.assertRaises(KeyError, getitem, datas['var1'], 'longitude')
self.assertEqual(datas2['var1']['latitude'].attrs['name'], 'latitude')
self.assertEqual(datas2['var1']['longitude'].attrs['name'], 'longitude')


class EncodingUpdateTest(unittest.TestCase):
    """Test update of netCDF encoding."""

    def setUp(self):
        """Create a 2x2 test dataset with a two-dimensional 'lon' coordinate."""
        import xarray as xr
        self.ds = xr.Dataset({'foo': (('y', 'x'), [[1, 2], [3, 4]]),
                              'bar': (('y', 'x'), [[3, 4], [5, 6]])},
                             coords={'y': [1, 2],
                                     'x': [3, 4],
                                     'lon': (('y', 'x'), [[7, 8], [9, 10]])})

    def test_without_time(self):
        """Test encoding update for a dataset without a time dimension.

        Dask chunks are preserved in the encoding, user-defined chunk sizes
        take precedence, and chunk sizes never exceed the variable shape.
        """
        from satpy.writers.cf_writer import update_encoding

        # Without time dimension
        ds = self.ds.chunk(2)
        kwargs = {'encoding': {'bar': {'chunksizes': (1, 1)}},
                  'other': 'kwargs'}
        enc, other_kwargs = update_encoding(ds, kwargs)
        self.assertDictEqual(enc, {'y': {'_FillValue': None},
                                   'x': {'_FillValue': None},
                                   'lon': {'chunksizes': (2, 2)},
                                   'foo': {'chunksizes': (2, 2)},
                                   'bar': {'chunksizes': (1, 1)}})
        self.assertDictEqual(other_kwargs, {'other': 'kwargs'})

        # Chunksize may not exceed shape
        ds = self.ds.chunk(8)
        kwargs = {'encoding': {}, 'other': 'kwargs'}
        enc, other_kwargs = update_encoding(ds, kwargs)
        self.assertDictEqual(enc, {'y': {'_FillValue': None},
                                   'x': {'_FillValue': None},
                                   'lon': {'chunksizes': (2, 2)},
                                   'foo': {'chunksizes': (2, 2)},
                                   'bar': {'chunksizes': (2, 2)}})

    def test_with_time(self):
        """Test encoding update for a dataset with a time dimension.

        'time' and 'time_bnds' must get matching CF datetime units/calendar,
        and the user-supplied encoding dictionary must not be mutated.
        """
        from satpy.writers.cf_writer import update_encoding

        # With time dimension
        ds = self.ds.chunk(8).expand_dims({'time': [datetime(2009, 7, 1, 12, 15)]})
        kwargs = {'encoding': {'bar': {'chunksizes': (1, 1, 1)}},
                  'other': 'kwargs'}
        enc, other_kwargs = update_encoding(ds, kwargs)
        self.assertDictEqual(enc, {'y': {'_FillValue': None},
                                   'x': {'_FillValue': None},
                                   'lon': {'chunksizes': (2, 2)},
                                   'foo': {'chunksizes': (1, 2, 2)},
                                   'bar': {'chunksizes': (1, 1, 1)},
                                   'time': {'_FillValue': None,
                                            'calendar': 'proleptic_gregorian',
                                            'units': 'days since 2009-07-01 12:15:00'},
                                   'time_bnds': {'_FillValue': None,
                                                 'calendar': 'proleptic_gregorian',
                                                 'units': 'days since 2009-07-01 12:15:00'}})

        # User-defined encoding may not be altered
        self.assertDictEqual(kwargs['encoding'], {'bar': {'chunksizes': (1, 1, 1)}})
118 changes: 83 additions & 35 deletions satpy/writers/cf_writer.py
Expand Up @@ -99,6 +99,7 @@
"""

from collections import OrderedDict, defaultdict
import copy
import logging
from datetime import datetime
import json
Expand Down Expand Up @@ -413,6 +414,73 @@ def encode_attrs_nc(attrs):
return OrderedDict(encoded_attrs)


def _set_default_chunks(encoding, dataset):
"""Update encoding to preserve current dask chunks.
Existing user-defined chunks take precedence.
"""
for var_name, variable in dataset.variables.items():
if variable.chunks:
chunks = tuple(
np.stack([variable.data.chunksize,
variable.shape]).min(axis=0)
) # Chunksize may not exceed shape
encoding.setdefault(var_name, {})
encoding[var_name].setdefault('chunksizes', chunks)


def _set_default_fill_value(encoding, dataset):
"""Set default fill values.
Avoid _FillValue attribute being added to coordinate variables
(https://github.com/pydata/xarray/issues/1865).
"""
coord_vars = []
for data_array in dataset.values():
coord_vars.extend(set(data_array.dims).intersection(data_array.coords))
for coord_var in coord_vars:
encoding.setdefault(coord_var, {})
encoding[coord_var].update({'_FillValue': None})


def _set_default_time_encoding(encoding, dataset):
    """Set default time encoding.

    Make sure time coordinates and bounds have the same units. Default is xarray's CF datetime
    encoding, which can be overridden by user-defined encoding.
    """
    if 'time' not in dataset:
        return
    time_data = dataset['time'].data
    try:
        sample = time_data[0]
    except IndexError:
        # Scalar time coordinate: use the value itself.
        sample = time_data

    # Let xarray pick the default CF units/calendar for this timestamp,
    # then allow any user-supplied 'time' encoding to override it.
    reference = CFDatetimeCoder().encode(xr.DataArray(sample))
    time_enc = {'units': reference.attrs['units'],
                'calendar': reference.attrs['calendar']}
    time_enc.update(encoding.get('time', {}))
    encoding['time'] = time_enc
    # Bounds must share units/calendar with the time coordinate.
    encoding['time_bnds'] = {'units': time_enc['units'],
                             'calendar': time_enc['calendar'],
                             '_FillValue': None}  # FUTURE: Not required anymore with xarray-0.14+


def update_encoding(dataset, to_netcdf_kwargs):
    """Update encoding.

    Preserve dask chunks, avoid fill values in coordinate variables and make sure that
    time & time bounds have the same units.
    """
    remaining_kwargs = dict(to_netcdf_kwargs)
    encoding = dict(remaining_kwargs.pop('encoding', {}))

    # Apply each encoding default in turn; user-supplied entries win.
    for apply_default in (_set_default_chunks,
                          _set_default_fill_value,
                          _set_default_time_encoding):
        apply_default(encoding, dataset)

    return encoding, remaining_kwargs


class CFWriter(Writer):
"""Writer producing NetCDF/CF compatible datasets."""

Expand All @@ -439,6 +507,11 @@ def da2cf(dataarray, epoch=EPOCH, flatten_attrs=False, exclude_attrs=None, compr
name = new_data.attrs.pop('name')
new_data = new_data.rename(name)

# Remove _satpy* attributes
satpy_attrs = [key for key in new_data.attrs if key.startswith('_satpy')]
for satpy_attr in satpy_attrs:
new_data.attrs.pop(satpy_attr)

# Remove area as well as user-defined attributes
for key in ['area'] + exclude_attrs:
new_data.attrs.pop(key, None)
Expand Down Expand Up @@ -474,7 +547,7 @@ def da2cf(dataarray, epoch=EPOCH, flatten_attrs=False, exclude_attrs=None, compr
new_data['y'].attrs['units'] = 'm'

if 'crs' in new_data.coords:
new_data = new_data.drop('crs')
new_data = new_data.drop_vars('crs')

if 'long_name' not in new_data.attrs and 'standard_name' not in new_data.attrs:
new_data.attrs['long_name'] = new_data.name
Expand All @@ -490,6 +563,13 @@ def da2cf(dataarray, epoch=EPOCH, flatten_attrs=False, exclude_attrs=None, compr

return new_data

@staticmethod
def update_encoding(dataset, to_netcdf_kwargs):
    """Update encoding (deprecated).

    Deprecated alias for the module-level
    :func:`satpy.writers.cf_writer.update_encoding`; delegates unchanged.
    """
    # stacklevel=2 attributes the DeprecationWarning to the caller of this
    # shim rather than to the shim itself.
    warnings.warn('CFWriter.update_encoding is deprecated. '
                  'Use satpy.writers.cf_writer.update_encoding instead.',
                  DeprecationWarning, stacklevel=2)
    return update_encoding(dataset, to_netcdf_kwargs)

def save_dataset(self, dataset, filename=None, fill_value=None, **kwargs):
    """Save the *dataset* to a given *filename*."""
    # NOTE(review): fill_value is accepted here but not forwarded to
    # save_datasets -- confirm whether dropping it is intentional.
    return self.save_datasets([dataset], filename, **kwargs)
Expand Down Expand Up @@ -529,39 +609,6 @@ def _collect_datasets(self, datasets, epoch=EPOCH, flatten_attrs=False, exclude_

return datas, start_times, end_times

def update_encoding(self, dataset, to_netcdf_kwargs):
    """Update encoding.

    Avoid _FillValue attribute being added to coordinate variables (https://github.com/pydata/xarray/issues/1865).
    """
    remaining_kwargs = to_netcdf_kwargs.copy()
    encoding = remaining_kwargs.pop('encoding', {}).copy()

    # Dimension coordinates get an explicit null fill value.
    for data_array in dataset.values():
        for coord_name in set(data_array.dims).intersection(data_array.coords):
            encoding.setdefault(coord_name, {})
            encoding[coord_name].update({'_FillValue': None})

    # Make sure time coordinates and bounds have the same units. Default is xarray's CF datetime
    # encoding, which can be overridden by user-defined encoding.
    if 'time' in dataset:
        time_data = dataset['time'].data
        try:
            sample = time_data[0]
        except IndexError:
            # Scalar time coordinate: use the value itself.
            sample = time_data

        default = CFDatetimeCoder().encode(xr.DataArray(sample))
        time_enc = {'units': default.attrs['units'],
                    'calendar': default.attrs['calendar']}
        time_enc.update(encoding.get('time', {}))
        encoding['time'] = time_enc
        encoding['time_bnds'] = {'units': time_enc['units'],
                                 'calendar': time_enc['calendar'],
                                 '_FillValue': None}  # FUTURE: Not required anymore with xarray-0.14+

    return encoding, remaining_kwargs

def save_datasets(self, datasets, filename=None, groups=None, header_attrs=None, engine=None, epoch=EPOCH,
flatten_attrs=False, exclude_attrs=None, include_lonlats=True, pretty=False,
compression=None, **to_netcdf_kwargs):
Expand Down Expand Up @@ -639,6 +686,7 @@ def save_datasets(self, datasets, filename=None, groups=None, header_attrs=None,
root.attrs['Conventions'] = CF_VERSION

# Remove satpy-specific kwargs
to_netcdf_kwargs = copy.deepcopy(to_netcdf_kwargs) # may contain dictionaries (encoding)
satpy_kwargs = ['overlay', 'decorate', 'config_files']
for kwarg in satpy_kwargs:
to_netcdf_kwargs.pop(kwarg, None)
Expand All @@ -664,7 +712,7 @@ def save_datasets(self, datasets, filename=None, groups=None, header_attrs=None,
grp_str = ' of group {}'.format(group_name) if group_name is not None else ''
logger.warning('No time dimension in datasets{}, skipping time bounds creation.'.format(grp_str))

encoding, other_to_netcdf_kwargs = self.update_encoding(dataset, to_netcdf_kwargs)
encoding, other_to_netcdf_kwargs = update_encoding(dataset, to_netcdf_kwargs)
res = dataset.to_netcdf(filename, engine=engine, group=group_name, mode='a', encoding=encoding,
**other_to_netcdf_kwargs)
written.append(res)
Expand Down

0 comments on commit 5063081

Please sign in to comment.