Skip to content

Commit

Permalink
Merge pull request #917 from djhoese/bugfix-background-compositor
Browse files Browse the repository at this point in the history
Fix BackgroundCompositor not retaining input metadata
  • Loading branch information
djhoese committed Oct 2, 2019
2 parents 6d7768b + 5e92278 commit 95b0aea
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 53 deletions.
49 changes: 33 additions & 16 deletions satpy/composites/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def __call__(self, projectables, **info):
# we were not given SZA, generate SZA then calculate cos(SZA)
from pyorbital.astronomy import cos_zen
LOG.debug("Computing sun zenith angles.")
lons, lats = vis.attrs["area"].get_lonlats_dask(CHUNK_SIZE)
lons, lats = vis.attrs["area"].get_lonlats(chunks=CHUNK_SIZE)

coords = {}
if 'y' in vis.coords and 'x' in vis.coords:
Expand Down Expand Up @@ -512,9 +512,7 @@ def get_angles(self, vis):
from pyorbital.astronomy import get_alt_az, sun_zenith_angle
from pyorbital.orbital import get_observer_look

lons, lats = vis.attrs['area'].get_lonlats_dask(
chunks=vis.data.chunks)

lons, lats = vis.attrs['area'].get_lonlats(chunks=vis.data.chunks)
sunalt, suna = get_alt_az(vis.attrs['start_time'], lons, lats)
suna = np.rad2deg(suna)
sunz = sun_zenith_angle(vis.attrs['start_time'], lons, lats)
Expand Down Expand Up @@ -635,7 +633,7 @@ def _get_reflectance(self, projectables, optional_datasets):
if sun_zenith is None:
if sun_zenith_angle is None:
raise ImportError("No module named pyorbital.astronomy")
lons, lats = _nir.attrs["area"].get_lonlats_dask(CHUNK_SIZE)
lons, lats = _nir.attrs["area"].get_lonlats(chunks=CHUNK_SIZE)
sun_zenith = sun_zenith_angle(_nir.attrs['start_time'], lons, lats)

return self._refl3x.reflectance_from_tbs(sun_zenith, _nir, _tb11, tb_ir_co2=tb13_4)
Expand Down Expand Up @@ -682,7 +680,7 @@ def __call__(self, projectables, optional_datasets=None, **info):
satz = optional_datasets[0]
else:
from pyorbital.orbital import get_observer_look
lons, lats = band.attrs['area'].get_lonlats_dask(CHUNK_SIZE)
lons, lats = band.attrs['area'].get_lonlats(chunks=CHUNK_SIZE)
sat_lon, sat_lat, sat_alt = get_satpos(band)
try:
dummy, satel = get_observer_look(sat_lon,
Expand Down Expand Up @@ -787,10 +785,22 @@ def __init__(self, name, common_channel_mask=True, **kwargs):
Args:
common_channel_mask (bool): If True, mask all the channels with
a mask that combines all the invalid areas of the given data.
"""
self.common_channel_mask = common_channel_mask
super(GenericCompositor, self).__init__(name, **kwargs)

@classmethod
def infer_mode(cls, data_arr):
"""Guess at the mode for a particular DataArray."""
if 'mode' in data_arr.attrs:
return data_arr.attrs['mode']
if 'bands' not in data_arr.dims:
return cls.modes[1]
if 'bands' in data_arr.coords and isinstance(data_arr.coords['bands'][0], str):
return ''.join(data_arr.coords['bands'].values)
return cls.modes[data_arr.sizes['bands']]

def _concat_datasets(self, projectables, mode):
try:
data = xr.concat(projectables, 'bands', coords='minimal')
Expand Down Expand Up @@ -866,7 +876,7 @@ class FillingCompositor(GenericCompositor):

def __call__(self, projectables, nonprojectables=None, **info):
"""Generate the composite."""
projectables = self.check_areas(projectables)
projectables = self.match_data_arrays(projectables)
projectables[1] = projectables[1].fillna(projectables[0])
projectables[2] = projectables[2].fillna(projectables[0])
projectables[3] = projectables[3].fillna(projectables[0])
Expand All @@ -878,7 +888,7 @@ class Filler(GenericCompositor):

def __call__(self, projectables, nonprojectables=None, **info):
"""Generate the composite."""
projectables = self.check_areas(projectables)
projectables = self.match_data_arrays(projectables)
filled_projectable = projectables[0].fillna(projectables[1])
return super(Filler, self).__call__([filled_projectable], **info)

Expand Down Expand Up @@ -1000,6 +1010,7 @@ def __init__(self, name, lim_low=85., lim_high=95., **kwargs):
blending of the given channels
lim_high (float): upper limit of Sun zenith angle for the
blending of the given channels
"""
self.lim_low = lim_low
self.lim_high = lim_high
Expand All @@ -1024,7 +1035,7 @@ def __call__(self, projectables, **kwargs):
chunks = day_data.sel(bands=day_data['bands'][0]).chunks
except KeyError:
chunks = day_data.chunks
lons, lats = day_data.attrs["area"].get_lonlats_dask(chunks)
lons, lats = day_data.attrs["area"].get_lonlats(chunks=chunks)
coszen = xr.DataArray(cos_zen(day_data.attrs["start_time"],
lons, lats),
dims=['y', 'x'],
Expand Down Expand Up @@ -1069,7 +1080,10 @@ def enhance2dataset(dset):
# Clip image data to interval [0.0, 1.0]
data = img.data.clip(0.0, 1.0)
data.attrs = attrs

# remove 'mode' if it is specified since it may have been updated
data.attrs.pop('mode', None)
# update mode since it may have changed (colorized/palettize)
data.attrs['mode'] = GenericCompositor.infer_mode(data)
return data


Expand Down Expand Up @@ -1198,7 +1212,7 @@ class RatioSharpenedRGB(GenericCompositor):
footprint. Note that the input data to this compositor must already be
resampled so all data arrays are the same shape.
Example:
Example::
R_lo - 1000m resolution - shape=(2000, 2000)
G - 1000m resolution - shape=(2000, 2000)
Expand Down Expand Up @@ -1319,7 +1333,7 @@ def _mean4(data, offset=(0, 0), block_id=None):
class SelfSharpenedRGB(RatioSharpenedRGB):
"""Sharpen RGB with ratio of a band with a strided-version of itself.
Example:
Example::
R - 500m resolution - shape=(4000, 4000)
G - 1000m resolution - shape=(2000, 2000)
Expand Down Expand Up @@ -1428,6 +1442,7 @@ def __init__(self, name, filename=None, area=None, **kwargs):
filename (str): Filename of the image to load
area (str): Name of area definition for the image. Optional
for images with built-in area definitions (geotiff)
"""
if filename is None:
raise ValueError("No image configured for static image compositor")
Expand Down Expand Up @@ -1479,7 +1494,7 @@ class BackgroundCompositor(GenericCompositor):

def __call__(self, projectables, *args, **kwargs):
"""Call the compositor."""
projectables = self.check_areas(projectables)
projectables = self.match_data_arrays(projectables)

# Get enhanced datasets
foreground = enhance2dataset(projectables[0])
Expand All @@ -1494,9 +1509,12 @@ def __call__(self, projectables, *args, **kwargs):

# Get merged metadata
attrs = combine_metadata(foreground, background)
if attrs.get('sensor') is None:
# sensor can be a set
attrs['sensor'] = self._get_sensors(projectables)

# Stack the images
if 'A' in foreground.mode:
if 'A' in foreground.attrs['mode']:
# Use alpha channel as weight and blend the two composites
alpha = foreground.sel(bands='A')
data = []
Expand All @@ -1514,6 +1532,5 @@ def __call__(self, projectables, *args, **kwargs):
data = [data.sel(bands=b) for b in data['bands']]

res = super(BackgroundCompositor, self).__call__(data, **kwargs)
res.attrs['area'] = attrs['area']

res.attrs.update(attrs)
return res
2 changes: 1 addition & 1 deletion satpy/etc/composites/seviri.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ composites:
standard_name: overview

colorized_ir_clouds:
compositor: !!python/name:satpy.composites.GenericCompositor
compositor: !!python/name:satpy.composites.SingleBandCompositor
prerequisites:
- name: 'IR_108'
standard_name: colorized_ir_clouds
Expand Down
82 changes: 46 additions & 36 deletions satpy/tests/compositor_tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def setUp(self):
lons = da.from_array(lons, lons.shape)
lats = np.array([[40., 41.], [42., 43.]])
lats = da.from_array(lats, lats.shape)
my_area.get_lonlats_dask.return_value = (lons, lats)
my_area.get_lonlats.return_value = (lons, lats)
self.data_a.attrs['area'] = my_area
self.data_b.attrs['area'] = my_area
# not used except to check that it matches the data arrays
Expand Down Expand Up @@ -545,10 +545,10 @@ def test_compositor(self, calculator, apply_modifier_info, sza):
nir.attrs['platform_name'] = platform
nir.attrs['sensor'] = sensor
nir.attrs['name'] = chan_name
get_lonlats_dask = mock.MagicMock()
get_lonlats = mock.MagicMock()
lons, lats = 1, 2
get_lonlats_dask.return_value = (lons, lats)
nir.attrs['area'] = mock.MagicMock(get_lonlats_dask=get_lonlats_dask)
get_lonlats.return_value = (lons, lats)
nir.attrs['area'] = mock.MagicMock(get_lonlats=get_lonlats)
start_time = 1
nir.attrs['start_time'] = start_time
ir_arr = 100 * np.random.random((2, 2))
Expand All @@ -574,7 +574,7 @@ def test_compositor(self, calculator, apply_modifier_info, sza):
refl_from_tbs.reset_mock()

res = comp([nir, ir_], optional_datasets=[], **info)
get_lonlats_dask.assert_called()
get_lonlats.assert_called()
sza.assert_called_with(start_time, lons, lats)
refl_from_tbs.assert_called_with(sunz2, nir, ir_, tb_ir_co2=None)
refl_from_tbs.reset_mock()
Expand Down Expand Up @@ -831,7 +831,7 @@ def test_call(self):
self.assertTrue('modifiers' not in res.attrs)
self.assertIsNone(res.attrs['wavelength'])
self.assertEqual(res.attrs['mode'], 'LA')
self.assertEquals(res.attrs['resolution'], 333)
self.assertEqual(res.attrs['resolution'], 333)


class TestAddBands(unittest.TestCase):
Expand Down Expand Up @@ -949,28 +949,23 @@ def load(self, arg):
self.assertEqual(res.attrs['area'].area_id, 'euro4')


def _enhance2dataset(dataset):
"""Mock the enhance2dataset to return the original data."""
return dataset


class TestBackgroundCompositor(unittest.TestCase):
"""Test case for the background compositor."""

@mock.patch('satpy.composites.combine_metadata')
@mock.patch('satpy.composites.add_bands')
@mock.patch('satpy.composites.enhance2dataset')
@mock.patch('satpy.composites.BackgroundCompositor.check_areas')
def test_call(self, check_areas, e2d, add_bands, combine_metadata):
@mock.patch('satpy.composites.enhance2dataset', _enhance2dataset)
def test_call(self):
"""Test the background compositing."""
from satpy.composites import BackgroundCompositor
import numpy as np

def check_areas_side_effect(projectables):
return projectables

check_areas.side_effect = check_areas_side_effect
comp = BackgroundCompositor("name")

# L mode images
attrs = {'mode': 'L', 'area': 'foo'}
combine_metadata.return_value = attrs

foreground = xr.DataArray(np.array([[[1., 0.5],
[0., np.nan]]]),
dims=('bands', 'y', 'x'),
Expand All @@ -979,16 +974,13 @@ def check_areas_side_effect(projectables):
background = xr.DataArray(np.ones((1, 2, 2)), dims=('bands', 'y', 'x'),
coords={'bands': [c for c in attrs['mode']]},
attrs=attrs)
add_bands.side_effect = [foreground, background]
res = comp([0, 1])
res = comp([foreground, background])
self.assertEqual(res.attrs['area'], 'foo')
self.assertTrue(np.all(res == np.array([[1., 0.5], [0., 1.]])))
self.assertEqual(res.mode, 'L')
self.assertEqual(res.attrs['mode'], 'L')

# LA mode images
attrs = {'mode': 'LA', 'area': 'foo'}
combine_metadata.return_value = attrs

foreground = xr.DataArray(np.array([[[1., 0.5],
[0., np.nan]],
[[0.5, 0.5],
Expand All @@ -999,15 +991,12 @@ def check_areas_side_effect(projectables):
background = xr.DataArray(np.ones((2, 2, 2)), dims=('bands', 'y', 'x'),
coords={'bands': [c for c in attrs['mode']]},
attrs=attrs)
add_bands.side_effect = [foreground, background]
res = comp([0, 1])
res = comp([foreground, background])
self.assertTrue(np.all(res == np.array([[1., 0.75], [0.5, 1.]])))
self.assertEqual(res.mode, 'L')
self.assertEqual(res.attrs['mode'], 'LA')

# RGB mode images
attrs = {'mode': 'RGB', 'area': 'foo'}
combine_metadata.return_value = attrs

foreground = xr.DataArray(np.array([[[1., 0.5],
[0., np.nan]],
[[1., 0.5],
Expand All @@ -1021,17 +1010,14 @@ def check_areas_side_effect(projectables):
coords={'bands': [c for c in attrs['mode']]},
attrs=attrs)

add_bands.side_effect = [foreground, background]
res = comp([0, 1])
res = comp([foreground, background])
self.assertTrue(np.all(res == np.array([[[1., 0.5], [0., 1.]],
[[1., 0.5], [0., 1.]],
[[1., 0.5], [0., 1.]]])))
self.assertEqual(res.mode, 'RGB')
self.assertEqual(res.attrs['mode'], 'RGB')

# RGBA mode images
attrs = {'mode': 'RGBA', 'area': 'foo'}
combine_metadata.return_value = attrs

foreground = xr.DataArray(np.array([[[1., 0.5],
[0., np.nan]],
[[1., 0.5],
Expand All @@ -1047,12 +1033,36 @@ def check_areas_side_effect(projectables):
coords={'bands': [c for c in attrs['mode']]},
attrs=attrs)

add_bands.side_effect = [foreground, background]
res = comp([0, 1])
res = comp([foreground, background])
self.assertTrue(np.all(res == np.array([[[1., 0.75], [0.5, 1.]],
[[1., 0.75], [0.5, 1.]],
[[1., 0.75], [0.5, 1.]]])))
self.assertEqual(res.mode, 'RGB')
self.assertEqual(res.attrs['mode'], 'RGBA')

@mock.patch('satpy.composites.enhance2dataset', _enhance2dataset)
def test_multiple_sensors(self):
"""Test the background compositing from multiple sensor data."""
from satpy.composites import BackgroundCompositor
import numpy as np
comp = BackgroundCompositor("name")

# L mode images
attrs = {'mode': 'L', 'area': 'foo'}
foreground = xr.DataArray(np.array([[[1., 0.5],
[0., np.nan]]]),
dims=('bands', 'y', 'x'),
coords={'bands': [c for c in attrs['mode']]},
attrs=attrs.copy())
foreground.attrs['sensor'] = 'abi'
background = xr.DataArray(np.ones((1, 2, 2)), dims=('bands', 'y', 'x'),
coords={'bands': [c for c in attrs['mode']]},
attrs=attrs.copy())
background.attrs['sensor'] = 'glm'
res = comp([foreground, background])
self.assertEqual(res.attrs['area'], 'foo')
self.assertTrue(np.all(res == np.array([[1., 0.5], [0., 1.]])))
self.assertEqual(res.attrs['mode'], 'L')
self.assertEqual(res.attrs['sensor'], {'abi', 'glm'})


def suite():
Expand Down

0 comments on commit 95b0aea

Please sign in to comment.