Merge pull request #1216 from gerritholl/combine-metadata-array-interface

Make combine_arrays understand non-numpy arrays
mraspaud committed May 28, 2020
2 parents 836c657 + f307d98 commit 2bf2e96
Showing 2 changed files with 90 additions and 15 deletions.
69 changes: 54 additions & 15 deletions satpy/dataset.py
@@ -17,14 +17,12 @@
 # satpy. If not, see <http://www.gnu.org/licenses/>.
 """Dataset objects."""
 
-import sys
 import logging
 import numbers
 from collections import namedtuple
+from collections.abc import Collection
 from datetime import datetime
 
-import numpy as np
-
 logger = logging.getLogger(__name__)
 
 
@@ -62,11 +60,13 @@ def average_datetimes(dt_list):
 def combine_metadata(*metadata_objects, **kwargs):
     """Combine the metadata of two or more Datasets.
 
-    If any keys are not equal or do not exist in all provided dictionaries
-    then they are not included in the returned dictionary.
-    By default any keys with the word 'time' in them and consisting
-    of datetime objects will be averaged. This is to handle cases where
-    data were observed at almost the same time but not exactly.
+    If the values corresponding to any keys are not equal or do not
+    exist in all provided dictionaries then they are not included in
+    the returned dictionary. By default any keys with the word 'time'
+    in them and consisting of datetime objects will be averaged. This
+    is to handle cases where data were observed at almost the same time
+    but not exactly. In the interest of time, arrays are compared by
+    object identity rather than by their contents.
 
     Args:
         *metadata_objects: MetadataObject or dict objects to combine
@@ -98,18 +98,57 @@ def combine_metadata(*metadata_objects, **kwargs):
     shared_info = {}
     for k in shared_keys:
         values = [nfo[k] for nfo in info_dicts]
-        any_arrays = any([isinstance(val, np.ndarray) for val in values])
-        if any_arrays:
-            if all(np.all(val == values[0]) for val in values[1:]):
+        if _share_metadata_key(k, values, average_times):
+            if 'time' in k and isinstance(values[0], datetime) and average_times:
+                shared_info[k] = average_datetimes(values)
+            else:
                 shared_info[k] = values[0]
-        elif 'time' in k and isinstance(values[0], datetime) and average_times:
-            shared_info[k] = average_datetimes(values)
-        elif all(val == values[0] for val in values[1:]):
-            shared_info[k] = values[0]
 
     return shared_info
 
 
+def _share_metadata_key(k, values, average_times):
+    """Helper for combine_metadata, decide if key is shared."""
+    any_arrays = any([hasattr(val, "__array__") for val in values])
+    # in the real world, the `ancillary_variables` attribute may be
+    # List[xarray.DataArray], this means our values are now
+    # List[List[xarray.DataArray]].
+    # note that this list_of_arrays check is also true for any
+    # higher-dimensional ndarray, but we only use this check after we have
+    # checked any_arrays so this false positive should have no impact
+    list_of_arrays = any(
+            [isinstance(val, Collection) and len(val) > 0 and
+             all([hasattr(subval, "__array__")
+                  for subval in val])
+             for val in values])
+    if any_arrays:
+        return _share_metadata_key_array(values)
+    elif list_of_arrays:
+        return _share_metadata_key_list_arrays(values)
+    elif 'time' in k and isinstance(values[0], datetime) and average_times:
+        return True
+    elif all(val == values[0] for val in values[1:]):
+        return True
+    return False
+
+
+def _share_metadata_key_array(values):
+    """Helper for combine_metadata, check object identity in list of arrays."""
+    for val in values[1:]:
+        if val is not values[0]:
+            return False
+    return True
+
+
+def _share_metadata_key_list_arrays(values):
+    """Helper for combine_metadata, check object identity in list of list of arrays."""
+    for val in values[1:]:
+        for arr, ref in zip(val, values[0]):
+            if arr is not ref:
+                return False
+    return True
+
+
 DATASET_KEYS = ("name", "wavelength", "resolution", "polarization",
                 "calibration", "level", "modifiers")
 DatasetID = namedtuple("DatasetID", " ".join(DATASET_KEYS))
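
The semantics described in the updated docstring are easiest to see interactively. A minimal sketch, assuming satpy with this commit installed; the keys and values are illustrative only, and the averaged result relies on average_datetimes as referenced above:

    import numpy as np
    from datetime import datetime
    from satpy.dataset import combine_metadata

    arr = np.zeros((5, 5))
    # Distinct array objects with equal contents are dropped: arrays are
    # compared by object identity rather than by their contents.
    print("quality" in combine_metadata({"quality": arr.copy()},
                                        {"quality": arr.copy()}))  # False
    # The very same array object in every dict is kept.
    print("quality" in combine_metadata({"quality": arr},
                                        {"quality": arr}))  # True
    # Keys containing 'time' with datetime values are averaged by default.
    meta = combine_metadata({"start_time": datetime(2020, 5, 28, 12, 0, 0)},
                            {"start_time": datetime(2020, 5, 28, 12, 0, 10)})
    print(meta["start_time"])  # expected: 2020-05-28 12:00:05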
36 changes: 36 additions & 0 deletions satpy/tests/test_dataset.py
@@ -84,3 +84,39 @@ def test_combine_times(self):
         ret = combine_metadata(*dts, average_times=False)
         # times are not equal so don't include it in the final result
         self.assertNotIn('start_time', ret)
+
+    def test_combine_arrays(self):
+        """Test the combine_metadata with arrays."""
+        from satpy.dataset import combine_metadata
+        from numpy import arange, ones
+        from xarray import DataArray
+        dts = [
+            {"quality": (arange(25) % 2).reshape(5, 5).astype("?")},
+            {"quality": (arange(1, 26) % 3).reshape(5, 5).astype("?")},
+            {"quality": ones((5, 5,), "?")},
+        ]
+        assert "quality" not in combine_metadata(*dts)
+        dts2 = [{"quality": DataArray(d["quality"])} for d in dts]
+        assert "quality" not in combine_metadata(*dts2)
+        # the ancillary_variables attribute is actually a list of data arrays
+        dts3 = [{"quality": [d["quality"]]} for d in dts]
+        assert "quality" not in combine_metadata(*dts3)
+        # check cases with repeated arrays
+        dts4 = [
+            {"quality": dts[0]["quality"]},
+            {"quality": dts[0]["quality"]},
+        ]
+        assert "quality" in combine_metadata(*dts4)
+        dts5 = [
+            {"quality": dts3[0]["quality"]},
+            {"quality": dts3[0]["quality"]},
+        ]
+        assert "quality" in combine_metadata(*dts5)
+        # check with other types
+        dts6 = [
+            DataArray(arange(5), attrs=dts[0]),
+            DataArray(arange(5), attrs=dts[0]),
+            DataArray(arange(5), attrs=dts[1]),
+            object()
+        ]
+        assert "quality" not in combine_metadata(*dts6)
