Skip to content

Commit

Permalink
Merge pull request #2623 from pnuu/bugfix-hybrid-green-compute
Browse files Browse the repository at this point in the history
Fix unnecessary Dask `compute()`s in `NDVIHybridGreen` compositor
  • Loading branch information
pnuu committed Nov 2, 2023
2 parents 86c075a + dd2c879 commit d59e467
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 26 deletions.
5 changes: 1 addition & 4 deletions satpy/composites/spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
import logging
import warnings

import dask.array as da

from satpy.composites import GenericCompositor
from satpy.dataset import combine_metadata

Expand Down Expand Up @@ -166,8 +164,7 @@ def __call__(self, projectables, optional_datasets=None, **attrs):

ndvi = (ndvi_input[1] - ndvi_input[0]) / (ndvi_input[1] + ndvi_input[0])

ndvi.data = da.where(ndvi > self.ndvi_min, ndvi, self.ndvi_min)
ndvi.data = da.where(ndvi < self.ndvi_max, ndvi, self.ndvi_max)
ndvi = ndvi.clip(self.ndvi_min, self.ndvi_max)

# Introduce non-linearity to ndvi for non-linear scaling to NIR blend fraction
if self.strength != 1.0: # self._apply_strength() has no effect if strength = 1.0 -> no non-linear behaviour
Expand Down
63 changes: 41 additions & 22 deletions satpy/tests/compositor_tests/test_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
# satpy. If not, see <http://www.gnu.org/licenses/>.
"""Tests for spectral correction compositors."""

import dask
import dask.array as da
import numpy as np
import pytest
import xarray as xr

from satpy.composites.spectral import GreenCorrector, HybridGreen, NDVIHybridGreen, SpectralBlender
from satpy.tests.utils import CustomScheduler


class TestSpectralComposites:
Expand Down Expand Up @@ -83,34 +85,51 @@ class TestNdviHybridGreenCompositor:

def setup_method(self):
"""Initialize channels."""
self.c01 = xr.DataArray(da.from_array([[0.25, 0.30], [0.20, 0.30]], chunks=25),
dims=("y", "x"), attrs={"name": "C02"})
self.c02 = xr.DataArray(da.from_array([[0.25, 0.30], [0.25, 0.35]], chunks=25),
dims=("y", "x"), attrs={"name": "C03"})
self.c03 = xr.DataArray(da.from_array([[0.35, 0.35], [0.28, 0.65]], chunks=25),
dims=("y", "x"), attrs={"name": "C04"})
self.c01 = xr.DataArray(
da.from_array(np.array([[0.25, 0.30], [0.20, 0.30]], dtype=np.float32), chunks=25),
dims=("y", "x"), attrs={"name": "C02"})
self.c02 = xr.DataArray(
da.from_array(np.array([[0.25, 0.30], [0.25, 0.35]], dtype=np.float32), chunks=25),
dims=("y", "x"), attrs={"name": "C03"})
self.c03 = xr.DataArray(
da.from_array(np.array([[0.35, 0.35], [0.28, 0.65]], dtype=np.float32), chunks=25),
dims=("y", "x"), attrs={"name": "C04"})

def test_ndvi_hybrid_green(self):
"""Test General functionality with linear scaling from ndvi to blend fraction."""
comp = NDVIHybridGreen("ndvi_hybrid_green", limits=(0.15, 0.05), prerequisites=(0.51, 0.65, 0.85),
standard_name="toa_bidirectional_reflectance")

# Test General functionality with linear strength (=1.0)
res = comp((self.c01, self.c02, self.c03))
assert isinstance(res, xr.DataArray)
assert isinstance(res.data, da.Array)
assert res.attrs["name"] == "ndvi_hybrid_green"
assert res.attrs["standard_name"] == "toa_bidirectional_reflectance"
data = res.values
with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = NDVIHybridGreen("ndvi_hybrid_green", limits=(0.15, 0.05), prerequisites=(0.51, 0.65, 0.85),
standard_name="toa_bidirectional_reflectance")

# Test General functionality with linear strength (=1.0)
res = comp((self.c01, self.c02, self.c03))
assert isinstance(res, xr.DataArray)
assert isinstance(res.data, da.Array)
assert res.attrs["name"] == "ndvi_hybrid_green"
assert res.attrs["standard_name"] == "toa_bidirectional_reflectance"
data = res.values
np.testing.assert_array_almost_equal(data, np.array([[0.2633, 0.3071], [0.2115, 0.3420]]), decimal=4)

def test_nonliniear_scaling(self):
"""Test non-linear scaling using `strength` term."""
comp = NDVIHybridGreen("ndvi_hybrid_green", limits=(0.15, 0.05), strength=2.0, prerequisites=(0.51, 0.65, 0.85),
standard_name="toa_bidirectional_reflectance")
def test_ndvi_hybrid_green_dtype(self):
"""Test that the datatype is not altered by the compositor."""
with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = NDVIHybridGreen("ndvi_hybrid_green", limits=(0.15, 0.05), prerequisites=(0.51, 0.65, 0.85),
standard_name="toa_bidirectional_reflectance")
res = comp((self.c01, self.c02, self.c03)).compute()
assert res.data.dtype == np.float32

res = comp((self.c01, self.c02, self.c03))
np.testing.assert_array_almost_equal(res.values, np.array([[0.2646, 0.3075], [0.2120, 0.3471]]), decimal=4)
def test_nonlinear_scaling(self):
"""Test non-linear scaling using `strength` term."""
with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
comp = NDVIHybridGreen("ndvi_hybrid_green", limits=(0.15, 0.05), strength=2.0,
prerequisites=(0.51, 0.65, 0.85),
standard_name="toa_bidirectional_reflectance")

res = comp((self.c01, self.c02, self.c03))
res_np = res.data.compute()
assert res.dtype == res_np.dtype
assert res.dtype == np.float32
np.testing.assert_array_almost_equal(res.data, np.array([[0.2646, 0.3075], [0.2120, 0.3471]]), decimal=4)

def test_invalid_strength(self):
"""Test using invalid `strength` term for non-linear scaling."""
Expand Down

0 comments on commit d59e467

Please sign in to comment.