Merge pull request #2737 from pnuu/min-max-dataset-times
Change `start_time` and `end_time` handling in `combine_metadata`
mraspaud committed Feb 15, 2024
2 parents e2153ab + b8a47a9 commit 25d5357
Showing 8 changed files with 147 additions and 110 deletions.
7 changes: 0 additions & 7 deletions satpy/composites/__init__.py
@@ -1665,13 +1665,6 @@ def __call__(self, *args, **kwargs):
         img.attrs["mode"] = "".join(img.bands.data)
         img.attrs.pop("modifiers", None)
         img.attrs.pop("calibration", None)
-        # Add start time if not present in the filename
-        if "start_time" not in img.attrs or not img.attrs["start_time"]:
-            import datetime as dt
-            img.attrs["start_time"] = dt.datetime.utcnow()
-        if "end_time" not in img.attrs or not img.attrs["end_time"]:
-            import datetime as dt
-            img.attrs["end_time"] = dt.datetime.utcnow()

         return img
80 changes: 66 additions & 14 deletions satpy/dataset/metadata.py
@@ -17,6 +17,7 @@
 # satpy. If not, see <http://www.gnu.org/licenses/>.
 """Utilities for merging metadata from various sources."""

+import warnings
 from collections.abc import Collection
 from datetime import datetime
 from functools import partial, reduce
@@ -27,33 +28,50 @@
 from satpy.writers.utils import flatten_dict


-def combine_metadata(*metadata_objects, average_times=True):
+def combine_metadata(*metadata_objects, average_times=None):
     """Combine the metadata of two or more Datasets.

     If the values corresponding to any keys are not equal or do not
     exist in all provided dictionaries then they are not included in
-    the returned dictionary. By default any keys with the word 'time'
-    in them and consisting of datetime objects will be averaged. This
-    is to handle cases where data were observed at almost the same time
-    but not exactly. In the interest of time, lazy arrays are compared by
-    object identity rather than by their contents.
+    the returned dictionary.
+
+    All values of keys containing the substring 'start_time' will be set
+    to the earliest value, and similarly those containing 'end_time' to the
+    latest. All other keys containing the word 'time' are averaged. Before
+    these adjustments, `None` values resulting from data that have no times
+    associated with them are removed. These rules also apply to values in
+    the 'time_parameters' dictionary.
+
+    .. versionchanged:: 0.47
+
+       Before Satpy 0.47, all times, including `start_time` and `end_time`, were averaged.
+
+    In the interest of processing time, lazy arrays are compared by object
+    identity rather than by their contents.

     Args:
         *metadata_objects: MetadataObject or dict objects to combine
-        average_times (bool): Average any keys with 'time' in the name

+    Kwargs:
+        average_times (bool): Removed option to average all time attributes.
+
     Returns:
         dict: the combined metadata

     """
-    info_dicts = _get_valid_dicts(metadata_objects)
+    if average_times is not None:
+        warnings.warn(
+            "'average_times' option has been removed and start/end times are handled with min/max instead.",
+            UserWarning
+        )

+    info_dicts = _get_valid_dicts(metadata_objects)
     if len(info_dicts) == 1:
         return info_dicts[0].copy()

     shared_keys = _shared_keys(info_dicts)

-    return _combine_shared_info(shared_keys, info_dicts, average_times)
+    return _combine_shared_info(shared_keys, info_dicts)
@@ -75,17 +93,51 @@ def _shared_keys(info_dicts):
     return reduce(set.intersection, key_sets)


-def _combine_shared_info(shared_keys, info_dicts, average_times):
+def _combine_shared_info(shared_keys, info_dicts):
     shared_info = {}
     for key in shared_keys:
         values = [info[key] for info in info_dicts]
-        if "time" in key and isinstance(values[0], datetime) and average_times:
-            shared_info[key] = average_datetimes(values)
-        elif _are_values_combinable(values):
-            shared_info[key] = values[0]
+        _combine_values(key, values, shared_info)
     return shared_info


+def _combine_values(key, values, shared_info):
+    if "time" in key:
+        times = _combine_times(key, values)
+        if times is not None:
+            shared_info[key] = times
+    elif _are_values_combinable(values):
+        shared_info[key] = values[0]
+
+
+def _combine_times(key, values):
+    if key == "time_parameters":
+        return _combine_time_parameters(values)
+    filtered_values = _filter_time_values(values)
+    if not filtered_values:
+        return None
+    if "end_time" in key:
+        return max(filtered_values)
+    elif "start_time" in key:
+        return min(filtered_values)
+    return average_datetimes(filtered_values)
+
+
+def _combine_time_parameters(values):
+    # Assume the first item has all the keys
+    keys = values[0].keys()
+    res = {}
+    for key in keys:
+        sub_values = [itm[key] for itm in values]
+        res[key] = _combine_times(key, sub_values)
+    return res
+
+
+def _filter_time_values(values):
+    """Remove values that are not datetime objects."""
+    return [v for v in values if isinstance(v, datetime)]
+
+
 def average_datetimes(datetime_list):
     """Average a series of datetime objects.
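
For a sense of the new behaviour, a minimal usage sketch (not part of the commit; the dictionaries and timestamps below are invented):

    from datetime import datetime

    from satpy.dataset.metadata import combine_metadata

    meta_a = {
        "start_time": datetime(2023, 1, 16, 11, 9, 17),
        "end_time": datetime(2023, 1, 16, 11, 19, 17),
        "scan_time": datetime(2023, 1, 16, 11, 14, 0),
        "platform_name": "Meteosat-11",
    }
    meta_b = {
        "start_time": datetime(2023, 1, 16, 11, 18, 1),
        "end_time": datetime(2023, 1, 16, 11, 28, 1),
        "scan_time": datetime(2023, 1, 16, 11, 24, 0),
        "platform_name": "Meteosat-11",
    }

    combined = combine_metadata(meta_a, meta_b)
    # 'start_time' is now the earliest value, 'end_time' the latest:
    # combined["start_time"] == datetime(2023, 1, 16, 11, 9, 17)
    # combined["end_time"]   == datetime(2023, 1, 16, 11, 28, 1)
    # Other 'time' keys are still averaged via average_datetimes():
    # combined["scan_time"]  -> the average of the two values (11:19:00)
    # Equal non-time values are kept as-is:
    # combined["platform_name"] == "Meteosat-11"

    # The removed keyword is still accepted but only triggers a UserWarning:
    combine_metadata(meta_a, meta_b, average_times=True)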
41 changes: 8 additions & 33 deletions satpy/multiscene/_blend_funcs.py
@@ -1,6 +1,5 @@
 from __future__ import annotations

-from datetime import datetime
 from typing import Callable, Iterable, Mapping, Optional, Sequence

 import pandas as pd
@@ -13,7 +12,6 @@
 def stack(
     data_arrays: Sequence[xr.DataArray],
     weights: Optional[Sequence[xr.DataArray]] = None,
-    combine_times: bool = True,
     blend_type: str = "select_with_weights"
 ) -> xr.DataArray:
     """Combine a series of datasets in different ways.
@@ -39,19 +37,18 @@ def stack(
     """
     if weights:
-        return _stack_with_weights(data_arrays, weights, combine_times, blend_type)
-    return _stack_no_weights(data_arrays, combine_times)
+        return _stack_with_weights(data_arrays, weights, blend_type)
+    return _stack_no_weights(data_arrays)


 def _stack_with_weights(
     datasets: Sequence[xr.DataArray],
     weights: Sequence[xr.DataArray],
-    combine_times: bool,
     blend_type: str
 ) -> xr.DataArray:
     blend_func = _get_weighted_blending_func(blend_type)
     filled_weights = list(_fill_weights_for_invalid_dataset_pixels(datasets, weights))
-    return blend_func(datasets, filled_weights, combine_times)
+    return blend_func(datasets, filled_weights)


 def _get_weighted_blending_func(blend_type: str) -> Callable:
@@ -84,10 +81,9 @@ def _fill_weights_for_invalid_dataset_pixels(
 def _stack_blend_by_weights(
     datasets: Sequence[xr.DataArray],
     weights: Sequence[xr.DataArray],
-    combine_times: bool
 ) -> xr.DataArray:
     """Stack datasets blending overlap using weights."""
-    attrs = _combine_stacked_attrs([data_arr.attrs for data_arr in datasets], combine_times)
+    attrs = _combine_stacked_attrs([data_arr.attrs for data_arr in datasets])

     overlays = []
     for weight, overlay in zip(weights, datasets):
@@ -109,14 +105,13 @@ def _stack_blend_by_weights(
 def _stack_select_by_weights(
     datasets: Sequence[xr.DataArray],
     weights: Sequence[xr.DataArray],
-    combine_times: bool
 ) -> xr.DataArray:
     """Stack datasets selecting pixels using weights."""
     indices = da.argmax(da.dstack(weights), axis=-1)
     if "bands" in datasets[0].dims:
         indices = [indices] * datasets[0].sizes["bands"]

-    attrs = _combine_stacked_attrs([data_arr.attrs for data_arr in datasets], combine_times)
+    attrs = _combine_stacked_attrs([data_arr.attrs for data_arr in datasets])
     dims = datasets[0].dims
     coords = datasets[0].coords
     selected_array = xr.DataArray(da.choose(indices, datasets), dims=dims, coords=coords, attrs=attrs)
@@ -125,7 +120,6 @@
 def _stack_no_weights(
     datasets: Sequence[xr.DataArray],
-    combine_times: bool
 ) -> xr.DataArray:
     base = datasets[0].copy()
     collected_attrs = [base.attrs]
@@ -136,32 +130,13 @@ def _stack_no_weights(
         except KeyError:
             base = base.where(data_arr.isnull(), data_arr)

-    attrs = _combine_stacked_attrs(collected_attrs, combine_times)
+    attrs = _combine_stacked_attrs(collected_attrs)
     base.attrs = attrs
     return base


-def _combine_stacked_attrs(collected_attrs: Sequence[Mapping], combine_times: bool) -> dict:
-    attrs = combine_metadata(*collected_attrs)
-    if combine_times and ("start_time" in attrs or "end_time" in attrs):
-        new_start, new_end = _get_combined_start_end_times(collected_attrs)
-        if new_start:
-            attrs["start_time"] = new_start
-        if new_end:
-            attrs["end_time"] = new_end
-    return attrs
-
-
-def _get_combined_start_end_times(metadata_objects: Iterable[Mapping]) -> tuple[datetime | None, datetime | None]:
-    """Get the start and end times attributes valid for the entire dataset series."""
-    start_time = None
-    end_time = None
-    for md_obj in metadata_objects:
-        if "start_time" in md_obj and (start_time is None or md_obj["start_time"] < start_time):
-            start_time = md_obj["start_time"]
-        if "end_time" in md_obj and (end_time is None or md_obj["end_time"] > end_time):
-            end_time = md_obj["end_time"]
-    return start_time, end_time
+def _combine_stacked_attrs(collected_attrs: Sequence[Mapping]) -> dict:
+    return combine_metadata(*collected_attrs)


 def timeseries(datasets):
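
The simplified `stack` signature in use, as a sketch (not from the commit; `data_arrays` and `weights` are assumed to be prepared elsewhere as matching lists of DataArrays):

    from functools import partial

    from satpy.multiscene import stack

    # data_arrays: Sequence[xr.DataArray] with start_time/end_time in .attrs
    # weights: matching Sequence[xr.DataArray] of per-pixel weights (assumed)
    stack_func = partial(stack, weights=weights, blend_type="select_with_weights")
    blended = stack_func(data_arrays)

    # The blended array's attrs now always come from combine_metadata, i.e.
    # start_time = min(...), end_time = max(...); the former combine_times
    # keyword has been removed and would raise a TypeError if passed.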
26 changes: 2 additions & 24 deletions satpy/readers/file_handlers.py
@@ -112,10 +112,9 @@ def combine_info(self, all_infos):
         """
         combined_info = combine_metadata(*all_infos)

-        new_dict = self._combine(all_infos, min, "start_time", "start_orbit")
-        new_dict.update(self._combine(all_infos, max, "end_time", "end_orbit"))
+        new_dict = self._combine(all_infos, min, "start_orbit")
+        new_dict.update(self._combine(all_infos, max, "end_orbit"))
         new_dict.update(self._combine_orbital_parameters(all_infos))
-        new_dict.update(self._combine_time_parameters(all_infos))

         try:
             area = SwathDefinition(lons=np.ma.vstack([info["area"].lons for info in all_infos]),
@@ -145,27 +144,6 @@ def _combine_orbital_parameters(self, all_infos):
         orb_params_comb.update(self._combine(orb_params, np.mean, *keys))
         return {"orbital_parameters": orb_params_comb}

-    def _combine_time_parameters(self, all_infos):
-        time_params = [info.get("time_parameters", {}) for info in all_infos]
-        if not all(time_params):
-            return {}
-        # Collect all available keys
-        time_params_comb = {}
-        for d in time_params:
-            time_params_comb.update(d)
-
-        start_keys = (
-            "nominal_start_time",
-            "observation_start_time",
-        )
-        end_keys = (
-            "nominal_end_time",
-            "observation_end_time",
-        )
-        time_params_comb.update(self._combine(time_params, min, *start_keys))
-        time_params_comb.update(self._combine(time_params, max, *end_keys))
-        return {"time_parameters": time_params_comb}
-
     @property
     def start_time(self):
         """Get start time."""
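
With `_combine_time_parameters` gone from the file handlers, nested time parameters go through the same generic path; a small sketch (not from the commit, values invented):

    from datetime import datetime

    from satpy.dataset.metadata import combine_metadata

    info_a = {"time_parameters": {
        "nominal_start_time": datetime(2024, 2, 15, 10, 0),
        "observation_end_time": datetime(2024, 2, 15, 10, 5),
    }}
    info_b = {"time_parameters": {
        "nominal_start_time": datetime(2024, 2, 15, 10, 5),
        "observation_end_time": datetime(2024, 2, 15, 10, 10),
    }}

    combined = combine_metadata(info_a, info_b)
    # Each key inside 'time_parameters' is combined by the same rules:
    # combined["time_parameters"]["nominal_start_time"]   -> 10:00 (min)
    # combined["time_parameters"]["observation_end_time"] -> 10:10 (max)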
15 changes: 5 additions & 10 deletions satpy/tests/multiscene_tests/test_blend.py
@@ -245,10 +245,9 @@ def test_blend_two_scenes_bad_blend_type(self, multi_scene_and_weights, groups):
         ("select_with_weights", _get_expected_stack_select),
         ("blend_with_weights", _get_expected_stack_blend),
     ])
-    @pytest.mark.parametrize("combine_times", [False, True])
     def test_blend_two_scenes_using_stack_weighted(self, multi_scene_and_weights, groups,
                                                    scene1_with_weights, scene2_with_weights,
-                                                   combine_times, blend_func, exp_result_func):
+                                                   blend_func, exp_result_func):
         """Test stacking two scenes using weights.

         Here we test that the start and end times can be combined so that they
@@ -266,7 +265,7 @@ def test_blend_two_scenes_using_stack_weighted(self, multi_scene_and_weights, groups,
         multi_scene.group(simple_groups)

         weights = [weights[0][0], weights[1][0]]
-        stack_func = partial(stack, weights=weights, blend_type=blend_func, combine_times=combine_times)
+        stack_func = partial(stack, weights=weights, blend_type=blend_func)
         weighted_blend = multi_scene.blend(blend_function=stack_func)

         expected = exp_result_func(scene1, scene2)
@@ -275,12 +274,8 @@ def test_blend_two_scenes_using_stack_weighted(self, multi_scene_and_weights, groups,
         np.testing.assert_allclose(result.data, expected.data)

         _check_stacked_metadata(result, "CloudType")
-        if combine_times:
-            assert result.attrs["start_time"] == datetime(2023, 1, 16, 11, 9, 17)
-            assert result.attrs["end_time"] == datetime(2023, 1, 16, 11, 28, 1, 900000)
-        else:
-            assert result.attrs["start_time"] == datetime(2023, 1, 16, 11, 11, 7, 250000)
-            assert result.attrs["end_time"] == datetime(2023, 1, 16, 11, 20, 11, 950000)
+        assert result.attrs["start_time"] == datetime(2023, 1, 16, 11, 9, 17)
+        assert result.attrs["end_time"] == datetime(2023, 1, 16, 11, 28, 1, 900000)

     @pytest.fixture()
     def datasets_and_weights(self):
@@ -329,7 +324,7 @@ def test_blend_function_stack_weighted(self, datasets_and_weights, line, column):
         input_data["weights"][1][line, :] = 2
         input_data["weights"][2][:, column] = 2

-        stack_with_weights = partial(stack, weights=input_data["weights"], combine_times=False)
+        stack_with_weights = partial(stack, weights=input_data["weights"])
        blend_result = stack_with_weights(input_data["datasets"][0:3])

         ds1 = input_data["datasets"][0]
4 changes: 0 additions & 4 deletions satpy/tests/test_composites.py
@@ -1420,8 +1420,6 @@ def load(self, arg):
                                       filenames=["/foo.tif"])
         register.assert_not_called()
         retrieve.assert_not_called()
-        assert "start_time" in res.attrs
-        assert "end_time" in res.attrs
         assert res.attrs["sensor"] is None
         assert "modifiers" not in res.attrs
         assert "calibration" not in res.attrs
@@ -1434,8 +1432,6 @@ def load(self, arg):
         res = comp()
         Scene.assert_called_once_with(reader="generic_image",
                                       filenames=["data_dir/foo.tif"])
-        assert "start_time" in res.attrs
-        assert "end_time" in res.attrs
         assert res.attrs["sensor"] is None
         assert "modifiers" not in res.attrs
         assert "calibration" not in res.attrs
