Add weighted blended stacking to MultiScene (fixes multi-band handling) #2394

Merged · 25 commits, merged Apr 17, 2023

Changes from 5 commits

Commits
7f878f5
Rename Adam's stacking function, add set_weights_to_zero_where_invali…
lobsiger Feb 17, 2023
b374ff9
Adding selecting with bands and true blending.
lobsiger Feb 17, 2023
a7e480b
Fixed indenting stuff an line break.
lobsiger Feb 17, 2023
5a3bca5
Fixed line break.
lobsiger Feb 17, 2023
6bb4cdd
Cosmetics, maybe should use enumerate() ...
lobsiger Feb 18, 2023
c64d556
Adapted stack() for two blending functions.
lobsiger Feb 21, 2023
9ff0f94
Made one blend function out of two.
lobsiger Feb 21, 2023
b0c0701
Made one select function out of two.
lobsiger Feb 21, 2023
3e110b8
Added start because this is now after the .fillna() step.
lobsiger Feb 21, 2023
20a7de9
Just a test to test my test theory.
lobsiger Feb 21, 2023
1ddea03
Maybe Adams test was not invoked when all passed in my homebrew versi…
lobsiger Feb 21, 2023
e19d301
Got my first idea of an assert statement.
lobsiger Feb 22, 2023
7ec0057
Reword stack docstring in satpy/multiscene.py
djhoese Mar 31, 2023
4d4dcb8
Start refactoring new weighted stacking in MultiScene
djhoese Mar 31, 2023
a655318
Refactor multiscene blend tests to avoid unnecessary test setup
djhoese Apr 6, 2023
54797c0
Improve consistency between multiscene stack functions
djhoese Apr 6, 2023
0210ae4
Consolidate some multiscene blend tests
djhoese Apr 6, 2023
15d8d0c
Add initial tests for weighted blended stacking
djhoese Apr 7, 2023
f2ac7b2
Refactor multiscene blending fixtures
djhoese Apr 10, 2023
f253d61
Add RGB and float tests to multiscene blend tests
djhoese Apr 11, 2023
c6d8dea
Remove TODOs from multiscene regarding overlay/weight handling
djhoese Apr 11, 2023
c969ce7
Move multiscene to its own subpackage
djhoese Apr 11, 2023
53d7c22
Refactor multiscene blend functions to their own module
djhoese Apr 11, 2023
918260e
Make more objects in multiscene module private with `_` prefix
djhoese Apr 11, 2023
4ad2be8
Update docstring of multiscene stack and fix docstring errors in priv…
djhoese Apr 13, 2023
120 changes: 110 additions & 10 deletions satpy/multiscene.py
@@ -46,18 +46,29 @@
log = logging.getLogger(__name__)


def stack(datasets, weights=None, combine_times=True):
"""Overlay a series of datasets together.
def stack(datasets, weights=None, combine_times=True, blend_type=1):
Member:
The Pythonic way to do this is to pass a string rather than a number for the blend type, e.g. "stack_no_weights", "select_from_highest_weight", "blend_with_weights".

However, I wonder if it would be clearer or cleaner to have separate methods for this, so that instead of calling e.g.

multiscene.stack(datasets, weights=five_kilos, blend_type="select_with_weights")

we could just call

multiscene.select(datasets, weights=five_kilos)

?
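
For illustration, a minimal sketch of that string-based dispatch; the BLEND_FUNCS mapping and _stack_no_weights are hypothetical names, not code from this PR:

# Hypothetical dispatch table; only the pattern is the point here.
BLEND_FUNCS = {
    "select_from_highest_weight": _stack_selected_single,
    "blend_with_weights": _stack_blended_single,
}

def stack(datasets, weights=None, combine_times=True,
          blend_type="select_from_highest_weight"):
    if weights is None:
        return _stack_no_weights(datasets, combine_times)  # hypothetical helper
    try:
        blend_func = BLEND_FUNCS[blend_type]
    except KeyError:
        raise ValueError(f"Unknown blend_type: {blend_type!r}")
    return blend_func(datasets, weights, combine_times)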

"""Combine a series of datasets in different ways.

By default, datasets are stacked on top of each other, so the last one applied is
on top. If a sequence of weights arrays are provided the datasets will
be combined according to those weights. The result will be a composite
dataset where the data in each pixel is coming from the dataset having the
highest weight.
on top. If a sequence of weights (with equal shape) is provided, the datasets will
be combined according to those weights. Datasets can be integer data like 'ct' or 'cma',
single-channel radiances, or RGB composites. In the latter case the weights are applied
to each 'R', 'G', 'B' coordinate in the same way. The result will be a composite
dataset where each pixel is constructed in a way that depends on blend_type:

blend_type=1 : Each pixel is selected from the dataset with the highest weight.
blend_type=2 : Each pixel is blended from all datasets with their respective weights.
Other blend_type values fall back to stacking the datasets without weights applied.

"""
if weights:
return _stack_weighted(datasets, weights, combine_times)
bands = datasets[0].dims[0] == 'bands'
Member:
This is likely not an issue in Satpy, but we may want to check if "bands" is a dimension at all and, if so, extract the individual bands by doing something like data_array.sel(bands=["R"]), where the "R" is actually taken from data_array.coords["bands"].
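
As an illustration of that check, a small sketch (data_arr is a stand-in DataArray, not a name from this PR):

if "bands" in data_arr.dims:
    # Take band names from the coordinate instead of assuming an R/G/B order.
    band_names = data_arr.coords["bands"].values  # e.g. ['R', 'G', 'B']
    first_band = data_arr.sel(bands=band_names[0])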

Contributor (author):
@djhoese regarding your question above ("This is essentially a weighted average, right?"), the answer is YES. But for blend_type = "blend_with_weights" I also use my special "AzB" type weights that go to zero towards the boundaries of available data. Otherwise we would still see e.g. the ends of the scan lines for multi-pass LEO images. In other words, for blend_type = "blend_with_weights" a setting like the above weights = (180. - satellite_zenith_angle) makes no sense.

if weights and bands and blend_type == 1:
return _stack_selected_bands(datasets, weights, combine_times)
if weights and not bands and blend_type == 1:
return _stack_selected_single(datasets, weights, combine_times)
if weights and bands and blend_type == 2:
return _stack_blended_bands(datasets, weights, combine_times)
if weights and not bands and blend_type == 2:
return _stack_blended_single(datasets, weights, combine_times)

base = datasets[0].copy()
for dataset in datasets[1:]:
@@ -69,8 +80,86 @@ def stack(datasets, weights=None, combine_times=True):
return base
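
For orientation, a hypothetical call of the new signature (data_a/data_b and the weight arrays are invented stand-ins, not names from this PR):

from satpy.multiscene import stack

# Weights must match the data shape pixel-for-pixel, e.g. derived from
# satellite zenith angle as discussed in this thread.
blended = stack([data_a, data_b], weights=[weights_a, weights_b],
                blend_type=2)  # 2 blends with weights; 1 selects the highest weight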


def _stack_weighted(datasets, weights, combine_times):
"""Stack datasets using weights."""
def _stack_blended_bands(datasets, weights, combine_times):
"""Stack datasets with bands blending overlap using weights."""
weights = set_weights_to_zero_where_invalid_red(datasets, weights)

dims = datasets[0].dims
attrs = combine_metadata(*[x.attrs for x in datasets])

if combine_times:
if 'start_time' in attrs and 'end_time' in attrs:
attrs['start_time'], attrs['end_time'] = _get_combined_start_end_times(*[x.attrs for x in datasets])

total = weights[0].copy() + 1.e-9
for n in range(1, len(weights)):
total += weights[n]
Member:
Suggested change:
-    total = weights[0].copy() + 1.e-9
-    for n in range(1, len(weights)):
-        total += weights[n]
+    total = sum(weights)

the sum function even has a "start" argument, where you could put the offset you have here, like total = sum(weights, start=1.e-9)
However, I think this is not necessary. A division by zero later on will result in a NaN value, which will then be replaced with the fillna call.
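
Spelled out, that suggestion is a one-liner (built-in sum() works on DataArrays, and Python 3.8+ accepts start as a keyword):

total = sum(weights, start=1.e-9)  # or plain sum(weights): 0/0 gives NaN, cleared by the later fillna(0)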


datasets0 = []
Member:
this variable should be renamed to something more understandable...

for n in range(0, len(datasets)):
weights[n] /= total
datasets0.append(datasets[n].fillna(0))
datasets0[n][0] *= weights[n]
datasets0[n][1] *= weights[n]
datasets0[n][2] *= weights[n]
Member:
Suggested change:
-    for n in range(0, len(datasets)):
-        weights[n] /= total
-        datasets0.append(datasets[n].fillna(0))
-        datasets0[n][0] *= weights[n]
-        datasets0[n][1] *= weights[n]
-        datasets0[n][2] *= weights[n]
+    for weight, dataset in zip(weights, datasets):
+        weight /= total
+        datasets0.append(dataset * weight)

I think the array broadcasting mechanism should take care of the weight multiplication. If that works, it would probably mean that we can merge the single and multiple band cases into one function, right?

Regarding the fillna call, I think we should be careful here about what we use as a fill value. In particular, it's probably good to check if the dataset has a defined _FillValue attribute already.
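
A minimal sketch of that broadcasting idea, assuming the weights are DataArrays with dims ('y', 'x') and the data has dims ('bands', 'y', 'x') or plain ('y', 'x'):

# xarray broadcasts by dimension name, so one loop covers both the
# single-band and the RGB case.
scaled = [dataset.fillna(0) * (weight / total)
          for dataset, weight in zip(datasets, weights)]
base = sum(scaled)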


base = datasets0[0].copy()
for n in range(1, len(datasets0)):
base += datasets0[n]
Member:
Suggested change:
-    base = datasets0[0].copy()
-    for n in range(1, len(datasets0)):
-        base += datasets0[n]
+    base = sum(datasets0)


blended_array = xr.DataArray(base, dims=dims, attrs=attrs)
return blended_array


def _stack_blended_single(datasets, weights, combine_times):
"""Stack single channel datasets blending overlap using weights."""
weights = set_weights_to_zero_where_invalid(datasets, weights)

dims = datasets[0].dims
attrs = combine_metadata(*[x.attrs for x in datasets])

if combine_times:
if 'start_time' in attrs and 'end_time' in attrs:
attrs['start_time'], attrs['end_time'] = _get_combined_start_end_times(*[x.attrs for x in datasets])

total = weights[0].copy() + 1.e-9
for n in range(1, len(weights)):
total += weights[n]

datasets0 = []
for n in range(0, len(datasets)):
weights[n] /= total
datasets0.append(datasets[n].fillna(0))
datasets0[n] *= weights[n]

base = datasets0[0].copy()
for n in range(1, len(datasets0)):
base += datasets0[n]

blended_array = xr.DataArray(base, dims=dims, attrs=attrs)
return blended_array


def _stack_selected_bands(datasets, weights, combine_times):
"""Stack datasets with bands selecting pixels using weights."""
weights = set_weights_to_zero_where_invalid_red(datasets, weights)

indices = da.argmax(da.dstack(weights), axis=-1)
attrs = combine_metadata(*[x.attrs for x in datasets])

if combine_times:
if 'start_time' in attrs and 'end_time' in attrs:
attrs['start_time'], attrs['end_time'] = _get_combined_start_end_times(*[x.attrs for x in datasets])

dims = datasets[0].dims
coords = datasets[0].coords
selected_array = xr.DataArray(da.choose([indices, indices, indices], datasets),
coords=coords, dims=dims, attrs=attrs)
return selected_array
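
To make the selection mechanics concrete, a tiny self-contained sketch (toy values) of the argmax-plus-choose pattern used above:

import numpy as np
import dask.array as da

# Two weight layers and two matching data layers (toy values).
w0 = da.from_array(np.array([[1.0, 0.0], [0.0, 1.0]]))
w1 = da.from_array(np.array([[0.0, 1.0], [1.0, 0.0]]))
d0 = da.zeros((2, 2))
d1 = da.ones((2, 2))

# Index of the winning weight per pixel, then pick that layer's value.
indices = da.argmax(da.dstack([w0, w1]), axis=-1)
picked = da.choose(indices, [d0, d1])
print(picked.compute())  # 0.0 where w0 wins, 1.0 where w1 wins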


def _stack_selected_single(datasets, weights, combine_times):
"""Stack single channel datasets selecting pixels using weights."""
weights = set_weights_to_zero_where_invalid(datasets, weights)

indices = da.argmax(da.dstack(weights), axis=-1)
@@ -85,6 +174,17 @@ def _stack_weighted(datasets, weights, combine_times):
return weighted_array


def set_weights_to_zero_where_invalid_red(datasets, weights):
"""Go through the weights and set to pixel values to zero where corresponding datasets[0] are invalid."""
for i, dataset in enumerate(datasets):
try:
weights[i] = xr.where(dataset[0] == dataset.attrs["_FillValue"], 0, weights[i])
except KeyError:
weights[i] = xr.where(dataset[0].isnull(), 0, weights[i])
Member:
@lobsiger I'm refactoring this while also writing tests, but noticed that this says datasets[0], should this be datasets[i]? Don't worry about fixing it yourself if this is wrong, I'll do it.

Member:
Oh! This is only looking at the red channel. Why?

Contributor (author):
> Oh! This is only looking at the red channel. Why?

@djhoese this is simply because I was unable to write an accepted function that treated all available bands properly. So I just copied the code @adybbroe wrote for single band cloud data and looked at the RED band only. My (probably wrong) assumption was that either none or all bands would be null or have the _FillValue. At least it seemed to correct OLCI ERR data images. Without this function those images had black lines at the ends (left/right) of the swath of each satellite pass.


return weights
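
As a hedged sketch, an all-bands variant of the check discussed above (not part of this PR): zero a pixel's weight if any band is invalid there.

import xarray as xr

def _zero_weights_where_any_band_invalid(datasets, weights):
    # Hypothetical alternative to the red-only check: a pixel's weight
    # drops to zero if *any* band is null or equals _FillValue there.
    for i, dataset in enumerate(datasets):
        fill = dataset.attrs.get("_FillValue")
        invalid = dataset.isnull() if fill is None else (dataset == fill)
        weights[i] = xr.where(invalid.any(dim="bands"), 0, weights[i])
    return weights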


def set_weights_to_zero_where_invalid(datasets, weights):
"""Go through the weights and set to pixel values to zero where corresponding datasets are invalid."""
for i, dataset in enumerate(datasets):