Skip to content

Commit

Permalink
Merge pull request #2617 from pnuu/bugfix-dnc-compute
Browse files Browse the repository at this point in the history
Reduce Dask computations in `DayNightCompositor`
  • Loading branch information
pnuu committed Oct 27, 2023
2 parents 652a236 + 3d6041a commit 86c075a
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 65 deletions.
100 changes: 69 additions & 31 deletions satpy/composites/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
import os
import warnings
from typing import Optional, Sequence

import dask.array as da
import numpy as np
Expand Down Expand Up @@ -119,7 +120,12 @@ def id(self):
id_keys = self.attrs.get("_satpy_id_keys", minimal_default_keys_config)
return DataID(id_keys, **self.attrs)

def __call__(self, datasets, optional_datasets=None, **info):
def __call__(
self,
datasets: Sequence[xr.DataArray],
optional_datasets: Optional[Sequence[xr.DataArray]] = None,
**info
) -> xr.DataArray:
"""Generate a composite."""
raise NotImplementedError()

Expand Down Expand Up @@ -422,7 +428,12 @@ def _get_sensors(self, projectables):
sensor = list(sensor)[0]
return sensor

def __call__(self, projectables, nonprojectables=None, **attrs):
def __call__(
self,
datasets: Sequence[xr.DataArray],
optional_datasets: Optional[Sequence[xr.DataArray]] = None,
**attrs
) -> xr.DataArray:
"""Build the composite."""
if "deprecation_warning" in self.attrs:
warnings.warn(
Expand All @@ -431,29 +442,29 @@ def __call__(self, projectables, nonprojectables=None, **attrs):
stacklevel=2
)
self.attrs.pop("deprecation_warning", None)
num = len(projectables)
num = len(datasets)
mode = attrs.get("mode")
if mode is None:
# num may not be in `self.modes` so only check if we need to
mode = self.modes[num]
if len(projectables) > 1:
projectables = self.match_data_arrays(projectables)
data = self._concat_datasets(projectables, mode)
if len(datasets) > 1:
datasets = self.match_data_arrays(datasets)
data = self._concat_datasets(datasets, mode)
# Skip masking if user wants it or a specific alpha channel is given.
if self.common_channel_mask and mode[-1] != "A":
data = data.where(data.notnull().all(dim="bands"))
else:
data = projectables[0]
data = datasets[0]

# if inputs have a time coordinate that may differ slightly between
# themselves then find the mid time and use that as the single
# time coordinate value
if len(projectables) > 1:
time = check_times(projectables)
if len(datasets) > 1:
time = check_times(datasets)
if time is not None and "time" in data.dims:
data["time"] = [time]

new_attrs = combine_metadata(*projectables)
new_attrs = combine_metadata(*datasets)
# remove metadata that shouldn't make sense in a composite
new_attrs["wavelength"] = None
new_attrs.pop("units", None)
Expand All @@ -467,7 +478,7 @@ def __call__(self, projectables, nonprojectables=None, **attrs):
new_attrs.update(self.attrs)
if resolution is not None:
new_attrs["resolution"] = resolution
new_attrs["sensor"] = self._get_sensors(projectables)
new_attrs["sensor"] = self._get_sensors(datasets)
new_attrs["mode"] = mode

return xr.DataArray(data=data.data, attrs=new_attrs,
Expand Down Expand Up @@ -692,22 +703,27 @@ def __init__(self, name, lim_low=85., lim_high=88., day_night="day_night", inclu
self._has_sza = False
super(DayNightCompositor, self).__init__(name, **kwargs)

def __call__(self, projectables, **kwargs):
def __call__(
self,
datasets: Sequence[xr.DataArray],
optional_datasets: Optional[Sequence[xr.DataArray]] = None,
**attrs
) -> xr.DataArray:
"""Generate the composite."""
projectables = self.match_data_arrays(projectables)
datasets = self.match_data_arrays(datasets)
# At least one composite is requested.
foreground_data = projectables[0]
foreground_data = datasets[0]

weights = self._get_coszen_blending_weights(projectables)
weights = self._get_coszen_blending_weights(datasets)

# Apply enhancements to the foreground data
foreground_data = enhance2dataset(foreground_data)

if "only" in self.day_night:
attrs = foreground_data.attrs.copy()
fg_attrs = foreground_data.attrs.copy()
day_data, night_data, weights = self._get_data_for_single_side_product(foreground_data, weights)
else:
day_data, night_data, attrs = self._get_data_for_combined_product(foreground_data, projectables[1])
day_data, night_data, fg_attrs = self._get_data_for_combined_product(foreground_data, datasets[1])

# The computed coszen is for the full area, so it needs to be masked for missing and off-swath data
if self.include_alpha and not self._has_sza:
Expand All @@ -718,11 +734,18 @@ def __call__(self, projectables, **kwargs):
day_data = zero_missing_data(day_data, night_data)
night_data = zero_missing_data(night_data, day_data)

data = self._weight_data(day_data, night_data, weights, attrs)
data = self._weight_data(day_data, night_data, weights, fg_attrs)

return super(DayNightCompositor, self).__call__(data, **kwargs)
return super(DayNightCompositor, self).__call__(
data,
optional_datasets=optional_datasets,
**attrs
)

def _get_coszen_blending_weights(self, projectables):
def _get_coszen_blending_weights(
self,
projectables: Sequence[xr.DataArray],
) -> xr.DataArray:
lim_low = np.cos(np.deg2rad(self.lim_low))
lim_high = np.cos(np.deg2rad(self.lim_high))
try:
Expand All @@ -739,7 +762,11 @@ def _get_coszen_blending_weights(self, projectables):

return coszen.clip(0, 1)

def _get_data_for_single_side_product(self, foreground_data, weights):
def _get_data_for_single_side_product(
self,
foreground_data: xr.DataArray,
weights: xr.DataArray,
) -> tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
# Only one portion (day or night) is selected. One composite is requested.
# Add alpha band to single L/RGB composite to make the masked-out portion transparent when needed
# L -> LA
Expand All @@ -754,8 +781,8 @@ def _get_data_for_single_side_product(self, foreground_data, weights):

def _mask_weights(self, weights):
if "day" in self.day_night:
return da.where(weights != 0, weights, np.nan)
return da.where(weights != 1, weights, np.nan)
return weights.where(weights != 0, np.nan)
return weights.where(weights != 1, np.nan)

def _get_day_night_data_for_single_side_product(self, foreground_data):
if "day" in self.day_night:
Expand All @@ -778,25 +805,33 @@ def _get_data_for_combined_product(self, day_data, night_data):

return day_data, night_data, attrs

def _mask_weights_with_data(self, weights, day_data, night_data):
def _mask_weights_with_data(
self,
weights: xr.DataArray,
day_data: xr.DataArray,
night_data: xr.DataArray,
) -> xr.DataArray:
data_a = _get_single_channel(day_data)
data_b = _get_single_channel(night_data)
if "only" in self.day_night:
mask = _get_weight_mask_for_single_side_product(data_a, data_b)
else:
mask = _get_weight_mask_for_daynight_product(weights, data_a, data_b)

return da.where(mask, weights, np.nan)
return weights.where(mask, np.nan)

def _weight_data(self, day_data, night_data, weights, attrs):
def _weight_data(
self,
day_data: xr.DataArray,
night_data: xr.DataArray,
weights: xr.DataArray,
attrs: dict,
) -> list[xr.DataArray]:
if not self.include_alpha:
fill = 1 if self.day_night == "night_only" else 0
weights = da.where(np.isnan(weights), fill, weights)

weights = weights.where(~np.isnan(weights), fill)
data = []
for b in _get_band_names(day_data, night_data):
# if self.day_night == "night_only" and self.include_alpha is False:
# import ipdb; ipdb.set_trace()
day_band = _get_single_band_data(day_data, b)
night_band = _get_single_band_data(night_data, b)
# For day-only and night-only products only the alpha channel is weighted
Expand Down Expand Up @@ -824,9 +859,12 @@ def _get_single_band_data(data, band):
return data.sel(bands=band)


def _get_single_channel(data):
def _get_single_channel(data: xr.DataArray) -> xr.DataArray:
try:
data = data[0, :, :]
# remove coordinates that may be band-specific (ex. "bands")
# and we don't care about anymore
data = data.reset_coords(drop=True)
except (IndexError, TypeError):
pass
return data
Expand Down
2 changes: 1 addition & 1 deletion satpy/modifiers/_crefl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def __call__(self, sensor_azimuth, sensor_zenith, solar_azimuth, solar_zenith, a
def _run_crefl(self, mus, muv, phi, solar_zenith, sensor_zenith, height, coeffs):
raise NotImplementedError()

def _height_from_avg_elevation(self, avg_elevation: Optional[np.ndarray]) -> da.Array:
def _height_from_avg_elevation(self, avg_elevation: Optional[np.ndarray]) -> da.Array | float:
"""Get digital elevation map data for our granule with ocean fill value set to 0."""
if avg_elevation is None:
LOG.debug("No average elevation information provided in CREFL")
Expand Down
89 changes: 56 additions & 33 deletions satpy/tests/test_composites.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from pyresample import AreaDefinition

import satpy
from satpy.tests.utils import CustomScheduler

# NOTE:
# The following fixtures are not defined in this file, but are used and injected by Pytest:
Expand Down Expand Up @@ -431,28 +432,34 @@ def setUp(self):
def test_daynight_sza(self):
"""Test compositor with both day and night portions when SZA data is included."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="day_night")
res = comp((self.data_a, self.data_b, self.sza))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="day_night")
res = comp((self.data_a, self.data_b, self.sza))
res = res.compute()
expected = np.array([[0., 0.22122352], [0.5, 1.]])
np.testing.assert_allclose(res.values[0], expected)

def test_daynight_area(self):
"""Test compositor both day and night portions when SZA data is not provided."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="day_night")
res = comp((self.data_a, self.data_b))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="day_night")
res = comp((self.data_a, self.data_b))
res = res.compute()
expected_channel = np.array([[0., 0.33164983], [0.66835017, 1.]])
for i in range(3):
np.testing.assert_allclose(res.values[i], expected_channel)

def test_night_only_sza_with_alpha(self):
"""Test compositor with night portion with alpha band when SZA data is included."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="night_only", include_alpha=True)
res = comp((self.data_b, self.sza))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="night_only", include_alpha=True)
res = comp((self.data_b, self.sza))
res = res.compute()
expected_red_channel = np.array([[np.nan, 0.], [0.5, 1.]])
expected_alpha = np.array([[0., 0.33296056], [1., 1.]])
np.testing.assert_allclose(res.values[0], expected_red_channel)
Expand All @@ -461,19 +468,23 @@ def test_night_only_sza_with_alpha(self):
def test_night_only_sza_without_alpha(self):
"""Test compositor with night portion without alpha band when SZA data is included."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="night_only", include_alpha=False)
res = comp((self.data_a, self.sza))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="night_only", include_alpha=False)
res = comp((self.data_a, self.sza))
res = res.compute()
expected = np.array([[0., 0.11042631], [0.66835017, 1.]])
np.testing.assert_allclose(res.values[0], expected)
assert "A" not in res.bands

def test_night_only_area_with_alpha(self):
"""Test compositor with night portion with alpha band when SZA data is not provided."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="night_only", include_alpha=True)
res = comp((self.data_b,))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="night_only", include_alpha=True)
res = comp((self.data_b,))
res = res.compute()
expected_l_channel = np.array([[np.nan, 0.], [0.5, 1.]])
expected_alpha = np.array([[np.nan, 0.], [0., 0.]])
np.testing.assert_allclose(res.values[0], expected_l_channel)
Expand All @@ -482,19 +493,23 @@ def test_night_only_area_with_alpha(self):
def test_night_only_area_without_alpha(self):
"""Test compositor with night portion without alpha band when SZA data is not provided."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="night_only", include_alpha=False)
res = comp((self.data_b,))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="night_only", include_alpha=False)
res = comp((self.data_b,))
res = res.compute()
expected = np.array([[np.nan, 0.], [0., 0.]])
np.testing.assert_allclose(res.values[0], expected)
assert "A" not in res.bands

def test_day_only_sza_with_alpha(self):
"""Test compositor with day portion with alpha band when SZA data is included."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=True)
res = comp((self.data_a, self.sza))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=True)
res = comp((self.data_a, self.sza))
res = res.compute()
expected_red_channel = np.array([[0., 0.33164983], [0.66835017, 1.]])
expected_alpha = np.array([[1., 0.66703944], [0., 0.]])
np.testing.assert_allclose(res.values[0], expected_red_channel)
Expand All @@ -503,9 +518,11 @@ def test_day_only_sza_with_alpha(self):
def test_day_only_sza_without_alpha(self):
"""Test compositor with day portion without alpha band when SZA data is included."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=False)
res = comp((self.data_a, self.sza))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=False)
res = comp((self.data_a, self.sza))
res = res.compute()
expected_channel_data = np.array([[0., 0.22122352], [0., 0.]])
for i in range(3):
np.testing.assert_allclose(res.values[i], expected_channel_data)
Expand All @@ -514,9 +531,11 @@ def test_day_only_sza_without_alpha(self):
def test_day_only_area_with_alpha(self):
"""Test compositor with day portion with alpha_band when SZA data is not provided."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=True)
res = comp((self.data_a,))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=True)
res = comp((self.data_a,))
res = res.compute()
expected_l_channel = np.array([[0., 0.33164983], [0.66835017, 1.]])
expected_alpha = np.array([[1., 1.], [1., 1.]])
np.testing.assert_allclose(res.values[0], expected_l_channel)
Expand All @@ -525,9 +544,11 @@ def test_day_only_area_with_alpha(self):
def test_day_only_area_with_alpha_and_missing_data(self):
"""Test compositor with day portion with alpha_band when SZA data is not provided and there is missing data."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=True)
res = comp((self.data_b,))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=True)
res = comp((self.data_b,))
res = res.compute()
expected_l_channel = np.array([[np.nan, 0.], [0.5, 1.]])
expected_alpha = np.array([[np.nan, 1.], [1., 1.]])
np.testing.assert_allclose(res.values[0], expected_l_channel)
Expand All @@ -536,9 +557,11 @@ def test_day_only_area_with_alpha_and_missing_data(self):
def test_day_only_area_without_alpha(self):
"""Test compositor with day portion without alpha_band when SZA data is not provided."""
from satpy.composites import DayNightCompositor
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=False)
res = comp((self.data_a,))
res = res.compute()

with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=False)
res = comp((self.data_a,))
res = res.compute()
expected = np.array([[0., 0.33164983], [0.66835017, 1.]])
np.testing.assert_allclose(res.values[0], expected)
assert "A" not in res.bands
Expand Down

0 comments on commit 86c075a

Please sign in to comment.