Skip to content

Commit

Permalink
Add helper method for checking areas in compositors
Browse files Browse the repository at this point in the history
  • Loading branch information
djhoese committed Mar 15, 2018
1 parent e6099da commit 171bb27
Showing 1 changed file with 60 additions and 19 deletions.
79 changes: 60 additions & 19 deletions satpy/composites/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from satpy.utils import sunzen_corr_cos, atmospheric_path_length_correction
from satpy.writers import get_enhanced_image
from satpy import CHUNK_SIZE
from pyresample.geometry import AreaDefinition

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -283,6 +284,58 @@ def apply_modifier_info(self, origin, destination):
elif o.get(k) is not None:
d[k] = o[k]

def remove_coords(self, data_arrays, coords=('x', 'y', 'time')):
new_data_arrays = []
for coord in coords:
for ds in data_arrays:
ds = ds.copy()
if coord in ds.coords:
del ds.coords[coord]
new_data_arrays.append(ds)
return new_data_arrays

def check_areas(self, data_arrays, adjust_coords=True):
if len(data_arrays) == 1:
return data_arrays

if not all(x.shape == data_arrays[0].shape for x in data_arrays[1:]):
raise IncompatibleAreas("Data shapes are not the same in "
"'{}'".format(self.attrs['name']))

areas = [ds.attrs.get('area', None) for ds in data_arrays]
if not areas or any(a is None for a in areas):
raise ValueError("Missing 'area' attribute")

coords_to_adjust = []
if 'x' in data_arrays[0].coords and 'y' in data_arrays[0].coords:
comp_x = data_arrays[0]['x']
comp_y = data_arrays[0]['y']
all_x = all(np.all(comp_x == ds['x']) for ds in data_arrays[1:])
all_y = all(np.all(comp_y == ds['y']) for ds in data_arrays[1:])
matching_coords = all_x and all_y
if not adjust_coords and not matching_coords:
raise IncompatibleAreas("Dataset coordinates do not match")
if adjust_coords and not matching_coords:
coords_to_adjust.extend(['y', 'x'])
return self.remove_coords(data_arrays)

# Check time coords are the same
if 'time' in data_arrays[0].coords:
comp_t = data_arrays[0]['time']
all_times = all(np.all(comp_t == ds['time'])
for ds in data_arrays[1:])
if adjust_coords and not all_times:
coords_to_adjust.append('time')

if coords_to_adjust:
return self.remove_coords(data_arrays, coords_to_adjust)
# FUTURE: Replace the areas with one shared area

if all(areas[0] == x for x in areas[1:]):
LOG.debug("Not all areas are the same in "
"'{}'".format(self.attrs['name']))
raise IncompatibleAreas


class SunZenithCorrectorBase(CompositeBase):

Expand Down Expand Up @@ -382,6 +435,7 @@ def __call__(self, projectables, optional_datasets=None, **info):
sunalt, suna = get_alt_az(vis.attrs['start_time'], lons, lats)
suna = xu.rad2deg(suna)
sunz = sun_zenith_angle(vis.attrs['start_time'], lons, lats)
# FIXME: Make it daskified
sata, satel = get_observer_look(vis.attrs['satellite_longitude'],
vis.attrs['satellite_latitude'],
vis.attrs['satellite_altitude'],
Expand Down Expand Up @@ -577,16 +631,8 @@ class GenericCompositor(CompositeBase):

modes = {1: 'L', 2: 'LA', 3: 'RGB', 4: 'RGBA'}

def check_area_compatibility(self, projectables):
areas = [projectable.attrs.get('area', None)
for projectable in projectables]
areas = [area for area in areas if area is not None]
if areas and areas.count(areas[0]) != len(areas):
LOG.debug("Not all areas are the same in '{}'".format(self.attrs['name']))
raise IncompatibleAreas

def _concat_datasets(self, projectables, mode):
self.check_area_compatibility(projectables)
projectables = self.check_areas(projectables)

try:
data = xr.concat(projectables, 'bands', coords='minimal')
Expand Down Expand Up @@ -1012,16 +1058,10 @@ def __call__(self, datasets, optional_datasets=None, **info):
'the same size. Must resample first.')

new_attrs = {}
p1, p2, p3 = datasets
if optional_datasets:
high_res = optional_datasets[0]
low_res = datasets[["red", "green", "blue"].index(
self.high_resolution_band)]
if high_res.attrs["area"] != low_res.attrs["area"]:
raise IncompatibleAreas("High resolution band is not "
"mapped to the same area as the "
"low resolution bands. Must "
"resample first.")
datasets = self.check_areas(datasets + optional_datasets)
high_res = datasets[-1]
p1, p2, p3 = datasets[:3]
if 'rows_per_scan' in high_res.attrs:
new_attrs.setdefault('rows_per_scan',
high_res.attrs['rows_per_scan'])
Expand Down Expand Up @@ -1055,7 +1095,8 @@ def __call__(self, datasets, optional_datasets=None, **info):
g = p2
b = p3
else:
r, g, b = p1, p2, p3
datasets = self.check_areas(datasets)
r, g, b = datasets[:3]
# combine the masks
mask = ~(da.isnull(r.data) | da.isnull(g.data) | da.isnull(b.data))
r = r.where(mask)
Expand Down

0 comments on commit 171bb27

Please sign in to comment.