## 1) Import some packages

In [None]:
import xarray as xr
import hvplot.xarray
from dask.distributed import Client
import holoviews as hv
from holoviews import opts
import numpy as np
import pandas as pd
import geopandas as gpd


# Render HoloViews / hvPlot output with the Bokeh backend.
hv.extension("bokeh")

# Notebook-wide default styling applied to every Curve element.
curve_defaults = opts.Curve(
    color='blue',
    height=500,
    width=650,
    bgcolor='lightgray',
    show_grid=True,
)
opts.defaults(curve_defaults)

I have found Dask's distributed client to greatly improve performance (even locally).<br> 

You can start a client and go to the dashboard link to see the status of workers and tasks when code is running.<br> 

We can also create a Dask cluster and connect to it through the client to scale up the computation.

In [None]:
# Start a Dask distributed client with default settings.
client = Client()
# Bare expression: the notebook displays the client as the cell output.
client

## 2) Converting LISFLOOD output to Zarr format for each region (for informational purposes)

In [None]:
def get_lisflood_dataset_region2():
    """Open the four Region02 LISFLOOD discharge files as one dataset clipped to region 2.

    The calibration/validation runs overlap in time, so each file is trimmed to
    its own time window while it is opened. The combined dataset is then
    clipped to the polygon(s) whose ``DN`` attribute equals 2 in the regions
    mask GeoJSON.

    Returns:
        The combined, dask-chunked dataset with its CRS set to EPSG:4326 and
        clipped to region 2.

    Raises:
        ValueError: if a file is opened whose ``encoding["source"]`` is not one
            of the expected input paths (previously this case silently returned
            ``None`` from the preprocess hook).
    """
    # Imported for its side effect only: it registers the ``.rio`` accessor
    # used below, so the function does not rely on an implicit import elsewhere.
    import rioxarray  # noqa: F401

    basedir = "/home/jovyan/shared/lisflood/ColumbiaUniv/"

    file_list = [basedir + "Region02/outputs_cal1/dis.nc",
                 basedir + "Region02/outputs_cal2/dis.nc",
                 basedir + "Region02/outputs_val1/dis.nc",
                 basedir + "Region02/outputs_val2/dis.nc"]

    # Time window to keep from each file; ``None`` keeps the file untrimmed.
    time_windows = {
        file_list[2]: None,
        file_list[3]: slice("2012-01-01 00:00", "2013-12-31 23:00"),
        file_list[0]: slice("2014-01-01 00:00", "2016-06-30 23:00"),
        file_list[1]: slice("2016-07-01 00:00", "2018-12-31 23:00"),
    }

    # Trim the overlapping times of a single file as it is opened
    def preprocess(ds):
        source = ds.encoding["source"]
        if source not in time_windows:
            # Fail loudly instead of handing None to open_mfdataset.
            raise ValueError(f"Unexpected LISFLOOD source file: {source!r}")
        window = time_windows[source]
        if window is None:
            return ds
        return ds.sel(time=window)

    # Open all at once, calling the trim function at the same time
    ds_all = xr.open_mfdataset(
        file_list,
        preprocess=preprocess,
        chunks={"time": 3000, "lat": 100, "lon": 100},  # CHUNKING HAS A BIG EFFECT!
        parallel=True,
    )

    # Now clip to reg2
    gdf_reg = gpd.read_file("/home/jovyan/shared/flood_dev/sam/lisflood/regions_mask_diss.geojson")
    gdf_reg = gdf_reg[gdf_reg.DN == 2]
    # Rebind rather than mutate in place so the data lineage stays explicit.
    ds_all = ds_all.rio.write_crs("EPSG:4326")
    ds_all = ds_all.rio.clip(gdf_reg.geometry, "EPSG:4326", all_touched=False)

    return ds_all

In [None]:
%%time

# Add some Zarr compression to reduce file size
compressor = zarr.Blosc(cname='zstd', clevel=3)
encoding = {vname: {"compressor": compressor} for vname in ds_chnk.data_vars}

# Write to NFS
ds_chnk.to_zarr("/home/jovyan/shared/flood_dev/sam/lisflood/zarr_test/lisflood_v2_region1_dis_clip.zarr", encoding=encoding, consolidated=True)

## 3) Now push the Zarr store (a directory) to GCS

#### You can write Zarr stores directly to a GCS bucket from xarray, but I've found it's much faster to use gsutil
ex: `gsutil -m cp -r lisflood_v2_region1_dis_clip.zarr gs://sam-temp-dev/lisflood_region1` (a Zarr store is a directory, so `-r` is needed to copy it recursively)

## 4) Now we can access the data directly from the GCS bucket

#### NOTE: I've noticed it's much faster if you disconnect from the VPN first!

In [None]:
# For reading data, you can start a new client with a different configuration for better performance.
# Close the previous client first so its workers are not left running
# alongside the new ones.
client.close()
client = Client(n_workers=4, threads_per_worker=4, memory_limit='4GB', processes=False)
client

This example shows the difference in performance in reading the data based on two different chunking schemes.  

The first zarr store is chunked in the lat/lon dimensions as well as the time dimension (the data itself is a stack of 2-D grids).<br> This gives reasonable read performance in requesting the entire time series at a single point, as well as slicing a single time step across all points.

The second is only chunked in the lat/lon dimensions while keeping the time series in one huge chunk.<br>  This makes it faster to retrieve the entire time series at a single point, but causes memory problems when trying to slice a single time step.

The chunking scheme should be created to optimize the way data will be retrieved.  Zarr is nice for this because it's so flexible.<br>
You can try accessing the data with the two different chunking schemes below to see the difference in performance.

In [None]:
# GCS locations of the two Zarr stores compared in this section; the variable
# names record which chunking scheme each store uses.
chunked_by_time_and_latlon = "gs://sam-temp-dev/lisflood_hokkaido_v2"
chunked_by_latlon_only = "gs://sam-temp-dev/lisflood_hokkaido_v2_2"

# Open one store using its consolidated metadata; swap in the other variable to
# compare read performance.
ds_dis = xr.open_zarr(chunked_by_time_and_latlon, consolidated=True) 
ds_dis

#### Here we request the entire time series at a single point

In [None]:
%%time
# NOTE: DISCONNECT FROM VPN FOR LOCAL PROCESSING!!

# Coordinates of the point whose time series is plotted.
# region 4 (Hokkaido)
lat = 43.49516727
lon = 141.89472086
# lat = [43.49516727, 44.99505638]
# lon = [141.89472086, 141.79640070]

# # region2 (okayama)
# lat = 34.61490570
# lon = 133.96546034

# Select the grid cell nearest to (lat, lon) and plot the `dis` variable there.
ds_dis.dis.sel(lat=lat, lon=lon, method="nearest").hvplot()

#### Here we slice one array at a single time step, or you can even request several time steps

In [None]:
%%time
# Get a single time step
# ds_dis.dis.sel(time="2011-09-03 16:00").hvplot.quadmesh(x='lon',
#                                                         y='lat',
#                                                         title='dis',
#                                                         geo=True,
#                                                         width=650,
#                                                         height=600,
#                                                         rasterize=True,
#                                                         project=True,
#                                                         cmap="bmw",
#                                                         clim=(0, 500),
#                                                         tiles='EsriImagery')
# Get several time steps
ds_dis.dis.sel(time=slice("2011-09-02 16:00", "2011-09-03 16:00")).hvplot.quadmesh(x='lon',
                                                                                   y='lat',
                                                                                   title='dis',
                                                                                   geo=True,
                                                                                   width=650,
                                                                                   height=600,
                                                                                   rasterize=True,
                                                                                   project=True,
                                                                                   cmap="bmw",
                                                                                   clim=(0, 500),
                                                                                   tiles='EsriImagery')

## 5) Historical JMA as Zarr

In [None]:
%%time
# Open the JMA Zarr store straight from GCS using its consolidated metadata.
ds_jma = xr.open_zarr("gs://sam-temp-dev/320-420-201201-2", consolidated=True)
ds_jma

In [None]:
%%time
# Load the JMA coordinate arrays into memory.
# NOTE(review): this rebinds `lat`/`lon`, which an earlier cell set to the
# scalar coordinates of a single point.
lat = ds_jma.latitude.values
lon = ds_jma.longitude.values

In [None]:
%%time
ds_jma.rainrate.isel(time=slice(600, 700)).hvplot.quadmesh(x='longitude',
                                            y='latitude',
                                            title='rainrate',
                                            geo=True,
                                            width=650,
                                            height=600,
                                            rasterize=True,
                                            project=True,
                                            cmap="bmy",
                                            clim=(0, 2),
                                            tiles='EsriImagery')

In [None]:
%%time
# Load the rain-rate values for one timestamp into memory as a NumPy array.
ds_jma.rainrate.sel(time="2012-01-01 01:00:00").values

In [None]:
%%time
ds.rainrate.sum(dim="time").values

In [None]:
%%time
ds.rainrate.sum(dim="time").values

## 6) Reading parquet files with dask

In [None]:
# Shut down the current Dask client before the parquet example below.
client.close()

In [None]:
%%time
# dask.dataframe is imported here rather than in the top import cell.
import dask.dataframe as dd
# Read the parquet dataset from GCS into a Dask DataFrame.
ddf_bo = dd.read_parquet("gs://sam-temp-dev/parquet/test_grid_1_baseline.parquet")  # ~32 MB vs. 250 MB .xyz
# Show the first rows.
ddf_bo.head()

## 7) Reading and Visualizing SCHISM output

In [None]:
import datashader.transfer_functions as tf
import datashader.utils as du
from datashader.colors import inferno, viridis
import datashader as dsh
import holoviews as hv
from holoviews import opts
from holoviews.operation.datashader import datashade, dynspread, rasterize
import geoviews as gv
import geoviews.feature as gf
import cartopy.crs as ccrs
from matplotlib import cm

# Activate the Bokeh backend for GeoViews and set the default output size.
gv.extension('bokeh')
gv.output(size=200)

In [None]:
# Close whatever client `client` currently refers to, then start a new one for
# the SCHISM section: 4 workers x 4 threads, 4GB memory limit, processes=False.
client.close()
client = Client(n_workers=4, threads_per_worker=4, memory_limit='4GB', processes=False)
client

In [None]:
%%time
# Open the SCHISM output Zarr store from GCS using its consolidated metadata.
ds_sch = xr.open_zarr("gs://sam-temp-dev/schout_w7xb7_05.zarr", consolidated=True)
ds_sch

In [None]:
%%time
# Maximum of `elev` over the first 190 time steps, plus `depth` taken from the
# first time step.
# NOTE(review): presumably elev is the water-surface elevation and depth the
# bathymetry, making the sum the peak water depth per node - confirm.
max_depth = ds_sch.isel(time=slice(0, 190)).elev.max(dim="time") + ds_sch.isel(time=0).depth
# Materialise the result as a NumPy array.
z = max_depth.values

In [None]:
# Inspect the shape of the max-depth array `z`.
z.shape

In [None]:
%%time
# Pull the mesh geometry from the first time step as NumPy arrays.
x = ds_sch.isel(time=0).SCHISM_hgrid_node_x.values
y = ds_sch.isel(time=0).SCHISM_hgrid_node_y.values
# NOTE(review): the -1 presumably converts 1-based node indices to 0-based
# ones - confirm against the SCHISM output convention.
faces = ds_sch.isel(time=0).SCHISM_hgrid_face_nodes.values-1

In [None]:
%%time
PLINTH = 0.01
cities = ["Chiba"]
POLYS_107 = "/Users/slamont/japan_gis/geo_boundaries_shp/cities_107_with_grid_num.geojson"

gv_map = trimesh_max_depth_map(x, y, z, faces, "Zarr-based Map")
gv_map

In [None]:
def trimesh_max_depth_map(x, y, z, faces, title):

    # Build the standard static map (max depth over some interval)
    z[z < PLINTH] = np.nan

    df_verts = pd.DataFrame({'x': x, 'y': y, 'z': z})
    df_tris = pd.DataFrame(faces[:, 0:3],columns=['v0','v1','v2'])

    gv_basemap = gv.tile_sources.CartoLight

    gdf_107 = gpd.read_file(POLYS_107)
    mask = gdf_107["E-Name"].isin(cities)
    gdf_107 = gdf_107[mask]
    xmin, ymin, xmax, ymax = gdf_107.total_bounds
    gv_poly_107 = gv.Polygons(gdf_107).opts(alpha=0.15)


    gv_trimap = gv_basemap * gv_poly_107 * rasterize(gv.TriMesh((df_tris, df_verts), crs=ccrs.PlateCarree())).options(
        cmap=cm.Spectral_r,
        colorbar=True,
        clim=(PLINTH, 10.),
        clabel='meter',
        width=520,
        height=440,
        title=f'{title}',
        tools=['hover']).redim.range(x=(xmin, xmax),y=(ymin, ymax))

    return gv_trimap