Make viirs-compact datasets compatible with dask distributed #1546

Merged
merged 5 commits on Feb 16, 2021
237 changes: 114 additions & 123 deletions satpy/readers/viirs_compact.py
@@ -30,35 +30,35 @@
"""

import logging
from contextlib import suppress
from datetime import datetime, timedelta

import dask.array as da
import h5py
import numpy as np
import xarray as xr
import dask.array as da

from satpy import CHUNK_SIZE
from satpy.readers.file_handlers import BaseFileHandler
from satpy.readers.utils import np2str
from satpy.utils import angle2xyz, lonlat2xyz, xyz2angle, xyz2lonlat
from satpy import CHUNK_SIZE

chans_dict = {"M01": "M1",
"M02": "M2",
"M03": "M3",
"M04": "M4",
"M05": "M5",
"M06": "M6",
"M07": "M7",
"M08": "M8",
"M09": "M9",
"M10": "M10",
"M11": "M11",
"M12": "M12",
"M13": "M13",
"M14": "M14",
"M15": "M15",
"M16": "M16",
"DNB": "DNB"}
_channels_dict = {"M01": "M1",
"M02": "M2",
"M03": "M3",
"M04": "M4",
"M05": "M5",
"M06": "M6",
"M07": "M7",
"M08": "M8",
"M09": "M9",
"M10": "M10",
"M11": "M11",
"M12": "M12",
"M13": "M13",
"M14": "M14",
"M15": "M15",
"M16": "M16",
"DNB": "DNB"}

logger = logging.getLogger(__name__)

@@ -99,40 +99,35 @@ def __init__(self, filename, filename_info, filetype_info):
or (max(abs(self.min_lat), abs(self.max_lat)) > 60))

self.scans = self.h5f["All_Data"]["NumberOfScans"][0]
self.geostuff = self.h5f["All_Data"]['VIIRS-%s-GEO_All' % self.ch_type]
self.geography = self.h5f["All_Data"]['VIIRS-%s-GEO_All' % self.ch_type]

for key in self.h5f["All_Data"].keys():
if key.startswith("VIIRS") and key.endswith("SDR_All"):
channel = key.split('-')[1]
break

# FIXME: this supposes there is only one tiepoint zone in the
# track direction
self.scan_size = self.h5f["All_Data/VIIRS-%s-SDR_All" %
channel].attrs["TiePointZoneSizeTrack"].item()
self.track_offset = self.h5f["All_Data/VIIRS-%s-SDR_All" %
channel].attrs["PixelOffsetTrack"]
self.scan_offset = self.h5f["All_Data/VIIRS-%s-SDR_All" %
channel].attrs["PixelOffsetScan"]
# This supposes there is only one tiepoint zone in the track direction.
channel_path = f"All_Data/VIIRS-{channel}-SDR_All"
self.scan_size = self.h5f[channel_path].attrs["TiePointZoneSizeTrack"].item()
self.track_offset = self.h5f[channel_path].attrs["PixelOffsetTrack"][()]
self.scan_offset = self.h5f[channel_path].attrs["PixelOffsetScan"][()]
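
A note on the `[()]` reads introduced here: indexing an h5py dataset or attribute with `[()]` pulls the values into a plain numpy array, which can be pickled and shipped to dask distributed workers, unlike a live h5py object bound to an open file handle. A minimal sketch, using a hypothetical scratch file:

    import pickle

    import h5py
    import numpy as np

    with h5py.File("scratch.h5", "w") as f:  # hypothetical scratch file
        f["zones"] = np.arange(4)

    h5f = h5py.File("scratch.h5", "r")
    lazy = h5f["zones"]       # h5py.Dataset, bound to the open file handle
    eager = h5f["zones"][()]  # plain numpy array, safe to put in a dask graph
    pickle.dumps(eager)       # fine
    # pickle.dumps(lazy)      # TypeError: h5py objects cannot be pickled
    h5f.close()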

try:
self.group_locations = self.geostuff[
"TiePointZoneGroupLocationScanCompact"][()]
self.group_locations = self.geography["TiePointZoneGroupLocationScanCompact"][()]
except KeyError:
self.group_locations = [0]

self.tpz_sizes = da.from_array(self.h5f["All_Data/VIIRS-%s-SDR_All" % channel].attrs["TiePointZoneSizeScan"],
chunks=1)
self.tpz_sizes = da.from_array(self.h5f[channel_path].attrs["TiePointZoneSizeScan"], chunks=1)
if len(self.tpz_sizes.shape) == 2:
if self.tpz_sizes.shape[1] != 1:
raise NotImplementedError("Can't handle 2 dimensional tiepoint zones.")
self.tpz_sizes = self.tpz_sizes.squeeze(1)
self.nb_tpzs = self.geostuff["NumberOfTiePointZonesScan"]
self.c_align = da.from_array(self.geostuff["AlignmentCoefficient"],
chunks=tuple(self.nb_tpzs))
self.c_exp = da.from_array(self.geostuff["ExpansionCoefficient"],
chunks=tuple(self.nb_tpzs))
self.nb_tpzs = da.from_array(self.nb_tpzs, chunks=1)
self.nb_tiepoint_zones = self.geography["NumberOfTiePointZonesScan"][()]
self.c_align = da.from_array(self.geography["AlignmentCoefficient"],
chunks=tuple(self.nb_tiepoint_zones))
self.c_exp = da.from_array(self.geography["ExpansionCoefficient"],
chunks=tuple(self.nb_tiepoint_zones))
self.nb_tiepoint_zones = da.from_array(self.nb_tiepoint_zones, chunks=1)
self._expansion_coefs = None
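
The chunking above gives the coefficient arrays one dask block per tie-point zone, so each zone can later be expanded independently. A toy sketch with made-up zone counts (for 1-D arrays dask also accepts the flat tuple form used above):

    import dask.array as da
    import numpy as np

    nb_tiepoint_zones = np.array([3, 2])    # zones per group, made up
    coefs = np.arange(5, dtype=np.float64)  # one coefficient per zone
    c_align = da.from_array(coefs, chunks=(tuple(nb_tiepoint_zones),))
    print(c_align.chunks)                   # ((3, 2),)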

self.cache = {}
@@ -144,15 +139,13 @@ def __init__(self, filename, filename_info, filetype_info):

def __del__(self):
"""Close file handlers when we are done."""
try:
with suppress(OSError):
self.h5f.close()
except OSError:
pass
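
`contextlib.suppress` is the idiomatic replacement for a try/except that swallows the error; the two versions shown in this hunk behave identically. A self-contained sketch with a toy stand-in for the file handle:

    from contextlib import suppress

    class ToyFile:
        def close(self):
            raise OSError("already closed")

    handler = ToyFile()
    with suppress(OSError):  # same effect as try/close/except OSError: pass
        handler.close()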

def get_dataset(self, key, info):
"""Load a dataset."""
logger.debug('Reading %s.', key['name'])
if key['name'] in chans_dict:
if key['name'] in _channels_dict:
m_data = self.read_dataset(key, info)
else:
m_data = self.read_geo(key, info)
@@ -164,10 +157,8 @@ def get_bounding_box(self):
"""Get the bounding box of the data."""
for key in self.h5f["Data_Products"].keys():
if key.startswith("VIIRS") and key.endswith("GEO"):
lats = self.h5f["Data_Products"][key][
key + '_Gran_0'].attrs['G-Ring_Latitude']
lons = self.h5f["Data_Products"][key][
key + '_Gran_0'].attrs['G-Ring_Longitude']
lats = self.h5f["Data_Products"][key][key + '_Gran_0'].attrs['G-Ring_Latitude'][()]
lons = self.h5f["Data_Products"][key][key + '_Gran_0'].attrs['G-Ring_Longitude'][()]
break
else:
raise KeyError('Cannot find bounding coordinates!')
@@ -214,8 +205,6 @@ def read_geo(self, key, info):
attrs=self.mda, dims=('y', 'x'))

if info.get('standard_name') in ['latitude', 'longitude']:
if self.lons is None or self.lats is None:
self.lons, self.lats = self.navigate()
mda = self.mda.copy()
mda.update(info)
if info['standard_name'] == 'longitude':
@@ -226,13 +215,13 @@
if key['name'] == 'dnb_moon_illumination_fraction':
mda = self.mda.copy()
mda.update(info)
return xr.DataArray(da.from_array(self.geostuff["MoonIllumFraction"]),
return xr.DataArray(da.from_array(self.geography["MoonIllumFraction"]),
attrs=info)

def read_dataset(self, dataset_key, info):
"""Read a dataset."""
h5f = self.h5f
channel = chans_dict[dataset_key['name']]
channel = _channels_dict[dataset_key['name']]
chan_dict = dict([(key.split("-")[1], key)
for key in h5f["All_Data"].keys()
if key.startswith("VIIRS")])
@@ -245,13 +234,6 @@
h5attrs = h5rads.attrs
scans = h5f["All_Data"]["NumberOfScans"][0]
rads = rads[:scans * 16, :]
# if channel in ("M9", ):
# arr = rads[:scans * 16, :].astype(np.float32)
# arr[arr > 65526] = np.nan
# arr = np.ma.masked_array(arr, mask=arr_mask)
# else:
# arr = np.ma.masked_greater(rads[:scans * 16, :].astype(np.float32),
# 65526)
rads = rads.where(rads <= 65526)
try:
rads = xr.where(rads <= h5attrs['Threshold'],
@@ -299,79 +281,38 @@ def read_dataset(self, dataset_key, info):
rads.attrs['units'] = unit
return rads

def expand(self, data, coefs):
"""Perform the expansion in numpy domain."""
data = data.reshape(data.shape[:-1])

coefs = coefs.reshape(self.scans, self.scan_size, data.shape[1] - 1, -1, 4)

coef_a = coefs[:, :, :, :, 0]
coef_b = coefs[:, :, :, :, 1]
coef_c = coefs[:, :, :, :, 2]
coef_d = coefs[:, :, :, :, 3]

data_a = data[:self.scans * 2:2, np.newaxis, :-1, np.newaxis]
data_b = data[:self.scans * 2:2, np.newaxis, 1:, np.newaxis]
data_c = data[1:self.scans * 2:2, np.newaxis, 1:, np.newaxis]
data_d = data[1:self.scans * 2:2, np.newaxis, :-1, np.newaxis]

fdata = (coef_a * data_a + coef_b * data_b + coef_d * data_d + coef_c * data_c)

return fdata.reshape(self.scans * self.scan_size, -1)

def expand_angle_and_nav(self, arrays):
"""Expand angle and navigation datasets."""
res = []
for array in arrays:
res.append(da.map_blocks(self.expand, array[:, :, np.newaxis], self.expansion_coefs,
res.append(da.map_blocks(expand, array[:, :, np.newaxis], self.expansion_coefs,
scans=self.scans, scan_size=self.scan_size,
dtype=array.dtype, drop_axis=2, chunks=self.expansion_coefs.chunks[:-1]))
return res
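
This hunk carries the core idea of the PR: `expand` is now a module-level function, so `da.map_blocks` ships only a plain picklable function plus array blocks to the workers, instead of a bound method whose `self` would drag the open `h5py.File` (unpicklable, and thus fatal for dask distributed) into every task. A minimal sketch of the pattern with a toy function:

    import dask.array as da
    import numpy as np

    def double(block):  # free function: picklable, no captured state
        return block * 2

    arr = da.from_array(np.arange(8), chunks=4)
    result = da.map_blocks(double, arr, dtype=arr.dtype)
    print(result.compute())  # [ 0  2  4  6  8 10 12 14]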

def get_coefs(self, c_align, c_exp, tpz_size, nb_tpz, v_track):
"""Compute the coeffs in numpy domain."""
nties = nb_tpz.item()
tpz_size = tpz_size.item()
v_scan = (np.arange(nties * tpz_size) % tpz_size + self.scan_offset) / tpz_size
s_scan, s_track = np.meshgrid(v_scan, v_track)
s_track = s_track.reshape(self.scans, self.scan_size, nties, tpz_size)
s_scan = s_scan.reshape(self.scans, self.scan_size, nties, tpz_size)

c_align = c_align[np.newaxis, np.newaxis, :, np.newaxis]
c_exp = c_exp[np.newaxis, np.newaxis, :, np.newaxis]

a_scan = s_scan + s_scan * (1 - s_scan) * c_exp + s_track * (
1 - s_track) * c_align
a_track = s_track
coef_a = (1 - a_track) * (1 - a_scan)
coef_b = (1 - a_track) * a_scan
coef_d = a_track * (1 - a_scan)
coef_c = a_track * a_scan
res = np.stack([coef_a, coef_b, coef_c, coef_d], axis=4).reshape(self.scans * self.scan_size, -1, 4)
return res

@property
def expansion_coefs(self):
"""Compute the expansion coefficients."""
if self._expansion_coefs is not None:
return self._expansion_coefs
v_track = (np.arange(self.scans * self.scan_size) % self.scan_size + self.track_offset) / self.scan_size
self.tpz_sizes = self.tpz_sizes.persist()
self.nb_tpzs = self.nb_tpzs.persist()
col_chunks = (self.tpz_sizes * self.nb_tpzs).compute()
self._expansion_coefs = da.map_blocks(self.get_coefs, self.c_align, self.c_exp, self.tpz_sizes, self.nb_tpzs,
dtype=np.float64, v_track=v_track, new_axis=[0, 2],
chunks=(self.scans * self.scan_size,
tuple(col_chunks), 4))
self.nb_tiepoint_zones = self.nb_tiepoint_zones.persist()
col_chunks = (self.tpz_sizes * self.nb_tiepoint_zones).compute()
self._expansion_coefs = da.map_blocks(get_coefs, self.c_align, self.c_exp, self.tpz_sizes,
self.nb_tiepoint_zones,
v_track=v_track, scans=self.scans, scan_size=self.scan_size,
scan_offset=self.scan_offset,
dtype=np.float64, new_axis=[0, 2],
chunks=(self.scans * self.scan_size, tuple(col_chunks), 4))

return self._expansion_coefs
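
Since `get_coefs` changes the block shapes, `map_blocks` must be told the output layout explicitly through `new_axis` and `chunks`; the `persist()`/`compute()` round above pins down the per-zone column sizes before the graph is built. A toy sketch of the same pattern:

    import dask.array as da
    import numpy as np

    def blow_up(block):
        # turn a 1-D block into a (3, len(block), 4) block
        return np.broadcast_to(block[np.newaxis, :, np.newaxis],
                               (3, block.size, 4)).copy()

    x = da.from_array(np.arange(6), chunks=((4, 2),))
    res = da.map_blocks(blow_up, x, dtype=x.dtype, new_axis=[0, 2],
                        chunks=(3, (4, 2), 4))
    print(res.shape)  # (3, 6, 4)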

def navigate(self):
"""Generate the navigation datasets."""
shape = self.geostuff['Longitude'].shape
hchunks = (self.nb_tpzs + 1).compute()
chunks = (shape[0], tuple(hchunks))
lon = da.from_array(self.geostuff["Longitude"], chunks=chunks)
lat = da.from_array(self.geostuff["Latitude"], chunks=chunks)
chunks = self._get_geographical_chunks()
lon = da.from_array(self.geography["Longitude"], chunks=chunks)
lat = da.from_array(self.geography["Latitude"], chunks=chunks)
if self.switch_to_cart:
arrays = lonlat2xyz(lon, lat)
else:
@@ -383,14 +324,18 @@ def navigate(self):

return expanded
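
For context, the `switch_to_cart` branch interpolates in Cartesian (x, y, z) space whenever the granule crosses the dateline or comes close to a pole, because interpolating angles directly wraps the wrong way. A tiny illustration of the failure mode, independent of this reader:

    import numpy as np

    lons = np.array([179.0, -179.0])
    print(lons.mean())  # 0.0 -- wrong side of the planet

    x = np.cos(np.deg2rad(lons))
    y = np.sin(np.deg2rad(lons))
    print(np.rad2deg(np.arctan2(y.mean(), x.mean())))  # +/-180.0 -- correct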

def _get_geographical_chunks(self):
shape = self.geography['Longitude'].shape
horizontal_chunks = (self.nb_tiepoint_zones + 1).compute()
chunks = (shape[0], tuple(horizontal_chunks))
return chunks
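
The chunks built here give each dask block exactly one tie-point zone group plus its closing tie-point column (hence the `+ 1`), while the full track dimension stays in a single block. A toy sketch with made-up sizes:

    import dask.array as da
    import numpy as np

    lon = np.zeros((96, 9))        # toy tie-point grid: 96 rows, 9 columns
    horizontal_chunks = (4, 3, 2)  # zone counts + 1, made up
    dlon = da.from_array(lon, chunks=(96, horizontal_chunks))
    print(dlon.chunks)             # ((96,), (4, 3, 2))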

def angles(self, azi_name, zen_name):
"""Generate the angle datasets."""
shape = self.geostuff['Longitude'].shape
hchunks = (self.nb_tpzs + 1).compute()
chunks = (shape[0], tuple(hchunks))
chunks = self._get_geographical_chunks()

azi = self.geostuff[azi_name]
zen = self.geostuff[zen_name]
azi = self.geography[azi_name]
zen = self.geography[zen_name]

switch_to_cart = ((np.max(azi) - np.min(azi) > 5)
or (np.min(zen) < 10)
@@ -433,6 +378,56 @@ def convert_to_angles(x, y, z):
return azi, zen


def get_coefs(c_align, c_exp, tpz_size, nb_tpz, v_track, scans, scan_size, scan_offset):
"""Compute the coeffs in numpy domain."""
nties = nb_tpz.item()
tpz_size = tpz_size.item()
v_scan = (np.arange(nties * tpz_size) % tpz_size + scan_offset) / tpz_size
s_scan, s_track = np.meshgrid(v_scan, v_track)
s_track = s_track.reshape(scans, scan_size, nties, tpz_size)
s_scan = s_scan.reshape(scans, scan_size, nties, tpz_size)

c_align = c_align[np.newaxis, np.newaxis, :, np.newaxis]
c_exp = c_exp[np.newaxis, np.newaxis, :, np.newaxis]

a_scan = s_scan + s_scan * (1 - s_scan) * c_exp + s_track * (
1 - s_track) * c_align
a_track = s_track
coef_a = (1 - a_track) * (1 - a_scan)
coef_b = (1 - a_track) * a_scan
coef_d = a_track * (1 - a_scan)
coef_c = a_track * a_scan
res = np.stack([coef_a, coef_b, coef_c, coef_d], axis=4).reshape(scans * scan_size, -1, 4)
return res
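
A quick numerical sanity check of the now module-level `get_coefs` (assuming it is importable as below): with zero alignment and expansion coefficients the four corner weights reduce to plain bilinear weights, so they sum to one everywhere:

    import numpy as np

    from satpy.readers.viirs_compact import get_coefs

    scans, scan_size = 1, 2
    v_track = (np.arange(scans * scan_size) % scan_size + 0.5) / scan_size
    coefs = get_coefs(c_align=np.zeros(1), c_exp=np.zeros(1),
                      tpz_size=np.array(4), nb_tpz=np.array(1),
                      v_track=v_track, scans=scans, scan_size=scan_size,
                      scan_offset=0.5)
    print(coefs.shape)         # (2, 4, 4): (line, pixel, corner)
    print(coefs.sum(axis=-1))  # all ones: bilinear weights sum to 1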


def expand(data, coefs, scans, scan_size):
"""Perform the expansion in numpy domain."""
data = data.reshape(data.shape[:-1])

coefs = coefs.reshape(scans, scan_size, data.shape[1] - 1, -1, 4)

coef_a = coefs[:, :, :, :, 0]
coef_b = coefs[:, :, :, :, 1]
coef_c = coefs[:, :, :, :, 2]
coef_d = coefs[:, :, :, :, 3]

corner_coefficients = (coef_a, coef_b, coef_c, coef_d)
fdata = _interpolate_data(data, corner_coefficients, scans)
return fdata.reshape(scans * scan_size, -1)


def _interpolate_data(data, corner_coefficients, scans):
"""Interpolate the data using the provided coefficients."""
coef_a, coef_b, coef_c, coef_d = corner_coefficients
data_a = data[:scans * 2:2, np.newaxis, :-1, np.newaxis]
data_b = data[:scans * 2:2, np.newaxis, 1:, np.newaxis]
data_c = data[1:scans * 2:2, np.newaxis, 1:, np.newaxis]
data_d = data[1:scans * 2:2, np.newaxis, :-1, np.newaxis]
fdata = (coef_a * data_a + coef_b * data_b + coef_d * data_d + coef_c * data_c)
return fdata
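
Putting the helpers together: expanding a tie-point field with a constant gradient through `expand` reproduces straight linear interpolation across the zone. A sketch reusing the toy setup above:

    import numpy as np

    from satpy.readers.viirs_compact import expand, get_coefs

    scans, scan_size = 1, 2
    v_track = (np.arange(scans * scan_size) % scan_size + 0.5) / scan_size
    coefs = get_coefs(np.zeros(1), np.zeros(1), np.array(4), np.array(1),
                      v_track, scans, scan_size, scan_offset=0.5)
    tiepoints = np.array([[0.0, 4.0],   # first tie-point row of the scan
                          [0.0, 4.0]])  # second tie-point row
    print(expand(tiepoints[:, :, np.newaxis], coefs, scans, scan_size))
    # [[0.5 1.5 2.5 3.5]
    #  [0.5 1.5 2.5 3.5]]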


def expand_arrays(arrays,
scans,
c_align,
@@ -460,12 +455,8 @@ def expand_arrays(arrays,
coef_b = (1 - a_track) * a_scan
coef_d = a_track * (1 - a_scan)
coef_c = a_track * a_scan
corner_coefficients = (coef_a, coef_b, coef_c, coef_d)
for data in arrays:
data_a = data[:scans * 2:2, np.newaxis, :-1, np.newaxis]
data_b = data[:scans * 2:2, np.newaxis, 1:, np.newaxis]
data_c = data[1:scans * 2:2, np.newaxis, 1:, np.newaxis]
data_d = data[1:scans * 2:2, np.newaxis, :-1, np.newaxis]
fdata = (coef_a * data_a + coef_b * data_b
+ coef_d * data_d + coef_c * data_c)
fdata = _interpolate_data(data, corner_coefficients, scans)
expanded.append(fdata.reshape(scans * scan_size, nties * tpz_size))
return expanded