Add weighted blended stacking to MultiScene (fixes multi-band handling) #2394
Changes from 5 commits
@@ -46,18 +46,29 @@
 log = logging.getLogger(__name__)


-def stack(datasets, weights=None, combine_times=True):
-    """Overlay a series of datasets together.
+def stack(datasets, weights=None, combine_times=True, blend_type=1):
+    """Combine a series of datasets in different ways.
     By default, datasets are stacked on top of each other, so the last one applied is
-    on top. If a sequence of weights arrays are provided the datasets will
-    be combined according to those weights. The result will be a composite
-    dataset where the data in each pixel is coming from the dataset having the
-    highest weight.
+    on top. If a sequence of weights (with equal shape) is provided, the datasets will
+    be combined according to those weights. Datasets can be integers like 'ct', 'cma',
+    or radiances of single channels or RGB composites. In the latter case the weights
+    are applied to each 'R', 'G', 'B' coordinate in the same way. The result will be a
+    composite dataset where each pixel is constructed in a way that depends on the
+    blend_type variable.
+    blend_type=1 : Each pixel is selected from the dataset with the highest weight
+    blend_type=2 : Each pixel is blended from all datasets with their respective weights
+    Other blend_types fall back to stacking the datasets without weights applied.

     """
-    if weights:
-        return _stack_weighted(datasets, weights, combine_times)
+    bands = datasets[0].dims[0] == 'bands'
Review comment: This is likely not an issue in Satpy, but we may want to check if "bands" is a dimension at all and if so extract the individual bands by doing something like …

Reply: @djhoese regarding your question above "This is essentially a weighted average, right?" the answer is YES. But for blend_type = "blend_with_weights" I also use my special "AzB" type weights that go to zero towards the boundaries of the available data. Otherwise we would still see e.g. the ends of the scan lines for multi-pass LEO images. In other words, for blend_type = "blend_with_weights" a setting like the above weights = (180. - satellite_zenith_angle) makes no sense.
+    if weights and bands and blend_type == 1:
+        return _stack_selected_bands(datasets, weights, combine_times)
+    if weights and not bands and blend_type == 1:
+        return _stack_selected_single(datasets, weights, combine_times)
+    if weights and bands and blend_type == 2:
+        return _stack_blended_bands(datasets, weights, combine_times)
+    if weights and not bands and blend_type == 2:
+        return _stack_blended_single(datasets, weights, combine_times)

     base = datasets[0].copy()
     for dataset in datasets[1:]:
@@ -69,8 +80,86 @@ def stack(datasets, weights=None, combine_times=True):
     return base

-def _stack_weighted(datasets, weights, combine_times):
-    """Stack datasets using weights."""
+def _stack_blended_bands(datasets, weights, combine_times):
+    """Stack datasets with bands blending overlap using weights."""
+    weights = set_weights_to_zero_where_invalid_red(datasets, weights)
+
+    dims = datasets[0].dims
+    attrs = combine_metadata(*[x.attrs for x in datasets])
+
+    if combine_times:
+        if 'start_time' in attrs and 'end_time' in attrs:
+            attrs['start_time'], attrs['end_time'] = _get_combined_start_end_times(*[x.attrs for x in datasets])
+
+    total = weights[0].copy() + 1.e-9
+    for n in range(1, len(weights)):
+        total += weights[n]
Review comment (suggested change): the sum function even has a "start" argument, where you could put the offset you have here, like …
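For illustration, the builtin `sum` does accept a `start` argument that can carry the small offset; a minimal NumPy sketch (the arrays here are invented) comparing it with the loop form used in the diff:

```python
import numpy as np

# Hypothetical per-dataset weight arrays standing in for the ones in the PR.
weights = [np.array([0.2, 0.5]), np.array([0.8, 0.5]), np.array([0.0, 1.0])]

# Loop form used in the diff:
total_loop = weights[0].copy() + 1.e-9
for n in range(1, len(weights)):
    total_loop += weights[n]

# The reviewer's suggestion: builtin sum() with the offset passed as ``start``.
total_sum = sum(weights[1:], start=weights[0] + 1.e-9)

assert np.allclose(total_loop, total_sum)
```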

+    datasets0 = []
Review comment: this variable should be renamed to something more understandable...
+    for n in range(0, len(datasets)):
+        weights[n] /= total
+        datasets0.append(datasets[n].fillna(0))
+        datasets0[n][0] *= weights[n]
+        datasets0[n][1] *= weights[n]
+        datasets0[n][2] *= weights[n]
Review comment (suggested change): I think the array broadcasting mechanism should take care of the weight multiplication. If that works, it would probably mean that we can merge the single and multiple band cases into one function, right? Regarding the …

+    base = datasets0[0].copy()
+    for n in range(1, len(datasets0)):
+        base += datasets0[n]
+    blended_array = xr.DataArray(base, dims=dims, attrs=attrs)
+    return blended_array

+def _stack_blended_single(datasets, weights, combine_times):
+    """Stack single channel datasets blending overlap using weights."""
+    weights = set_weights_to_zero_where_invalid(datasets, weights)
+
+    dims = datasets[0].dims
+    attrs = combine_metadata(*[x.attrs for x in datasets])
+
+    if combine_times:
+        if 'start_time' in attrs and 'end_time' in attrs:
+            attrs['start_time'], attrs['end_time'] = _get_combined_start_end_times(*[x.attrs for x in datasets])
+
+    total = weights[0].copy() + 1.e-9
+    for n in range(1, len(weights)):
+        total += weights[n]
+
+    datasets0 = []
+    for n in range(0, len(datasets)):
+        weights[n] /= total
+        datasets0.append(datasets[n].fillna(0))
+        datasets0[n] *= weights[n]
+
+    base = datasets0[0].copy()
+    for n in range(1, len(datasets0)):
+        base += datasets0[n]
+
+    blended_array = xr.DataArray(base, dims=dims, attrs=attrs)
+    return blended_array
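In miniature, the blending implemented above is a weighted average with normalized weights; a NumPy sketch with invented values:

```python
import numpy as np

# Two made-up single-band datasets and their weights.
d0 = np.array([[100., 200.]])
d1 = np.array([[300., 400.]])
w0 = np.array([[1., 3.]])
w1 = np.array([[1., 1.]])

# Normalize by the total (plus a tiny offset against division by zero),
# then sum the weighted contributions, as the function above does.
total = w0 + w1 + 1.e-9
blended = d0 * (w0 / total) + d1 * (w1 / total)

assert np.allclose(blended, [[200., 250.]])
```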

+def _stack_selected_bands(datasets, weights, combine_times):
+    """Stack datasets with bands selecting pixels using weights."""
+    weights = set_weights_to_zero_where_invalid_red(datasets, weights)
+
+    indices = da.argmax(da.dstack(weights), axis=-1)
+    attrs = combine_metadata(*[x.attrs for x in datasets])
+
+    if combine_times:
+        if 'start_time' in attrs and 'end_time' in attrs:
+            attrs['start_time'], attrs['end_time'] = _get_combined_start_end_times(*[x.attrs for x in datasets])
+
+    dims = datasets[0].dims
+    coords = datasets[0].coords
+    selected_array = xr.DataArray(da.choose([indices, indices, indices], datasets),
+                                  coords=coords, dims=dims, attrs=attrs)
+    return selected_array
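The select-by-highest-weight step can be illustrated with NumPy's `argmax`/`choose`, which mirror the `da.argmax`/`da.choose` calls used above (the data here is invented):

```python
import numpy as np

# Two made-up single-band datasets and their weight arrays.
d0 = np.array([[1., 2.], [3., 4.]])
d1 = np.array([[10., 20.], [30., 40.]])
w0 = np.array([[0.9, 0.1], [0.8, 0.2]])
w1 = np.array([[0.1, 0.9], [0.2, 0.8]])

# Index of the dataset holding the highest weight at each pixel.
indices = np.argmax(np.dstack([w0, w1]), axis=-1)

# Pick every pixel from the winning dataset.
selected = np.choose(indices, [d0, d1])

assert np.array_equal(selected, np.array([[1., 20.], [3., 40.]]))
```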

+def _stack_selected_single(datasets, weights, combine_times):
+    """Stack single channel datasets selecting pixels using weights."""
+    weights = set_weights_to_zero_where_invalid(datasets, weights)
+
+    indices = da.argmax(da.dstack(weights), axis=-1)

@@ -85,6 +174,17 @@ def _stack_weighted(datasets, weights, combine_times):
     return weighted_array

+def set_weights_to_zero_where_invalid_red(datasets, weights):
+    """Go through the weights and set pixel values to zero where the corresponding datasets[0] is invalid."""
+    for i, dataset in enumerate(datasets):
+        try:
+            weights[i] = xr.where(dataset[0] == dataset.attrs["_FillValue"], 0, weights[i])
+        except KeyError:
+            weights[i] = xr.where(dataset[0].isnull(), 0, weights[i])
Review comment: @lobsiger I'm refactoring this while also writing tests, but noticed that this says …

Review comment: Oh! This is only looking at the red channel. Why?

Reply: @djhoese this is simply because I was unable to write an accepted function that treated all available bands properly. So I just copied the code @adybbroe wrote for single band cloud data and looked at the RED band only. My (probably wrong) assumption was that either none or all bands would be null or have the _FillValue. At least it seemed to correct OLCI ERR data images. Without this function those images had black lines at the ends (left/right) of the swath of each satellite pass.

+    return weights

+def set_weights_to_zero_where_invalid(datasets, weights):
+    """Go through the weights and set pixel values to zero where the corresponding dataset is invalid."""
+    for i, dataset in enumerate(datasets):
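The masking idea behind both helpers can be sketched with plain NumPy instead of `xr.where` (values invented): zero out a weight wherever its dataset is invalid, so that pixel can neither win a selection nor contribute to a blend.

```python
import numpy as np

# A made-up dataset with one invalid (NaN) pixel, and its weights.
data = np.array([[1., np.nan], [3., 4.]])
weights = np.array([[0.9, 0.9], [0.5, 0.5]])

# Zero the weight wherever the data is invalid.
masked = np.where(np.isnan(data), 0.0, weights)

assert np.array_equal(masked, np.array([[0.9, 0.0], [0.5, 0.5]]))
```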
Review comment: The pythonic way to do this is to pass a string rather than a number for the blend type, e.g. "stack_no_weights", "select_from_highest_weight", "blend_with_weights". However, I wonder if it would be clearer or cleaner to have separate methods for this, so that instead of calling e.g. … we could just call … ?