Add weighted blended stacking to MultiScene (fixes multi-band handling) #2394

Merged: 25 commits, Apr 17, 2023
Changes shown are from 17 of the 25 commits.

Commits:
7f878f5  Rename Adam's stacking function, add set_weights_to_zero_where_invali… (lobsiger, Feb 17, 2023)
b374ff9  Adding selecting with bands and true blending. (lobsiger, Feb 17, 2023)
a7e480b  Fixed indenting stuff an line break. (lobsiger, Feb 17, 2023)
5a3bca5  Fixed line break. (lobsiger, Feb 17, 2023)
6bb4cdd  Cosmetics, maybe should use enumerate() ... (lobsiger, Feb 18, 2023)
c64d556  Adapted stack() for two blending functions. (lobsiger, Feb 21, 2023)
9ff0f94  Made one blend function out of two. (lobsiger, Feb 21, 2023)
b0c0701  Made one select function out of two. (lobsiger, Feb 21, 2023)
3e110b8  Added start because this is now after the .fillna() step. (lobsiger, Feb 21, 2023)
20a7de9  Just a test to test my test theory. (lobsiger, Feb 21, 2023)
1ddea03  Maybe Adams test was not invoked when all passed in my homebrew versi… (lobsiger, Feb 21, 2023)
e19d301  Got my first idea of an assert statement. (lobsiger, Feb 22, 2023)
7ec0057  Reword stack docstring in satpy/multiscene.py (djhoese, Mar 31, 2023)
4d4dcb8  Start refactoring new weighted stacking in MultiScene (djhoese, Mar 31, 2023)
a655318  Refactor multiscene blend tests to avoid unnecessary test setup (djhoese, Apr 6, 2023)
54797c0  Improve consistency between multiscene stack functions (djhoese, Apr 6, 2023)
0210ae4  Consolidate some multiscene blend tests (djhoese, Apr 6, 2023)
15d8d0c  Add initial tests for weighted blended stacking (djhoese, Apr 7, 2023)
f2ac7b2  Refactor multiscene blending fixtures (djhoese, Apr 10, 2023)
f253d61  Add RGB and float tests to multiscene blend tests (djhoese, Apr 11, 2023)
c6d8dea  Remove TODOs from multiscene regarding overlay/weight handling (djhoese, Apr 11, 2023)
c969ce7  Move multiscene to its own subpackage (djhoese, Apr 11, 2023)
53d7c22  Refactor multiscene blend functions to their own module (djhoese, Apr 11, 2023)
918260e  Make more objects in multiscene module private with `_` prefix (djhoese, Apr 11, 2023)
4ad2be8  Update docstring of multiscene stack and fix docstring errors in priv… (djhoese, Apr 13, 2023)
146 changes: 111 additions & 35 deletions satpy/multiscene.py
@@ -1,4 +1,4 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 # Copyright (c) 2016-2023 Satpy developers
 #

[CodeScene Delta Analysis / CodeScene Cloud Delta Analysis (main), warning on line 1 in satpy/multiscene.py]
❌ New issue: Lines of Code in a Single File. This module has 634 lines of code; improve code health by reducing it to 600.
@@ -16,13 +16,15 @@
 # You should have received a copy of the GNU General Public License along with
 # satpy. If not, see <http://www.gnu.org/licenses/>.
 """MultiScene object to work with multiple timesteps of satellite data."""
+from __future__ import annotations
 
 import copy
 import logging
 import warnings
 from datetime import datetime
 from queue import Queue
 from threading import Thread
+from typing import Callable, Iterable, Mapping, Optional, Sequence
 
 import dask.array as da
 import numpy as np
@@ -46,66 +48,140 @@
 log = logging.getLogger(__name__)
 
 
-def stack(datasets, weights=None, combine_times=True):
-    """Overlay a series of datasets together.
+def stack(
+        datasets: Sequence[xr.DataArray],
+        weights: Optional[Sequence[xr.DataArray]] = None,
+        combine_times: bool = True,
+        blend_type: str = 'select_with_weights'
+) -> xr.DataArray:
+    """Combine a series of datasets in different ways.
 
     By default, datasets are stacked on top of each other, so the last one applied is
-    on top. If a sequence of weights arrays are provided the datasets will
-    be combined according to those weights. The result will be a composite
-    dataset where the data in each pixel is coming from the dataset having the
-    highest weight.
+    on top. If a sequence of weights (with equal shape) is provided, the datasets will
+    be combined according to those weights. Datasets can be integer category products
+    (ex. cloud type), single channels (ex. radiance), or RGB composites. In the
+    latter case, weights are applied to each 'R', 'G', 'B' coordinate in the same
+    way. The result will be a composite dataset where each pixel is constructed in a
+    way depending on ``blend_type``.
 
     """
     if weights:
-        return _stack_weighted(datasets, weights, combine_times)
-
-    base = datasets[0].copy()
-    for dataset in datasets[1:]:
-        try:
-            base = base.where(dataset == dataset.attrs["_FillValue"], dataset)
-        except KeyError:
-            base = base.where(dataset.isnull(), dataset)
-
-    return base
-
-
-def _stack_weighted(datasets, weights, combine_times):
-    """Stack datasets using weights."""
-    weights = set_weights_to_zero_where_invalid(datasets, weights)
-
-    indices = da.argmax(da.dstack(weights), axis=-1)
-    attrs = combine_metadata(*[x.attrs for x in datasets])
-
-    if combine_times:
-        if 'start_time' in attrs and 'end_time' in attrs:
-            attrs['start_time'], attrs['end_time'] = _get_combined_start_end_times(*[x.attrs for x in datasets])
-
-    dims = datasets[0].dims
-    weighted_array = xr.DataArray(da.choose(indices, datasets), dims=dims, attrs=attrs)
-    return weighted_array
-
-
-def set_weights_to_zero_where_invalid(datasets, weights):
-    """Go through the weights and set to pixel values to zero where corresponding datasets are invalid."""
-    for i, dataset in enumerate(datasets):
-        try:
-            weights[i] = xr.where(dataset == dataset.attrs["_FillValue"], 0, weights[i])
-        except KeyError:
-            weights[i] = xr.where(dataset.isnull(), 0, weights[i])
-
-    return weights
+        return _stack_with_weights(datasets, weights, combine_times, blend_type)
+    return _stack_no_weights(datasets, combine_times)
+
+
+def _stack_with_weights(
+        datasets: Sequence[xr.DataArray],
+        weights: Sequence[xr.DataArray],
+        combine_times: bool,
+        blend_type: str
+) -> xr.DataArray:
+    blend_func = _get_weighted_blending_func(blend_type)
+    filled_weights = list(_fill_weights_for_invalid_dataset_pixels(datasets, weights))
+    return blend_func(datasets, filled_weights, combine_times)
+
+
+def _get_weighted_blending_func(blend_type: str) -> Callable:
+    WEIGHTED_BLENDING_FUNCS = {
+        "select_with_weights": _stack_select_by_weights,
+        "blend_with_weights": _stack_blend_by_weights,
+    }
+    blend_func = WEIGHTED_BLENDING_FUNCS.get(blend_type)
+    if blend_func is None:
+        raise ValueError(f"Unknown weighted blending type: {blend_type}. "
+                         f"Expected one of: {WEIGHTED_BLENDING_FUNCS.keys()}")
+    return blend_func
+
+
+def _fill_weights_for_invalid_dataset_pixels(
+        datasets: Sequence[xr.DataArray],
+        weights: Sequence[xr.DataArray]
+) -> Iterable[xr.DataArray]:
+    """Replace weight values with 0 where data values are invalid/null."""
+    has_bands_dims = "bands" in datasets[0].dims
+    for i, dataset in enumerate(datasets):
+        # if multi-band only use the red band
+        compare_ds = dataset[0] if has_bands_dims else dataset
+        try:
+            yield xr.where(compare_ds == compare_ds.attrs["_FillValue"], 0, weights[i])
+        except KeyError:
+            yield xr.where(compare_ds.isnull(), 0, weights[i])
+
+
+def _stack_blend_by_weights(
+        datasets: Sequence[xr.DataArray],
+        weights: Sequence[xr.DataArray],
+        combine_times: bool
+) -> xr.DataArray:
+    """Stack datasets blending overlap using weights."""
+    attrs = _combine_stacked_attrs([data_arr.attrs for data_arr in datasets], combine_times)
+
+    overlays = []
+    for weight, overlay in zip(weights, datasets):
+        overlays.append(overlay.fillna(0) * weight)
+    base = sum(overlays) / sum(weights, start=1.e-9)
djhoese (Member) commented:

@lobsiger I'm not sure I agree with this start parameter as a way to handle dividing by 0. What do we think the value should be if the weights sum to 0?

There are different ways to handle a divide by 0, but I'm kind of leaning towards letting it be NaN in the final result.

lobsiger (Contributor Author) commented:

@djhoese I agree that we should not have start=1e-6 when summing the weights in the denominator. Instead we should divide by 0 outside the data area of passes, probably producing 0/0 = NaN (?). I took the latest version of your multiscene.py and tested all possible combinations of three MetopB passes over area eurol10, similar to what I did here:

pytroll/pycoast#95 (comment)

I made images for composites 'natural_color', 'ir108_3d' and channel '4' (IR 10.8, used for 'ir108_3d'). I allowed for all 3 defined blend types, generate=True/False, fill_value=0/255/None. This resulted in 3 x 3 x 2 x 3 = 54 different image files without setting start in the denominator. All of these images looked as expected. Then I set sum(weights, start=1e-6) and made 18 more images for 'blend_with_weights'. All 18 of these files are problematic, not discovered so far because I mainly looked at RGB composites with fill_value=0. I attach 3 examples: left is the image as expected, right is the wrong image with start=1e-6. The original files produced are *.png, but I reduced them in size and converted them to *.jpg to save space.

Problems with 'blend_with_weights' and sum(weights, start=1e-6); generate=False/True does not matter:

'4', data pale (IR 10.8, almost white), fill_value=0: no-data region is BLACK (O.K., but caused by 1e-6)
'4', data pale (IR 10.8, almost white), fill_value=255: no-data region is BLACK (instead of white)
'4', data pale (IR 10.8, almost white), fill_value=None: no-data region is BLACK (instead of transparent)

'ir108_3d', data dark (almost black), fill_value=0: no-data region is WHITE (instead of black, reversed)
'ir108_3d', data dark (almost black), fill_value=255: no-data region is WHITE (O.K., but caused by 1e-6)
'ir108_3d', data dark (almost black), fill_value=None: no-data region is WHITE (instead of transparent)

'natural_color', data look as expected, fill_value=0: no-data region is BLACK (O.K., but caused by 1e-6)
'natural_color', data look as expected, fill_value=255: no-data region is BLACK (instead of white)
'natural_color', data look as expected, fill_value=None: no-data region is BLACK (instead of transparent)

[Attached image pairs, expected (startNone) vs. wrong (start1e-6):
MetopB-4-eurol10-blend_with_weights-generateTrue-fill_value255
MetopB-ir108_3d-eurol10-blend_with_weights-generateTrue-fill_value0
MetopB-natural_color-eurol10-blend_with_weights-generateTrue-fill_value255]

djhoese (Member) commented:

Divide by zero with numpy will generally produce NaNs. In Satpy/trollimage these are considered fill values, so they either become transparent via the alpha band or they get set with your fill_value keyword argument. Based on your comment I think this is what you're seeing. For the start value sum cases, yeah, I don't see how that would work properly in real-world cases, so I'm not concerned. I think I can continue with the rest of my TODO list then.
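To make the arithmetic under discussion concrete, here is a minimal sketch (an editorial illustration, not code from the PR) of the blend math with no `start` seed in the denominator; the toy arrays, weights, and variable names are hypothetical:

```python
import numpy as np
import xarray as xr

a = xr.DataArray([[1.0, np.nan], [1.0, np.nan]])  # pass 1: right column has no data
b = xr.DataArray([[np.nan, np.nan], [3.0, 3.0]])  # pass 2: top row has no data
w_a = xr.DataArray([[1.0, 0.0], [1.0, 0.0]])      # weights already zeroed where data is invalid
w_b = xr.DataArray([[0.0, 0.0], [1.0, 1.0]])

overlays = [a.fillna(0) * w_a, b.fillna(0) * w_b]
blended = sum(overlays) / sum([w_a, w_b])  # no start= seed; numpy warns on the 0/0 pixel

print(blended.values)
# [[ 1. nan]
#  [ 2.  3.]]
```

Where every weight is 0 the division is 0/0, which numpy evaluates to NaN; Satpy/trollimage then treat that pixel as a fill value (transparent alpha or the user's fill_value). With sum(weights, start=1e-6) the same pixel would instead come out as 0.0, matching the black/white no-data regions reported above.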

+
+    dims = datasets[0].dims
+    blended_array = xr.DataArray(base, dims=dims, attrs=attrs)
+    return blended_array
+
+
+def _stack_select_by_weights(
+        datasets: Sequence[xr.DataArray],
+        weights: Sequence[xr.DataArray],
+        combine_times: bool
+) -> xr.DataArray:
+    """Stack datasets selecting pixels using weights."""
+    indices = da.argmax(da.dstack(weights), axis=-1)
+    if "bands" in datasets[0].dims:
+        indices = [indices] * datasets[0].sizes["bands"]
djhoese (Member) commented:

@lobsiger I refactored this so that regardless of the number of bands in the DataArray (L, LA, RGB, RGBA) it would have indices for each one, which I thought was what you wanted here. I'm realizing now: should this ignore alpha bands?
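As background for this comment, a minimal sketch (an editorial illustration, not code from the PR) of the selection mechanism: the weight arrays are depth-stacked, a per-pixel argmax picks the winning input, and `da.choose` pulls each pixel from that input. The toy values are hypothetical:

```python
import dask.array as da

d0 = da.from_array([[10, 10], [10, 10]], chunks=2)
d1 = da.from_array([[20, 20], [20, 20]], chunks=2)
w0 = da.from_array([[1.0, 0.0], [1.0, 0.0]], chunks=2)
w1 = da.from_array([[0.0, 2.0], [0.0, 2.0]], chunks=2)

indices = da.argmax(da.dstack([w0, w1]), axis=-1)  # 0 or 1 per pixel
selected = da.choose(indices, [d0, d1])

print(selected.compute())
# [[10 20]
#  [10 20]]
```

For a multi-band DataArray, the same 2D `indices` array is repeated once per band (`[indices] * n_bands`), so every band of a given pixel, alpha included, comes from the same input; that is the open question about alpha bands raised above.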

+
+    attrs = _combine_stacked_attrs([data_arr.attrs for data_arr in datasets], combine_times)
+    dims = datasets[0].dims
+    coords = datasets[0].coords
+    selected_array = xr.DataArray(da.choose(indices, datasets), dims=dims, coords=coords, attrs=attrs)
+    return selected_array
+
+
+def _stack_no_weights(
+        datasets: Sequence[xr.DataArray],
+        combine_times: bool
+) -> xr.DataArray:
+    base = datasets[0].copy()
+    collected_attrs = [base.attrs]
+    for data_arr in datasets[1:]:
+        collected_attrs.append(data_arr.attrs)
+        try:
+            base = base.where(data_arr == data_arr.attrs["_FillValue"], data_arr)
+        except KeyError:
+            base = base.where(data_arr.isnull(), data_arr)
+
+    attrs = _combine_stacked_attrs(collected_attrs, combine_times)
+    base.attrs = attrs
+    return base


+def _combine_stacked_attrs(collected_attrs: Sequence[Mapping], combine_times: bool) -> dict:
+    attrs = combine_metadata(*collected_attrs)
+    if combine_times and ('start_time' in attrs or 'end_time' in attrs):
+        new_start, new_end = _get_combined_start_end_times(collected_attrs)
+        if new_start:
+            attrs["start_time"] = new_start
+        if new_end:
+            attrs["end_time"] = new_end
+    return attrs
+
+
-def _get_combined_start_end_times(*metadata_objects):
+def _get_combined_start_end_times(metadata_objects: Iterable[Mapping]) -> tuple[datetime | None, datetime | None]:
     """Get the start and end times attributes valid for the entire dataset series."""
-    start_time = datetime.now()
-    end_time = datetime.fromtimestamp(0)
+    start_time = None
+    end_time = None
     for md_obj in metadata_objects:
-        if md_obj['start_time'] < start_time:
+        if "start_time" in md_obj and (start_time is None or md_obj['start_time'] < start_time):
             start_time = md_obj['start_time']
-        if md_obj['end_time'] > end_time:
+        if "end_time" in md_obj and (end_time is None or md_obj['end_time'] > end_time):
             end_time = md_obj['end_time']
 
     return start_time, end_time
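Taken together, the rewritten module offers one unweighted and two weighted behaviors. The following standalone sketch (an editorial illustration based on the new API; toy data and values are hypothetical, and it has not been run against this exact commit) shows what each weighted ``blend_type`` produces:

```python
import dask.array as da
import numpy as np
import xarray as xr

from satpy.multiscene import stack

data1 = xr.DataArray(da.full((3, 3), 1.0, chunks=-1), dims=("y", "x"))
data2 = xr.DataArray(da.full((3, 3), 3.0, chunks=-1), dims=("y", "x"))

w1 = np.zeros((3, 3))
w1[0, :] = 2  # first row strongly favors data1
wgt1 = xr.DataArray(w1, dims=("y", "x"))
wgt2 = xr.DataArray(np.ones((3, 3)), dims=("y", "x"))

# Default: per-pixel winner-takes-all selection.
sel = stack([data1, data2], weights=[wgt1, wgt2], blend_type="select_with_weights")
# First row comes from data1 (1.0), everything else from data2 (3.0).

# Alternative: weighted average over the overlap.
bln = stack([data1, data2], weights=[wgt1, wgt2], blend_type="blend_with_weights")
# First row: (2*1.0 + 1*3.0) / (2 + 1) = 5/3; other rows: 3.0.
```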
105 changes: 39 additions & 66 deletions satpy/tests/multiscene_tests/test_blend.py
@@ -31,15 +31,13 @@
 from satpy.tests.multiscene_tests.test_utils import _create_test_area, _create_test_dataset, _create_test_int8_dataset
 from satpy.tests.utils import make_dataid
 
+NUM_TEST_ROWS = 2
+NUM_TEST_COLS = 3
+
 
 class TestBlendFuncs:
     """Test individual functions used for blending."""
 
-    def setup_method(self):
-        """Set up test functions."""
-        self._line = 2
-        self._column = 3
-
     @pytest.fixture
     def scene1_with_weights(self):
         """Create first test scene with a dataset of weights."""

@@ -54,7 +52,7 @@
         )
         scene[dsid1] = _create_test_int8_dataset(name='geo-ct', area=area, values=1)
         scene[dsid1].attrs['platform_name'] = 'Meteosat-11'
-        scene[dsid1].attrs['sensor'] = set({'seviri'})
+        scene[dsid1].attrs['sensor'] = {'seviri'}
         scene[dsid1].attrs['units'] = '1'
         scene[dsid1].attrs['long_name'] = 'NWC GEO CT Cloud Type'
         scene[dsid1].attrs['orbital_parameters'] = {'satellite_nominal_altitude': 35785863.0,

@@ -65,8 +63,8 @@
 
         wgt1 = _create_test_dataset(name='geo-ct-wgt', area=area, values=0)
 
-        wgt1[self._line, :] = 2
-        wgt1[:, self._column] = 2
+        wgt1[NUM_TEST_ROWS, :] = 2
+        wgt1[:, NUM_TEST_COLS] = 2
 
         dsid2 = make_dataid(
             name="geo-cma",

@@ -95,7 +93,7 @@
         )
         scene[dsid1] = _create_test_int8_dataset(name='polar-ct', area=area, values=3)
         scene[dsid1].attrs['platform_name'] = 'NOAA-18'
-        scene[dsid1].attrs['sensor'] = set({'avhrr-3'})
+        scene[dsid1].attrs['sensor'] = {'avhrr-3'}
         scene[dsid1].attrs['units'] = '1'
         scene[dsid1].attrs['long_name'] = 'SAFNWC PPS CT Cloud Type'
         scene[dsid1][-1, :] = scene[dsid1].attrs['_FillValue']
@@ -150,89 +148,48 @@
         expected[-1, :] = scene1['geo-ct'][-1, :]
 
         xr.testing.assert_equal(result, expected.compute())
-        assert result.attrs['platform_name'] == 'Meteosat-11'
-        assert result.attrs['sensor'] == set({'seviri'})
-        assert result.attrs['long_name'] == 'NWC GEO CT Cloud Type'
-        assert result.attrs['units'] == '1'
-        assert result.attrs['name'] == 'CloudType'
-        assert result.attrs['_FillValue'] == 255
-        assert result.attrs['valid_range'] == [1, 15]
-
+        _check_stacked_metadata(result, "CloudType")
         assert result.attrs['start_time'] == datetime(2023, 1, 16, 11, 9, 17)
-        assert result.attrs['end_time'] == datetime(2023, 1, 16, 11, 12, 22)
+        assert result.attrs['end_time'] == datetime(2023, 1, 16, 11, 28, 1, 900000)
 
+    @pytest.mark.parametrize("combine_times", [False, True])
     def test_blend_two_scenes_using_stack_weighted(self, multi_scene_and_weights, groups,
-                                                   scene1_with_weights, scene2_with_weights):
+                                                   scene1_with_weights, scene2_with_weights,
+                                                   combine_times):
         """Test stacking two scenes using weights - testing that metadata are combined correctly.
 
         Here we test that the start and end times can be combined so that they
         describe the start and end times of the entire data series.
 
         """
         from functools import partial
 
         multi_scene, weights = multi_scene_and_weights
         scene1, weights1 = scene1_with_weights
         scene2, weights2 = scene2_with_weights
 
         simple_groups = {DataQuery(name='CloudType'): groups[DataQuery(name='CloudType')]}
         multi_scene.group(simple_groups)
 
         weights = [weights[0][0], weights[1][0]]
-        stack_with_weights = partial(stack, weights=weights)
+        stack_with_weights = partial(stack, weights=weights, combine_times=combine_times)
         weighted_blend = multi_scene.blend(blend_function=stack_with_weights)
 
         expected = scene2['polar-ct']
-        expected[self._line, :] = scene1['geo-ct'][self._line, :]
-        expected[:, self._column] = scene1['geo-ct'][:, self._column]
+        expected[NUM_TEST_ROWS, :] = scene1['geo-ct'][NUM_TEST_ROWS, :]
+        expected[:, NUM_TEST_COLS] = scene1['geo-ct'][:, NUM_TEST_COLS]
         expected[-1, :] = scene1['geo-ct'][-1, :]
 
         result = weighted_blend['CloudType'].compute()
         xr.testing.assert_equal(result, expected.compute())
-
-        expected_area = _create_test_area()
-        assert result.attrs['area'] == expected_area
-        assert 'sensor' not in result.attrs
-        assert 'platform_name' not in result.attrs
-        assert 'long_name' not in result.attrs
-        assert result.attrs['units'] == '1'
-        assert result.attrs['name'] == 'CloudType'
-        assert result.attrs['_FillValue'] == 255
-        assert result.attrs['valid_range'] == [1, 15]
-
-        assert result.attrs['start_time'] == datetime(2023, 1, 16, 11, 9, 17)
-        assert result.attrs['end_time'] == datetime(2023, 1, 16, 11, 28, 1, 900000)
-
-    def test_blend_two_scenes_using_stack_weighted_no_time_combination(self, multi_scene_and_weights, groups,
-                                                                       scene1_with_weights, scene2_with_weights):
-        """Test stacking two scenes using weights - test that the start and end times are averaged and not combined."""
-        from functools import partial
-
-        multi_scene, weights = multi_scene_and_weights
-        scene1, weights1 = scene1_with_weights
-        scene2, weights2 = scene2_with_weights
-
-        simple_groups = {DataQuery(name='CloudType'): groups[DataQuery(name='CloudType')]}
-        multi_scene.group(simple_groups)
-
-        weights = [weights[0][0], weights[1][0]]
-        stack_with_weights = partial(stack, weights=weights, combine_times=False)
-        weighted_blend = multi_scene.blend(blend_function=stack_with_weights)
-
-        result = weighted_blend['CloudType'].compute()
-
-        expected_area = _create_test_area()
-        assert result.attrs['area'] == expected_area
-        assert 'sensor' not in result.attrs
-        assert 'platform_name' not in result.attrs
-        assert 'long_name' not in result.attrs
-        assert result.attrs['units'] == '1'
-        assert result.attrs['name'] == 'CloudType'
-        assert result.attrs['_FillValue'] == 255
-        assert result.attrs['valid_range'] == [1, 15]
-
-        assert result.attrs['start_time'] == datetime(2023, 1, 16, 11, 11, 7, 250000)
-        assert result.attrs['end_time'] == datetime(2023, 1, 16, 11, 20, 11, 950000)
+        _check_stacked_metadata(result, "CloudType")
+        if combine_times:
+            assert result.attrs['start_time'] == datetime(2023, 1, 16, 11, 9, 17)
+            assert result.attrs['end_time'] == datetime(2023, 1, 16, 11, 28, 1, 900000)
+        else:
+            assert result.attrs['start_time'] == datetime(2023, 1, 16, 11, 11, 7, 250000)
+            assert result.attrs['end_time'] == datetime(2023, 1, 16, 11, 20, 11, 950000)

[CodeScene Delta Analysis / CodeScene Cloud Delta Analysis (main), warning on line 192 in satpy/tests/multiscene_tests/test_blend.py]
❌ New issue: Excess Number of Function Arguments. TestBlendFuncs.test_blend_two_scenes_using_stack_weighted has 5 arguments, threshold = 4.

     @pytest.fixture
     def datasets_and_weights(self):

@@ -294,7 +251,6 @@
         expected.attrs = combine_metadata(*[x.attrs for x in input_data['datasets'][0:3]])
 
         xr.testing.assert_equal(blend_result.compute(), expected.compute())
-
         assert expected.attrs == blend_result.attrs
 
     def test_blend_function_stack(self, datasets_and_weights):

@@ -308,8 +264,10 @@
 
         res = stack([ds1, ds2])
         expected = ds2.copy()
+        expected.attrs["start_time"] = ds1.attrs["start_time"]
 
         xr.testing.assert_equal(res.compute(), expected.compute())
+        assert expected.attrs == res.attrs
 
     def test_timeseries(self, datasets_and_weights):
         """Test the 'timeseries' function."""

@@ -329,3 +287,18 @@
         assert isinstance(res2, xr.DataArray)
         assert (2, ds1.shape[0], ds1.shape[1]) == res.shape
         assert (ds4.shape[0], ds4.shape[1]+ds5.shape[1]) == res2.shape


+def _check_stacked_metadata(data_arr: xr.DataArray, exp_name: str) -> None:
+    assert data_arr.attrs['units'] == '1'
+    assert data_arr.attrs['name'] == exp_name
+    assert data_arr.attrs['_FillValue'] == 255
+    assert data_arr.attrs['valid_range'] == [1, 15]
+
+    expected_area = _create_test_area()
+    assert data_arr.attrs['area'] == expected_area
+
+    # these metadata items don't match between all inputs
+    assert 'sensor' not in data_arr.attrs
+    assert 'platform_name' not in data_arr.attrs
+    assert 'long_name' not in data_arr.attrs
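
For completeness, the MultiScene-level workflow these tests exercise looks roughly like the sketch below (an editorial illustration mirroring the fixtures above). The Scene objects and per-scene weight DataArrays (``geo_scene``, ``polar_scene``, ``geo_weights``, ``polar_weights``) are hypothetical placeholders assumed to exist, for example with weights derived from satellite zenith angles:

```python
from functools import partial

from satpy import DataQuery, MultiScene
from satpy.multiscene import stack

# Group the cloud-type products from both scenes under one name.
mscn = MultiScene([geo_scene, polar_scene])
mscn.group({DataQuery(name='CloudType'): ['geo-ct', 'polar-ct']})

# Bind the weights and blend type into the blend function, then blend.
stack_with_weights = partial(stack, weights=[geo_weights, polar_weights],
                             blend_type='select_with_weights', combine_times=True)
blended = mscn.blend(blend_function=stack_with_weights)
blended['CloudType'].compute()
```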