Add weighted blended stacking to MultiScene (fixes multi-band handling) #2394

Merged
merged 25 commits on Apr 17, 2023
Changes from all commits
Commits
25 commits
7f878f5
Rename Adam's stacking function, add set_weights_to_zero_where_invali…
lobsiger Feb 17, 2023
b374ff9
Adding selecting with bands and true blending.
lobsiger Feb 17, 2023
a7e480b
Fixed indenting stuff an line break.
lobsiger Feb 17, 2023
5a3bca5
Fixed line break.
lobsiger Feb 17, 2023
6bb4cdd
Cosmetics, maybe should use enumerate() ...
lobsiger Feb 18, 2023
c64d556
Adapted stack() for two blending functions.
lobsiger Feb 21, 2023
9ff0f94
Made one blend function out of two.
lobsiger Feb 21, 2023
b0c0701
Made one select function out of two.
lobsiger Feb 21, 2023
3e110b8
Added start because this is now after the .fillna() step.
lobsiger Feb 21, 2023
20a7de9
Just a test to test my test theory.
lobsiger Feb 21, 2023
1ddea03
Maybe Adams test was not invoked when all passed in my homebrew versi…
lobsiger Feb 21, 2023
e19d301
Got my first idea of an assert statement.
lobsiger Feb 22, 2023
7ec0057
Reword stack docstring in satpy/multiscene.py
djhoese Mar 31, 2023
4d4dcb8
Start refactoring new weighted stacking in MultiScene
djhoese Mar 31, 2023
a655318
Refactor multiscene blend tests to avoid unnecessary test setup
djhoese Apr 6, 2023
54797c0
Improve consistency between multiscene stack functions
djhoese Apr 6, 2023
0210ae4
Consolidate some multiscene blend tests
djhoese Apr 6, 2023
15d8d0c
Add initial tests for weighted blended stacking
djhoese Apr 7, 2023
f2ac7b2
Refactor multiscene blending fixtures
djhoese Apr 10, 2023
f253d61
Add RGB and float tests to multiscene blend tests
djhoese Apr 11, 2023
c6d8dea
Remove TODOs from multiscene regarding overlay/weight handling
djhoese Apr 11, 2023
c969ce7
Move multiscene to its own subpackage
djhoese Apr 11, 2023
53d7c22
Refactor multiscene blend functions to their own module
djhoese Apr 11, 2023
918260e
Make more objects in multiscene module private with `_` prefix
djhoese Apr 11, 2023
4ad2be8
Update docstring of multiscene stack and fix docstring errors in priv…
djhoese Apr 13, 2023
3 changes: 3 additions & 0 deletions doc/source/conf.py
@@ -99,6 +99,9 @@ def __getattr__(cls, name):
'readers/scatsat1_l2b.py',
]
apidoc_separate_modules = True
apidoc_extra_args = [
"--private",
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
4 changes: 4 additions & 0 deletions satpy/multiscene/__init__.py
@@ -0,0 +1,4 @@
"""Functions and classes related to MultiScene functionality."""

from ._blend_funcs import stack, timeseries # noqa
from ._multiscene import MultiScene # noqa
180 changes: 180 additions & 0 deletions satpy/multiscene/_blend_funcs.py
@@ -0,0 +1,180 @@
from __future__ import annotations

from datetime import datetime
from typing import Callable, Iterable, Mapping, Optional, Sequence

import pandas as pd
import xarray as xr
from dask import array as da

from satpy.dataset import combine_metadata


def stack(
data_arrays: Sequence[xr.DataArray],
weights: Optional[Sequence[xr.DataArray]] = None,
combine_times: bool = True,
blend_type: str = 'select_with_weights'
) -> xr.DataArray:
"""Combine a series of datasets in different ways.

By default, DataArrays are stacked on top of each other, so the last one
applied is on top. Each DataArray is assumed to represent the same
geographic region, meaning they have the same area. If a sequence of
weights is provided then they must have the same shape as the area.
Weights with greater than 2 dimensions are not currently supported.

When weights are provided, the DataArrays will be combined according to
those weights. Data can be integer category products (ex. cloud type),
single channels (ex. radiance), or a multi-band composite (ex. an RGB or
RGBA true_color). In the latter case, the weight array is applied
to each band (R, G, B, A) in the same way. The result will be a composite
DataArray where each pixel is constructed in a way depending on ``blend_type``.

Blend type can be one of the following:

* select_with_weights: The input pixel with the maximum weight is chosen.
* blend_with_weights: The final pixel is a weighted average of all valid
input pixels.

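A minimal usage sketch (``arr1``/``arr2`` and ``w1``/``w2`` below are
placeholder DataArrays sharing one area; they are not defined by this
module)::

    from satpy.multiscene import stack

    combined = stack([arr1, arr2], weights=[w1, w2],
                     blend_type="blend_with_weights")
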
"""
if weights:
return _stack_with_weights(data_arrays, weights, combine_times, blend_type)
return _stack_no_weights(data_arrays, combine_times)


def _stack_with_weights(
datasets: Sequence[xr.DataArray],
weights: Sequence[xr.DataArray],
combine_times: bool,
blend_type: str
) -> xr.DataArray:
blend_func = _get_weighted_blending_func(blend_type)
filled_weights = list(_fill_weights_for_invalid_dataset_pixels(datasets, weights))
return blend_func(datasets, filled_weights, combine_times)


def _get_weighted_blending_func(blend_type: str) -> Callable:
WEIGHTED_BLENDING_FUNCS = {
"select_with_weights": _stack_select_by_weights,
"blend_with_weights": _stack_blend_by_weights,
}
blend_func = WEIGHTED_BLENDING_FUNCS.get(blend_type)
if blend_func is None:
raise ValueError(f"Unknown weighted blending type: {blend_type}. "
f"Expected one of: {WEIGHTED_BLENDING_FUNCS.keys()}")
return blend_func


def _fill_weights_for_invalid_dataset_pixels(
datasets: Sequence[xr.DataArray],
weights: Sequence[xr.DataArray]
) -> Iterable[xr.DataArray]:
"""Replace weight values with 0 where data values are invalid/null."""
has_bands_dims = "bands" in datasets[0].dims
for i, dataset in enumerate(datasets):
# if multi-band only use the red-band
compare_ds = dataset[0] if has_bands_dims else dataset
try:
yield xr.where(compare_ds == compare_ds.attrs["_FillValue"], 0, weights[i])
except KeyError:
yield xr.where(compare_ds.isnull(), 0, weights[i])

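# Illustration of the intended behavior (hypothetical values, not from the PR):
#   data   = [[1.0, nan], [5.0, 6.0]]
#   weight = [[0.3, 0.9], [0.5, 0.5]]
#   result = [[0.3, 0.0], [0.5, 0.5]]   # weight zeroed where data is invalid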

def _stack_blend_by_weights(
datasets: Sequence[xr.DataArray],
weights: Sequence[xr.DataArray],
combine_times: bool
) -> xr.DataArray:
"""Stack datasets blending overlap using weights."""
attrs = _combine_stacked_attrs([data_arr.attrs for data_arr in datasets], combine_times)

overlays = []
for weight, overlay in zip(weights, datasets):
# Any 'overlay' fill values should already be reflected in the weights
# as 0. See _fill_weights_for_invalid_dataset_pixels. We fill NA with
# 0 here to avoid NaNs affecting valid pixels in other datasets. Note
# `.fillna` does not handle the `_FillValue` attribute so this filling
# is purely to remove NaNs.
overlays.append(overlay.fillna(0) * weight)
# NOTE: Currently no way to ignore numpy divide by 0 warnings without
# making a custom map_blocks version of the divide
base = sum(overlays) / sum(weights)

dims = datasets[0].dims
blended_array = xr.DataArray(base, dims=dims, attrs=attrs)
return blended_array

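# Worked example of the weighted average above (hypothetical values): a pixel
# with values 10 and 20 and weights 1 and 3 blends to
# (10 * 1 + 20 * 3) / (1 + 3) = 17.5. Where every weight is 0 the division
# becomes 0 / 0 and the output pixel is NaN, hence the divide-by-zero note above.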

def _stack_select_by_weights(
datasets: Sequence[xr.DataArray],
weights: Sequence[xr.DataArray],
combine_times: bool
) -> xr.DataArray:
"""Stack datasets selecting pixels using weights."""
indices = da.argmax(da.dstack(weights), axis=-1)
if "bands" in datasets[0].dims:
indices = [indices] * datasets[0].sizes["bands"]

attrs = _combine_stacked_attrs([data_arr.attrs for data_arr in datasets], combine_times)
dims = datasets[0].dims
coords = datasets[0].coords
selected_array = xr.DataArray(da.choose(indices, datasets), dims=dims, coords=coords, attrs=attrs)
return selected_array

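# Worked example of the selection above (hypothetical values): for a pixel with
# weights (0.2, 0.7, 0.5) across three datasets, argmax yields index 1, so the
# output pixel comes from datasets[1]; for multi-band data the same 2D index
# map is reused for every band.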

def _stack_no_weights(
datasets: Sequence[xr.DataArray],
combine_times: bool
) -> xr.DataArray:
base = datasets[0].copy()
collected_attrs = [base.attrs]
for data_arr in datasets[1:]:
collected_attrs.append(data_arr.attrs)
try:
base = base.where(data_arr == data_arr.attrs["_FillValue"], data_arr)
except KeyError:
base = base.where(data_arr.isnull(), data_arr)

attrs = _combine_stacked_attrs(collected_attrs, combine_times)
base.attrs = attrs
return base


def _combine_stacked_attrs(collected_attrs: Sequence[Mapping], combine_times: bool) -> dict:
attrs = combine_metadata(*collected_attrs)
if combine_times and ('start_time' in attrs or 'end_time' in attrs):
new_start, new_end = _get_combined_start_end_times(collected_attrs)
if new_start:
attrs["start_time"] = new_start
if new_end:
attrs["end_time"] = new_end
return attrs


def _get_combined_start_end_times(metadata_objects: Iterable[Mapping]) -> tuple[datetime | None, datetime | None]:
"""Get the start and end times attributes valid for the entire dataset series."""
start_time = None
end_time = None
for md_obj in metadata_objects:
if "start_time" in md_obj and (start_time is None or md_obj['start_time'] < start_time):
start_time = md_obj['start_time']
if "end_time" in md_obj and (end_time is None or md_obj['end_time'] > end_time):
end_time = md_obj['end_time']
return start_time, end_time


def timeseries(datasets):
"""Expand each dataset with a time dimension and concatenate them along it."""
expanded_ds = []
for ds in datasets:
if 'time' not in ds.dims:
tmp = ds.expand_dims("time")
tmp.coords["time"] = pd.DatetimeIndex([ds.attrs["start_time"]])
else:
tmp = ds
expanded_ds.append(tmp)

res = xr.concat(expanded_ds, dim="time")
res.attrs = combine_metadata(*[x.attrs for x in expanded_ds])
return res
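
In practice these functions are usually reached through MultiScene.blend rather
than called directly. A minimal sketch of that path (scene_a, scene_b,
weights_a and weights_b are placeholders for already-loaded Scenes and matching
2D weight DataArrays; they are not part of this PR):

    from functools import partial

    from satpy.multiscene import MultiScene, stack

    mscn = MultiScene([scene_a, scene_b])
    weighted_stack = partial(stack,
                             weights=[weights_a, weights_b],
                             blend_type="select_with_weights")
    blended_scene = mscn.blend(blend_function=weighted_stack)
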
121 changes: 28 additions & 93 deletions satpy/multiscene.py → satpy/multiscene/_multiscene.py
@@ -16,20 +16,20 @@
# You should have received a copy of the GNU General Public License along with
# satpy. If not, see <http://www.gnu.org/licenses/>.
"""MultiScene object to work with multiple timesteps of satellite data."""
from __future__ import annotations

import copy
import logging
import warnings
from datetime import datetime
from queue import Queue
from threading import Thread
from typing import Callable, Collection, Mapping

import dask.array as da
import numpy as np
import pandas as pd
import xarray as xr

from satpy.dataset import DataID, combine_metadata
from satpy.dataset import DataID
from satpy.scene import Scene
from satpy.writers import get_enhanced_image, split_results

@@ -46,86 +46,7 @@
log = logging.getLogger(__name__)


def stack(datasets, weights=None, combine_times=True):
"""Overlay a series of datasets together.

By default, datasets are stacked on top of each other, so the last one applied is
on top. If a sequence of weights arrays are provided the datasets will
be combined according to those weights. The result will be a composite
dataset where the data in each pixel is coming from the dataset having the
highest weight.

"""
if weights:
return _stack_weighted(datasets, weights, combine_times)

base = datasets[0].copy()
for dataset in datasets[1:]:
try:
base = base.where(dataset == dataset.attrs["_FillValue"], dataset)
except KeyError:
base = base.where(dataset.isnull(), dataset)

return base


def _stack_weighted(datasets, weights, combine_times):
"""Stack datasets using weights."""
weights = set_weights_to_zero_where_invalid(datasets, weights)

indices = da.argmax(da.dstack(weights), axis=-1)
attrs = combine_metadata(*[x.attrs for x in datasets])

if combine_times:
if 'start_time' in attrs and 'end_time' in attrs:
attrs['start_time'], attrs['end_time'] = _get_combined_start_end_times(*[x.attrs for x in datasets])

dims = datasets[0].dims
weighted_array = xr.DataArray(da.choose(indices, datasets), dims=dims, attrs=attrs)
return weighted_array


def set_weights_to_zero_where_invalid(datasets, weights):
"""Go through the weights and set to pixel values to zero where corresponding datasets are invalid."""
for i, dataset in enumerate(datasets):
try:
weights[i] = xr.where(dataset == dataset.attrs["_FillValue"], 0, weights[i])
except KeyError:
weights[i] = xr.where(dataset.isnull(), 0, weights[i])

return weights


def _get_combined_start_end_times(*metadata_objects):
"""Get the start and end times attributes valid for the entire dataset series."""
start_time = datetime.now()
end_time = datetime.fromtimestamp(0)
for md_obj in metadata_objects:
if md_obj['start_time'] < start_time:
start_time = md_obj['start_time']
if md_obj['end_time'] > end_time:
end_time = md_obj['end_time']

return start_time, end_time


def timeseries(datasets):
"""Expand dataset with and concatenate by time dimension."""
expanded_ds = []
for ds in datasets:
if 'time' not in ds.dims:
tmp = ds.expand_dims("time")
tmp.coords["time"] = pd.DatetimeIndex([ds.attrs["start_time"]])
else:
tmp = ds
expanded_ds.append(tmp)

res = xr.concat(expanded_ds, dim="time")
res.attrs = combine_metadata(*[x.attrs for x in expanded_ds])
return res


def group_datasets_in_scenes(scenes, groups):
def _group_datasets_in_scenes(scenes, groups):
"""Group different datasets in multiple scenes by adding aliases.

Args:
@@ -140,11 +140,11 @@

"""
for scene in scenes:
grp = GroupAliasGenerator(scene, groups)
grp = _GroupAliasGenerator(scene, groups)
yield grp.duplicate_datasets_with_group_alias()


class GroupAliasGenerator:
class _GroupAliasGenerator:
"""Add group aliases to a scene."""

def __init__(self, scene, groups):
@@ -271,17 +271,23 @@
return self._scene_gen.first

@classmethod
def from_files(cls, files_to_sort, reader=None,
ensure_all_readers=False, scene_kwargs=None, **kwargs):
def from_files(
cls,
files_to_sort: Collection[str],
reader: str | Collection[str] | None = None,
ensure_all_readers: bool = False,
scene_kwargs: Mapping | None = None,
**kwargs
):
"""Create multiple Scene objects from multiple files.

Args:
files_to_sort (Collection[str]): files to read
reader (str or Collection[str]): reader or readers to use
ensure_all_readers (bool): If True, limit to scenes where all
files_to_sort: files to read
reader: reader or readers to use
ensure_all_readers: If True, limit to scenes where all
readers have at least one file. If False (default), include
all scenes where at least one reader has at least one file.
scene_kwargs (Mapping): additional arguments to pass on to
scene_kwargs: additional arguments to pass on to
:func:`Scene.__init__` for each created scene.

This uses the :func:`satpy.readers.group_files` function to group
@@ -397,7 +397,10 @@
"""Resample the multiscene."""
return self._generate_scene_func(self._scenes, 'resample', True, destination=destination, **kwargs)

def blend(self, blend_function=stack):
def blend(
self,
blend_function: Callable[..., xr.DataArray] | None = None
) -> Scene:
"""Blend the datasets into one scene.

Reduce the :class:`MultiScene` to a single :class:`~satpy.scene.Scene`. Datasets
@@ -418,6 +418,11 @@
MultiScene.

"""
if blend_function is None:
# delay importing blend funcs until now in case they aren't used
from ._blend_funcs import stack
blend_function = stack

new_scn = Scene()
common_datasets = self.shared_dataset_ids
for ds_id in common_datasets:
@@ -440,7 +440,7 @@
DataQuery('my_group', wavelength=(10, 11, 12)): ['IR_108', 'B13', 'C13']
}
"""
self._scenes = group_datasets_in_scenes(self._scenes, groups)
self._scenes = _group_datasets_in_scenes(self._scenes, groups)

def _distribute_save_datasets(self, scenes_iter, client, batch_size=1, **kwargs):
"""Distribute save_datasets across a cluster."""