# Land Use / Land Cover

This notebook demonstrates applying a land use classification model to NAIP data.

It does a batch prediction on many scenes that have been merged together, using a cluster of machines, each with a GPU.

In [None]:
import numpy as np
import xarray as xr
import rioxarray
from dask_cuda import LocalCUDACluster
from distributed import Client, wait
import adlfs
import utils
import torch
import pandas as pd
import dask
import itertools
from dask_gateway import Gateway
import rasterio
import affine
import gdal

## Cluster Setup

We'll make a cluster of Dask workers running on AKS (using an autoscaling VM scaleset) where each worker has a GPU.

In [None]:
# Alternative setup, kept for reference: request a Dask Gateway cluster with GPU workers.
# It is disabled here; the active code below starts a LocalCUDACluster instead.
# N_WORKERS = 4
# g = Gateway()
# options = g.cluster_options()
# options['gpu'] = True
# options['worker_memory'] = 64  # TODO: bump to 100
# options["worker_cores"] = 5
# display(options)

# cluster = g.new_cluster(options)
# client = cluster.get_client()
# cluster.scale(N_WORKERS)
# cluster
# Active setup: workers are created with a "gpu" resource of 1 and 5 threads each.
from dask_cuda import LocalCUDACluster
cluster = LocalCUDACluster(resources={"gpu": 1}, threads_per_worker=5)
client = Client(cluster)
# Number of workers the scheduler currently reports; reused later by wait_for_workers.
N_WORKERS = len(client.scheduler_info()['workers'])
client

## Step 1: Data Discovery

Using the Planetary Computer's metadata query API, we can select a region of interest and get back the URLs to the assets in Azure Blob Storage. In this case, it's NAIP scenes from 2013 and 2017.

In [None]:
# This will be replaced by the query API. For now we just grab some images.
fs = adlfs.AzureBlobFileSystem(account_name="naipeuwest")


def _vsicurl_urls(pattern):
    """Glob ``pattern`` on ``fs`` and return one GDAL ``/vsicurl/`` URL per matching blob."""
    return [f"/vsicurl/{fs.account_url}/{blob}" for blob in fs.glob(pattern)]


blobs_2013 = _vsicurl_urls("naip/v002/md/2013/md_100cm_2013/39076/*.tif")
blobs_2017 = _vsicurl_urls("naip/v002/md/2017/md_100cm_2017/39076/*.tif")

## Step 2: Mosaic / Alignment

We have many files in blob storage. We want to treat them as one dataset.

In [None]:
# Build a virtual mosaic (VRT) over the first 12 scenes from 2013 and time the call.
%time mosaic2013 = gdal.BuildVRT("mosaic-2013.vrt", blobs_2013[:12])
# Presumably flushes the VRT to disk so it can be reopened by filename later -- TODO confirm.
mosaic2013.FlushCache()

In [None]:
# Build a virtual mosaic (VRT) over the first 12 scenes from 2017 and time the call.
%time mosaic2017 = gdal.BuildVRT("mosaic-2017.vrt", blobs_2017[:12])
# Presumably flushes the VRT to disk so it can be reopened by filename later -- TODO confirm.
mosaic2017.FlushCache()

The 2017 files need to be aligned to the 2013 grid for comparison.

In [None]:
# Read the 2013 mosaic's bounds and pixel dimensions; the context manager
# closes the dataset as soon as the warp options have been built.
with rasterio.open("mosaic-2013.vrt") as a:
    options = gdal.WarpOptions(
        outputBounds=tuple(a.bounds),
        width=a.width,
        height=a.height,
    )
# Warp the 2017 mosaic onto the 2013 extent and raster size.
warped2017 = gdal.Warp("warped-2017.vrt", "mosaic-2017.vrt", options=options)
warped2017.FlushCache()

Finalize the cluster setup. This copies the VRT files to the workers and ensures the workers are properly registered with GPUs.

In [None]:
# TODO: Make this workflow nicer.
# 1. VRT: We could perhaps pre-generate this and throw it in a bucket somewhere.
# 2. Worker resources: https://github.com/dask/distributed/pull/4456

# In the future we'll set
# options['environment'] = {"DASK_DISTRIBUTED__WORKER__RESOURCES__GPU": "1"}
# 

import os
import logging
from distributed.diagnostics.plugin import WorkerPlugin


class UploadVRTs(WorkerPlugin):
    """Worker plugin that makes a set of VRT files available on each worker.

    The file contents are read into memory when the plugin is constructed and
    written out again, under the same filenames, when ``setup`` runs on a worker.
    """

    name = "upload_vrt"

    def __init__(self, vrt_filenames):
        """Read every file named in ``vrt_filenames`` into ``self.vrt_data`` as bytes."""
        self.vrt_filenames = vrt_filenames
        contents = {}
        for path in vrt_filenames:
            with open(path, "rb") as src:
                contents[path] = src.read()

        self.vrt_data = contents

    def setup(self, worker):
        """Upgrade rioxarray on the worker, then write any VRT file that is missing."""
        logging.getLogger("distributed.worker").info("Copying vrt for %s", worker)
        import subprocess
        # Installs rioxarray straight from its GitHub repository; the exit status is ignored.
        subprocess.call(['pip', 'install', '-U', 'git+https://github.com/corteva/rioxarray.git'])
        for path in self.vrt_filenames:
            if os.path.exists(path):
                # Never overwrite a file that is already present.
                continue
            with open(path, "wb") as dst:
                dst.write(self.vrt_data[path])


# We're hacking things to set worker resources. Fixed properly in https://github.com/dask/distributed/pull/4456
async def set_resources(dask_worker, **resources):
    """Apply ``resources`` to ``dask_worker`` and return its ``total_resources`` afterwards."""
    await dask_worker.set_resources(**resources)
    totals = dask_worker.total_resources
    return totals


# Block until every expected worker has connected before configuring them.
client.wait_for_workers(N_WORKERS)

# vrt: register the plugin that writes the mosaic files on each worker
plugin = UploadVRTs(["mosaic-2013.vrt", "warped-2017.vrt"])
client.register_worker_plugin(plugin)
# gpu: apply the resource mapping on every worker (single source of truth for the values)
resources = {"gpu": 1}
client.run(set_resources, **resources)
# utils: make the local helper module importable on the workers
client.upload_file("utils.py")

client.wait_for_workers(N_WORKERS)

# model: load once locally, broadcast to all workers, then drop the local copy
model = utils.load_model("/srv/unet_both_lc.pt")
remote_model = client.scatter(model, broadcast=True)
del model

## Step 3: Preprocessing

xarray provides a convenient data structure for working with large, labeled datasets like this.

In [None]:
# Lazily open the 2013 mosaic as a chunked array; the chunk tuple is presumably
# (band, y, x) given how the dimensions are used further down -- TODO confirm.
ds1 = rioxarray.open_rasterio("mosaic-2013.vrt", chunks=(4, 4096, 4096), lock=False)
ds1

In [None]:
# Lazily open the warped 2017 mosaic with the same chunking as the 2013 mosaic.
ds2 = rioxarray.open_rasterio("warped-2017.vrt", chunks=(4, 4096, 4096), lock=False)
ds2

Now we have a big dataset on an aligned grid for the two time periods.

In [None]:
# Stack the two years along a new "time" dimension (index 0 = ds1, index 1 = ds2).
ds = xr.concat([ds1, ds2], dim="time")
ds

The model requires a bit of preprocessing upfront.

In [None]:
# Standardize with the statistics shipped in utils, then rechunk for prediction.
normalized = (ds - utils.mean) / utils.std
normalized = normalized.chunk(utils.CHUNKS)

# Avoid predictions on partial chunks: trim y and x down to a whole multiple of
# CHIP_SIZE. An explicit stop index is used because a negative-remainder slice
# degenerates to slice(None, 0) -- an empty selection -- when the size is
# already an exact multiple.
n_y = normalized.sizes["y"]
n_x = normalized.sizes["x"]
normalized = normalized.isel(
    y=slice(0, n_y - n_y % utils.CHIP_SIZE),
    x=slice(0, n_x - n_x % utils.CHIP_SIZE),
)

## Step 4: Predict

The only wrinkle here is that our dataset is much larger than our available GPU memory.

In [None]:
# The prediction collapses the band dimension into a single classification layer,
# so the output template is one band of the input with the band-related
# coordinates dropped and the dtype set to uint8.
template = normalized.isel(band=0).drop_vars(["band", "spatial_ref"]).astype("uint8")

# Annotate the tasks created here with a requirement of one "gpu" resource.
with dask.annotate(resources={'gpu': 1}):
    predictions = normalized.map_blocks(utils.predict_datarray,
                                        kwargs=dict(model=remote_model), template=template)
    # Carry the CRS of the input over to the predictions.
    predictions = predictions.rio.set_crs(normalized.rio.crs)

## Step 5: Inspect

In [None]:
# Boolean mask: True where the predicted class differs between the two time steps.
change = predictions.sel(time=0) != predictions.sel(time=1)

We notice that there are some per-pixel (or sub-pixel) differences in alignment, so we'll smooth the output. We'll flag a pixel as changed only when it and all of its neighbors have changed.

In [None]:
from numba import njit, stencil
import numpy as np

# Stencil kernel: yields the centre value only when every cell in the window
# (relative offsets -3..2 on both axes) equals it; otherwise yields 0.
# NOTE(review): the window is asymmetric (-3..2 rather than -3..3) -- confirm this is intended.
@stencil(neighborhood=((-3, 2), (-3, 2)))
def _smooth(x):
    base = result = x[0, 0]
    for i in range(-3, 3):
        for j in range(-3, 3):
            if x[i, j] != base:
                result = 0
                # Only leaves the inner loop; the outer loop keeps going but result stays 0.
                break
    return result


@njit(parallel=True)
def smooth(x):
    """Run the ``_smooth`` stencil over ``x`` and return the filtered array."""
    filtered = _smooth(x)
    return filtered


n_classes = utils.lc_cmap.N
# Encode every changed pixel's (before, after) class pair as a single integer and
# set unchanged pixels to 0. The mask must be `~change`: zeroing the *changed*
# pixels instead would smooth the unchanged areas, the opposite of what is wanted.
change2 = xr.where(~change, 0, n_classes * predictions.sel(time=0) + predictions.sel(time=1))
# Apply the stencil chunk by chunk, with an overlap depth of 3 on both axes.
smoothed = change2.data.map_overlap(smooth, (3, 3))

In [None]:
# Start computing the imagery and the predictions on the cluster and keep handles to the results.
ds_, predictions_ = client.persist([ds, predictions], optimize_graph=False)

In [None]:
# Compare the persisted predictions for the two time steps.
before = predictions_.sel(time=0)
after = predictions_.sel(time=1)
change = before != after

# 0 where nothing changed, otherwise a single integer encoding the (before, after) classes.
change2 = xr.where(~change, 0, n_classes * before + after)

# Pull the encoded array into local memory, run the stencil, and keep the non-zero cells.
smoothed = smooth(change2.compute().data) != 0
smoothed = xr.DataArray(
    smoothed,
    coords=change2.coords,
    dims=change2.dims,
    attrs=change2.attrs,
)

# Keep predictions only where the smoothed mask is True; elsewhere they become NaN.
changed_predictions = predictions_.astype("float32").where(smoothed).persist()

In [None]:
import hvplot.xarray
import panel

# Shared hvplot options for the two classification images.
kwargs = {
    "x": "x",
    "y": "y",
    "cmap": utils.lc_cmap,
    "rasterize": True,
    "aggregator": "mode",
    "clim": (0, utils.lc_cmap.N - 1),
}

In [None]:
# One row per time step: the RGB imagery next to the changed-pixel classification.
panel.Column(
    *[
        panel.Row(
            ds_.sel(time=t).hvplot.rgb(bands="band", rasterize=True),
            changed_predictions.sel(time=t).hvplot.image(**kwargs),
        )
        for t in (0, 1)
    ]
)

## Step 6: Write to Azure Blob Storage

In [None]:
import azure.storage.blob
import dask.base
import dask.array

In [None]:
# Read the storage credentials from the environment; raises KeyError if the variable is unset.
connection_string = os.environ["AZURE_CONNECTION_STRING"]
# CRS of the predictions, re-applied to each chunk in write_chunk.
CRS = predictions.rio.crs
def rename_key(key):
    """Turn ``key`` into a blob name of the form ``naip/<flattened key>.tif``.

    The key's string form has ``", "`` replaced by ``"-"``, single quotes
    removed, and leading/trailing parentheses stripped.
    """
    stem = str(key)
    for old, new in ((", ", "-"), ("'", "")):
        stem = stem.replace(old, new)
    stem = stem.strip("()")
    return 'naip/' + stem + ".tif"


def write_chunk(chunk, year, slice_):
    """Write one prediction chunk to blob storage as a raster and return the blob name.

    Parameters
    ----------
    chunk
        The array chunk to write; the module-level ``CRS`` is set on it first.
    year
        Value used as the second path component of the blob name.
    slice_
        The slices identifying this chunk; their dask token becomes the file stem.

    Returns
    -------
    str
        The name of the uploaded blob, ``naip/<year>/<token>.tif``.
    """
    chunk = chunk.rio.set_crs(CRS)
    name = f'naip/{year}/{dask.base.tokenize(slice_)}.tif'

    # The context manager releases the in-memory file once the upload has finished.
    with rasterio.io.MemoryFile() as memory_file:
        chunk.rio.to_raster(memory_file)
        # Rewind so the upload reads from the start of the buffer.
        memory_file.seek(0)

        # Named `container` so it does not shadow the notebook's Dask `client`.
        container = azure.storage.blob.ContainerClient.from_connection_string(
            connection_string,
            container_name="pangeo-scratch"
        )
        container.upload_blob(name, memory_file, length=len(memory_file), overwrite=True)
    return name

# One delayed write task per chunk, for each of the two time steps.
writes = []
for year in (0, 1):
    subset = predictions[year]
    slices = dask.array.core.slices_from_chunks(subset.chunks)
    for chunk_slices in slices:
        # Prepend the time index so the selection addresses the full array.
        selection = (year,) + chunk_slices
        writes.append(dask.delayed(write_chunk)(predictions[selection], year, chunk_slices))

In [None]:
# Submit every chunk write to the cluster and block until all of them have finished.
keys = client.compute(writes, optimize_graph=False)
_ = wait(keys);