Skip to content

Commit

Permalink
added some logic to deal with rasterio objects in addition to filepat…
Browse files Browse the repository at this point in the history
…hs (#2589)

* added some logic to deal with rasterio objects in addition to filepath strings

* added no network test, pep8 compliance, whatsnew.rst

* removed subclass, added to base RasterioArrayWrapper

* upped rasterio test version to > 1

* specified rasterio version should be greater than 1
  • Loading branch information
scottyhq authored and shoyer committed Dec 23, 2018
1 parent ce52341 commit 9352b3c
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ matrix:
- env: CONDA_ENV=py36-bottleneck-dev
- env: CONDA_ENV=py36-condaforge-rc
- env: CONDA_ENV=py36-pynio-dev
- env: CONDA_ENV=py36-rasterio-0.36
- env: CONDA_ENV=py36-rasterio
- env: CONDA_ENV=py36-zarr-dev
- env: CONDA_ENV=docs
- env: CONDA_ENV=py36-hypothesis
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies:
- scipy
- seaborn
- toolz
- rasterio=0.36.0
- rasterio>=1.0
- bottleneck
- pip:
- coveralls
Expand Down
2 changes: 1 addition & 1 deletion doc/installing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ For netCDF and IO
for accessing CAMx, GEOS-Chem (bpch), NOAA ARL files, ICARTT files
(ffi1001) and many other.
- `rasterio <https://github.com/mapbox/rasterio>`__: for reading GeoTiffs and
other gridded raster datasets.
other gridded raster datasets. (version 1.0 or later)
- `iris <https://github.com/scitools/iris>`__: for conversion to and from iris'
Cube objects
- `cfgrib <https://github.com/ecmwf/cfgrib>`__: for reading GRIB files via the
Expand Down
5 changes: 5 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ v0.11.1 (unreleased)
Breaking changes
~~~~~~~~~~~~~~~~

- Minimum rasterio version increased from 0.36 to 1.0 (for ``open_rasterio``)
- Time bounds variables are now also decoded according to CF conventions
(:issue:`2565`). The previous behavior was to decode them only if they
had specific time attributes, now these attributes are copied
Expand All @@ -49,6 +50,10 @@ Enhancements
- :py:class:`CFTimeIndex` uses slicing for string indexing when possible (like
:py:class:`pandas.DatetimeIndex`), which avoids unnecessary copies.
By `Stephan Hoyer <https://github.com/shoyer>`_
- Enable passing ``rasterio.io.DatasetReader`` or ``rasterio.vrt.WarpedVRT`` to
``open_rasterio`` instead of file path string. Allows for in-memory
reprojection, see (:issue:`2588`).
By `Scott Henderson <https://github.com/scottyhq>`_.
- Like :py:class:`pandas.DatetimeIndex`, :py:class:`CFTimeIndex` now supports
"dayofyear" and "dayofweek" accessors (:issue:`2597`). By `Spencer Clark
<https://github.com/spencerkclark>`_.
Expand Down
35 changes: 25 additions & 10 deletions xarray/backends/rasterio_.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import warnings
from collections import OrderedDict
from distutils.version import LooseVersion

import numpy as np

from .. import DataArray
Expand All @@ -23,13 +22,14 @@

class RasterioArrayWrapper(BackendArray):
"""A wrapper around rasterio dataset objects"""

def __init__(self, manager):
def __init__(self, manager, vrt_params=None):
from rasterio.vrt import WarpedVRT
self.manager = manager

# cannot save riods as an attribute: this would break pickleability
riods = manager.acquire()

riods = riods if vrt_params is None else WarpedVRT(riods, **vrt_params)
self.vrt_params = vrt_params
self._shape = (riods.count, riods.height, riods.width)

dtypes = riods.dtypes
Expand Down Expand Up @@ -103,6 +103,7 @@ def _get_indexer(self, key):
return band_key, tuple(window), tuple(squeeze_axis), tuple(np_inds)

def _getitem(self, key):
from rasterio.vrt import WarpedVRT
band_key, window, squeeze_axis, np_inds = self._get_indexer(key)

if not band_key or any(start == stop for (start, stop) in window):
Expand All @@ -112,6 +113,7 @@ def _getitem(self, key):
out = np.zeros(shape, dtype=self.dtype)
else:
riods = self.manager.acquire()
riods = riods if self.vrt_params is None else WarpedVRT(riods,**self.vrt_params)
out = riods.read(band_key, window=window)

if squeeze_axis:
Expand Down Expand Up @@ -176,8 +178,8 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
Parameters
----------
filename : str
Path to the file to open.
filename : str, rasterio.DatasetReader, or rasterio.WarpedVRT
Path to the file to open. Or already open rasterio dataset.
parse_coordinates : bool, optional
Whether to parse the x and y coordinates out of the file's
``transform`` attribute or not. The default is to automatically
Expand All @@ -204,11 +206,24 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
data : DataArray
The newly created DataArray.
"""

import rasterio
from rasterio.vrt import WarpedVRT
vrt_params = None
if isinstance(filename, rasterio.io.DatasetReader):
filename = filename.name
elif isinstance(filename, rasterio.vrt.WarpedVRT):
vrt = filename
filename = vrt.src_dataset.name
vrt_params = dict(crs=vrt.crs.to_string(),
resampling=vrt.resampling,
src_nodata=vrt.src_nodata,
dst_nodata=vrt.dst_nodata,
tolerance=vrt.tolerance,
warp_extras=vrt.warp_extras)

manager = CachingFileManager(rasterio.open, filename, mode='r')
riods = manager.acquire()
riods = riods if vrt_params is None else WarpedVRT(riods, **vrt_params)

if cache is None:
cache = chunks is None
Expand Down Expand Up @@ -282,13 +297,13 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
for k, v in meta.items():
# Add values as coordinates if they match the band count,
# as attributes otherwise
if (isinstance(v, (list, np.ndarray)) and
len(v) == riods.count):
if (isinstance(v, (list, np.ndarray))
and len(v) == riods.count):
coords[k] = ('band', np.asarray(v))
else:
attrs[k] = v

data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(manager))
data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(manager, vrt_params))

# this lets you write arrays loaded with rasterio
data = indexing.CopyOnWriteArray(data)
Expand Down
113 changes: 90 additions & 23 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,9 @@ def check_dtypes_roundtripped(self, expected, actual):
actual_dtype = actual.variables[k].dtype
# TODO: check expected behavior for string dtypes more carefully
string_kinds = {'O', 'S', 'U'}
assert (expected_dtype == actual_dtype or
(expected_dtype.kind in string_kinds and
actual_dtype.kind in string_kinds))
assert (expected_dtype == actual_dtype
or (expected_dtype.kind in string_kinds and
actual_dtype.kind in string_kinds))

def test_roundtrip_test_data(self):
expected = create_test_data()
Expand Down Expand Up @@ -410,17 +410,17 @@ def test_roundtrip_cftime_datetime_data(self):
with self.roundtrip(expected, save_kwargs=kwds) as actual:
abs_diff = abs(actual.t.values - expected_decoded_t)
assert (abs_diff <= np.timedelta64(1, 's')).all()
assert (actual.t.encoding['units'] ==
'days since 0001-01-01 00:00:00.000000')
assert (actual.t.encoding['calendar'] ==
expected_calendar)
assert (actual.t.encoding['units']
== 'days since 0001-01-01 00:00:00.000000')
assert (actual.t.encoding['calendar']
== expected_calendar)

abs_diff = abs(actual.t0.values - expected_decoded_t0)
assert (abs_diff <= np.timedelta64(1, 's')).all()
assert (actual.t0.encoding['units'] ==
'days since 0001-01-01')
assert (actual.t.encoding['calendar'] ==
expected_calendar)
assert (actual.t0.encoding['units']
== 'days since 0001-01-01')
assert (actual.t.encoding['calendar']
== expected_calendar)

def test_roundtrip_timedelta_data(self):
time_deltas = pd.to_timedelta(['1h', '2h', 'NaT'])
Expand Down Expand Up @@ -668,24 +668,24 @@ def test_roundtrip_mask_and_scale(self, decoded_fn, encoded_fn):

with self.roundtrip(decoded) as actual:
for k in decoded.variables:
assert (decoded.variables[k].dtype ==
actual.variables[k].dtype)
assert (decoded.variables[k].dtype
== actual.variables[k].dtype)
assert_allclose(decoded, actual, decode_bytes=False)

with self.roundtrip(decoded,
open_kwargs=dict(decode_cf=False)) as actual:
# TODO: this assumes that all roundtrips will first
# encode. Is that something we want to test for?
for k in encoded.variables:
assert (encoded.variables[k].dtype ==
actual.variables[k].dtype)
assert (encoded.variables[k].dtype
== actual.variables[k].dtype)
assert_allclose(encoded, actual, decode_bytes=False)

with self.roundtrip(encoded,
open_kwargs=dict(decode_cf=False)) as actual:
for k in encoded.variables:
assert (encoded.variables[k].dtype ==
actual.variables[k].dtype)
assert (encoded.variables[k].dtype
== actual.variables[k].dtype)
assert_allclose(encoded, actual, decode_bytes=False)

# make sure roundtrip encoding didn't change the
Expand Down Expand Up @@ -2621,8 +2621,8 @@ def myatts(**attrs):
'ULOD_FLAG': '-7777', 'ULOD_VALUE': 'N/A',
'LLOD_FLAG': '-8888',
'LLOD_VALUE': ('N/A, N/A, N/A, N/A, 0.025'),
'OTHER_COMMENTS': ('www-air.larc.nasa.gov/missions/etc/' +
'IcarttDataFormat.htm'),
'OTHER_COMMENTS': ('www-air.larc.nasa.gov/missions/etc/'
+ 'IcarttDataFormat.htm'),
'REVISION': 'R0',
'R0': 'No comments for this revision.',
'TFLAG': 'Start_UTC'
Expand Down Expand Up @@ -2711,8 +2711,8 @@ def test_uamiv_format_read(self):
expected = xr.Variable(('TSTEP',), data,
dict(bounds='time_bounds',
long_name=('synthesized time coordinate ' +
'from SDATE, STIME, STEP ' +
'global attributes')))
'from SDATE, STIME, STEP '
+ 'global attributes')))
actual = camxfile.variables['time']
assert_allclose(expected, actual)
camxfile.close()
Expand Down Expand Up @@ -2741,8 +2741,8 @@ def test_uamiv_format_mfread(self):
data = np.concatenate([data1] * 2, axis=0)
attrs = dict(bounds='time_bounds',
long_name=('synthesized time coordinate ' +
'from SDATE, STIME, STEP ' +
'global attributes'))
'from SDATE, STIME, STEP '
+ 'global attributes'))
expected = xr.Variable(('TSTEP',), data, attrs)
actual = camxfile.variables['time']
assert_allclose(expected, actual)
Expand Down Expand Up @@ -3158,6 +3158,73 @@ def test_http_url(self):
import dask.array as da
assert isinstance(actual.data, da.Array)

def test_rasterio_environment(self):
import rasterio
with create_tmp_geotiff() as (tmp_file, expected):
# Should fail with error since suffix not allowed
with pytest.raises(Exception):
with rasterio.Env(GDAL_SKIP='GTiff'):
with xr.open_rasterio(tmp_file) as actual:
assert_allclose(actual, expected)

def test_rasterio_vrt(self):
import rasterio
# tmp_file default crs is UTM: CRS({'init': 'epsg:32618'}
with create_tmp_geotiff() as (tmp_file, expected):
with rasterio.open(tmp_file) as src:
with rasterio.vrt.WarpedVRT(src, crs='epsg:4326') as vrt:
expected_shape = (vrt.width, vrt.height)
expected_crs = vrt.crs
print(expected_crs)
expected_res = vrt.res
# Value of single pixel in center of image
lon, lat = vrt.xy(vrt.width // 2, vrt.height // 2)
expected_val = next(vrt.sample([(lon, lat)]))
with xr.open_rasterio(vrt) as da:
actual_shape = (da.sizes['x'], da.sizes['y'])
actual_crs = da.crs
print(actual_crs)
actual_res = da.res
actual_val = da.sel(dict(x=lon, y=lat),
method='nearest').data

assert actual_crs == expected_crs
assert actual_res == expected_res
assert actual_shape == expected_shape
assert expected_val.all() == actual_val.all()

@network
def test_rasterio_vrt_network(self):
import rasterio

url = 'https://storage.googleapis.com/\
gcp-public-data-landsat/LC08/01/047/027/\
LC08_L1TP_047027_20130421_20170310_01_T1/\
LC08_L1TP_047027_20130421_20170310_01_T1_B4.TIF'
env = rasterio.Env(GDAL_DISABLE_READDIR_ON_OPEN='EMPTY_DIR',
CPL_VSIL_CURL_USE_HEAD=False,
CPL_VSIL_CURL_ALLOWED_EXTENSIONS='TIF')
with env:
with rasterio.open(url) as src:
with rasterio.vrt.WarpedVRT(src, crs='epsg:4326') as vrt:
expected_shape = (vrt.width, vrt.height)
expected_crs = vrt.crs
expected_res = vrt.res
# Value of single pixel in center of image
lon, lat = vrt.xy(vrt.width // 2, vrt.height // 2)
expected_val = next(vrt.sample([(lon, lat)]))
with xr.open_rasterio(vrt) as da:
actual_shape = (da.sizes['x'], da.sizes['y'])
actual_crs = da.crs
actual_res = da.res
actual_val = da.sel(dict(x=lon, y=lat),
method='nearest').data

assert_equal(actual_shape, expected_shape)
assert_equal(actual_crs, expected_crs)
assert_equal(actual_res, expected_res)
assert_equal(expected_val, actual_val)


class TestEncodingInvalid(object):

Expand Down

0 comments on commit 9352b3c

Please sign in to comment.