In [None]:
import fsspec
import xarray as xr

In [None]:
from dask_gateway import Gateway
from dask.distributed import Client
# Create the Dask Gateway client with its default arguments.
gateway = Gateway()
# To see the options that can be passed to gateway.new_cluster(), run:
#   gateway.cluster_options()

In [None]:
# Show the clusters the gateway currently reports (displayed as the cell output).
gateway.list_clusters()

In [None]:
# Stop a cluster left over from an earlier session. The ID below is specific
# to that session, so the stop request is only sent when the gateway still
# lists a cluster with that name; otherwise this cell is a no-op.
stale_cluster_name = '83cf47c242e6446b95562bcfff2b8fa9'
if any(report.name == stale_cluster_name for report in gateway.list_clusters()):
    gateway.stop_cluster(stale_cluster_name)

In [None]:
# Reuse an existing cluster when the gateway reports one; otherwise start a
# new one. The cluster list is fetched once so that the emptiness check and
# the lookup by index operate on the same snapshot.
idx = 0  # position, in the cluster list, of the cluster to reconnect to
running_clusters = gateway.list_clusters()
if running_clusters:
    cluster = gateway.connect(running_clusters[idx].name)
else:
    cluster = gateway.new_cluster(environment='default', profile='Small Worker')

In [None]:
# Import s3fs and show its installed version as the cell output.
import s3fs
s3fs.__version__

In [None]:
# Attach a distributed Client to the gateway cluster held in `cluster`.
client = Client(cluster)

In [None]:
# try a local cluster
# client = Client()

In [None]:
import configparser
import os
def set_aws_credentials(
    cfile=os.path.join(os.path.expanduser('~'), '.aws', 'credentials'),
    profile_name='default',
    region_name='us-east-1',
    endpoint='s3.amazonaws.com',
    verbose=False,
):
    '''Export one profile of an AWS credentials file into ``os.environ``.

    The access key and secret key are read from section ``profile_name`` of
    ``cfile``. Together with the profile, region and endpoint they are written
    to the environment of the current process, overwriting any values that
    are already set. Nothing is written if the profile or either key is
    missing.

    Parameters
    ----------
    cfile : str
        Path of the INI-style credentials file. Defaults to
        ``~/.aws/credentials``.
    profile_name : str
        Section of ``cfile`` to read; also stored in ``AWS_PROFILE`` and
        ``AWS_DEFAULT_PROFILE``.
    region_name : str
        Stored in ``AWS_S3_REGION`` and ``AWS_DEFAULT_REGION``.
    endpoint : str
        Stored in ``AWS_S3_ENDPOINT``.
    verbose : bool
        When true, print ``export NAME=value`` lines for the two keys.

    Raises
    ------
    KeyError
        If ``profile_name`` is not a section of ``cfile`` (which is also the
        case when the file is missing or unreadable), or if the section lacks
        ``aws_access_key_id`` or ``aws_secret_access_key``.
    '''
    cp = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open, so a missing
    # file shows up as a missing section; report both in one clear message.
    cp.read(cfile)
    if profile_name not in cp:
        raise KeyError(
            'profile {!r} not found in {!r} (file missing, unreadable, or '
            'profile not defined)'.format(profile_name, cfile)
        )
    profile = cp[profile_name]

    # Read both keys before touching the environment so a half-defined
    # profile cannot leave a mix of old and new credentials behind.
    access_key = profile['aws_access_key_id']
    secret_key = profile['aws_secret_access_key']

    os.environ.update({
        'AWS_ACCESS_KEY_ID': access_key,
        'AWS_SECRET_ACCESS_KEY': secret_key,
        'AWS_PROFILE': profile_name,
        'AWS_DEFAULT_PROFILE': profile_name,
        'AWS_S3_REGION': region_name,
        'AWS_S3_ENDPOINT': endpoint,
        'AWS_DEFAULT_REGION': region_name,
    })

    if verbose:
        # NOTE: this prints the secret key in clear text; in a notebook the
        # output is saved with the file, so clear it before sharing.
        print('export {}={}'.format('AWS_ACCESS_KEY_ID', access_key))
        print('export {}={}'.format('AWS_SECRET_ACCESS_KEY', secret_key))

In [None]:
# Export the 'esip-qhub' profile of the default credentials file into this
# process's environment.
set_aws_credentials(profile_name='esip-qhub')

In [None]:
from dask.distributed import WorkerPlugin
import os
import uuid
import asyncio

In [None]:
class InitWorker(WorkerPlugin):
    """Worker plugin that uploads files and/or an inline script to a worker.

    Payloads are collected by ``__init__`` into ``self.data`` (file name ->
    bytes) and sent from ``setup`` with ``worker.upload_file(..., load=True)``.
    """

    name = "init_worker"

    def __init__(self, filepath=None, script=None):
        """Collect the payloads to upload.

        Parameters
        ----------
        filepath : str or iterable of str, optional
            Path(s) of local files. Each file is read in binary mode and
            stored under its base name.
        script : str or bytes, optional
            Source code, stored under a generated ``<uuid1>.py`` name.
        """
        self.data = {}
        if filepath:
            if isinstance(filepath, str):
                filepath = [filepath]
            for file_ in filepath:
                with open(file_, "rb") as f:
                    filename = os.path.basename(file_)
                    self.data[filename] = f.read()
        if script:
            filename = f"{uuid.uuid1()}.py"
            # Keep text as UTF-8 bytes, like the file payloads, so that the
            # size check in setup() compares a byte count with the worker's
            # "nbytes" reply; len() of a str counts characters, which differs
            # from the byte count for non-ASCII text.
            if isinstance(script, str):
                script = script.encode("utf-8")
            self.data[filename] = script

    async def setup(self, worker):
        """Upload every payload to ``worker`` and check the reported sizes.

        The uploads are awaited together via ``asyncio.gather``.

        Raises
        ------
        AssertionError
            If a reply's ``"nbytes"`` differs from the payload length.
        """
        responses = await asyncio.gather(
            *[
                worker.upload_file(
                    comm=None, filename=filename, data=data, load=True
                )
                for filename, data in self.data.items()
            ]
        )
        assert all(
            len(data) == r["nbytes"]
            for r, data in zip(responses, self.data.values())
        ), "worker reported a different size for an uploaded payload"

In [None]:
def build_worker_script(env=None):
    """Return Python source that copies AWS settings into ``os.environ``.

    For each of AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and
    AWS_DEFAULT_REGION that is present in ``env``, the source contains one
    ``os.environ[...] = ...`` assignment, followed by the s3fs debug-logging
    setup. Variables missing from ``env`` are left out instead of being
    written as the text "None". Values are embedded with ``repr()`` so quotes
    and backslashes in them cannot break the generated source.

    Parameters
    ----------
    env : mapping of str to str, optional
        Where to read the values from; defaults to ``os.environ``.

    Returns
    -------
    str
        The generated source. It holds the secret key in clear text, so do
        not display or commit it.
    """
    if env is None:
        env = os.environ
    exported = ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION")
    assignments = "\n".join(
        f'os.environ["{name}"] = {env[name]!r}' for name in exported if name in env
    )
    return f"""
import os
{assignments}
import s3fs
s3fs.core.setup_logging('DEBUG')
"""


script = build_worker_script()

In [None]:
# Wrap the bootstrap script in an InitWorker plugin and register it through
# the client; InitWorker.setup then uploads the script with load=True.
plugin = InitWorker(script=script)
client.register_worker_plugin(plugin)

In [None]:
# Show the client summary as the cell output.
client

In [None]:
# Ask the cluster to scale to 10 workers.
cluster.scale(10)

In [None]:
# Open one object from the noaa-nwm-retro-v2.0-pds bucket through fsspec and
# hand the opened file object to xarray (h5netcdf engine), chunked along
# feature_id in blocks of 600000.
url = 's3://noaa-nwm-retro-v2.0-pds/full_physics/2018/201801010000.CHRTOUT_DOMAIN1.comp'
ncfile = fsspec.open(url)
# The file object returned by ncfile.open() is not closed in this cell.
dset = xr.open_dataset(ncfile.open(), engine='h5netcdf', chunks={'feature_id':600000})

In [None]:
# Show the dataset's data variables.
dset.data_vars

In [None]:
# Show the `elevation` variable of the dataset.
dset.elevation

In [None]:
# Compute the mean over feature_id for every data variable; the results are
# discarded. A variable whose reduction fails is reported and skipped rather
# than silently ignored. `Exception` is caught instead of using a bare
# `except` so that KeyboardInterrupt can still stop the loop.
for var in dset.data_vars:
    try:
        dset[var].mean(dim='feature_id').compute()
    except Exception as err:
        print(f'skipping {var!r}: {err!r}')

In [None]:
# Build a mapper over the target Zarr location on S3, using a filesystem
# created with anon=False (i.e. not anonymous access).
fs = fsspec.filesystem('s3', anon=False)
chunked_url = 's3://esip-qhub/usgs/zarr/nwm/chunked.zarr'
zarr_chunked = fs.get_mapper(chunked_url)

In [None]:
# Time the Zarr write four times. Each pass first builds the write with
# compute=False, then executes it; both steps are timed separately with
# %time. Every pass uses mode='w' on the same target (zarr_chunked).
for i in range(4):
    %time delayed_obj = dset.to_zarr(zarr_chunked, mode='w', compute=False)
    %time delayed_obj.compute()

In [None]:
# Show the client summary again, after the Zarr writes.
client

In [None]:
#cluster.shutdown()

In [None]:
# Second run of the four-pass Zarr write timing: build the write with
# compute=False, then execute it, timing each step with %time.
for i in range(4):
    %time delayed_obj = dset.to_zarr(zarr_chunked, mode='w', compute=False)
    %time delayed_obj.compute()