Skip to content

Commit

Permalink
Fix integer stretching
Browse files Browse the repository at this point in the history
  • Loading branch information
pnuu committed Nov 24, 2023
1 parent 2137081 commit 2d8e89d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
9 changes: 5 additions & 4 deletions trollimage/tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1631,16 +1631,17 @@ def test_linear_stretch_does_not_affect_alpha(self, dtype):

np.testing.assert_allclose(img.data.values, res, atol=1.e-6)

def test_linear_stretch_uint8(self):
"""Test linear stretch with uint8 data."""
arr = np.arange(75, dtype=np.uint8).reshape(5, 5, 3)
@pytest.mark.parametrize("dtype", (np.uint8, np.uint16, int))
def test_linear_stretch_integer(self, dtype):
"""Test linear stretch with integer data."""
arr = np.arange(75, dtype=dtype).reshape(5, 5, 3)
arr[4, 4, :] = 255
data = xr.DataArray(arr.copy(), dims=['y', 'x', 'bands'],
coords={'bands': ['R', 'G', 'B']})
img = xrimage.XRImage(data)
img.stretch_linear()

assert img.data.values.min() == pytest.approx(0.0)
assert img.data.values.min() == pytest.approx(-0.0015614156835530892)
assert img.data.values.max() == pytest.approx(1.0960743801652901)

@pytest.mark.parametrize("dtype", (np.float32, np.float64, float))
Expand Down
6 changes: 5 additions & 1 deletion trollimage/xrimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,7 @@ def crude_stretch(self, min_stretch=None, max_stretch=None):

attrs = self.data.attrs
offset = -min_stretch * scale_factor

self.data = np.multiply(self.data, scale_factor, dtype=scale_factor.dtype) + offset
self.data.attrs = attrs
self.data.attrs.setdefault('enhancement_history', []).append({'scale': scale_factor,
Expand All @@ -1095,8 +1096,11 @@ def _check_stretch_value(self, val, kind='min'):
if isinstance(val, (list, tuple)):
val = self.xrify_tuples(val)

dtype = self.data.dtype
if np.issubdtype(dtype, np.integer):
dtype = np.dtype(np.float32)
try:
val = val.astype(self.data.dtype)
val = val.astype(dtype)
except AttributeError:
val = self.data.dtype.type(val)

Expand Down

0 comments on commit 2d8e89d

Please sign in to comment.