From 2ad06ed64a221d777ebbc6a573f71b28447ebabe Mon Sep 17 00:00:00 2001 From: David Hoese Date: Tue, 30 Oct 2018 20:03:09 -0500 Subject: [PATCH 1/3] Refactor SCMI writer to be dask friendly Includes support for RGB datasets --- satpy/etc/writers/scmi.yaml | 2 +- satpy/tests/writer_tests/test_scmi.py | 107 +++-- satpy/writers/scmi.py | 543 ++++++++++++++------------ 3 files changed, 358 insertions(+), 294 deletions(-) diff --git a/satpy/etc/writers/scmi.yaml b/satpy/etc/writers/scmi.yaml index e8d5963fc4..55f21c4546 100644 --- a/satpy/etc/writers/scmi.yaml +++ b/satpy/etc/writers/scmi.yaml @@ -48,7 +48,7 @@ sectors: upper_right_xy: [5433893.2095645051, 5433892.6923244298] resolution: [2500000, 2500000] projection: '+proj=geos +lon_0=-105.0 +h=35786023.0 +a=6378137.0 +b=6356752.31414 +sweep=x +units=m +no_defs' - AHI_HIMAWARI8: + AHI Full Disk: lower_left_xy: [-5499999.901174725, -5499999.901174725] upper_right_xy: [5499999.901174725, 5499999.901174725] resolution: [2500000, 2500000] diff --git a/satpy/tests/writer_tests/test_scmi.py b/satpy/tests/writer_tests/test_scmi.py index 913c666d69..264960fb9f 100644 --- a/satpy/tests/writer_tests/test_scmi.py +++ b/satpy/tests/writer_tests/test_scmi.py @@ -1,11 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # -# Copyright (c) 2017 David Hoese -# -# Author(s): -# -# David Hoese +# Copyright (c) 2017-2018 SatPy Developers # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -23,9 +19,11 @@ """ import os import sys +from glob import glob from datetime import datetime, timedelta import numpy as np +import dask.array as da try: from unittest import mock @@ -75,9 +73,9 @@ def test_basic_numbered_1_tile(self): y_size=200, area_extent=(-1000., -1500., 1000., 1500.), ) - now = datetime.utcnow() + now = datetime(2018, 1, 1, 12, 0, 0) ds = DataArray( - np.linspace(0., 1., 20000, dtype=np.float32).reshape((200, 100)), + da.from_array(np.linspace(0., 1., 20000, dtype=np.float32).reshape((200, 100)), chunks=50), attrs=dict( name='test_ds', platform_name='PLAT', @@ -87,8 +85,10 @@ def test_basic_numbered_1_tile(self): start_time=now, end_time=now + timedelta(minutes=20)) ) - fn = w.save_datasets([ds], sector_id='TEST', source_name="TESTS") - self.assertTrue(os.path.isfile(fn)) + w.save_datasets([ds], sector_id='TEST', source_name='TESTS') + all_files = glob(os.path.join(self.base_dir, 'TESTS_AII*.nc')) + self.assertEqual(len(all_files), 1) + self.assertEqual(os.path.basename(all_files[0]), 'TESTS_AII_PLAT_SENSOR_test_ds_TEST_T001_20180101_1200.nc') def test_basic_numbered_tiles(self): """Test creating a multiple numbered tiles""" @@ -106,9 +106,9 @@ def test_basic_numbered_tiles(self): y_size=200, area_extent=(-1000., -1500., 1000., 1500.), ) - now = datetime.utcnow() + now = datetime(2018, 1, 1, 12, 0, 0) ds = DataArray( - np.linspace(0., 1., 20000, dtype=np.float32).reshape((200, 100)), + da.from_array(np.linspace(0., 1., 20000, dtype=np.float32).reshape((200, 100)), chunks=50), attrs=dict( name='test_ds', platform_name='PLAT', @@ -118,13 +118,9 @@ def test_basic_numbered_tiles(self): start_time=now, end_time=now + timedelta(minutes=20)) ) - fn = w.save_datasets([ds], - sector_id='TEST', - source_name="TESTS", - tile_count=(3, 3)) - # `fn` is currently the last file created - self.assertTrue(os.path.isfile(fn)) - self.assertIn('T009', fn) + w.save_datasets([ds], sector_id='TEST', source_name="TESTS", tile_count=(3, 3)) + all_files = glob(os.path.join(self.base_dir, 'TESTS_AII*.nc')) + self.assertEqual(len(all_files), 9) def test_basic_lettered_tiles(self): """Test creating a lettered grid""" @@ -142,9 +138,9 @@ def test_basic_lettered_tiles(self): y_size=2000, area_extent=(-1000000., -1500000., 1000000., 1500000.), ) - now = datetime.utcnow() + now = datetime(2018, 1, 1, 12, 0, 0) ds = DataArray( - np.linspace(0., 1., 2000000, dtype=np.float32).reshape((2000, 1000)), + da.from_array(np.linspace(0., 1., 2000000, dtype=np.float32).reshape((2000, 1000)), chunks=500), attrs=dict( name='test_ds', platform_name='PLAT', @@ -154,13 +150,9 @@ def test_basic_lettered_tiles(self): start_time=now, end_time=now + timedelta(minutes=20)) ) - fn = w.save_datasets([ds], - sector_id='LCC', - source_name="TESTS", - tile_count=(3, 3), - lettered_grid=True) - # `fn` is currently the last file created - self.assertTrue(os.path.isfile(fn)) + w.save_datasets([ds], sector_id='LCC', source_name="TESTS", tile_count=(3, 3), lettered_grid=True) + all_files = glob(os.path.join(self.base_dir, 'TESTS_AII*.nc')) + self.assertEqual(len(all_files), 16) def test_lettered_tiles_no_fit(self): """Test creating a lettered grid with no data""" @@ -178,9 +170,9 @@ def test_lettered_tiles_no_fit(self): y_size=2000, area_extent=(4000000., 5000000., 5000000., 6000000.), ) - now = datetime.utcnow() + now = datetime(2018, 1, 1, 12, 0, 0) ds = DataArray( - np.linspace(0., 1., 2000000, dtype=np.float32).reshape((2000, 1000)), + da.from_array(np.linspace(0., 1., 2000000, dtype=np.float32).reshape((2000, 1000)), chunks=500), attrs=dict( name='test_ds', platform_name='PLAT', @@ -190,14 +182,10 @@ def test_lettered_tiles_no_fit(self): start_time=now, end_time=now + timedelta(minutes=20)) ) - fn = w.save_datasets([ds], - sector_id='LCC', - source_name="TESTS", - tile_count=(3, 3), - lettered_grid=True) - # `fn` is currently the last file created + w.save_datasets([ds], sector_id='LCC', source_name="TESTS", tile_count=(3, 3), lettered_grid=True) # No files created - self.assertIsNone(fn) + all_files = glob(os.path.join(self.base_dir, 'TESTS_AII*.nc')) + self.assertEqual(len(all_files), 0) def test_lettered_tiles_bad_filename(self): """Test creating a lettered grid with a bad filename""" @@ -215,9 +203,9 @@ def test_lettered_tiles_bad_filename(self): y_size=2000, area_extent=(-1000000., -1500000., 1000000., 1500000.), ) - now = datetime.utcnow() + now = datetime(2018, 1, 1, 12, 0, 0) ds = DataArray( - np.linspace(0., 1., 2000000, dtype=np.float32).reshape((2000, 1000)), + da.from_array(np.linspace(0., 1., 2000000, dtype=np.float32).reshape((2000, 1000)), chunks=500), attrs=dict( name='test_ds', platform_name='PLAT', @@ -230,14 +218,51 @@ def test_lettered_tiles_bad_filename(self): self.assertRaises(KeyError, w.save_datasets, [ds], sector_id='LCC', - source_name="TESTS", + source_name='TESTS', tile_count=(3, 3), lettered_grid=True) + def test_basic_numbered_tiles_rgb(self): + """Test creating a multiple numbered tiles with RGB""" + from satpy.writers.scmi import SCMIWriter + from xarray import DataArray + from pyresample.geometry import AreaDefinition + from pyresample.utils import proj4_str_to_dict + w = SCMIWriter(base_dir=self.base_dir, compress=True) + area_def = AreaDefinition( + 'test', + 'test', + 'test', + proj_dict=proj4_str_to_dict('+proj=lcc +datum=WGS84 +ellps=WGS84 +lon_0=-95. +lat_0=25 +lat_1=25 +units=m +no_defs'), + x_size=100, + y_size=200, + area_extent=(-1000., -1500., 1000., 1500.), + ) + now = datetime(2018, 1, 1, 12, 0, 0) + ds = DataArray( + da.from_array(np.linspace(0., 1., 60000, dtype=np.float32).reshape((3, 200, 100)), chunks=50), + dims=('bands', 'y', 'x'), + coords={'bands': ['R', 'G', 'B']}, + attrs=dict( + name='test_ds', + platform_name='PLAT', + sensor='SENSOR', + units='1', + area=area_def, + start_time=now, + end_time=now + timedelta(minutes=20)) + ) + w.save_datasets([ds], sector_id='TEST', source_name="TESTS", tile_count=(3, 3)) + all_files = glob(os.path.join(self.base_dir, 'TESTS_AII*test_ds_R*.nc')) + self.assertEqual(len(all_files), 9) + all_files = glob(os.path.join(self.base_dir, 'TESTS_AII*test_ds_G*.nc')) + self.assertEqual(len(all_files), 9) + all_files = glob(os.path.join(self.base_dir, 'TESTS_AII*test_ds_B*.nc')) + self.assertEqual(len(all_files), 9) + def suite(): - """The test suite for this writer's tests. - """ + """The test suite for this writer's tests.""" loader = unittest.TestLoader() mysuite = unittest.TestSuite() mysuite.addTest(loader.loadTestsFromTestCase(TestSCMIWriter)) diff --git a/satpy/writers/scmi.py b/satpy/writers/scmi.py index 393582361d..3d9303b8c2 100644 --- a/satpy/writers/scmi.py +++ b/satpy/writers/scmi.py @@ -72,8 +72,10 @@ import numpy as np from pyproj import Proj -from satpy.writers import Writer, DecisionTree +import dask.array as da +from satpy.writers import Writer, DecisionTree, Enhancer, get_enhanced_image from pyresample.geometry import AreaDefinition +from collections import namedtuple try: from pyresample.utils import proj4_radius_parameters @@ -109,6 +111,23 @@ 'K': 'kelvin', } +TileInfo = namedtuple('TileInfo', ['tile_count', 'image_shape', 'tile_shape', + 'tile_row_offset', 'tile_column_offset', 'tile_id', + 'x', 'y', 'tile_slices', 'data_slices']) +XYFactors = namedtuple('XYFactors', ['mx', 'bx', 'my', 'by']) + + +def fix_awips_file(fn): + # hack to get files created by new NetCDF library + # versions to be read by AWIPS buggy java version + # of NetCDF + LOG.info("Modifying SCMI NetCDF file to work with AWIPS") + import h5py + h = h5py.File(fn, 'a') + if '_NCProperties' in h.attrs: + del h.attrs['_NCProperties'] + h.close() + class NumberedTileGenerator(object): def __init__(self, area_definition, @@ -124,6 +143,7 @@ def __init__(self, area_definition, # and must be stored in the file as 0, 1, 2, 3, ... # (X factor, X offset, Y factor, Y offset) self.mx, self.bx, self.my, self.by = self._get_xy_scaling_parameters() + self.xy_factors = XYFactors(self.mx, self.bx, self.my, self.by) self._tile_cache = [] def _get_tile_properties(self, tile_shape, tile_count): @@ -235,29 +255,25 @@ def _generate_tile_info(self): tmp_x = x[data_slices[1]] tmp_y = y[data_slices[0]] - tile_info = (tile_row_offset, tile_column_offset, tile_id, tmp_x, tmp_y, tile_slices, data_slices) + tile_info = TileInfo( + tc, self.image_shape, ts, + tile_row_offset, tile_column_offset, tile_id, + tmp_x, tmp_y, tile_slices, data_slices) self._tile_cache.append(tile_info) yield tile_info - def __call__(self, data, fill_value=np.nan): - ts = self.tile_shape - tmp_tile = np.ma.zeros(ts, dtype=np.float32) - tmp_tile.set_fill_value(fill_value) - tmp_tile[:] = np.ma.masked - + def __call__(self, data): if self._tile_cache: tile_infos = self._tile_cache else: tile_infos = self._generate_tile_info() for tile_info in tile_infos: - tmp_tile[tile_info[-2]] = data[tile_info[-1]] - if tmp_tile.mask.all(): - LOG.info("Tile {} contains all masked data, skipping...".format(tile_info[2])) + tile_data = data[tile_info.data_slices] + if not tile_data.size: + LOG.info("Tile {} is empty, skipping...".format(tile_info[2])) continue - - yield tile_info[:-2], tmp_tile - tmp_tile[:] = np.ma.masked + yield tile_info, tile_data class LetteredTileGenerator(NumberedTileGenerator): @@ -408,7 +424,9 @@ def _generate_tile_info(self): slice(data_x_idx_min, data_x_idx_max + 1)) data_slices = (y_slice, x_slice) - tile_info = (gy * ts[0], gx * ts[1], tile_id, tmp_x, tmp_y, tile_slices, data_slices) + tile_info = TileInfo( + self.tile_count, self.image_shape, ts, + gy * ts[0], gx * ts[1], tile_id, tmp_x, tmp_y, tile_slices, data_slices) self._tile_cache.append(tile_info) yield tile_info @@ -491,7 +509,6 @@ class NetCDFWriter(object): FUTURE: optionally add zenith and azimuth angles """ - _nc = None _kind = None # 'albedo', 'brightness_temp' _band = None _include_fgf = True @@ -503,15 +520,23 @@ class NetCDFWriter(object): fgf_x = None projection = None - def __init__(self, filename, include_fgf=True, helper=None, compress=False): - self._nc = Dataset(filename, 'w') + def __init__(self, filename, include_fgf=True, ds_info=None, compress=False): + self._nc = None + self.filename = filename self._include_fgf = include_fgf self._compress = compress - self.helper = helper + self.helper = AttributeHelper(ds_info) + self.image_data = None + + @property + def nc(self): + if self._nc is None: + self._nc = Dataset(self.filename, 'w') + return self._nc def create_dimensions(self, lines, columns): # Create Dimensions - _nc = self._nc + _nc = self.nc _nc.createDimension(self.row_dim_name, lines) _nc.createDimension(self.col_dim_name, columns) @@ -519,18 +544,20 @@ def create_variables(self, bitdepth, fill_value, scale_factor=None, add_offset=N valid_min=None, valid_max=None): fgf_coords = "%s %s" % (self.y_var_name, self.x_var_name) - self.image_data = self._nc.createVariable(self.image_var_name, - AWIPS_DATA_DTYPE, - dimensions=(self.row_dim_name, self.col_dim_name), - fill_value=fill_value, - zlib=self._compress) + self.image_data = self.nc.createVariable(self.image_var_name, + AWIPS_DATA_DTYPE, + dimensions=(self.row_dim_name, self.col_dim_name), + fill_value=fill_value, + zlib=self._compress) self.image_data.coordinates = fgf_coords self.apply_data_attributes(bitdepth, scale_factor, add_offset, valid_min=valid_min, valid_max=valid_max) if self._include_fgf: - self.fgf_y = self._nc.createVariable(self.y_var_name, 'i2', dimensions=(self.row_dim_name,), zlib=self._compress) - self.fgf_x = self._nc.createVariable(self.x_var_name, 'i2', dimensions=(self.col_dim_name,), zlib=self._compress) + self.fgf_y = self.nc.createVariable( + self.y_var_name, 'i2', dimensions=(self.row_dim_name,), zlib=self._compress) + self.fgf_x = self.nc.createVariable( + self.x_var_name, 'i2', dimensions=(self.col_dim_name,), zlib=self._compress) def apply_data_attributes(self, bitdepth, scale_factor, add_offset, valid_min=None, valid_max=None): @@ -586,19 +613,18 @@ def set_fgf(self, x, mx, bx, y, my, by, units='meters', downsample_factor=1): self.fgf_x.standard_name = "projection_x_coordinate" self.fgf_x[:] = x - def set_image_data(self, data, fill_value): - LOG.info('writing image data') + def set_image_data(self, data): + LOG.debug('writing image data') + if not hasattr(data, 'mask'): + data = np.ma.masked_array(data, np.isnan(data)) # note: autoscaling will be applied to make int16 - assert(hasattr(data, 'mask')) self.image_data[:, :] = np.require(data, dtype=np.float32) def set_projection_attrs(self, area_id, proj4_info): - """ - assign projection attributes per GRB standard - """ + """Assign projection attributes per GRB standard""" proj4_info['a'], proj4_info['b'] = proj4_radius_parameters(proj4_info) if proj4_info["proj"] == "geos": - p = self.projection = self._nc.createVariable("fixedgrid_projection", 'i4') + p = self.projection = self.nc.createVariable("fixedgrid_projection", 'i4') self.image_data.grid_mapping = "fixedgrid_projection" p.short_name = area_id p.grid_mapping_name = "geostationary" @@ -607,7 +633,7 @@ def set_projection_attrs(self, area_id, proj4_info): p.latitude_of_projection_origin = np.float32(0.0) p.longitude_of_projection_origin = np.float32(proj4_info.get('lon_0', 0.0)) # is the float32 needed? elif proj4_info["proj"] == "lcc": - p = self.projection = self._nc.createVariable("lambert_projection", 'i4') + p = self.projection = self.nc.createVariable("lambert_projection", 'i4') self.image_data.grid_mapping = "lambert_projection" p.short_name = area_id p.grid_mapping_name = "lambert_conformal_conic" @@ -615,7 +641,7 @@ def set_projection_attrs(self, area_id, proj4_info): p.longitude_of_central_meridian = proj4_info["lon_0"] p.latitude_of_projection_origin = proj4_info.get('lat_1', proj4_info['lat_0']) # Correct? elif proj4_info['proj'] == 'stere': - p = self.projection = self._nc.createVariable("polar_projection", 'i4') + p = self.projection = self.nc.createVariable("polar_projection", 'i4') self.image_data.grid_mapping = "polar_projection" p.short_name = area_id p.grid_mapping_name = "polar_stereographic" @@ -623,7 +649,7 @@ def set_projection_attrs(self, area_id, proj4_info): p.straight_vertical_longitude_from_pole = proj4_info.get("lon_0", 0.0) p.latitude_of_projection_origin = proj4_info["lat_0"] # ? elif proj4_info['proj'] == 'merc': - p = self.projection = self._nc.createVariable("mercator_projection", 'i4') + p = self.projection = self.nc.createVariable("mercator_projection", 'i4') self.image_data.grid_mapping = "mercator_projection" p.short_name = area_id p.grid_mapping_name = "mercator" @@ -639,74 +665,126 @@ def set_projection_attrs(self, area_id, proj4_info): def set_global_attrs(self, physical_element, awips_id, sector_id, creating_entity, total_tiles, total_pixels, - tile_row, tile_column, - tile_height, tile_width, creator=None): - self._nc.Conventions = "CF-1.7" + tile_row, tile_column, tile_height, tile_width, creator=None): + self.nc.Conventions = "CF-1.7" if creator is None: - self._nc.creator = "UW SSEC - CSPP Polar2Grid" + from satpy import __version__ + self.nc.creator = "SatPy Version {} - SCMI Writer".format(__version__) else: - self._nc.creator = creator - self._nc.creation_time = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%S') + self.nc.creator = creator + self.nc.creation_time = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%S') # name as it shows in the product browser (physicalElement) - self._nc.physical_element = physical_element - self._nc.satellite_id = creating_entity + self.nc.physical_element = physical_element + self.nc.satellite_id = creating_entity # identifying name to match against AWIPS common descriptions (ex. "AWIPS_product_name") - self._nc.awips_id = awips_id - self._nc.sector_id = sector_id - self._nc.tile_row_offset = tile_row - self._nc.tile_column_offset = tile_column - self._nc.product_tile_height = tile_height - self._nc.product_tile_width = tile_width - self._nc.number_product_tiles = total_tiles[0] * total_tiles[1] - self._nc.product_rows = total_pixels[0] - self._nc.product_columns = total_pixels[1] - - self.helper.apply_attributes(self._nc, SCMI_GLOBAL_ATT, '_global_') + self.nc.awips_id = awips_id + self.nc.sector_id = sector_id + self.nc.tile_row_offset = tile_row + self.nc.tile_column_offset = tile_column + self.nc.product_tile_height = tile_height + self.nc.product_tile_width = tile_width + self.nc.number_product_tiles = total_tiles[0] * total_tiles[1] + self.nc.product_rows = total_pixels[0] + self.nc.product_columns = total_pixels[1] + + self.helper.apply_attributes(self.nc, SCMI_GLOBAL_ATT, '_global_') def close(self): - self._nc.sync() - self._nc.close() - self._nc = None + if self._nc is not None: + self._nc.sync() + self._nc.close() + self._nc = None -class SCMIWriter(Writer): - def __init__(self, compress=False, fix_awips=False, **kwargs): - super(SCMIWriter, self).__init__( - self, default_config_filename="writers/scmi.yaml", **kwargs) - self.keep_intermediate = False - self.overwrite_existing = True - self.scmi_sectors = self.config['sectors'] - self.scmi_datasets = SCMIDatasetDecisionTree([self.config['datasets']]) - self.compress = compress - self.fix_awips = fix_awips - self._fill_sector_info() +class NetCDFWrapper(object): + """Object to wrap all NetCDF data-based operations in to a single call. - @classmethod - def separate_init_kwargs(cls, kwargs): - # FUTURE: Don't pass Scene.save_datasets kwargs to init and here - init_kwargs, kwargs = super(SCMIWriter, cls).separate_init_kwargs( - kwargs) - for kw in ['compress', 'fix_awips']: - if kw in kwargs: - init_kwargs[kw] = kwargs.pop(kw) + This makes it possible to do SCMI writing with dask's delayed `da.store` function. - return init_kwargs, kwargs + """ + def __init__(self, filename, sector_id, ds_info, awips_info, + xy_factors, tile_info, compress=False, fix_awips=False): + self.filename = filename + self.sector_id = sector_id + self.ds_info = ds_info + self.awips_info = awips_info + self.tile_info = tile_info + self.xy_factors = xy_factors + self.compress = compress + self.fix_awips = fix_awips - def _fill_sector_info(self): - for sector_info in self.scmi_sectors.values(): - p = Proj(sector_info['projection']) - if 'lower_left_xy' in sector_info: - sector_info['lower_left_lonlat'] = p(*sector_info['lower_left_xy'], inverse=True) - else: - sector_info['lower_left_xy'] = p(*sector_info['lower_left_lonlat']) - if 'upper_right_xy' in sector_info: - sector_info['upper_right_lonlat'] = p(*sector_info['upper_right_xy'], inverse=True) - else: - sector_info['upper_right_xy'] = p(*sector_info['upper_right_lonlat']) + def __getstate__(self): + """State for pickling.""" + args = (self.filename, self.sector_id, self.ds_info, self.awips_info, self.xy_factors, self.tile_info) + kwargs = dict( + compress=self.compress, + fix_awips=self.fix_awips, + ) + return args, kwargs + + def __setstate__(self, state): + """Restore from a pickle.""" + args, kwargs = state + self.__init__(*args, **kwargs) + + def __setitem__(self, key, data): + """Write an entire tile to a file.""" + if np.isnan(data).all(): + LOG.info("Tile {} contains all invalid data, skipping...".format(self.filename)) + return + + ds_info = self.ds_info + awips_info = self.awips_info + tile_info = self.tile_info + area_def = ds_info['area'] + LOG.debug("Scaling %s data to fit in netcdf file...", ds_info["name"]) + bit_depth = ds_info.get("bit_depth", 16) + valid_min = ds_info.get('valid_min') + if valid_min is None: + valid_min = np.nanmin(data) + valid_max = ds_info.get('valid_max') + if valid_max is None: + valid_max = np.nanmax(data) + + LOG.debug("Using product valid min {} and valid max {}".format(valid_min, valid_max)) + is_cat = 'flag_meanings' in ds_info + fills, factor, offset = self._calc_factor_offset( + data=data, bitdepth=bit_depth, min=valid_min, max=valid_max, dtype=AWIPS_DATA_DTYPE, flag_meanings=is_cat) + if is_cat: + data = data.astype(AWIPS_DATA_DTYPE) + + tmp_tile = np.empty(tile_info.tile_shape, dtype=data.dtype) + tmp_tile[:] = np.nan + tmp_tile[tile_info.tile_slices] = data + + LOG.info("Writing tile '%s' to '%s'", self.tile_info[2], self.filename) + nc = NetCDFWriter(self.filename, ds_info=self.ds_info, compress=self.compress) + LOG.debug("Creating dimensions...") + nc.create_dimensions(tmp_tile.shape[0], tmp_tile.shape[1]) + LOG.debug("Creating variables...") + nc.create_variables(bit_depth, fills[0], factor, offset) + LOG.debug("Creating global attributes...") + nc.set_global_attrs(awips_info['physical_element'], + awips_info['awips_id'], self.sector_id, + awips_info['creating_entity'], + tile_info.tile_count, tile_info.image_shape, + tile_info.tile_row_offset, tile_info.tile_column_offset, + tmp_tile.shape[0], tmp_tile.shape[1]) + LOG.debug("Creating projection attributes...") + nc.set_projection_attrs(area_def.area_id, area_def.proj_dict) + LOG.debug("Writing image data...") + np.clip(tmp_tile, valid_min, valid_max, out=tmp_tile) + nc.set_image_data(tmp_tile) + LOG.debug("Writing X/Y navigation data...") + mx, bx, my, by = self.xy_factors + nc.set_fgf(tile_info.x, mx, bx, tile_info.y, my, by, units='meters') + nc.close() + + if self.fix_awips: + fix_awips_file(self.filename) def _calc_factor_offset(self, data=None, dtype=np.int16, bitdepth=None, - min=None, max=None, num_fills=1, - flag_meanings=False): + min=None, max=None, num_fills=1, flag_meanings=False): if num_fills > 1: raise NotImplementedError("More than one fill value is not implemented yet") @@ -750,16 +828,48 @@ def _calc_factor_offset(self, data=None, dtype=np.int16, bitdepth=None, return fills, mx, bx - def _fix_awips_file(self, fn): - # hack to get files created by new NetCDF library - # versions to be read by AWIPS buggy java version - # of NetCDF - LOG.info("Modifying SCMI NetCDF file to work with AWIPS") - import h5py - h = h5py.File(fn, 'a') - if '_NCProperties' in h.attrs: - del h.attrs['_NCProperties'] - h.close() + +class SCMIWriter(Writer): + def __init__(self, compress=False, fix_awips=False, **kwargs): + super(SCMIWriter, self).__init__(default_config_filename="writers/scmi.yaml", **kwargs) + self.keep_intermediate = False + self.overwrite_existing = True + self.scmi_sectors = self.config['sectors'] + self.scmi_datasets = SCMIDatasetDecisionTree([self.config['datasets']]) + self.compress = compress + self.fix_awips = fix_awips + self._fill_sector_info() + self._enhancer = None + + @property + def enhancer(self): + """Lazy loading of enhancements only if needed.""" + if self._enhancer is None: + self._enhancer = Enhancer(ppp_config_dir=self.ppp_config_dir) + return self._enhancer + + @classmethod + def separate_init_kwargs(cls, kwargs): + # FUTURE: Don't pass Scene.save_datasets kwargs to init and here + init_kwargs, kwargs = super(SCMIWriter, cls).separate_init_kwargs( + kwargs) + for kw in ['compress', 'fix_awips']: + if kw in kwargs: + init_kwargs[kw] = kwargs.pop(kw) + + return init_kwargs, kwargs + + def _fill_sector_info(self): + for sector_info in self.scmi_sectors.values(): + p = Proj(sector_info['projection']) + if 'lower_left_xy' in sector_info: + sector_info['lower_left_lonlat'] = p(*sector_info['lower_left_xy'], inverse=True) + else: + sector_info['lower_left_xy'] = p(*sector_info['lower_left_lonlat']) + if 'upper_right_xy' in sector_info: + sector_info['upper_right_lonlat'] = p(*sector_info['upper_right_xy'], inverse=True) + else: + sector_info['upper_right_xy'] = p(*sector_info['upper_right_lonlat']) def _get_sector_info(self, sector_id, lettered_grid): try: @@ -790,7 +900,7 @@ def _get_tile_generator(self, area_def, lettered_grid, sector_id, num_subtiles, def _get_awips_info(self, ds_info, source_name=None, physical_element=None): try: - awips_info = self.scmi_datasets.find_match(**ds_info) + awips_info = self.scmi_datasets.find_match(**ds_info).copy() awips_info['awips_id'] = "AWIPS_" + ds_info['name'] if not physical_element: @@ -809,14 +919,71 @@ def _get_awips_info(self, ds_info, source_name=None, physical_element=None): def_ce = "{}-{}".format(ds_info["platform_name"].upper(), ds_info["sensor"].upper()) awips_info.setdefault('creating_entity', def_ce) return awips_info - except KeyError as e: + except KeyError: LOG.error("Could not get information on dataset from backend configuration file") raise + def _group_by_area(self, datasets): + """Group datasets by their area.""" + def _area_id(area_def): + return area_def.name + str(area_def.area_extent) + str(area_def.shape) + + # get all of the datasets stored by area + area_datasets = {} + for x in datasets: + area_id = _area_id(x.attrs['area']) + area, ds_list = area_datasets.setdefault(area_id, (x.attrs['area'], [])) + ds_list.append(x) + return area_datasets + + def _split_rgbs(self, ds): + for component in 'RGB': + band_data = ds.sel(bands=component) + band_data.attrs['name'] += '_{}'.format(component) + band_data.attrs['valid_min'] = 0.0 + band_data.attrs['valid_max'] = 1.0 + yield band_data + + def _enhance_and_split_rgbs(self, datasets): + new_datasets = [] + for ds in datasets: + if ds.ndim == 2: + new_datasets.append(ds) + continue + elif ds.ndim > 3 or ds.ndim < 1 or (ds.ndim == 3 and 'bands' not in ds.coords): + LOG.error("Can't save datasets with more or less than 2 dimensions " + "that aren't RGBs to SCMI format: {}".format(ds.name)) + else: + # this is an RGB + img = get_enhanced_image(ds.squeeze(), self.enhancer) + res_data = img.finalize(fill_value=0, dtype=np.float32)[0] + new_datasets.extend(self._split_rgbs(res_data)) + + return new_datasets + def save_dataset(self, dataset, **kwargs): LOG.warning("For best performance use `save_datasets`") return self.save_datasets([dataset], **kwargs) + def get_filename(self, area_def, tile_info, sector_id, **kwargs): + # format the filename + kwargs["start_time"] += timedelta(minutes=int(os.environ.get("DEBUG_TIME_SHIFT", 0))) + return super(SCMIWriter, self).get_filename( + area_id=area_def.area_id, + rows=area_def.y_size, + columns=area_def.x_size, + sector_id=sector_id, + tile_id=tile_info.tile_id, + **kwargs) + + def check_tile_exists(self, output_filename): + if os.path.isfile(output_filename): + if not self.overwrite_existing: + LOG.error("AWIPS file already exists: %s", output_filename) + raise RuntimeError("AWIPS file already exists: %s" % (output_filename,)) + else: + LOG.warning("AWIPS file already exists, will overwrite: %s", output_filename) + def save_datasets(self, datasets, sector_id=None, source_name=None, filename=None, tile_count=(1, 1), tile_size=None, @@ -824,157 +991,29 @@ def save_datasets(self, datasets, sector_id=None, compute=True, **kwargs): if sector_id is None: raise TypeError("Keyword 'sector_id' is required") - if not compute: - import warnings - warnings.warn("SCMI Writer does not support delayed computing " - "yet.") - - def _area_id(area_def): - return area_def.name + str(area_def.area_extent) + str(area_def.shape) - # get all of the datasets stored by area - area_datasets = {} - for x in datasets: - area_id = _area_id(x.attrs['area']) - area, ds_list = area_datasets.setdefault(area_id, (x.attrs['area'], [])) - ds_list.append(x) - output_filenames = [] - dtype = AWIPS_DATA_DTYPE - fill_value = np.nan + area_datasets = self._group_by_area(datasets) + sources_targets = [] for area_id, (area_def, ds_list) in area_datasets.items(): tile_gen = self._get_tile_generator(area_def, lettered_grid, sector_id, num_subtiles, tile_size, tile_count) - for dataset in ds_list: - pkwargs = {} - ds_info = dataset.attrs.copy() - LOG.info("Writing product %s to AWIPS SCMI NetCDF file", ds_info["name"]) - if isinstance(dataset, np.ma.MaskedArray): - data = dataset - else: - # FIXME: Handle data better by using `da.store` or move - # netcdf creation/storing to a dask delayed object - mask = dataset.isnull() - data = np.ma.masked_array(dataset.values, mask=mask, copy=False) - - pkwargs['awips_info'] = self._get_awips_info(ds_info, source_name=source_name) - pkwargs['attr_helper'] = AttributeHelper(ds_info) - - LOG.debug("Scaling %s data to fit in netcdf file...", ds_info["name"]) - bit_depth = ds_info.setdefault("bit_depth", 16) - valid_min = ds_info.get('valid_min') - if valid_min is None: - valid_min = np.nanmin(data) - valid_max = ds_info.get('valid_max') - if valid_max is None: - valid_max = np.nanmax(data) - pkwargs['valid_min'] = valid_min - pkwargs['valid_max'] = valid_max - pkwargs['bit_depth'] = bit_depth - - LOG.debug("Using product valid min {} and valid max {}".format(valid_min, valid_max)) - fills, factor, offset = self._calc_factor_offset( - data=data, - bitdepth=bit_depth, - min=valid_min, - max=valid_max, - dtype=dtype, - flag_meanings='flag_meanings' in ds_info) - pkwargs['fills'] = fills - pkwargs['factor'] = factor - pkwargs['offset'] = offset - if 'flag_meanings' in ds_info: - pkwargs['data'] = data.astype(dtype) - else: - pkwargs['data'] = data - - for (trow, tcol, tile_id, tmp_x, tmp_y), tmp_tile in tile_gen(data, fill_value=fill_value): - try: - fn = self.create_tile_output( - dataset.attrs, sector_id, - trow, tcol, tile_id, tmp_x, tmp_y, tmp_tile, - tile_gen.tile_count, tile_gen.image_shape, - tile_gen.mx, tile_gen.bx, tile_gen.my, tile_gen.by, - filename, **pkwargs) - if fn is None: - if lettered_grid: - LOG.warning("Data did not fit in to any lettered tile") - raise RuntimeError("No SCMI tiles were created") - output_filenames.append(fn) - except (RuntimeError, KeyError, AttributeError): - LOG.error("Could not create output for '%s'", ds_info['name']) - LOG.debug("Writer exception: ", exc_info=True) - raise - - return output_filenames[-1] if output_filenames else None - - def create_tile_output(self, ds_info, sector_id, - trow, tcol, tile_id, tmp_x, tmp_y, tmp_tile, - tile_count, image_shape, - mx, bx, my, by, - filename, - awips_info, attr_helper, - fills, factor, offset, valid_min, valid_max, bit_depth, **kwargs): - # Create the netcdf file - area_def = ds_info['area'] - created_files = [] - try: - if filename is None: - # format the filename - of_kwargs = ds_info.copy() - of_kwargs["start_time"] += timedelta(minutes=int(os.environ.get("DEBUG_TIME_SHIFT", 0))) - output_filename = self.get_filename( - area_id=area_def.area_id, - rows=area_def.y_size, - columns=area_def.x_size, - source_name=awips_info['source_name'], - sector_id=sector_id, - tile_id=tile_id, - **of_kwargs - ) - else: - output_filename = filename - if os.path.isfile(output_filename): - if not self.overwrite_existing: - LOG.error("AWIPS file already exists: %s", output_filename) - raise RuntimeError("AWIPS file already exists: %s" % (output_filename,)) - else: - LOG.warning("AWIPS file already exists, will overwrite: %s", output_filename) - created_files.append(output_filename) - - LOG.info("Writing tile '%s' to '%s'", tile_id, output_filename) - - nc = NetCDFWriter(output_filename, helper=attr_helper, - compress=self.compress) - LOG.debug("Creating dimensions...") - nc.create_dimensions(tmp_tile.shape[0], tmp_tile.shape[1]) - LOG.debug("Creating variables...") - nc.create_variables(bit_depth, fills[0], factor, offset) - LOG.debug("Creating global attributes...") - nc.set_global_attrs(awips_info['physical_element'], - awips_info['awips_id'], sector_id, - awips_info['creating_entity'], - tile_count, image_shape, - trow, tcol, tmp_tile.shape[0], tmp_tile.shape[1]) - LOG.debug("Creating projection attributes...") - nc.set_projection_attrs(area_def.area_id, area_def.proj_dict) - LOG.debug("Writing image data...") - np.clip(tmp_tile, valid_min, valid_max, out=tmp_tile) - nc.set_image_data(tmp_tile, fills[0]) - LOG.debug("Writing X/Y navigation data...") - nc.set_fgf(tmp_x, mx, bx, - tmp_y, my, by, units='meters') - nc.close() - - if self.fix_awips: - self._fix_awips_file(output_filename) - except (KeyError, AttributeError, RuntimeError): - last_fn = created_files[-1] if created_files else "N/A" - LOG.error("Error while filling in NC file with data: %s", last_fn) - for fn in created_files: - if not self.keep_intermediate and os.path.isfile(fn): - os.remove(fn) - raise - - return created_files[-1] if created_files else None + for dataset in self._enhance_and_split_rgbs(ds_list): + LOG.info("Preparing product %s to be written to AWIPS SCMI NetCDF file", dataset.attrs["name"]) + awips_info = self._get_awips_info(dataset.attrs, source_name=source_name) + for tile_info, tmp_tile in tile_gen(dataset): + # make sure this entire tile is loaded as one single array + tmp_tile.data = tmp_tile.data.rechunk(tmp_tile.shape) + output_filename = filename or self.get_filename(area_def, tile_info, sector_id, + source_name=awips_info['source_name'], + **dataset.attrs) + self.check_tile_exists(output_filename) + nc_wrapper = NetCDFWrapper(output_filename, sector_id, dataset.attrs, awips_info, + tile_gen.xy_factors, tile_info, + compress=self.compress, fix_awips=self.fix_awips) + sources_targets.append((tmp_tile.data, nc_wrapper)) + + if compute and sources_targets: + return da.store(*zip(*sources_targets)) + return sources_targets def _create_debug_array(sector_info, num_subtiles, font_path='Verdana.ttf'): From 0cf4fe1f6d93abbab9c848e6eb04929b1068aae3 Mon Sep 17 00:00:00 2001 From: David Hoese Date: Tue, 30 Oct 2018 20:08:30 -0500 Subject: [PATCH 2/3] Fix styling error in test_scmi tests --- satpy/tests/writer_tests/test_scmi.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/satpy/tests/writer_tests/test_scmi.py b/satpy/tests/writer_tests/test_scmi.py index 264960fb9f..909d4b2926 100644 --- a/satpy/tests/writer_tests/test_scmi.py +++ b/satpy/tests/writer_tests/test_scmi.py @@ -68,7 +68,8 @@ def test_basic_numbered_1_tile(self): 'test', 'test', 'test', - proj_dict=proj4_str_to_dict('+proj=lcc +datum=WGS84 +ellps=WGS84 +lon_0=-95. +lat_0=25 +lat_1=25 +units=m +no_defs'), + proj_dict=proj4_str_to_dict('+proj=lcc +datum=WGS84 +ellps=WGS84 +lon_0=-95. ' + '+lat_0=25 +lat_1=25 +units=m +no_defs'), x_size=100, y_size=200, area_extent=(-1000., -1500., 1000., 1500.), @@ -101,7 +102,8 @@ def test_basic_numbered_tiles(self): 'test', 'test', 'test', - proj_dict=proj4_str_to_dict('+proj=lcc +datum=WGS84 +ellps=WGS84 +lon_0=-95. +lat_0=25 +lat_1=25 +units=m +no_defs'), + proj_dict=proj4_str_to_dict('+proj=lcc +datum=WGS84 +ellps=WGS84 +lon_0=-95. ' + '+lat_0=25 +lat_1=25 +units=m +no_defs'), x_size=100, y_size=200, area_extent=(-1000., -1500., 1000., 1500.), @@ -133,7 +135,8 @@ def test_basic_lettered_tiles(self): 'test', 'test', 'test', - proj_dict=proj4_str_to_dict('+proj=lcc +datum=WGS84 +ellps=WGS84 +lon_0=-95. +lat_0=25 +lat_1=25 +units=m +no_defs'), + proj_dict=proj4_str_to_dict('+proj=lcc +datum=WGS84 +ellps=WGS84 +lon_0=-95. ' + '+lat_0=25 +lat_1=25 +units=m +no_defs'), x_size=1000, y_size=2000, area_extent=(-1000000., -1500000., 1000000., 1500000.), @@ -165,7 +168,8 @@ def test_lettered_tiles_no_fit(self): 'test', 'test', 'test', - proj_dict=proj4_str_to_dict('+proj=lcc +datum=WGS84 +ellps=WGS84 +lon_0=-95. +lat_0=25 +lat_1=25 +units=m +no_defs'), + proj_dict=proj4_str_to_dict('+proj=lcc +datum=WGS84 +ellps=WGS84 +lon_0=-95. ' + '+lat_0=25 +lat_1=25 +units=m +no_defs'), x_size=1000, y_size=2000, area_extent=(4000000., 5000000., 5000000., 6000000.), @@ -198,7 +202,8 @@ def test_lettered_tiles_bad_filename(self): 'test', 'test', 'test', - proj_dict=proj4_str_to_dict('+proj=lcc +datum=WGS84 +ellps=WGS84 +lon_0=-95. +lat_0=25 +lat_1=25 +units=m +no_defs'), + proj_dict=proj4_str_to_dict('+proj=lcc +datum=WGS84 +ellps=WGS84 +lon_0=-95. ' + '+lat_0=25 +lat_1=25 +units=m +no_defs'), x_size=1000, y_size=2000, area_extent=(-1000000., -1500000., 1000000., 1500000.), @@ -233,7 +238,8 @@ def test_basic_numbered_tiles_rgb(self): 'test', 'test', 'test', - proj_dict=proj4_str_to_dict('+proj=lcc +datum=WGS84 +ellps=WGS84 +lon_0=-95. +lat_0=25 +lat_1=25 +units=m +no_defs'), + proj_dict=proj4_str_to_dict('+proj=lcc +datum=WGS84 +ellps=WGS84 +lon_0=-95. ' + '+lat_0=25 +lat_1=25 +units=m +no_defs'), x_size=100, y_size=200, area_extent=(-1000., -1500., 1000., 1500.), From 0537a4dacb6753ebda13616efc5b8433f60ce6ec Mon Sep 17 00:00:00 2001 From: David Hoese Date: Wed, 31 Oct 2018 12:14:11 -0500 Subject: [PATCH 3/3] Remove unnecessary state methods and locking in SCMI writer --- satpy/writers/scmi.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/satpy/writers/scmi.py b/satpy/writers/scmi.py index 3d9303b8c2..4aa4babee8 100644 --- a/satpy/writers/scmi.py +++ b/satpy/writers/scmi.py @@ -713,20 +713,6 @@ def __init__(self, filename, sector_id, ds_info, awips_info, self.compress = compress self.fix_awips = fix_awips - def __getstate__(self): - """State for pickling.""" - args = (self.filename, self.sector_id, self.ds_info, self.awips_info, self.xy_factors, self.tile_info) - kwargs = dict( - compress=self.compress, - fix_awips=self.fix_awips, - ) - return args, kwargs - - def __setstate__(self, state): - """Restore from a pickle.""" - args, kwargs = state - self.__init__(*args, **kwargs) - def __setitem__(self, key, data): """Write an entire tile to a file.""" if np.isnan(data).all(): @@ -1012,7 +998,8 @@ def save_datasets(self, datasets, sector_id=None, sources_targets.append((tmp_tile.data, nc_wrapper)) if compute and sources_targets: - return da.store(*zip(*sources_targets)) + # the NetCDF creation is per-file so we don't need to lock + return da.store(*zip(*sources_targets), lock=False) return sources_targets