# CONUS404 Regridding (Curvilinear => Rectilinear)
Create a rectilinear grid (1D lon/lat coordinates) for a specific region. Extract a spatial and temporal subset of the regridded data to a NetCDF file. (Extraction to NetCDF may also be done for the curvilinear grid.)

In [None]:
%%time
import xarray as xr
import xesmf as xe
import numpy as np
import fsspec
import hvplot.xarray
import geoviews as gv
from matplotlib import path 
import intake
import os

In [None]:
# Location of the HyTEST intake catalog (a YAML file in the hytest GitHub repository).
url = 'https://raw.githubusercontent.com/hytest-org/hytest/main/dataset_catalog/hytest_intake_catalog.yml'

In [None]:
# Open the HyTEST intake catalog found at `url` and list the names of its entries.
hytest_cat = intake.open_catalog(url)
list(hytest_cat)

In [None]:
# Select the 'conus404-catalog' entry (the CONUS404 sub-catalog) and list the
# names of the datasets it offers.
cat = hytest_cat['conus404-catalog']
list(cat)

In [None]:
# Use the on-prem copy of the hourly data when running under SLURM (the
# SLURM_CLUSTER_NAME environment variable is set); otherwise use the cloud copy.
dataset = (
    'conus404-hourly-onprem'
    if 'SLURM_CLUSTER_NAME' in os.environ
    else 'conus404-hourly-cloud'
)

In [None]:
# Open the selected catalog entry through intake's to_dask(); presumably this
# yields a dask-backed (lazily loaded) xarray dataset -- confirm against the
# catalog driver.
ds = cat[dataset].to_dask()

In [None]:
# Display the dataset summary (the notebook shows the repr of the last expression).
ds

In [None]:
# ---- Extraction settings used by the cells below ----

# NetCDF file that receives the regridded (rectilinear) subset.
nc_outfile = 'CONUS404_DRB_rectilinear.nc'

# Region of interest as [lon_min, lon_max, lat_min, lat_max].
bbox = [-75.9, -74.45, 38.7, 42.55]

# Target grid spacing in degrees: 3 km expressed with ~111 km per degree.
dy = 3. / 111.
dx = dy

# Variables to regrid and write out.
vars_out = ['T2', 'SNOW']

# Time window of the extraction.
start, stop = '2017-04-01 00:00', '2017-05-01 00:00'

#### Use xESMF to regrid
xESMF is an xarray-enabled interface to the ESMF regridder from NCAR.
ESMF has options for regridding between curvilinear, rectilinear, and unstructured grids, with conservative regridding options, and much more.

In [None]:
# Start a dask.distributed client. No scheduler address is passed, so
# dask.distributed is expected to spin up a local cluster -- see its docs.
from dask.distributed import Client
client = Client()

In [None]:
def bbox2ij(lon, lat, bbox=(-160., -155., 18., 23.)):
    """Return the index range of the grid points lying inside a bounding box.

    Usage: ``i0, i1, j0, j1 = bbox2ij(lon, lat, bbox)``

    Parameters
    ----------
    lon, lat : 2D array-like
        Longitude and latitude of every grid point; both must share the same
        (n, m) shape.
    bbox : sequence of four floats
        Bounding box as [lon_min, lon_max, lat_min, lat_max].

    Returns
    -------
    i0, i1, j0, j1 : int
        Smallest and largest index along the second axis (i) and along the
        first axis (j) of the points that fall inside ``bbox``.  All four
        bounds are inclusive, so add 1 to ``i1``/``j1`` when building a slice.

    Raises
    ------
    ValueError
        If ``lon``/``lat`` are not 2D arrays of equal shape, or if no grid
        point falls inside ``bbox``.

    Example
    -------
    >>> i0, i1, j0, j1 = bbox2ij(lon_rho, lat_rho, [-71., -63., 39., 46.])
    >>> h_subset = nc.variables['h'][j0:j1 + 1, i0:i1 + 1]
    """
    lon = np.asarray(lon)
    lat = np.asarray(lat)
    if lon.ndim != 2 or lon.shape != lat.shape:
        raise ValueError(
            "lon and lat must be 2D arrays of identical shape, got "
            f"{lon.shape} and {lat.shape}"
        )

    lon_a, lon_b, lat_a, lat_b = (float(v) for v in bbox)
    # Accept the corners in either order, as a polygon-based test would.
    lon_lo, lon_hi = min(lon_a, lon_b), max(lon_a, lon_b)
    lat_lo, lat_hi = min(lat_a, lat_b), max(lat_a, lat_b)

    # The box is an axis-aligned rectangle, so plain vectorized comparisons
    # are enough; points exactly on the box edge count as inside.
    inside = (
        (lon >= lon_lo) & (lon <= lon_hi)
        & (lat >= lat_lo) & (lat <= lat_hi)
    )

    # nonzero() returns (first-axis indices, second-axis indices) = (j, i).
    jj, ii = np.nonzero(inside)
    if ii.size == 0:
        raise ValueError(f"no grid points fall inside bbox {list(bbox)}")
    return int(ii.min()), int(ii.max()), int(jj.min()), int(jj.max())

##### Before we regrid to rectilinear, let's subset a region that covers our area of interest. Because lon, lat are 2D arrays, we can't just use xarray to slice these coordinate variables. So we have a routine that finds the i,j locations of a specified bounding box, and then slice on those.

In [None]:
# Find the i (x) and j (y) index range of the grid points inside `bbox`.
# `.values` passes the 2D lon/lat coordinates to bbox2ij as numpy arrays.
i0,i1,j0,j1 = bbox2ij(ds['lon'].values, ds['lat'].values, bbox=bbox)
print(i0,i1,j0,j1)

In [None]:
# Pad the index box by one grid cell on every side so the subset fully
# surrounds the bounding box.  bbox2ij returns inclusive indices while a slice
# stop is exclusive, hence "+2" on the upper ends.  The lower ends are clamped
# at 0 because a negative slice start would count from the end of the axis.
ds_subset = ds.isel(x=slice(max(i0-1, 0), i1+2), y=slice(max(j0-1, 0), j1+2))

In [None]:
# Keep only the time steps between `start` and `stop` (label-based slice on
# the time coordinate).
ds_subset = ds_subset.sel(time=slice(start,stop))

In [None]:
# Display the summary of the spatial/temporal subset.
ds_subset

In [None]:
# Size of the subset in gigabytes (bytes / 1e9).
ds_subset.nbytes/1e9

In [None]:
# Quick-look map of T2 at the time step nearest to 2017-04-25 00:00, drawn as
# a rasterized quadmesh on the 2D lon/lat coordinates.
da = ds_subset.T2.sel(time='2017-04-25 00:00', method='nearest')
viz = da.hvplot.quadmesh(x='lon', y='lat', geo=True, rasterize=True, cmap='turbo')
# OpenStreetMap tiles serve as the background layer.
base = gv.tile_sources.OSM
# Overlay the half-transparent field on the basemap.
base * viz.opts(alpha=0.5)

In [None]:
# Size of the subset in gigabytes (repeats the earlier size check).
ds_subset.nbytes/1e9

In [None]:
%%time
# Rechunk so that each chunk holds the full x/y extent (-1 = whole dimension)
# and 24 time steps -- presumably one day of the hourly data.
ds_subset = ds_subset.chunk({'x':-1, 'y':-1, 'time':24})

In [None]:
%%time
# Target rectilinear grid: 1D lon/lat coordinates spanning the bounding box
# with spacing dx/dy (np.arange excludes the upper bound of each range).
ds_out = xr.Dataset({'lon': (['lon'], np.arange(bbox[0], bbox[1], dx)),
                     'lat': (['lat'], np.arange(bbox[2], bbox[3], dy))})

# Build the bilinear regridder from the curvilinear subset to the target grid
# and display it.
regridder = xe.Regridder(ds_subset, ds_out, 'bilinear')
regridder

In [None]:
%%time
# Regrid the selected variables. Note that this rebinds `ds_out` from the
# target-grid definition to the regridded data.
ds_out = regridder(ds_subset[vars_out])
print(ds_out)

In [None]:
# Inspect the regridded SNOW variable.
ds_out['SNOW']

In [None]:
# All variables in the regridded dataset (data variables and coordinates).
list(ds_out.variables)

In [None]:
# Data variables only (coordinates excluded).
list(ds_out.data_vars)

In [None]:
# Inspect the encoding attached to the regridded T2 variable.
ds_out['T2'].encoding

In [None]:
# Inspect the time coordinate of the regridded dataset.
ds_out.time

In [None]:
# Per-variable NetCDF encoding for the regridded dataset: level-2 zlib
# compression with byte shuffling, no Fletcher-32 checksum, and no fill value.
# Every variable in ds_out.variables gets its own settings dict.
encoding = {}
for var in ds_out.variables:
    encoding[var] = {
        'zlib': True,
        'complevel': 2,
        'fletcher32': False,
        'shuffle': True,
        '_FillValue': None,
    }

In [None]:
%%time

# Write the regridded dataset to `nc_outfile` with the per-variable encoding;
# mode='w' replaces any existing file of that name.
ds_out.to_netcdf(nc_outfile, encoding=encoding, 
                 mode='w')

In [None]:
# Re-open the file just written to check its contents.
ds_nc = xr.open_dataset(nc_outfile)

In [None]:
# Display the summary of the dataset read back from disk.
ds_nc

In [None]:
# Plot T2 shifted by -273.15 (presumably Kelvin to degrees Celsius) on the 1D
# lon/lat grid over OSM tiles, with the colour range fixed to 2..15.
(ds_nc['T2']-273.15).hvplot(x='lon',y='lat', geo=True,
                rasterize=True, cmap='turbo', 
                tiles='OSM', clim=(2,15))

In [None]:
# Same output variables, taken from the un-regridded (curvilinear) subset.
ds_outcl = ds_subset[vars_out]

In [None]:
# Data variables selected for the curvilinear output.
list(ds_outcl.data_vars)

In [None]:
# Per-variable NetCDF encoding for the curvilinear dataset: level-2 zlib
# compression with byte shuffling, no Fletcher-32 checksum, and no fill value.
# Every variable in ds_outcl.variables gets its own settings dict.
encoding = {}
for var in ds_outcl.variables:
    encoding[var] = {
        'zlib': True,
        'complevel': 2,
        'fletcher32': False,
        'shuffle': True,
        '_FillValue': None,
    }

In [None]:
%%time

# Write the curvilinear subset to its own NetCDF file with the per-variable
# encoding; mode='w' replaces any existing file of that name.
ds_outcl.to_netcdf('CONUS404_DRB_curvilinear.nc', encoding=encoding, 
                 mode='w')

In [None]:
# Shut down the dask client now that all computation and writing is done.
client.close()