Skip to content

Commit

Permalink
Merge pull request #40 from goodsonr/enhancement-colormap-handling-in…
Browse files Browse the repository at this point in the history
…-colorize

Modify colorize routine to allow colorizing using colormaps with alpha channel
  • Loading branch information
djhoese committed Feb 15, 2019
2 parents 5ff71c7 + 5e9189f commit 8073e70
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 24 deletions.
104 changes: 95 additions & 9 deletions trollimage/tests/test_image.py
Expand Up @@ -800,15 +800,22 @@ def test_save(self):
import dask.array as da
from dask.delayed import Delayed
from trollimage import xrimage
from trollimage.colormap import brbg, Colormap

data = xr.DataArray(np.arange(75).reshape(5, 5, 3) / 75., dims=[
# RGBA colormap
bw = Colormap(
(0.0, (1.0, 1.0, 1.0, 1.0)),
(1.0, (0.0, 0.0, 0.0, 0.5)),
)

data = xr.DataArray(np.arange(75).reshape(5, 5, 3) / 74., dims=[
'y', 'x', 'bands'], coords={'bands': ['R', 'G', 'B']})
img = xrimage.XRImage(data)
with NamedTemporaryFile(suffix='.png') as tmp:
img.save(tmp.name)

# Single band image
data = xr.DataArray(np.arange(75).reshape(15, 5, 1) / 75., dims=[
data = xr.DataArray(np.arange(75).reshape(15, 5, 1) / 74., dims=[
'y', 'x', 'bands'], coords={'bands': ['L']})
# Single band image to JPEG
img = xrimage.XRImage(data)
Expand All @@ -819,15 +826,29 @@ def test_save(self):
with NamedTemporaryFile(suffix='.png') as tmp:
img.save(tmp.name)

data = xr.DataArray(da.from_array(np.arange(75).reshape(5, 5, 3) / 75.,
# Single band image palettized
data = xr.DataArray(np.arange(75).reshape(15, 5, 1) / 74., dims=[
'y', 'x', 'bands'], coords={'bands': ['L']})
# Single band image to JPEG
img = xrimage.XRImage(data)
img.palettize(brbg)
with NamedTemporaryFile(suffix='.png') as tmp:
img.save(tmp.name)
# RGBA colormap
img = xrimage.XRImage(data)
img.palettize(bw)
with NamedTemporaryFile(suffix='.png') as tmp:
img.save(tmp.name)

data = xr.DataArray(da.from_array(np.arange(75).reshape(5, 5, 3) / 74.,
chunks=5),
dims=['y', 'x', 'bands'],
coords={'bands': ['R', 'G', 'B']})
img = xrimage.XRImage(data)
with NamedTemporaryFile(suffix='.png') as tmp:
img.save(tmp.name)

data = data.where(data > (10 / 75.0))
data = data.where(data > (10 / 74.0))
img = xrimage.XRImage(data)
with NamedTemporaryFile(suffix='.png') as tmp:
img.save(tmp.name)
Expand Down Expand Up @@ -1376,7 +1397,13 @@ def test_convert_modes(self):
import dask
import xarray as xr
from trollimage import xrimage
from trollimage.colormap import brbg
from trollimage.colormap import brbg, Colormap

# RGBA colormap
bw = Colormap(
(0.0, (1.0, 1.0, 1.0, 1.0)),
(1.0, (0.0, 0.0, 0.0, 0.5)),
)

arr1 = np.arange(150).reshape(1, 15, 10) / 150.
arr2 = np.append(arr1, np.ones(150).reshape(arr1.shape)).reshape(2, 15, 10)
Expand Down Expand Up @@ -1472,10 +1499,10 @@ def test_convert_modes(self):
img.palettize(brbg)
pal = img.palette

img = img.convert('RGBA')
self.assertTrue(np.issubdtype(img.data.dtype, np.floating))
self.assertTrue(img.mode == 'RGBA')
self.assertTrue(len(img.data.coords['bands']) == 4)
img2 = img.convert('RGBA')
self.assertTrue(np.issubdtype(img2.data.dtype, np.floating))
self.assertTrue(img2.mode == 'RGBA')
self.assertTrue(len(img2.data.coords['bands']) == 4)

# PA -> RGB (float)
img = xrimage.XRImage(dataset3)
Expand All @@ -1488,7 +1515,23 @@ def test_convert_modes(self):

self.assertRaises(ValueError, img.convert, 'A')

# L -> palettize -> RGBA (float) with RGBA colormap
with dask.config.set(scheduler=CustomScheduler(max_computes=0)):
img = xrimage.XRImage(dataset1)
img.palettize(bw)

img2 = img.convert('RGBA')
self.assertTrue(np.issubdtype(img2.data.dtype, np.floating))
self.assertTrue(img2.mode == 'RGBA')
self.assertTrue(len(img2.data.coords['bands']) == 4)
# convert to RGB, use RGBA from colormap regardless
img2 = img.convert('RGB')
self.assertTrue(np.issubdtype(img2.data.dtype, np.floating))
self.assertTrue(img2.mode == 'RGBA')
self.assertTrue(len(img2.data.coords['bands']) == 4)

def test_colorize(self):
"""Test colorize with an RGB colormap."""
import xarray as xr
from trollimage import xrimage
from trollimage.colormap import brbg
Expand Down Expand Up @@ -1594,7 +1637,29 @@ def test_colorize(self):
alpha.reshape((1,) + alpha.shape)))
np.testing.assert_allclose(values, expected)

def test_colorize_rgba(self):
"""Test colorize with an RGBA colormap."""
import xarray as xr
from trollimage import xrimage
from trollimage.colormap import Colormap

# RGBA colormap
bw = Colormap(
(0.0, (1.0, 1.0, 1.0, 1.0)),
(1.0, (0.0, 0.0, 0.0, 0.5)),
)

arr = np.arange(75).reshape(5, 15) / 74.
data = xr.DataArray(arr.copy(), dims=['y', 'x'])
img = xrimage.XRImage(data)
img.colorize(bw)
values = img.data.compute()
self.assertTupleEqual((4, 5, 15), values.shape)
np.testing.assert_allclose(values[:, 0, 0], [1.0, 1.0, 1.0, 1.0], rtol=1e-03)
np.testing.assert_allclose(values[:, -1, -1], [0.0, 0.0, 0.0, 0.5])

def test_palettize(self):
"""Test palettize with an RGB colormap."""
import xarray as xr
from trollimage import xrimage
from trollimage.colormap import brbg
Expand All @@ -1613,6 +1678,27 @@ def test_palettize(self):
[8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 10]]])
np.testing.assert_allclose(values, expected)

def test_palettize_rgba(self):
"""Test palettize with an RGBA colormap."""
import xarray as xr
from trollimage import xrimage
from trollimage.colormap import Colormap

# RGBA colormap
bw = Colormap(
(0.0, (1.0, 1.0, 1.0, 1.0)),
(1.0, (0.0, 0.0, 0.0, 0.5)),
)

arr = np.arange(75).reshape(5, 15) / 74.
data = xr.DataArray(arr.copy(), dims=['y', 'x'])
img = xrimage.XRImage(data)
img.palettize(bw)

values = img.data.values
self.assertTupleEqual((1, 5, 15), values.shape)
self.assertTupleEqual((2, 4), bw.colors.shape)

def test_merge(self):
pass

Expand Down
41 changes: 26 additions & 15 deletions trollimage/xrimage.py
Expand Up @@ -464,28 +464,36 @@ def _from_p(self, mode):
"""Convert the image from P or PA to RGB or RGBA."""
self._check_modes(("P", "PA"))

if self.mode.endswith("A"):
alpha = self.data.sel(bands=["A"]).data
mode = mode + "A" if not mode.endswith("A") else mode
else:
alpha = None

if not self.palette:
raise RuntimeError("Can't convert palettized image, missing palette.")

pal = np.array(self.palette)
pal = da.from_array(pal, chunks=pal.shape)
flat_indexes = self.data.data[0].ravel().astype('int64')
new_shape = (3,) + self.data.shape[1:3]
new_data = pal[flat_indexes].transpose().reshape(new_shape)

if pal.shape[1] == 4:
# colormap's alpha overrides data alpha
mode = "RGBA"
alpha = None
elif self.mode.endswith("A"):
# add a new/fake 'bands' dimension to the end
alpha = self.data.sel(bands="A").data[..., None]
mode = mode + "A" if not mode.endswith("A") else mode
else:
alpha = None

flat_indexes = self.data.sel(bands='P').data.ravel().astype('int64')
dim_sizes = ((key, val) for key, val in self.data.sizes.items() if key != 'bands')
dims, new_shape = zip(*dim_sizes)
dims = dims + ('bands',)
new_shape = new_shape + (pal.shape[1],)
new_data = pal[flat_indexes].reshape(new_shape)
coords = dict(self.data.coords)
coords["bands"] = list(mode)

if alpha is not None:
new_arr = da.concatenate((new_data, alpha), axis=0)
data = xr.DataArray(new_arr, coords=coords, attrs=self.data.attrs, dims=self.data.dims)
new_arr = da.concatenate((new_data, alpha), axis=-1)
data = xr.DataArray(new_arr, coords=coords, attrs=self.data.attrs, dims=dims)
else:
data = xr.DataArray(new_data, coords=coords, attrs=self.data.attrs, dims=self.data.dims)
data = xr.DataArray(new_data, coords=coords, attrs=self.data.attrs, dims=dims)

return data

Expand Down Expand Up @@ -926,9 +934,12 @@ def _colorize(l_data, colormap):
return np.concatenate(channels, axis=0)

new_data = l_data.data.map_blocks(_colorize, colormap,
chunks=(3,) + l_data.data.chunks[1:], dtype=np.float64)
chunks=(colormap.colors.shape[1],) + l_data.data.chunks[1:],
dtype=np.float64)

if alpha is not None:
if colormap.colors.shape[1] == 4:
mode = "RGBA"
elif alpha is not None:
new_data = da.concatenate([new_data, alpha.data], axis=0)
mode = "RGBA"
else:
Expand Down

0 comments on commit 8073e70

Please sign in to comment.