Skip to content

Commit

Permalink
💥 Evaluate deepbedmap directly on xarray grid instead of temp file
Browse files Browse the repository at this point in the history
Oh the two weeks of background work just to come to this, pygmt.grdtrack and whatnot! Our srgan_train.get_deepbedmap_test_result function now evaluates the trained neural network model's predicted grid directly in memory instead of on a file! Lots of temporary file related boilerplate code removed, woohoo! There's been plenty of background work to get this in-memory grid to be georeferenced correctly (see #150, 7fd3345, 4a074d9), but it all ends up making the evaluation more accurate, and hopefully faster to run and cleaner to scale up.

Note also that the deepbedmap.get_image_and_bounds function has been refactored and renamed to get_image_with_bounds. The previous bounds was using xarray's centre-based pixel coordinates when it should have returned rasterio-style corner-based pixel coordinates used by data_prep.selective_tile. This is resolved using good ol' salem, and there's some extra code to handle getting the bounds for multiple inputs (instead of relying on salem's .extent function). The bounds themselves are now stored as an attribute inside the groundtruth xarray grid. As we are returning an xarray.DataArray grid instead of a numpy.array, we can use xarray's .plot() method to plot merged groundtruth grids (that are not on a regular grid) in a less funny way. There is also an 'indexers' parameter introduced to enable manually getting bounding boxes exactly divisible by 4, a quirky requirement of DeepBedMap...

The hardcoded bounding box view of Antarctica used in our deepbedmap.feature integration test has been updated in alignment with all the changes above (pixel offsets, manual crops, etc). New 'weiji14/deebedmap/model/test' tiles have been uploaded to quilt, and the new quilt hash to use is df0d28b24283c642f5dbe1a9baa22b605d8ae02ec1875c2edd067a614e99e5a4. Also patching 4a074d9 to fix data_prep.selective_tile's masking not handling fp16 NaN conversions as the Synthetic HighRes geotiff was acting up, and remove NaN checks after gapfilling as it trips up the big DeepBedMap tiller. What else? Phew!
  • Loading branch information
weiji14 committed Jun 13, 2019
1 parent 4a074d9 commit a0ab7c7
Show file tree
Hide file tree
Showing 9 changed files with 293 additions and 466 deletions.
10 changes: 7 additions & 3 deletions data_prep.ipynb
Expand Up @@ -1232,13 +1232,15 @@
" for da in daarray_list\n",
" ]\n",
" daarray_stack = dask.array.ma.masked_values(\n",
" x=dask.array.stack(seq=daarray_list), value=dataset.nodatavals\n",
" x=dask.array.stack(seq=daarray_list),\n",
" value=np.nan_to_num(-np.inf)\n",
" if dataset.dtype == np.int16 and dataset.nodatavals == (np.nan,)\n",
" else dataset.nodatavals,\n",
" )\n",
"\n",
" assert daarray_stack.ndim == 4 # check that shape is like (m, 1, height, width)\n",
" assert daarray_stack.shape[1] == 1 # channel-first (assuming only 1 channel)\n",
" assert not 0 in daarray_stack.shape # ensure no empty dimensions (bad window)\n",
" print(\"done!\")\n",
"\n",
" out_tiles = dask.array.ma.getdata(daarray_stack).compute().astype(dtype=np.float32)\n",
" mask = dask.array.ma.getmaskarray(daarray_stack).compute()\n",
Expand All @@ -1249,6 +1251,7 @@
"\n",
" # Replace pixels from another raster if available, else raise error\n",
" if gapfill_raster_filepath is not None:\n",
" print(f\"gapfilling ... \", end=\"\")\n",
" with xr.open_rasterio(gapfill_raster_filepath, chunks={}) as dataset2:\n",
" daarray_list2 = [\n",
" dataset2.interp_like(daarray_list[idx].squeeze(), method=\"linear\")\n",
Expand All @@ -1266,14 +1269,15 @@
" for i, array2 in enumerate(fill_tiles):\n",
" idx = nan_grid_indexes[i]\n",
" np.copyto(dst=out_tiles[idx], src=array2, where=mask[idx])\n",
" assert not (mask[idx] & mask2[i]).any() # Ensure no NANs after gapfill\n",
" # assert not (mask[idx] & mask2[i]).any() # Ensure no NANs after gapfill\n",
"\n",
" else:\n",
" for i in nan_grid_indexes:\n",
" daarray_list[i].plot()\n",
" plt.show()\n",
" print(f\"WARN: Tiles have missing data, try pass in gapfill_raster_filepath\")\n",
"\n",
" print(\"done!\")\n",
" return out_tiles"
]
},
Expand Down
10 changes: 7 additions & 3 deletions data_prep.py
Expand Up @@ -642,13 +642,15 @@ def selective_tile(
for da in daarray_list
]
daarray_stack = dask.array.ma.masked_values(
x=dask.array.stack(seq=daarray_list), value=dataset.nodatavals
x=dask.array.stack(seq=daarray_list),
value=np.nan_to_num(-np.inf)
if dataset.dtype == np.int16 and dataset.nodatavals == (np.nan,)
else dataset.nodatavals,
)

assert daarray_stack.ndim == 4 # check that shape is like (m, 1, height, width)
assert daarray_stack.shape[1] == 1 # channel-first (assuming only 1 channel)
assert not 0 in daarray_stack.shape # ensure no empty dimensions (bad window)
print("done!")

out_tiles = dask.array.ma.getdata(daarray_stack).compute().astype(dtype=np.float32)
mask = dask.array.ma.getmaskarray(daarray_stack).compute()
Expand All @@ -659,6 +661,7 @@ def selective_tile(

# Replace pixels from another raster if available, else raise error
if gapfill_raster_filepath is not None:
print(f"gapfilling ... ", end="")
with xr.open_rasterio(gapfill_raster_filepath, chunks={}) as dataset2:
daarray_list2 = [
dataset2.interp_like(daarray_list[idx].squeeze(), method="linear")
Expand All @@ -676,14 +679,15 @@ def selective_tile(
for i, array2 in enumerate(fill_tiles):
idx = nan_grid_indexes[i]
np.copyto(dst=out_tiles[idx], src=array2, where=mask[idx])
assert not (mask[idx] & mask2[i]).any() # Ensure no NANs after gapfill
# assert not (mask[idx] & mask2[i]).any() # Ensure no NANs after gapfill

else:
for i in nan_grid_indexes:
daarray_list[i].plot()
plt.show()
print(f"WARN: Tiles have missing data, try pass in gapfill_raster_filepath")

print("done!")
return out_tiles


Expand Down
350 changes: 188 additions & 162 deletions deepbedmap.ipynb

Large diffs are not rendered by default.

87 changes: 56 additions & 31 deletions deepbedmap.py
Expand Up @@ -26,6 +26,9 @@

os.environ["CUDA_VISIBLE_DEVICES"] = ""

import xarray as xr
import salem

import comet_ml
import cupy
import matplotlib
Expand All @@ -37,7 +40,6 @@
import quilt
import rasterio
import skimage
import xarray as xr

import chainer

Expand All @@ -49,40 +51,60 @@
# ## Get bounding box of area we want to predict on

# %%
def get_image_and_bounds(filepaths: list) -> (np.ndarray, rasterio.coords.BoundingBox):
def get_image_with_bounds(filepaths: list, indexers: dict = None) -> xr.DataArray:
"""
Retrieve raster image in numpy array format and
geographic bounds as (xmin, ymin, xmax, ymax)
Retrieve raster image in xarray.DataArray format patched
with projected coordinate bounds as (xmin, ymin, xmax, ymax)
Note that if more than one filepath is passed in,
the output groundtruth image array will not be valid
(see https://github.com/pydata/xarray/issues/2159),
but the window_bound extents will be correct
"""
if len(filepaths) > 1:
print("WARN: using multiple inputs, output groundtruth image will look funny")

with xr.open_mfdataset(paths=filepaths, concat_dim=None) as data:
groundtruth = data.z.to_masked_array()
groundtruth = np.flipud(groundtruth) # flip on y-axis...
groundtruth = np.expand_dims(
np.expand_dims(groundtruth, axis=0), axis=0
) # add extra dimensions (batch and channel)
assert groundtruth.shape[0:2] == (1, 1) # check that shape is like (1, 1, h, w)

xmin, xmax = float(data.x.min()), float(data.x.max())
ymin, ymax = float(data.y.min()), float(data.y.max())

window_bound = rasterio.coords.BoundingBox(
left=xmin, bottom=ymin, right=xmax, top=ymax

with xr.open_mfdataset(paths=filepaths, concat_dim=None) as dataset:
# Retrieve dataarray from NetCDF datasets
dataarray = dataset.z.isel(indexers=indexers)

# Patch projection information into xarray grid
dataarray.attrs["pyproj_srs"] = "epsg:3031"
sgrid = dataarray.salem.grid.corner_grid
assert sgrid.origin == "lower-left" # should be "lower-left", not "upper-left"

# Patch bounding box extent into xarray grid
if len(filepaths) == 1:
left, right, bottom, top = sgrid.extent
elif len(filepaths) > 1:
print("WARN: using multiple inputs, output groundtruth image may look funny")
x_offset, y_offset = sgrid.dx / 2, sgrid.dy / 2
left, right = (
float(dataarray.x[0] - x_offset),
float(dataarray.x[-1] + x_offset),
)
return groundtruth, window_bound
assert sgrid.x0 == left
bottom, top = (
float(dataarray.y[0] - y_offset),
float(dataarray.y[-1] + y_offset),
)
assert sgrid.y0 == bottom # dataarray.y.min()-y_offset

# check that y-axis and x-axis lengths are divisible by 4
try:
shape = int((top - bottom) / sgrid.dy), int((right - left) / sgrid.dx)
assert all(i % 4 == 0 for i in shape)
except AssertionError:
print(f"WARN: Image shape {shape} should be divisible by 4 for DeepBedMap")
finally:
dataarray.attrs["bounds"] = [left, bottom, right, top]

return dataarray


# %%
test_filepaths = ["highres/2007tx", "highres/2010tr", "highres/istarxx"]
groundtruth, window_bound = get_image_and_bounds(
filepaths=[f"{t}.nc" for t in test_filepaths]
groundtruth = get_image_with_bounds(
filepaths=[f"{t}.nc" for t in test_filepaths],
indexers={"y": slice(0, -1), "x": slice(0, -1)},
)

# %% [markdown]
Expand All @@ -91,7 +113,7 @@ def get_image_and_bounds(filepaths: list) -> (np.ndarray, rasterio.coords.Boundi
# %%
def get_deepbedmap_model_inputs(
window_bound: rasterio.coords.BoundingBox, padding=1000
) -> typing.Dict[str, np.ndarray]:
) -> (np.ndarray, np.ndarray, np.ndarray):
"""
Outputs one large tile for each of
BEDMAP2, REMA and MEASURES Ice Flow Velocity
Expand All @@ -103,18 +125,20 @@ def get_deepbedmap_model_inputs(
X_tile = data_prep.selective_tile(
filepath="lowres/bedmap2_bed.tif",
window_bounds=[[*window_bound]],
# out_shape=None, # 1000m spatial resolution
padding=padding,
)
W2_tile = data_prep.selective_tile(
filepath="misc/MEaSUREs_IceFlowSpeed_450m.tif",
window_bounds=[[*window_bound]],
out_shape=(2 * X_tile.shape[2], 2 * X_tile.shape[3]),
out_shape=(2 * X_tile.shape[2], 2 * X_tile.shape[3]), # 500m spatial resolution
padding=padding,
gapfill_raster_filepath="misc/lisa750_2013182_2017120_0000_0400_vv_v1_myr.tif",
)
W1_tile = data_prep.selective_tile(
filepath="misc/REMA_100m_dem.tif",
window_bounds=[[*window_bound]],
# out_shape=(5 * W2_tile.shape[2], 5 * W2_tile.shape[3]), # 100m spatial resolution
padding=padding,
gapfill_raster_filepath="misc/REMA_200m_dem_filled.tif",
)
Expand Down Expand Up @@ -145,7 +169,8 @@ def plot_3d_view(


# %%
X_tile, W1_tile, W2_tile = get_deepbedmap_model_inputs(window_bound=window_bound)
X_tile, W1_tile, W2_tile = get_deepbedmap_model_inputs(window_bound=groundtruth.bounds)
print(X_tile.shape, W1_tile.shape, W2_tile.shape)

# Build quilt package for datasets covering our test region
reupload = False
Expand Down Expand Up @@ -253,7 +278,7 @@ def load_trained_model(

# %%
S_tile = data_prep.selective_tile(
filepath="model/hres.tif", window_bounds=[[*window_bound]]
filepath="model/hres.tif", window_bounds=[[*groundtruth.bounds]]
)

# %% [markdown]
Expand All @@ -267,7 +292,7 @@ def load_trained_model(
axarr[0, 1].set_title("Super Resolution Generative Adversarial Network prediction")
axarr[0, 2].imshow(S_tile[0, 0, :, :], cmap="BrBG")
axarr[0, 2].set_title("Synthetic High Resolution Grid")
axarr[0, 3].imshow(groundtruth[0, 0, :, :], cmap="BrBG")
groundtruth.plot(ax=axarr[0, 3], cmap="BrBG")
axarr[0, 3].set_title("Groundtruth grids")
plt.show()

Expand Down Expand Up @@ -354,13 +379,13 @@ def save_array_to_grid(
# %%
# Save BEDMAP3 to GeoTiff and NetCDF format
save_array_to_grid(
window_bound=window_bound, array=Y_hat, outfilepath="model/deepbedmap3"
window_bound=groundtruth.bounds, array=Y_hat, outfilepath="model/deepbedmap3"
)

# %%
# Save Bicubic Resampled BEDMAP2 to GeoTiff and NetCDF format
save_array_to_grid(
window_bound=window_bound, array=cubicbedmap2, outfilepath="model/cubicbedmap"
window_bound=groundtruth.bounds, array=cubicbedmap2, outfilepath="model/cubicbedmap"
)

# %%
Expand All @@ -374,7 +399,7 @@ def save_array_to_grid(
preserve_range=True,
)
save_array_to_grid(
window_bound=window_bound,
window_bound=groundtruth.bounds,
array=np.expand_dims(np.expand_dims(synthetic, axis=0), axis=0),
outfilepath="model/synthetichr",
)
Expand Down
2 changes: 1 addition & 1 deletion features/deepbedmap.feature
Expand Up @@ -13,4 +13,4 @@ Feature: DeepBedMap

Examples: Bounding box views of Antarctica
| bounding_box |
| -1593714.328,-164173.7848,-1575464.328,-97923.7848 |
[ -1593589.328,-164048.7848,-1575589.328,-98048.7848 ]
2 changes: 1 addition & 1 deletion features/steps/test_deepbedmap.py
Expand Up @@ -15,7 +15,7 @@ def window_view_of_Antarctica(context, bounding_box):
def get_model_input_raster_images(context):
# TODO refactor code below that is hardcoded for a particular test region
if context.window_bound == rasterio.coords.BoundingBox(
left=-1_593_714.328, bottom=-164_173.7848, right=-1_575_464.328, top=-97923.7848
left=-1_593_589.328, bottom=-164_048.7848, right=-1_575_589.328, top=-98048.7848
):
quilt.install(package="weiji14/deepbedmap/model/test", force=True)
pkg = quilt.load(pkginfo="weiji14/deepbedmap/model/test")
Expand Down

0 comments on commit a0ab7c7

Please sign in to comment.