Skip to content

Commit

Permalink
♻️ Precise selective tiling and interpolation using xarray slices
Browse files Browse the repository at this point in the history
Towards more fine-grained cropping of our image tiles! Basically have data_prep.selective tile do the crop using exact geographic coordinate slice ranges instead of having to convert (sometimes imprecisely) to image-based coordinates. Uses xarray's subset 'sel'(ection) method which does away with the mess that is rasterio.windows and affine transformations. REMA tiles doesn't seem to require gapfilling anymore so we've temporarily disabled gapfilling (raise NotImplementedError) until it is needed for getting tiles for the whole of Antarctica again.

Also using a nicer interpolation method in data_prep.selective_tile, especially relevant for W2_data aka the MEASURES Surface Ice Velocity which is resampled from 450m to 500m (since a8863e4). Still resampling billinearly, but interpolation at the cropped tile's edges take into account pixels beyond the border if available. I've actually inspected these new Ice Velocity tiles manually and they look awesome! Might help with the strange high-level checkerboard artifacts. Side effect is that interpolation runs slowly (mitigated somewhat by using dask), until we can vectorize the whole function properly.
  • Loading branch information
weiji14 committed Jun 12, 2019
1 parent c38f74a commit 51a504f
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 128 deletions.
121 changes: 51 additions & 70 deletions data_prep.ipynb
Expand Up @@ -54,6 +54,8 @@
"import xarray as xr\n",
"import salem\n",
"\n",
"import dask\n",
"import dask.diagnostics\n",
"import geopandas as gpd\n",
"import pygmt as gmt\n",
"import IPython.display\n",
Expand Down Expand Up @@ -1200,69 +1202,50 @@
" [0., 0.]]]], dtype=float32)\n",
" >>> os.remove(\"/tmp/tmp_st.nc\")\n",
" \"\"\"\n",
" array_list = []\n",
"\n",
" with rasterio.open(filepath) as dataset:\n",
" # Convert list of bounding box tuples to nice rasterio.coords.BoundingBox class\n",
" window_bounds = [\n",
" rasterio.coords.BoundingBox(\n",
" left=x0 - padding, bottom=y0 - padding, right=x1 + padding, top=y1 + padding\n",
" )\n",
" for x0, y0, x1, y1 in window_bounds # xmin, ymin, xmax, ymax\n",
" ]\n",
"\n",
" with xr.open_rasterio(\n",
" filepath, chunks=None if out_shape is None else {}, cache=False\n",
" ) as dataset:\n",
" print(f\"Tiling: {filepath} ... \", end=\"\")\n",
" for window_bound in window_bounds:\n",
"\n",
" if padding > 0:\n",
" window_bound = (\n",
" window_bound[0] - padding, # minx\n",
" window_bound[1] - padding, # miny\n",
" window_bound[2] + padding, # maxx\n",
" window_bound[3] + padding, # maxy\n",
" # Subset dataset according to window bound (wb)\n",
" daarray_list = [\n",
" dataset.sel(y=slice(wb.top, wb.bottom), x=slice(wb.left, wb.right))\n",
" for wb in window_bounds\n",
" ]\n",
" # Bilinear interpolate to new shape if out_shape is set\n",
" if out_shape is not None:\n",
" daarray_list = [\n",
" dataset.interp(\n",
" y=np.linspace(da.y[0], da.y[-1], num=out_shape[0]),\n",
" x=np.linspace(da.x[0], da.x[-1], num=out_shape[1]),\n",
" method=\"linear\",\n",
" )\n",
" for da in daarray_list\n",
" ]\n",
" daarray_stack = dask.array.stack(seq=daarray_list).astype(dtype=np.float32)\n",
"\n",
" window = rasterio.windows.from_bounds(\n",
" *window_bound, transform=dataset.transform, precision=None\n",
" ).round_offsets()\n",
"\n",
" # Read the raster according to the crop window\n",
" array = dataset.read(\n",
" indexes=list(range(1, dataset.count + 1)),\n",
" masked=True,\n",
" window=window,\n",
" out_shape=out_shape,\n",
" )\n",
" assert array.ndim == 3 # check that we have shape like (1, height, width)\n",
" assert array.shape[0] == 1 # channel-first (assuming only 1 channel)\n",
" assert not 0 in array.shape # ensure no empty dimensions (invalid window)\n",
"\n",
" try:\n",
" assert not array.mask.any() # check that there are no NAN values\n",
" except AssertionError:\n",
" # Replace pixels from another raster if available, else raise error\n",
" if gapfill_raster_filepath is not None:\n",
" with rasterio.open(gapfill_raster_filepath) as dataset2:\n",
" window2 = rasterio.windows.from_bounds(\n",
" *window_bound, transform=dataset2.transform, precision=None\n",
" ).round_offsets()\n",
"\n",
" array2 = dataset2.read(\n",
" indexes=list(range(1, dataset2.count + 1)),\n",
" masked=True,\n",
" window=window2,\n",
" out_shape=array.shape[1:],\n",
" )\n",
"\n",
" np.copyto(\n",
" dst=array, src=array2, where=array.mask\n",
" ) # fill in gaps where mask is True\n",
"\n",
" # assert not array.mask.any() # ensure no NAN values after gapfill\n",
" else:\n",
" plt.imshow(array.data[0, :, :])\n",
" plt.show()\n",
" print(\n",
" f\"WARN: Tile has missing data, try passing in gapfill_raster_filepath\"\n",
" )\n",
"\n",
" # assert array.shape[1] == array.shape[2] # check that height==width\n",
" array_list.append(array.data.astype(dtype=np.float32))\n",
" assert daarray_stack.ndim == 4 # check that shape is like (m, 1, height, width)\n",
" assert daarray_stack.shape[1] == 1 # channel-first (assuming only 1 channel)\n",
" assert not 0 in daarray_stack.shape # ensure no empty dimensions (bad window)\n",
" print(\"done!\")\n",
"\n",
" return np.stack(arrays=array_list)"
" with dask.diagnostics.ProgressBar(minimum=5.0):\n",
" try:\n",
" out_tiles = daarray_stack.compute()\n",
" assert not np.isnan(out_tiles).any() # check that there are no NAN values\n",
" except AssertionError:\n",
" raise NotImplementedError(\"gapfilling on dask xarray not yet implemented\")\n",
" finally:\n",
" return out_tiles"
]
},
{
Expand Down Expand Up @@ -1370,7 +1353,7 @@
" filepath=\"misc/REMA_100m_dem.tif\",\n",
" window_bounds=window_bounds_concat,\n",
" padding=1000,\n",
" gapfill_raster_filepath=\"misc/REMA_200m_dem_filled.tif\",\n",
" # gapfill_raster_filepath=\"misc/REMA_200m_dem_filled.tif\",\n",
")\n",
"print(rema.shape, rema.dtype)"
]
Expand Down Expand Up @@ -1407,7 +1390,8 @@
"output_type": "stream",
"text": [
"Tiling: misc/MEaSUREs_IceFlowSpeed_450m.tif ... done!\n",
"(2347, 1, 20, 20) float32\n"
"[########################################] | 100% Completed | 26.4s\n",
"(2347, 1, 20, 20) float64\n"
]
}
],
Expand Down Expand Up @@ -1501,7 +1485,7 @@
"name": "stdin",
"output_type": "stream",
"text": [
"Enter the code from the webpage: eyJjb2RlIjogImVmOTRiMzMzLTZkNDItNDJkYi1hM2Y1LTQ4NGNmZjc4OTIzOSIsICJpZCI6ICIyOWI4YzUyNS1lZmM1LTQ5NTItOGQ4Yy03NzQyYTg1YmI1MmEifQ==\n"
"Enter the code from the webpage: eyJjb2RlIjogIjA5YTRiM2M5LTU1ODAtNGE1Ni1iZDkzLWRkYzg4NzVjZWE1MSIsICJpZCI6ICIyOWI4YzUyNS1lZmM1LTQ5NTItOGQ4Yy03NzQyYTg1YmI1MmEifQ==\n"
]
}
],
Expand Down Expand Up @@ -1566,39 +1550,36 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Uploading 13 fragments (6739537494 bytes)...\n"
"Uploading 13 fragments (6743292694 bytes)...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|█████████| 6.73G/6.74G [00:01<00:00, 143MB/s]"
" 98%|█████████| 6.64G/6.74G [00:01<00:00, 3.87GB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fragment 28e2ca7656d61b0bc7f8f8c1db41914023e0cab1634e0ee645f38a87d894b416 already uploaded; skipping.\n",
"Fragment 80c9fa41ccc69be1d2cd4a367d56168321d1079e7260a1996089810db25172f6 already uploaded; skipping.\n",
"Fragment 4a4efc3a84204c3d67887e8d7fa1186467b51e696451f2832ebbea3ca491c8a8 already uploaded; skipping.\n",
"Fragment 28e2ca7656d61b0bc7f8f8c1db41914023e0cab1634e0ee645f38a87d894b416 already uploaded; skipping.\n",
"Fragment 1f66fe557ce079c063597f0b04d15862f67af2c9dd4f286801851e0c71f0e869 already uploaded; skipping.\n",
"Fragment e6b139801bf4541f1e4989a8aa8b26ab37eca81bb5eaffa8028b744782455db0 already uploaded; skipping.\n",
"Fragment 4a4efc3a84204c3d67887e8d7fa1186467b51e696451f2832ebbea3ca491c8a8 already uploaded; skipping.\n",
"Fragment ca9c41a8dd56097e40865d2e65c65d299c22fc17608ddb6c604c532a69936307 already uploaded; skipping.\n",
"Fragment 704bb2fafcc9a6411047f799030dde3b4c2fb14de2e8d1eccfe651dcc6a455bf already uploaded; skipping.\n",
"Fragment bb9e1e7a62187671e58009533d2e930265f6c0827925216d354b984e2d506996 already uploaded; skipping.\n",
"Fragment c665815f043b87cfe94d51caabd1b57d8f6f6773d632503de6db0725f20d391c already uploaded; skipping.\n",
"Fragment e6b139801bf4541f1e4989a8aa8b26ab37eca81bb5eaffa8028b744782455db0 already uploaded; skipping.\n",
"Fragment f750893861a1a268c8ffe0ba7db36c933223bbf5fcbb786ecef3f052b20f9b8a already uploaded; skipping.\n",
"Fragment c7d98ad4258130d8cdea7ec6c9fbb33a868e64d4a14a57955f759ba3d35180c4 already uploaded; skipping.\n",
"Fragment f1f660d1287225c30b8b2cbf2a727283d807a1ee443153519cbf407a08937965 already uploaded; skipping.\n",
"Fragment fae1c9c2308c944488a9bc4703518395f3056cbeb55fd11f0f114282eb8cdf32 already uploaded; skipping.\n"
"Fragment f1f660d1287225c30b8b2cbf2a727283d807a1ee443153519cbf407a08937965 already uploaded; skipping.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 6.74G/6.74G [00:03<00:00, 2.11GB/s]\n"
"100%|██████████| 6.74G/6.74G [00:11<00:00, 607MB/s] \n"
]
},
{
Expand Down
99 changes: 41 additions & 58 deletions data_prep.py
Expand Up @@ -38,6 +38,8 @@
import xarray as xr
import salem

import dask
import dask.diagnostics
import geopandas as gpd
import pygmt as gmt
import IPython.display
Expand Down Expand Up @@ -610,69 +612,50 @@ def selective_tile(
[0., 0.]]]], dtype=float32)
>>> os.remove("/tmp/tmp_st.nc")
"""
array_list = []

with rasterio.open(filepath) as dataset:
print(f"Tiling: {filepath} ... ", end="")
for window_bound in window_bounds:

if padding > 0:
window_bound = (
window_bound[0] - padding, # minx
window_bound[1] - padding, # miny
window_bound[2] + padding, # maxx
window_bound[3] + padding, # maxy
)
# Convert list of bounding box tuples to nice rasterio.coords.BoundingBox class
window_bounds = [
rasterio.coords.BoundingBox(
left=x0 - padding, bottom=y0 - padding, right=x1 + padding, top=y1 + padding
)
for x0, y0, x1, y1 in window_bounds # xmin, ymin, xmax, ymax
]

window = rasterio.windows.from_bounds(
*window_bound, transform=dataset.transform, precision=None
).round_offsets()
with xr.open_rasterio(
filepath, chunks=None if out_shape is None else {}, cache=False
) as dataset:
print(f"Tiling: {filepath} ... ", end="")

# Read the raster according to the crop window
array = dataset.read(
indexes=list(range(1, dataset.count + 1)),
masked=True,
window=window,
out_shape=out_shape,
)
assert array.ndim == 3 # check that we have shape like (1, height, width)
assert array.shape[0] == 1 # channel-first (assuming only 1 channel)
assert not 0 in array.shape # ensure no empty dimensions (invalid window)
# Subset dataset according to window bound (wb)
daarray_list = [
dataset.sel(y=slice(wb.top, wb.bottom), x=slice(wb.left, wb.right))
for wb in window_bounds
]
# Bilinear interpolate to new shape if out_shape is set
if out_shape is not None:
daarray_list = [
dataset.interp(
y=np.linspace(da.y[0], da.y[-1], num=out_shape[0]),
x=np.linspace(da.x[0], da.x[-1], num=out_shape[1]),
method="linear",
)
for da in daarray_list
]
daarray_stack = dask.array.stack(seq=daarray_list).astype(dtype=np.float32)

try:
assert not array.mask.any() # check that there are no NAN values
except AssertionError:
# Replace pixels from another raster if available, else raise error
if gapfill_raster_filepath is not None:
with rasterio.open(gapfill_raster_filepath) as dataset2:
window2 = rasterio.windows.from_bounds(
*window_bound, transform=dataset2.transform, precision=None
).round_offsets()

array2 = dataset2.read(
indexes=list(range(1, dataset2.count + 1)),
masked=True,
window=window2,
out_shape=array.shape[1:],
)

np.copyto(
dst=array, src=array2, where=array.mask
) # fill in gaps where mask is True

# assert not array.mask.any() # ensure no NAN values after gapfill
else:
plt.imshow(array.data[0, :, :])
plt.show()
print(
f"WARN: Tile has missing data, try passing in gapfill_raster_filepath"
)

# assert array.shape[1] == array.shape[2] # check that height==width
array_list.append(array.data.astype(dtype=np.float32))
assert daarray_stack.ndim == 4 # check that shape is like (m, 1, height, width)
assert daarray_stack.shape[1] == 1 # channel-first (assuming only 1 channel)
assert not 0 in daarray_stack.shape # ensure no empty dimensions (bad window)
print("done!")

return np.stack(arrays=array_list)
with dask.diagnostics.ProgressBar(minimum=5.0):
try:
out_tiles = daarray_stack.compute()
assert not np.isnan(out_tiles).any() # check that there are no NAN values
except AssertionError:
raise NotImplementedError("gapfilling on dask xarray not yet implemented")
finally:
return out_tiles


# %%
Expand Down Expand Up @@ -712,7 +695,7 @@ def selective_tile(
filepath="misc/REMA_100m_dem.tif",
window_bounds=window_bounds_concat,
padding=1000,
gapfill_raster_filepath="misc/REMA_200m_dem_filled.tif",
# gapfill_raster_filepath="misc/REMA_200m_dem_filled.tif",
)
print(rema.shape, rema.dtype)

Expand Down

0 comments on commit 51a504f

Please sign in to comment.