# Explore rechunked CONUS404 dataset
This dataset was created by extracting 38 variables from a collection of wrf2d output files, rechunking them to better facilitate data extraction for a variety of use cases, and adding CF-convention metadata to allow easier analysis, visualization, and data extraction using Xarray.

In [None]:
import fsspec
import xarray as xr
import hvplot.xarray
import intake
import os
import numpy as np
import warnings
from matplotlib import path
import panel as pn
warnings.filterwarnings('ignore')

#### Open dataset from Intake Catalog
* Automatically select the on-prem dataset from /caldera if running on-prem (Denali/Tallgrass)
* Automatically select the cloud dataset on S3 if not running on-prem

To test whether we are on-prem, we check whether SLURM_CLUSTER_NAME is defined. If it is not defined, the user is either not on Denali/Tallgrass, or is on the main node, which should not be used for this work.

In [None]:
# Intake catalog listing the CONUS404 on-prem and cloud dataset entries.
url = ('https://raw.githubusercontent.com/nhm-usgs/data-pipeline-helpers/'
       'main/conus404/conus404_intake.yml')

In [None]:
# Open the Intake catalog from the URL above.
cat = intake.open_catalog(url)
# Show the names of the entries available in the catalog.
list(cat)

In [None]:
# Pick the catalog entry by environment: a SLURM cluster name means we are
# on-prem (Denali/Tallgrass), otherwise fall back to the cloud copy.
ds = cat[
    'conus404-2017-onprem'
    if 'SLURM_CLUSTER_NAME' in os.environ
    else 'conus404-2017-cloud'
].to_dask()

In [None]:
# Display the opened dataset's summary repr.
ds

In [None]:
# Display the SNOW variable's repr (dims, coords, attributes).
ds.SNOW

#### Load the full domain at a specific time step

In [None]:
%%time
# Select SNOW at a single time step over the full spatial domain and load it into memory.
da = ds.SNOW.sel(time='2017-03-01 00:00').load()

In [None]:
# Rasterized map of the loaded time step over OSM tiles.
da.hvplot.quadmesh(
    x='lon', y='lat',
    rasterize=True, geo=True,
    tiles='OSM', alpha=0.7, cmap='turbo',
)

#### Load the full time series at a specific grid cell

In [None]:
# Target location for the single-grid-cell time series.
lat, lon = 29.9659416, -96.1206

In [None]:
def nearxy(x, y, xi, yi):
    """Return the flat index of the grid point nearest to each target point.

    Distance is the plain Euclidean distance in coordinate units (for lon/lat
    grids that is degrees, not a great-circle distance). Grid points whose
    coordinates are NaN are ignored.

    Parameters
    ----------
    x, y : array_like
        Grid coordinates of identical shape (e.g. 2-D lon/lat arrays or
        xarray DataArrays).
    xi, yi : sequence of float
        Target coordinates; must have the same length.

    Returns
    -------
    numpy.ndarray of int
        For each target, the index into the C-order flattened grid.

    Raises
    ------
    ValueError
        If ``xi`` and ``yi`` differ in length, or the grid has no valid
        (non-NaN) points.
    """
    if len(xi) != len(yi):
        raise ValueError(f"xi and yi must have the same length ({len(xi)} != {len(yi)})")
    # Materialize the coordinates once (DataArray/dask -> NumPy) rather than
    # re-evaluating them for every target point.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    ind = np.empty(len(xi), dtype=int)
    for i, (xp, yp) in enumerate(zip(xi, yi)):
        dist = np.hypot(x - xp, y - yp)
        # nanargmin on the full array gives a flat C-order index and skips NaN
        # coordinates (what DataArray.argmin did via skipna); calling
        # DataArray.argmin() without `dim` is deprecated in xarray.
        ind[i] = np.nanargmin(dist)
    return ind

In [None]:
def ind2sub(array_shape, ind):
    """Convert a flat (C-order) index into ``(row, col)`` subscripts.

    Parameters
    ----------
    array_shape : tuple of int
        Shape of the 2-D array the flat index refers to. Only
        ``array_shape[1]`` (the number of columns) is used.
    ind : int or array_like of int
        Flat index or indices, e.g. the output of ``nearxy``. Float values
        are truncated to int.

    Returns
    -------
    (rows, cols)
        Python ints when *ind* holds a single value (a scalar or a length-1
        array); otherwise integer arrays with the same shape as *ind*.
    """
    flat = np.asarray(ind).astype(int)
    # Integer divmod avoids float division and works for any number of indices.
    rows, cols = np.divmod(flat, array_shape[1])
    if rows.size == 1:
        # .item() instead of int(array): int() on an ndim>0 array is
        # deprecated in NumPy >= 1.25.
        return (int(rows.item()), int(cols.item()))
    return (rows, cols)

In [None]:
# Row/column of the grid cell nearest to (lon, lat).
jj, ii = ind2sub(ds.lon.shape, nearxy(ds.lon, ds.lat, [lon], [lat]))

In [None]:
%%time
# Extract the full T2 time series at grid cell (jj, ii) and load it into memory.
da = ds.T2.isel(south_north=jj,west_east=ii).load()

In [None]:
# Line plot of the extracted time series against time, with grid lines.
da.hvplot(x='time', grid=True)

#### Extract a subset to NetCDF

In [None]:
# DRB bounding box, ordered [lon_min, lon_max, lat_min, lat_max] as bbox2ij expects.
bbox = [
    -76.63290610753754,  # lon_min
    -73.55671530588432,  # lon_max
    37.57888442021855,   # lat_min
    41.225532965406224,  # lat_max
]

In [None]:
def bbox2ij(lon, lat, bbox=(-160., -155., 18., 23.)):
    """Return index bounds ``i0, i1, j0, j1`` covering the grid points inside *bbox*.

    The bounds are half-open, so they can be used directly as slices:
    ``lon[j0:j1, i0:i1]`` includes every grid point whose coordinates fall
    inside the box (edges inclusive).

    Parameters
    ----------
    lon, lat : array_like, 2-D
        Coordinates of the grid to subset; both of shape ``(n_rows, n_cols)``.
    bbox : sequence of 4 floats
        ``[lon_min, lon_max, lat_min, lat_max]``. Each min/max pair may be
        given in either order.

    Returns
    -------
    i0, i1, j0, j1 : int
        Column (i) and row (j) start/stop indices; the stops are exclusive.

    Raises
    ------
    ValueError
        If no grid point lies inside *bbox*.

    Example
    -------
    >>> i0, i1, j0, j1 = bbox2ij(lon_rho, lat_rho, [-71, -63., 39., 46])
    >>> h_subset = nc.variables['h'][j0:j1, i0:i1]
    """
    lon = np.asarray(lon)
    lat = np.asarray(lat)
    lon_min, lon_max = sorted(bbox[:2])
    lat_min, lat_max = sorted(bbox[2:4])
    # An axis-aligned box needs only vectorized comparisons; this also avoids
    # the ambiguous boundary handling of a polygon point-in-path test.
    inside = ((lon >= lon_min) & (lon <= lon_max)
              & (lat >= lat_min) & (lat <= lat_max))
    rows, cols = np.nonzero(inside)
    if rows.size == 0:
        raise ValueError(f"no grid points fall inside bbox {list(bbox)}")
    # +1 on the maxima: callers slice with [j0:j1, i0:i1], which would
    # otherwise drop the last row/column inside the box.
    return (int(cols.min()), int(cols.max()) + 1,
            int(rows.min()), int(rows.max()) + 1)

In [None]:
# Grid-index bounds of the bounding box on the dataset's lon/lat grid.
i0, i1, j0, j1 = bbox2ij(ds.lon.values, ds.lat.values, bbox=bbox)
print(i0, i1, j0, j1)

In [None]:
# Spatially subset every variable to the bbox index window.
ds_drb = ds.isel({'south_north': slice(j0, j1), 'west_east': slice(i0, i1)})

In [None]:
# Restrict to 2017-04-01 .. 2017-04-08 (label-based slice, both endpoints included).
ds_drb_timeslice = ds_drb.sel({'time': slice('2017-04-01 00:00', '2017-04-08 00:00')})

In [None]:
# A chunk size of -1 makes each spatial dimension a single chunk spanning the whole subset.
ds_drb_timeslice = ds_drb_timeslice.chunk(dict(south_north=-1, west_east=-1))

In [None]:
# Variable to visualize from the subset.
var = "T2"
da = ds_drb_timeslice[var]

In [None]:
%%time
viz = da.hvplot.quadmesh(x='lon', y='lat', geo=True,
                    cmap='turbo', rasterize=True, tiles='OSM', title=var)
viz = pn.panel(viz, widgets={'time': pn.widgets.Select} )
pn.Column(viz).servable('DRB Explorer')

In [None]:
# Variables to write to the NetCDF subset.
var_list = ['T2', 'SNOW']

# Keep only those variables from the space/time subset and display the result.
ds_nc = ds_drb_timeslice[var_list]
ds_nc

In [None]:
%%time
# Write the subset to a local NetCDF file; mode='w' overwrites any existing drb.nc.
ds_nc.to_netcdf('drb.nc', mode='w')