Skip to content

Commit

Permalink
Merge pull request #33 from pytroll/feature_dask_atm_correction
Browse files Browse the repository at this point in the history
Fix Rayleigh corrector to work with dask
  • Loading branch information
adybbroe committed Apr 17, 2018
2 parents 206ac06 + e583286 commit f6c2ded
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 89 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Expand Up @@ -13,6 +13,7 @@ matrix:
install:
- if [[ $TRAVIS_PYTHON_VERSION == "2.7" ]]; then travis_wait mvn; pip install "scipy>=0.14"; fi
- pip install -r requirements.txt
- pip install dask[array] # optional
- pip install .
- pip install coveralls
addons:
Expand Down
2 changes: 1 addition & 1 deletion appveyor.yml
Expand Up @@ -62,7 +62,7 @@ install:
# target Python version and architecture
- "conda update --yes conda"
- "conda config --add channels conda-forge"
- "conda create -q --yes -n test python=%PYTHON_VERSION% matplotlib tqdm six pyyaml xlrd sphinx h5py scipy pandas requests appdirs"
- "conda create -q --yes -n test python=%PYTHON_VERSION% matplotlib tqdm six pyyaml xlrd sphinx h5py scipy pandas requests appdirs dask toolz"
- "activate test"
- "pip install coveralls"
- "pip install mock"
Expand Down
204 changes: 128 additions & 76 deletions pyspectral/rayleigh.py
Expand Up @@ -26,14 +26,31 @@
"""

import logging
import os
import time
import logging
from six import integer_types

import h5py
import numpy as np
from scipy.interpolate import interpn

try:
from dask.array import (where, zeros, clip, rad2deg, deg2rad, cos, arccos,
atleast_2d, Array, map_blocks, from_array)
HAVE_DASK = True
try:
# use serializable h5py wrapper to make sure files are closed properly
import h5pickle as h5py
except ImportError:
pass
except ImportError:
from numpy import where, zeros, clip, rad2deg, deg2rad, cos, arccos, atleast_2d
map_blocks = None
from_array = None
Array = None
HAVE_DASK = False

from geotiepoints.multilinear import MultilinearInterpolator
from pyspectral.rsr_reader import RelativeSpectralResponse
from pyspectral.utils import (INSTRUMENTS, RAYLEIGH_LUT_DIRS,
AEROSOL_TYPES, ATMOSPHERES,
Expand All @@ -43,31 +60,25 @@

LOG = logging.getLogger(__name__)

WITH_CYTHON = True
try:
from geotiepoints.multilinear import MultilinearInterpolator
except ImportError:
LOG.warning(
"Couldn't import fast multilinear interpolation with Cython.")
LOG.warning("Check your geotiepoints installation!")
WITH_CYTHON = False


class BandFrequencyOutOfRange(Exception):
class BandFrequencyOutOfRange(ValueError):

"""Exception when the band frequency is out of the visible range"""

pass


class Rayleigh(object):

"""Container for the atmospheric correction of satellite imager short
wave bands. Removing background contributions of Rayleigh scattering of
"""Container for the atmospheric correction of satellite imager bands.
This class removes background contributions of Rayleigh scattering of
molecules and Mie scattering and absorption by aerosols.
"""

def __init__(self, platform_name, sensor, **kwargs):
"""Initialize class and determine LUT to use."""
self.platform_name = platform_name
self.sensor = sensor
self.coeff_filename = None
Expand Down Expand Up @@ -115,10 +126,14 @@ def __init__(self, platform_name, sensor, **kwargs):
str(self.reflectance_lut_filename))

LOG.debug('LUT filename: %s', str(self.reflectance_lut_filename))
self._rayl = None
self._wvl_coord = None
self._azid_coord = None
self._satz_sec_coord = None
self._sunz_sec_coord = None

def get_effective_wavelength(self, bandname):
"""Get the effective wavelength with Rayleigh scattering in mind"""

try:
rsr = RelativeSpectralResponse(self.platform_name, self.sensor)
except IOError:
Expand Down Expand Up @@ -151,75 +166,99 @@ def get_reflectance_lut(self):
secant, azimuth difference angle, and sun zenith secant
"""

return get_reflectance_lut(self.reflectance_lut_filename)
if self._rayl is None:
lut_vars = get_reflectance_lut(self.reflectance_lut_filename)
self._rayl = lut_vars[0]
self._wvl_coord = lut_vars[1]
self._azid_coord = lut_vars[2]
self._satz_sec_coord = lut_vars[3]
self._sunz_sec_coord = lut_vars[4]
return self._rayl, self._wvl_coord, self._azid_coord,\
self._satz_sec_coord, self._sunz_sec_coord

def get_reflectance(self, sun_zenith, sat_zenith, azidiff, bandname,
redband=None):
"""Get the reflectance from the three sun-sat angles."""
# Get wavelength in nm for band:
wvl = self.get_effective_wavelength(bandname) * 1000.0
rayl, wvl_coord, azid_coord, satz_sec_coord, sunz_sec_coord = self.get_reflectance_lut()

clip_angle = np.rad2deg(np.arccos(1. / sunz_sec_coord.max()))
sun_zenith = np.clip(np.asarray(sun_zenith), 0, clip_angle)
sunzsec = 1. / np.cos(np.deg2rad(sun_zenith))
clip_angle = np.rad2deg(np.arccos(1. / satz_sec_coord.max()))
sat_zenith = np.clip(np.asarray(sat_zenith), 0, clip_angle)
satzsec = 1. / np.cos(np.deg2rad(np.asarray(sat_zenith)))

rayl, wvl_coord, azid_coord, satz_sec_coord, sunz_sec_coord = \
self.get_reflectance_lut()

# force dask arrays
compute = False
if HAVE_DASK and not isinstance(sun_zenith, Array):
compute = True
sun_zenith = from_array(sun_zenith, chunks=sun_zenith.shape)
sat_zenith = from_array(sat_zenith, chunks=sat_zenith.shape)
azidiff = from_array(azidiff, chunks=azidiff.shape)
if redband is not None:
redband = from_array(redband, chunks=redband.shape)

clip_angle = rad2deg(arccos(1. / sunz_sec_coord.max()))
sun_zenith = clip(sun_zenith, 0, clip_angle)
sunzsec = 1. / cos(deg2rad(sun_zenith))
clip_angle = rad2deg(arccos(1. / satz_sec_coord.max()))
sat_zenith = clip(sat_zenith, 0, clip_angle)
satzsec = 1. / cos(deg2rad(sat_zenith))
shape = sun_zenith.shape

if not(wvl_coord.min() < wvl < wvl_coord.max()):
LOG.warning(
"Effective wavelength for band %s outside 400-800 nm range!", str(bandname))
"Effective wavelength for band %s outside 400-800 nm range!",
str(bandname))
LOG.info(
"Set the rayleigh/aerosol reflectance contribution to zero!")
return np.zeros(shape)
if HAVE_DASK:
chunks = sun_zenith.chunks if redband is None \
else redband.chunks
res = zeros(shape, chunks=chunks)
return res.compute() if compute else res
else:
return zeros(shape)

idx = np.searchsorted(wvl_coord, wvl)
wvl1 = wvl_coord[idx - 1]
wvl2 = wvl_coord[idx]

fac = (wvl2 - wvl) / (wvl2 - wvl1)

raylwvl = fac * rayl[idx - 1, :, :, :] + (1 - fac) * rayl[idx, :, :, :]

import time
tic = time.time()

if WITH_CYTHON:
smin = [sunz_sec_coord[0], azid_coord[0], satz_sec_coord[0]]
smax = [sunz_sec_coord[-1], azid_coord[-1], satz_sec_coord[-1]]
orders = [
len(sunz_sec_coord), len(azid_coord), len(satz_sec_coord)]
f_3d_grid = raylwvl

minterp = MultilinearInterpolator(smin, smax, orders)
minterp.set_values(np.atleast_2d(f_3d_grid.ravel()))

interp_points2 = np.vstack((np.asarray(sunzsec).ravel(),
np.asarray(180 - azidiff).ravel(),
np.asarray(satzsec).ravel()))

ipn = minterp(interp_points2).reshape(shape)
smin = [sunz_sec_coord[0], azid_coord[0], satz_sec_coord[0]]
smax = [sunz_sec_coord[-1], azid_coord[-1], satz_sec_coord[-1]]
orders = [
len(sunz_sec_coord), len(azid_coord), len(satz_sec_coord)]
minterp = MultilinearInterpolator(smin, smax, orders)

f_3d_grid = raylwvl
minterp.set_values(atleast_2d(f_3d_grid.ravel()))

def _do_interp(minterp, sunzsec, azidiff, satzsec):
interp_points2 = np.vstack((sunzsec.ravel(),
180 - azidiff.ravel(),
satzsec.ravel()))
res = minterp(interp_points2)
return res.reshape(sunzsec.shape)

if HAVE_DASK:
ipn = map_blocks(_do_interp, minterp, sunzsec, azidiff,
satzsec, dtype=raylwvl.dtype,
chunks=azidiff.chunks)
else:
interp_points = np.dstack((np.asarray(sunzsec),
np.asarray(180. - azidiff),
np.asarray(satzsec)))

ipn = interpn((sunz_sec_coord, azid_coord, satz_sec_coord),
raylwvl[:, :, :], interp_points)
ipn = _do_interp(minterp, sunzsec, azidiff, satzsec)

LOG.debug("Time - Interpolation: {0:f}".format(time.time() - tic))

ipn *= 100
res = ipn
if redband is not None:
res = np.where(np.less(redband, 20.), res,
(1 - (redband - 20) / 80) * res)
res = where(redband < 20., res,
(1 - (redband - 20) / 80) * res)

return np.clip(res, 0, 100)
res = clip(res, 0, 100)
if compute:
res = res.compute()
return res


def get_reflectance_lut(filename):
Expand All @@ -228,27 +267,40 @@ def get_reflectance_lut(filename):
"""

with h5py.File(filename, 'r') as h5f:
tab = h5f['reflectance'][:]
wvl = h5f['wavelengths'][:]
azidiff = h5f['azimuth_difference'][:]
satellite_zenith_secant = h5f['satellite_zenith_secant'][:]
sun_zenith_secant = h5f['sun_zenith_secant'][:]
h5f = h5py.File(filename, 'r')

tab = h5f['reflectance']
wvl = h5f['wavelengths']
azidiff = h5f['azimuth_difference']
satellite_zenith_secant = h5f['satellite_zenith_secant']
sun_zenith_secant = h5f['sun_zenith_secant']

if HAVE_DASK:
tab = from_array(tab, chunks=(10, 10, 10, 10))
# wvl_coord is used in a lot of non-dask functions, keep in memory
wvl = from_array(wvl, chunks=(100,)).persist()
azidiff = from_array(azidiff, chunks=(1000,))
satellite_zenith_secant = from_array(satellite_zenith_secant,
chunks=(1000,))
sun_zenith_secant = from_array(sun_zenith_secant,
chunks=(1000,))
else:
# load all of the data we are going to use in to memory
tab = tab[:]
wvl = wvl[:]
azidiff = azidiff[:]
satellite_zenith_secant = satellite_zenith_secant[:]
sun_zenith_secant = sun_zenith_secant[:]
h5f.close()

return tab, wvl, azidiff, satellite_zenith_secant, sun_zenith_secant

if __name__ == "__main__":

this = Rayleigh('Suomi-NPP', 'viirs')
# SUNZ = np.arange(200000).reshape(400, 500) * 0.0004
# SATZ = np.arange(200000).reshape(400, 500) * 0.00025
# AZIDIFF = np.arange(200000).reshape(400, 500) * 0.0009
# rfl = this.get_reflectance(SUNZ, SATZ, AZIDIFF, 'M4')

SHAPE = (1000, 3000)
NDIM = SHAPE[0] * SHAPE[1]
SUNZ = np.ma.arange(
NDIM / 2, NDIM + NDIM / 2).reshape(SHAPE) * 60. / float(NDIM)
SATZ = np.ma.arange(NDIM).reshape(SHAPE) * 60. / float(NDIM)
AZIDIFF = np.ma.arange(NDIM).reshape(SHAPE) * 179.9 / float(NDIM)
rfl = this.get_reflectance(SUNZ, SATZ, AZIDIFF, 'M4')

# if __name__ == "__main__":
# SHAPE = (1000, 3000)
# NDIM = SHAPE[0] * SHAPE[1]
# SUNZ = np.ma.arange(
# NDIM / 2, NDIM + NDIM / 2).reshape(SHAPE) * 60. / float(NDIM)
# SATZ = np.ma.arange(NDIM).reshape(SHAPE) * 60. / float(NDIM)
# AZIDIFF = np.ma.arange(NDIM).reshape(SHAPE) * 179.9 / float(NDIM)
# rfl = this.get_reflectance(SUNZ, SATZ, AZIDIFF, 'M4')
3 changes: 2 additions & 1 deletion pyspectral/tests/test_raw_readers.py
Expand Up @@ -24,7 +24,6 @@

import sys
import mock
from pyspectral.aatsr_reader import AatsrRSR

if sys.version_info < (2, 7):
import unittest2 as unittest
Expand All @@ -44,6 +43,7 @@ class TestAatsrRsrReader(unittest.TestCase):
@mock.patch('pyspectral.config.get_config')
def setUp(self, get_config, _load, open_workbook):
"""Setup the natve MSG file handler for testing."""
from pyspectral.aatsr_reader import AatsrRSR
open_workbook.return_code = None
get_config.return_code = {}
_load.return_code = None
Expand All @@ -68,6 +68,7 @@ def suite():
mysuite.addTest(loader.loadTestsFromTestCase(TestAatsrRsrReader))
return mysuite


if __name__ == "__main__":
# So you can run tests from this module individually.
unittest.main()

0 comments on commit f6c2ded

Please sign in to comment.