Skip to content

Commit

Permalink
Merge pull request #2640 from pnuu/keep-daynightcompositor-dtype
Browse files Browse the repository at this point in the history
Keep original dtype in DayNightCompositor
  • Loading branch information
mraspaud committed Nov 27, 2023
2 parents 684f364 + 2ff0b2a commit d4111e6
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 32 deletions.
19 changes: 9 additions & 10 deletions satpy/composites/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,9 +713,7 @@ def __call__(
datasets = self.match_data_arrays(datasets)
# At least one composite is requested.
foreground_data = datasets[0]

weights = self._get_coszen_blending_weights(datasets)

# Apply enhancements to the foreground data
foreground_data = enhance2dataset(foreground_data)

Expand Down Expand Up @@ -759,7 +757,6 @@ def _get_coszen_blending_weights(
# Calculate blending weights
coszen -= np.min((lim_high, lim_low))
coszen /= np.abs(lim_low - lim_high)

return coszen.clip(0, 1)

def _get_data_for_single_side_product(
Expand All @@ -786,8 +783,8 @@ def _mask_weights(self, weights):

def _get_day_night_data_for_single_side_product(self, foreground_data):
if "day" in self.day_night:
return foreground_data, 0
return 0, foreground_data
return foreground_data, foreground_data.dtype.type(0)
return foreground_data.dtype.type(0), foreground_data

def _get_data_for_combined_product(self, day_data, night_data):
# Apply enhancements also to night-side data
Expand Down Expand Up @@ -848,15 +845,16 @@ def _weight_data(
def _get_band_names(day_data, night_data):
try:
bands = day_data["bands"]
except TypeError:
except (IndexError, TypeError):
bands = night_data["bands"]
return bands


def _get_single_band_data(data, band):
if isinstance(data, int):
try:
return data.sel(bands=band)
except AttributeError:
return data
return data.sel(bands=band)


def _get_single_channel(data: xr.DataArray) -> xr.DataArray:
Expand All @@ -871,7 +869,7 @@ def _get_single_channel(data: xr.DataArray) -> xr.DataArray:


def _get_weight_mask_for_single_side_product(data_a, data_b):
if isinstance(data_a, int):
if data_b.shape:
return ~da.isnan(data_b)
return ~da.isnan(data_a)

Expand All @@ -894,7 +892,8 @@ def add_alpha_bands(data):
alpha = new_data[0].copy()
alpha.data = da.ones((data.sizes["y"],
data.sizes["x"]),
chunks=new_data[0].chunks)
chunks=new_data[0].chunks,
dtype=data.dtype)
# Rename band to indicate it's alpha
alpha["bands"] = "A"
new_data.append(alpha)
Expand Down
3 changes: 3 additions & 0 deletions satpy/modifiers/angles.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,9 @@ def get_cos_sza(data_arr: xr.DataArray) -> xr.DataArray:
"""
chunks = _geo_chunks_from_data_arr(data_arr)
lons, lats = _get_valid_lonlats(data_arr.attrs["area"], chunks)
if lons.dtype != data_arr.dtype and np.issubdtype(data_arr.dtype, np.floating):
lons = lons.astype(data_arr.dtype)
lats = lats.astype(data_arr.dtype)
cos_sza = _get_cos_sza(data_arr.attrs["start_time"], lons, lats)
return _geo_dask_to_data_array(cos_sza)

Expand Down
51 changes: 31 additions & 20 deletions satpy/tests/test_composites.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,15 +401,15 @@ def setUp(self):
start_time = datetime(2018, 1, 1, 18, 0, 0)

# RGB
a = np.zeros((3, 2, 2), dtype=np.float64)
a = np.zeros((3, 2, 2), dtype=np.float32)
a[:, 0, 0] = 0.1
a[:, 0, 1] = 0.2
a[:, 1, 0] = 0.3
a[:, 1, 1] = 0.4
a = da.from_array(a, a.shape)
self.data_a = xr.DataArray(a, attrs={"test": "a", "start_time": start_time},
coords={"bands": bands}, dims=("bands", "y", "x"))
b = np.zeros((3, 2, 2), dtype=np.float64)
b = np.zeros((3, 2, 2), dtype=np.float32)
b[:, 0, 0] = np.nan
b[:, 0, 1] = 0.25
b[:, 1, 0] = 0.50
Expand All @@ -418,7 +418,7 @@ def setUp(self):
self.data_b = xr.DataArray(b, attrs={"test": "b", "start_time": start_time},
coords={"bands": bands}, dims=("bands", "y", "x"))

sza = np.array([[80., 86.], [94., 100.]])
sza = np.array([[80., 86.], [94., 100.]], dtype=np.float32)
sza = da.from_array(sza, sza.shape)
self.sza = xr.DataArray(sza, dims=("y", "x"))

Expand All @@ -442,8 +442,9 @@ def test_daynight_sza(self):
comp = DayNightCompositor(name="dn_test", day_night="day_night")
res = comp((self.data_a, self.data_b, self.sza))
res = res.compute()
expected = np.array([[0., 0.22122352], [0.5, 1.]])
np.testing.assert_allclose(res.values[0], expected)
expected = np.array([[0., 0.22122374], [0.5, 1.]], dtype=np.float32)
assert res.dtype == np.float32
np.testing.assert_allclose(res.values[0], expected, rtol=1e-6)

def test_daynight_area(self):
"""Test compositor both day and night portions when SZA data is not provided."""
Expand All @@ -453,7 +454,8 @@ def test_daynight_area(self):
comp = DayNightCompositor(name="dn_test", day_night="day_night")
res = comp((self.data_a, self.data_b))
res = res.compute()
expected_channel = np.array([[0., 0.33164983], [0.66835017, 1.]])
expected_channel = np.array([[0., 0.33164983], [0.66835017, 1.]], dtype=np.float32)
assert res.dtype == np.float32
for i in range(3):
np.testing.assert_allclose(res.values[i], expected_channel)

Expand All @@ -465,8 +467,9 @@ def test_night_only_sza_with_alpha(self):
comp = DayNightCompositor(name="dn_test", day_night="night_only", include_alpha=True)
res = comp((self.data_b, self.sza))
res = res.compute()
expected_red_channel = np.array([[np.nan, 0.], [0.5, 1.]])
expected_alpha = np.array([[0., 0.33296056], [1., 1.]])
expected_red_channel = np.array([[np.nan, 0.], [0.5, 1.]], dtype=np.float32)
expected_alpha = np.array([[0., 0.3329599], [1., 1.]], dtype=np.float32)
assert res.dtype == np.float32
np.testing.assert_allclose(res.values[0], expected_red_channel)
np.testing.assert_allclose(res.values[-1], expected_alpha)

Expand All @@ -478,7 +481,8 @@ def test_night_only_sza_without_alpha(self):
comp = DayNightCompositor(name="dn_test", day_night="night_only", include_alpha=False)
res = comp((self.data_a, self.sza))
res = res.compute()
expected = np.array([[0., 0.11042631], [0.66835017, 1.]])
expected = np.array([[0., 0.11042609], [0.6683502, 1.]], dtype=np.float32)
assert res.dtype == np.float32
np.testing.assert_allclose(res.values[0], expected)
assert "A" not in res.bands

Expand All @@ -490,8 +494,9 @@ def test_night_only_area_with_alpha(self):
comp = DayNightCompositor(name="dn_test", day_night="night_only", include_alpha=True)
res = comp((self.data_b,))
res = res.compute()
expected_l_channel = np.array([[np.nan, 0.], [0.5, 1.]])
expected_alpha = np.array([[np.nan, 0.], [0., 0.]])
expected_l_channel = np.array([[np.nan, 0.], [0.5, 1.]], dtype=np.float32)
expected_alpha = np.array([[np.nan, 0.], [0., 0.]], dtype=np.float32)
assert res.dtype == np.float32
np.testing.assert_allclose(res.values[0], expected_l_channel)
np.testing.assert_allclose(res.values[-1], expected_alpha)

Expand All @@ -503,7 +508,8 @@ def test_night_only_area_without_alpha(self):
comp = DayNightCompositor(name="dn_test", day_night="night_only", include_alpha=False)
res = comp((self.data_b,))
res = res.compute()
expected = np.array([[np.nan, 0.], [0., 0.]])
expected = np.array([[np.nan, 0.], [0., 0.]], dtype=np.float32)
assert res.dtype == np.float32
np.testing.assert_allclose(res.values[0], expected)
assert "A" not in res.bands

Expand All @@ -515,8 +521,9 @@ def test_day_only_sza_with_alpha(self):
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=True)
res = comp((self.data_a, self.sza))
res = res.compute()
expected_red_channel = np.array([[0., 0.33164983], [0.66835017, 1.]])
expected_alpha = np.array([[1., 0.66703944], [0., 0.]])
expected_red_channel = np.array([[0., 0.33164983], [0.66835017, 1.]], dtype=np.float32)
expected_alpha = np.array([[1., 0.6670401], [0., 0.]], dtype=np.float32)
assert res.dtype == np.float32
np.testing.assert_allclose(res.values[0], expected_red_channel)
np.testing.assert_allclose(res.values[-1], expected_alpha)

Expand All @@ -528,7 +535,8 @@ def test_day_only_sza_without_alpha(self):
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=False)
res = comp((self.data_a, self.sza))
res = res.compute()
expected_channel_data = np.array([[0., 0.22122352], [0., 0.]])
expected_channel_data = np.array([[0., 0.22122373], [0., 0.]], dtype=np.float32)
assert res.dtype == np.float32
for i in range(3):
np.testing.assert_allclose(res.values[i], expected_channel_data)
assert "A" not in res.bands
Expand All @@ -541,8 +549,9 @@ def test_day_only_area_with_alpha(self):
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=True)
res = comp((self.data_a,))
res = res.compute()
expected_l_channel = np.array([[0., 0.33164983], [0.66835017, 1.]])
expected_alpha = np.array([[1., 1.], [1., 1.]])
expected_l_channel = np.array([[0., 0.33164983], [0.66835017, 1.]], dtype=np.float32)
expected_alpha = np.array([[1., 1.], [1., 1.]], dtype=np.float32)
assert res.dtype == np.float32
np.testing.assert_allclose(res.values[0], expected_l_channel)
np.testing.assert_allclose(res.values[-1], expected_alpha)

Expand All @@ -554,8 +563,9 @@ def test_day_only_area_with_alpha_and_missing_data(self):
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=True)
res = comp((self.data_b,))
res = res.compute()
expected_l_channel = np.array([[np.nan, 0.], [0.5, 1.]])
expected_alpha = np.array([[np.nan, 1.], [1., 1.]])
expected_l_channel = np.array([[np.nan, 0.], [0.5, 1.]], dtype=np.float32)
expected_alpha = np.array([[np.nan, 1.], [1., 1.]], dtype=np.float32)
assert res.dtype == np.float32
np.testing.assert_allclose(res.values[0], expected_l_channel)
np.testing.assert_allclose(res.values[-1], expected_alpha)

Expand All @@ -567,7 +577,8 @@ def test_day_only_area_without_alpha(self):
comp = DayNightCompositor(name="dn_test", day_night="day_only", include_alpha=False)
res = comp((self.data_a,))
res = res.compute()
expected = np.array([[0., 0.33164983], [0.66835017, 1.]])
expected = np.array([[0., 0.33164983], [0.66835017, 1.]], dtype=np.float32)
assert res.dtype == np.float32
np.testing.assert_allclose(res.values[0], expected)
assert "A" not in res.bands

Expand Down
6 changes: 4 additions & 2 deletions satpy/tests/test_modifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,15 @@ def test_basic_default_not_provided(self, sunz_ds1, as_32bit):
sunz_ds1 = sunz_ds1.astype(np.float32)
comp = SunZenithCorrector(name="sza_test", modifiers=tuple())
res = comp((sunz_ds1,), test_attr="test")
np.testing.assert_allclose(res.values, np.array([[22.401667, 22.31777], [22.437503, 22.353533]]))
np.testing.assert_allclose(res.values, np.array([[22.401667, 22.31777], [22.437503, 22.353533]]),
rtol=1e-6)
assert "y" in res.coords
assert "x" in res.coords
ds1 = sunz_ds1.copy().drop_vars(("y", "x"))
res = comp((ds1,), test_attr="test")
res_np = res.compute()
np.testing.assert_allclose(res_np.values, np.array([[22.401667, 22.31777], [22.437503, 22.353533]]))
np.testing.assert_allclose(res_np.values, np.array([[22.401667, 22.31777], [22.437503, 22.353533]]),
rtol=1e-6)
assert res.dtype == res_np.dtype
assert "y" not in res.coords
assert "x" not in res.coords
Expand Down

0 comments on commit d4111e6

Please sign in to comment.