From b48e0969670f17857a314b5a755b1a1bf7ee38df Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 24 May 2018 17:52:06 -0700 Subject: [PATCH] BUG: fix writing to groups with h5netcdf (#2181) * BUG: fix writing to groups with h5netcdf Fixes GH2177 Our test suite was inadvertently not checking this. * what's new note --- doc/whats-new.rst | 6 +++++- xarray/backends/h5netcdf_.py | 9 +++++++-- xarray/backends/netCDF4_.py | 10 +++++++--- xarray/tests/test_backends.py | 12 ++++++------ 4 files changed, 25 insertions(+), 12 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index d9f43fa1868..4a01065bd70 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -64,7 +64,11 @@ Bug fixes longer falsely returns an empty array when the slice includes the value in the index) (:issue:`2165`). By `Spencer Clark <https://github.com/spencerkclark>`_. - + +- Fix Dataset.to_netcdf() failing to create groups with engine="h5netcdf" + (:issue:`2177`). + By `Stephan Hoyer <https://github.com/shoyer>`_. + .. _whats-new.0.10.4: v0.10.4 (May 16, 2018) diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index f9e2b3dece1..6b3cd9ebb15 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -12,7 +12,7 @@ HDF5_LOCK, DataStorePickleMixin, WritableCFDataStore, find_root) from .netCDF4_ import ( BaseNetCDF4Array, _encode_nc4_variable, _extract_nc4_variable_encoding, - _get_datatype, _nc4_group) + _get_datatype, _nc4_require_group) class H5NetCDFArrayWrapper(BaseNetCDF4Array): @@ -57,11 +57,16 @@ def _read_attributes(h5netcdf_var): lsd_okay=False, h5py_okay=True, backend='h5netcdf') +def _h5netcdf_create_group(dataset, name): + return dataset.create_group(name) + + def _open_h5netcdf_group(filename, mode, group): import h5netcdf ds = h5netcdf.File(filename, mode=mode) with close_on_error(ds): - return _nc4_group(ds, group, mode) + return _nc4_require_group( + ds, group, mode, create_group=_h5netcdf_create_group) class H5NetCDFStore(WritableCFDataStore, DataStorePickleMixin): diff --git 
a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 14061a0fb08..5391a890fb3 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -108,7 +108,11 @@ def _nc4_dtype(var): return dtype -def _nc4_group(ds, group, mode): +def _netcdf4_create_group(dataset, name): + return dataset.createGroup(name) + + +def _nc4_require_group(ds, group, mode, create_group=_netcdf4_create_group): if group in set([None, '', '/']): # use the root group return ds @@ -123,7 +127,7 @@ def _nc4_group(ds, group, mode): ds = ds.groups[key] except KeyError as e: if mode != 'r': - ds = ds.createGroup(key) + ds = create_group(ds, key) else: # wrap error to provide slightly more helpful message raise IOError('group not found: %s' % key, e) @@ -210,7 +214,7 @@ def _open_netcdf4_group(filename, mode, group=None, **kwargs): ds = nc4.Dataset(filename, mode=mode, **kwargs) with close_on_error(ds): - ds = _nc4_group(ds, group, mode) + ds = _nc4_require_group(ds, group, mode) _disable_auto_decode_group(ds) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 513f5f0834e..0768a942a77 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -892,7 +892,7 @@ def test_open_group(self): # check equivalent ways to specify group for group in 'foo', '/foo', 'foo/', '/foo/': - with open_dataset(tmp_file, group=group) as actual: + with self.open(tmp_file, group=group) as actual: assert_equal(actual['x'], expected['x']) # check that missing group raises appropriate exception @@ -920,18 +920,18 @@ def test_open_subgroup(self): # check equivalent ways to specify group for group in 'foo/bar', '/foo/bar', 'foo/bar/', '/foo/bar/': - with open_dataset(tmp_file, group=group) as actual: + with self.open(tmp_file, group=group) as actual: assert_equal(actual['x'], expected['x']) def test_write_groups(self): data1 = create_test_data() data2 = data1 * 2 with create_tmp_file() as tmp_file: - data1.to_netcdf(tmp_file, group='data/1') - 
data2.to_netcdf(tmp_file, group='data/2', mode='a') - with open_dataset(tmp_file, group='data/1') as actual1: + self.save(data1, tmp_file, group='data/1') + self.save(data2, tmp_file, group='data/2', mode='a') + with self.open(tmp_file, group='data/1') as actual1: assert_identical(data1, actual1) - with open_dataset(tmp_file, group='data/2') as actual2: + with self.open(tmp_file, group='data/2') as actual2: assert_identical(data2, actual2) def test_roundtrip_string_with_fill_value_vlen(self):