In [None]:
import xarray as xr
import fsspec
import numpy as np

### Start a Dask cluster
This is not required, but speeds up computations. One can start a local cluster by just doing:
```
from dask.distributed import Client
client = Client()
```
but there are [many other ways to set up Dask clusters](https://docs.dask.org/en/latest/setup.html) that can scale larger than this. 

Since we used [Qhub](https://www.quansight.com/post/announcing-qhub) to install JupyterHub with a Dask Gateway running on Kubernetes, we can start a Dask cluster (with a specified environment and worker profile), scale it, and connect to it thusly:

In [None]:
import os
import sys

# The helper library lives under $HOME/shared/users/lib, which is appended to
# sys.path here so that the import below can find it.
sys.path.append(os.path.join(os.environ['HOME'], 'shared', 'users', 'lib'))
import ebdpy as ebd

# NOTE(review): this call looks redundant with the fuller set_credentials call
# a few lines down — confirm before removing it.
ebd.set_credentials(profile='esip-qhub')

profile = 'esip-qhub'
region = 'us-west-2'
endpoint = f's3.{region}.amazonaws.com'
ebd.set_credentials(profile=profile, region=region, endpoint=endpoint)

worker_max = 10
# Returns both the Dask client and the cluster handle; `cluster` is used in
# the following cells.
client, cluster = ebd.start_dask_cluster(
    profile=profile,
    worker_max=worker_max,
    region=region,
    use_existing_cluster=True,
    adaptive_scaling=False,
    wait_for_cluster=False,
    environment='pangeo',
    worker_profile='Pangeo Worker',
    propagate_env=True,
)

In [None]:
# Bare expression: Jupyter displays the cluster object's representation.
cluster

In [None]:
# NOTE(review): the cluster was started with worker_max = 10, but this requests
# 20 workers — confirm which value is intended.
cluster.scale(20)

Open Zarr datasets in Xarray using a mapper from fsspec.  We use `anon=True` for free-access public buckets like the AWS Open Data Program, and `requester_pays=True` for requester-pays public buckets. 

In [None]:
# S3 location of the Zarr store that is opened with xarray below.
url = 's3://noaa-nwm-retro-v2-zarr-pds'

We don't specify a profile here because we've passed the AWS credentials via environment variables to the cluster:

In [None]:
# anon=False: credentialed access. This same filesystem object is reused at the
# end of the notebook to write the output Zarr store, not only to read `url`.
fs = fsspec.filesystem('s3', anon=False)

In [None]:
%%time
# Open the store through an fsspec mapper; consolidated=True is passed so
# xarray uses the store's consolidated metadata.
ds = xr.open_zarr(fs.get_mapper(url), consolidated=True)

In [None]:
# Bare expression: Jupyter displays the dataset's summary representation.
ds

In [None]:
# NOTE(review): this repeats the previous cell's display of `ds` — possibly a
# leftover; confirm whether a different expression was intended here.
ds

In [None]:
# `var` was never assigned anywhere in the notebook, so this cell raised
# NameError on a fresh kernel. 'streamflow' is the variable subset and written
# out in the cells below.
var = 'streamflow'
# nbytes / 1e12 reports the variable's size in terabytes (10**12 bytes).
print(f'Variable size: {ds[var].nbytes/1e12:.1f} TB')

In [None]:
# Boolean mask used below to select features: True where latitude is strictly
# between 41.0 and 51.0 and longitude is strictly between -75.0 and -62.0.
idx = (
    (ds.latitude > 41.0)
    & (ds.latitude < 51.0)
    & (ds.longitude > -75.0)
    & (ds.longitude < -62.0)
)

In [None]:
# Streamflow for the masked features, first 672 time steps only.
# NOTE(review): `ds_out` is reassigned in the next cell (time from '2000-01-01'
# onward) before it is used, so this looks like a leftover quick-test subset — confirm.
ds_out = ds[['streamflow']].isel(feature_id=idx).isel(time=slice(0,672))

In [None]:
%%time
# Streamflow only, restricted to the masked features, from 2000-01-01 onward
# (open-ended slice: no stop bound).
ds_out = (
    ds[['streamflow']]
    .isel(feature_id=idx)
    .sel(time=slice('2000-01-01', None))
)

In [None]:
def gchunks(ds_chunk, chunks):
    """Rechunk every variable of ``ds_chunk`` and return a matching encoding dict.

    For each variable, the chunk length along each of its dimensions is:

    * the value from ``chunks`` if the dimension is listed there, capped at
      the dimension's length;
    * the full dimension length otherwise.

    Parameters
    ----------
    ds_chunk : dataset-like (an xarray Dataset in this notebook)
        Modified in place: each variable is replaced by the result of its
        ``.chunk(...)`` call with the computed chunk tuple.
    chunks : dict
        Maps dimension name to the requested chunk length.

    Returns
    -------
    dict
        ``{variable_name: {'chunks': (c0, c1, ...)}}``, one entry per variable.
    """
    encoding = {}

    for name in ds_chunk.variables:
        sizes = []
        for dim in ds_chunk[name].dims:
            dim_len = len(ds_chunk[dim])
            # Requested chunk, never longer than the dimension itself; full
            # length when no chunk size was requested for this dimension.
            sizes.append(min(chunks[dim], dim_len) if dim in chunks else dim_len)

        shape = tuple(sizes)
        ds_chunk[name] = ds_chunk[name].chunk(shape)
        encoding[name] = {'chunks': shape}

    return encoding

In [None]:
# gchunks rechunks ds_out in place and returns the per-variable chunk encoding
# that is passed to to_zarr below.
encoding = gchunks(
    ds_out,
    {'time': 672, 'feature_id': 10000},
)

In [None]:
%%time
# Write the subset to the target store with mode='w', using the per-variable
# chunking held in `encoding`.
ds_out.to_zarr(
    fs.get_mapper('esip-qhub/usgs/rsignell/testing/zarr/gulf_of_maine'),
    mode='w',
    encoding=encoding,
)