Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix enhance2dataset to support P-mode datasets #1432

Merged
merged 4 commits into from Nov 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
37 changes: 29 additions & 8 deletions satpy/composites/__init__.py
Expand Up @@ -559,12 +559,13 @@ def __call__(self, projectables, **kwargs):
return super(DayNightCompositor, self).__call__(data, **kwargs)


def enhance2dataset(dset):
"""Return the enhancemened to dataset *dset* as an array."""
def enhance2dataset(dset, convert_p=False):
"""Return the enhancement dataset *dset* as an array.

If `convert_p` is True, enhancements generating a P mode will be converted to RGB or RGBA.
"""
attrs = dset.attrs
img = get_enhanced_image(dset)
# Clip image data to interval [0.0, 1.0]
data = img.data.clip(0.0, 1.0)
data = _get_data_from_enhanced_image(dset, convert_p)
data.attrs = attrs
# remove 'mode' if it is specified since it may have been updated
data.attrs.pop('mode', None)
Expand All @@ -573,10 +574,31 @@ def enhance2dataset(dset):
return data


def _get_data_from_enhanced_image(dset, convert_p):
img = get_enhanced_image(dset)
if convert_p and img.mode == 'P':
img = _apply_palette_to_image(img)
if img.mode != 'P':
data = img.data.clip(0.0, 1.0)
else:
data = img.data
return data


def _apply_palette_to_image(img):
if len(img.palette[0]) == 3:
img = img.convert('RGB')
elif len(img.palette[0]) == 4:
img = img.convert('RGBA')
return img


def add_bands(data, bands):
"""Add bands so that they match *bands*."""
# Add R, G and B bands, remove L band
bands = bands.compute()
if 'P' in data['bands'].data or 'P' in bands.data:
raise NotImplementedError('Cannot mix datasets of mode P with other datasets at the moment.')
if 'L' in data['bands'].data and 'R' in bands.data:
lum = data.sel(bands='L')
# Keep 'A' if it was present
Expand Down Expand Up @@ -1021,9 +1043,8 @@ def __call__(self, projectables, *args, **kwargs):
"""Call the compositor."""
projectables = self.match_data_arrays(projectables)
# Get enhanced datasets
foreground = enhance2dataset(projectables[0])
background = enhance2dataset(projectables[1])

foreground = enhance2dataset(projectables[0], convert_p=True)
background = enhance2dataset(projectables[1], convert_p=True)
# Adjust bands so that they match
# L/RGB -> RGB/RGB
# LA/RGB -> RGBA/RGBA
Expand Down
92 changes: 86 additions & 6 deletions satpy/tests/test_composites.py
Expand Up @@ -25,6 +25,7 @@
import dask
import dask.array as da
import numpy as np
import pytest
import xarray as xr


Expand Down Expand Up @@ -367,7 +368,7 @@ def test_compositor(self):
# Three shades of grey
rgb_arr = np.array([1, 50, 100, 200, 1, 50, 100, 200, 1, 50, 100, 200])
rgb = xr.DataArray(rgb_arr.reshape((3, 2, 2)),
dims=['bands', 'y', 'x'])
dims=['bands', 'y', 'x'], coords={'bands': ['R', 'G', 'B']})
# 100 % luminance -> all result values ~1.0
lum = xr.DataArray(np.array([[100., 100.], [100., 100.]]),
dims=['y', 'x'])
Expand Down Expand Up @@ -778,12 +779,9 @@ def test_call(self):
class TestAddBands(unittest.TestCase):
"""Test case for the `add_bands` function."""

def test_add_bands(self):
def test_add_bands_l_rgb(self):
"""Test adding bands."""
from satpy.composites import add_bands
import dask.array as da
import numpy as np
import xarray as xr

# L + RGB -> RGB
data = xr.DataArray(da.ones((1, 3, 3)), dims=('bands', 'y', 'x'),
Expand All @@ -796,6 +794,10 @@ def test_add_bands(self):
np.testing.assert_array_equal(res.bands, res_bands)
np.testing.assert_array_equal(res.coords['bands'], res_bands)

def test_add_bands_l_rgba(self):
"""Test adding bands."""
from satpy.composites import add_bands

# L + RGBA -> RGBA
data = xr.DataArray(da.ones((1, 3, 3)), dims=('bands', 'y', 'x'),
coords={'bands': ['L']}, attrs={'mode': 'L'})
Expand All @@ -807,6 +809,10 @@ def test_add_bands(self):
np.testing.assert_array_equal(res.bands, res_bands)
np.testing.assert_array_equal(res.coords['bands'], res_bands)

def test_add_bands_la_rgb(self):
"""Test adding bands."""
from satpy.composites import add_bands

# LA + RGB -> RGBA
data = xr.DataArray(da.ones((2, 3, 3)), dims=('bands', 'y', 'x'),
coords={'bands': ['L', 'A']}, attrs={'mode': 'LA'})
Expand All @@ -818,6 +824,10 @@ def test_add_bands(self):
np.testing.assert_array_equal(res.bands, res_bands)
np.testing.assert_array_equal(res.coords['bands'], res_bands)

def test_add_bands_rgb_rbga(self):
"""Test adding bands."""
from satpy.composites import add_bands

# RGB + RGBA -> RGBA
data = xr.DataArray(da.ones((3, 3, 3)), dims=('bands', 'y', 'x'),
coords={'bands': ['R', 'G', 'B']},
Expand All @@ -830,6 +840,19 @@ def test_add_bands(self):
np.testing.assert_array_equal(res.bands, res_bands)
np.testing.assert_array_equal(res.coords['bands'], res_bands)

def test_add_bands_p_l(self):
"""Test adding bands."""
from satpy.composites import add_bands

# P(RGBA) + L -> RGBA
data = xr.DataArray(da.ones((1, 3, 3)), dims=('bands', 'y', 'x'),
coords={'bands': ['P']},
attrs={'mode': 'P'})
new_bands = xr.DataArray(da.array(['L']), dims=('bands'),
coords={'bands': ['L']})
with pytest.raises(NotImplementedError):
add_bands(data, new_bands)


class TestStaticImageCompositor(unittest.TestCase):
"""Test case for the static compositor."""
Expand Down Expand Up @@ -896,7 +919,7 @@ def load(self, arg):
self.assertEqual(comp.filename, "/path/to/image/foo.tif")


def _enhance2dataset(dataset):
def _enhance2dataset(dataset, convert_p=False):
"""Mock the enhance2dataset to return the original data."""
return dataset

Expand Down Expand Up @@ -1242,6 +1265,63 @@ def temp_func(*args):
self.assertEqual(res[2], projectables[2])


class TestEnhance2Dataset(unittest.TestCase):
"""Test the enhance2dataset utility."""

@mock.patch('satpy.composites.get_enhanced_image')
def test_enhance_p_to_rgb(self, get_enhanced_image):
"""Test enhancing a paletted dataset in RGB mode."""
from trollimage.xrimage import XRImage
img = XRImage(xr.DataArray(np.ones((1, 20, 20)) * 2, dims=('bands', 'y', 'x'), coords={'bands': ['P']}))
img.palette = ((0, 0, 0), (4, 4, 4), (8, 8, 8))
get_enhanced_image.return_value = img

from satpy.composites import enhance2dataset
dataset = xr.DataArray(np.ones((1, 20, 20)))
res = enhance2dataset(dataset, convert_p=True)
assert res.attrs['mode'] == 'RGB'

@mock.patch('satpy.composites.get_enhanced_image')
def test_enhance_p_to_rgba(self, get_enhanced_image):
"""Test enhancing a paletted dataset in RGBA mode."""
from trollimage.xrimage import XRImage
img = XRImage(xr.DataArray(np.ones((1, 20, 20)) * 2, dims=('bands', 'y', 'x'), coords={'bands': ['P']}))
img.palette = ((0, 0, 0, 255), (4, 4, 4, 255), (8, 8, 8, 255))
get_enhanced_image.return_value = img

from satpy.composites import enhance2dataset
dataset = xr.DataArray(np.ones((1, 20, 20)))
res = enhance2dataset(dataset, convert_p=True)
assert res.attrs['mode'] == 'RGBA'

@mock.patch('satpy.composites.get_enhanced_image')
def test_enhance_p(self, get_enhanced_image):
"""Test enhancing a paletted dataset in P mode."""
from trollimage.xrimage import XRImage
img = XRImage(xr.DataArray(np.ones((1, 20, 20)) * 2, dims=('bands', 'y', 'x'), coords={'bands': ['P']}))
img.palette = ((0, 0, 0, 255), (4, 4, 4, 255), (8, 8, 8, 255))
get_enhanced_image.return_value = img

from satpy.composites import enhance2dataset
dataset = xr.DataArray(np.ones((1, 20, 20)))
res = enhance2dataset(dataset)
assert res.attrs['mode'] == 'P'
assert res.max().values == 2

@mock.patch('satpy.composites.get_enhanced_image')
def test_enhance_l(self, get_enhanced_image):
"""Test enhancing a paletted dataset in P mode."""
from trollimage.xrimage import XRImage
img = XRImage(xr.DataArray(np.ones((1, 20, 20)) * 2, dims=('bands', 'y', 'x'), coords={'bands': ['L']}))
get_enhanced_image.return_value = img

from satpy.composites import enhance2dataset
dataset = xr.DataArray(np.ones((1, 20, 20)))
res = enhance2dataset(dataset)
assert res.attrs['mode'] == 'L'
assert res.max().values == 1


class TestInferMode(unittest.TestCase):
"""Test the infer_mode utility."""

Expand Down