diff --git a/satpy/enhancements/__init__.py b/satpy/enhancements/__init__.py index f35be96292..5b5de4abf0 100644 --- a/satpy/enhancements/__init__.py +++ b/satpy/enhancements/__init__.py @@ -68,7 +68,7 @@ def apply_enhancement(data, func, exclude=None, separate=False, if separate: data_arrs = [] - for band_name in bands: + for idx, band_name in enumerate(bands): band_data = data.sel(bands=[band_name]) if band_name in exclude: # don't modify alpha @@ -78,10 +78,10 @@ def apply_enhancement(data, func, exclude=None, separate=False, if pass_dask: dims = band_data.dims coords = band_data.coords - d_arr = func(band_data.data) + d_arr = func(band_data.data, index=idx) band_data = xr.DataArray(d_arr, dims=dims, coords=coords) else: - band_data = func(band_data) + band_data = func(band_data, index=idx) data_arrs.append(band_data) # we assume that the func can add attrs attrs.update(band_data.attrs) @@ -132,21 +132,21 @@ def func(band_data): def lookup(img, **kwargs): """Assign values to channels based on a table.""" luts = np.array(kwargs['luts'], dtype=np.float32) / 255.0 - luts = luts.ravel() - def func(band_data, luts=luts): + def func(band_data, luts=luts, index=-1): # NaN/null values will become 0 - band_data = band_data.clip(0, luts.size - 1).astype(np.uint8) + lut = luts[:, index] if len(luts.shape) == 2 else luts + band_data = band_data.clip(0, lut.size - 1).astype(np.uint8) def _delayed(luts, band_data): # can't use luts.__getitem__ for some reason return luts[band_data] - new_delay = dask.delayed(_delayed)(luts, band_data) + new_delay = dask.delayed(_delayed)(lut, band_data) new_data = da.from_delayed(new_delay, shape=band_data.shape, dtype=luts.dtype) return new_data - return apply_enhancement(img.data, func, pass_dask=True) + return apply_enhancement(img.data, func, separate=True, pass_dask=True) def colorize(img, **kwargs): @@ -206,7 +206,7 @@ def create_colormap(palette): cmap.append((value, tuple(color))) return Colormap(*cmap) - if isinstance(colors, basestring): + if isinstance(colors, str): from trollimage import colormap import copy return copy.copy(getattr(colormap, colors)) @@ -224,7 +224,9 @@ def three_d_effect(img, **kwargs): [-w, 0, w]]) mode = kwargs.get('convolve_mode', 'same') - def func(band_data, kernel=kernel, mode=mode): + def func(band_data, kernel=kernel, mode=mode, index=None): + del index + def _delayed(band_data, kernel, mode): band_data = band_data.reshape(band_data.shape[1:]) new_data = convolve2d(band_data, kernel, mode=mode) diff --git a/satpy/tests/test_enhancements.py b/satpy/tests/test_enhancements.py index 5ba3e8a3a2..ffd0487fb1 100644 --- a/satpy/tests/test_enhancements.py +++ b/satpy/tests/test_enhancements.py @@ -74,6 +74,16 @@ def test_lookup(self): lut = np.arange(256.) self._test_enhancement(lookup, self.ch1, expected, luts=lut) + expected = np.array([[[0., 0., 0., 0.333333, 0.705882], + [1., 1., 1., 1., 1.]], + [[0., 0., 0., 0.333333, 0.705882], + [1., 1., 1., 1., 1.]], + [[0., 0., 0., 0.333333, 0.705882], + [1., 1., 1., 1., 1.]]]) + lut = np.arange(256.) + lut = np.vstack((lut, lut, lut)).T + self._test_enhancement(lookup, self.rgb, expected, luts=lut) + def test_colorize(self): from satpy.enhancements import colorize from trollimage.colormap import brbg