# Try rechunking from Zarr-on-S3 to Zarr-on-S3

In [None]:
import fsspec
import xarray as xr
import hvplot.xarray
import numpy as np
import dask
import zarr

In [None]:
import rechunker
rechunker.__version__

In [None]:
fs = fsspec.filesystem('s3', anon=False)

In [None]:
import s3fs
s3fs.core.setup_logging('CRITICAL')

In [None]:
from dask_gateway import Gateway

In [None]:
gateway = Gateway()

In [None]:
gateway.cluster_options()

In [None]:
# Add the shared user library directory under $HOME to the import path so
# that the `ebdpy` helper module imported on the next line can be found.
import sys, os
sys.path.append(os.path.join(os.environ['HOME'],'shared','users','lib'))
import ebdpy as ebd
# NOTE(review): credentials are set again a few lines below with the same
# profile plus region and endpoint; this first call looks redundant — confirm.
ebd.set_credentials(profile='esip-qhub')

# AWS settings shared by the credential setup and the Dask cluster start-up.
aws_profile = 'esip-qhub'
aws_region = 'us-west-2'
endpoint = f's3.{aws_region}.amazonaws.com'
ebd.set_credentials(profile=aws_profile, region=aws_region, endpoint=endpoint)
# Value forwarded as `worker_max` to the cluster helper below.
worker_max = 60
# Start a new cluster (use_existing_cluster=False) with adaptive scaling
# turned off; the helper returns both the Dask client and the cluster handle.
client,cluster = ebd.start_dask_cluster(profile=aws_profile, worker_max=worker_max, 
                                      region=aws_region, use_existing_cluster=False,
                                      adaptive_scaling=False, wait_for_cluster=False, 
                                      environment='pangeodev', worker_profile='Large Worker', 
                                      propagate_env=True)     #client.close(); cluster.close();

In [None]:
source_store = fs.get_mapper('s3://esip-qhub/usgs/COAWST/surface_vars/')

In [None]:
%%time
ds = xr.open_zarr(source_store, consolidated=True)

In [None]:
ds['Hwave']

In [None]:
def delete_s3(url):
    """Recursively delete ``url`` if it exists.

    The filesystem is obtained from ``fsspec.open(url, anon=False).fs``,
    i.e. with non-anonymous access. When ``url`` does not exist the function
    does nothing.
    """
    remote_fs = fsspec.open(url, anon=False).fs
    if not remote_fs.exists(url):
        # Nothing to clean up.
        return
    remote_fs.rm(url, recursive=True)

In [None]:
s3_target_store = 's3://esip-qhub/usgs/zarr/new/step2.zarr'
s3_temp_store = 's3://esip-qhub/usgs/zarr/tmp/temp2.zarr'

In [None]:
%%time
delete_s3(s3_target_store)
delete_s3(s3_temp_store)

In [None]:
fs.ls(s3_target_store)

In [None]:
fs.ls(s3_temp_store)

In [None]:
target_store = fsspec.get_mapper(s3_target_store, anon=False)
temp_store = fsspec.get_mapper(s3_temp_store, anon=False)

### Try writing to Zarr on S3 without rechunker

In [None]:
%%time
ds[['Hwave','zeta']].isel(ocean_time=slice(0,10)).to_zarr(target_store, consolidated=True)

#### Check a value in the input dataset:

In [None]:
ds.Hwave[5,100,100].values

#### Check the same value in the output dataset:

In [None]:
xr.open_dataset(target_store, engine='zarr')['Hwave'][5,100,100].values

#### Good! They agree. Now delete this test output so we can create the output dataset using rechunker

In [None]:
%%time
delete_s3(s3_target_store)
delete_s3(s3_temp_store)

In [None]:
fs.ls(s3_target_store)

### Rechunk the whole dataset using rechunker

In [None]:
max_mem = '3.0GB'    # workers are 4GB, max_mem should be set to 75% or less

In [None]:
client

In [None]:
def rechunker_wrapper(source_store, target_store, temp_store, chunks=None,
                      mem=None, consolidated=False, verbose=True):
    """Rechunk a dataset into ``target_store`` using ``rechunker``.

    Parameters
    ----------
    source_store : xarray.Dataset or store-like
        A Dataset is handed to ``rechunker.rechunk`` directly. Anything else
        is converted with ``str()`` and opened with ``zarr.group`` (the
        rechunk source) and ``xr.open_zarr`` (to read dimension sizes).
    target_store, temp_store
        Forwarded unchanged to ``rechunker.rechunk``.
    chunks : dict, optional
        Mapping of dimension name to requested chunk length. A dimension
        that is not listed gets one chunk spanning its full length, and a
        requested length larger than the dimension is clipped to the
        dimension size. ``None`` (the default) is treated as an empty
        mapping, so every dimension gets a single full-length chunk.
    mem
        Forwarded to ``rechunker.rechunk`` as ``max_mem``.
    consolidated : bool
        When true, call ``zarr.convenience.consolidate_metadata`` on
        ``target_store`` after the rechunk has executed.
    verbose : bool
        When true, print the computed chunking, the memory setting and
        progress messages.
    """
    # The default chunks=None used to fail with AttributeError on
    # `chunks.keys()`; treat it as "no dimension-specific chunking requested".
    requested = {} if chunks is None else chunks

    if isinstance(source_store, xr.Dataset):
        # Work directly with the dataset: it is both the rechunk source and
        # the place dimension sizes are read from.
        source = ds_chunk = source_store
    else:
        source = zarr.group(str(source_store))
        # Load the store as a Dataset to get dimension names and lengths.
        ds_chunk = xr.open_zarr(str(source_store))

    # One chunk tuple per variable, ordered like that variable's dimensions.
    # min() clips requests that exceed the dimension length; dimensions
    # without a request default to their full length.
    dim_sizes = ds_chunk.sizes
    group_chunks = {
        name: tuple(
            min(requested.get(dim, dim_sizes[dim]), dim_sizes[dim])
            for dim in variable.dims
        )
        for name, variable in ds_chunk.variables.items()
    }

    if verbose:
        print(f"Rechunking to: {group_chunks}")
        print(f"mem:{mem}")

    rechunked = rechunker.rechunk(source, target_chunks=group_chunks, max_mem=mem,
                                  target_store=target_store, temp_store=temp_store)
    rechunked.execute(retries=10)

    if consolidated:
        if verbose:
            print('consolidating metadata')
        zarr.convenience.consolidate_metadata(target_store)
    if verbose:
        print('done')

In [None]:
#client.close(); cluster.shutdown()

In [None]:
#np.diff(ds.ocean_time).min()/(3600*1000*1000*1000)

In [None]:
%%time
# Select the two variables of interest and the first 654*103 time steps.
ds_sub = ds[['Hwave', 'zeta']].isel(ocean_time=slice(0, 654*103))

# Rechunk so that 'ocean_time' is requested as one chunk of the full
# selected length, with 25x25 tiles in 'eta_rho' / 'xi_rho'.
rechunker_wrapper(
    ds_sub,
    target_store=target_store,
    temp_store=temp_store,
    chunks={'ocean_time': 654*103, 'eta_rho': 25, 'xi_rho': 25},
    mem=max_mem,
    consolidated=True,
    verbose=True,
)

In [None]:
#%%time 
#ds_sub = ds[['Hwave','zeta']]
#rechunker_wrapper(ds_sub, target_store=target_store, temp_store=temp_store, 
#                  mem=max_mem, consolidated=True, verbose=False,
#              chunks={'ocean_time':67375, 'eta_rho':25, 'xi_rho':25})

In [None]:
654

In [None]:
ds_new = xr.open_zarr(target_store, consolidated=True)

In [None]:
ds_new.Hwave

In [None]:
ds_new.Hwave[5,100,100].values

In [None]:
# Quantiles of Hwave along 'ocean_time' at 21 evenly spaced levels
# from 0 to 1 (i.e. every 0.05).
a = ds_new['Hwave'].quantile(q=np.linspace(0, 1, num=21), dim='ocean_time')                      

In [None]:
%%time
b = dask.compute(a, retries=10)[0]

In [None]:
b

In [None]:
b.to_dataset(name='Hwave').to_netcdf('Hwave_quantile3.nc', 'w')

In [None]:
# Pick the computed quantile level nearest to 0.45.
c = b.sel(quantile=0.45, method='nearest')
# Keep only strictly positive values, then draw an interactive quadmesh on
# lon_rho/lat_rho over OSM tiles.
c.where(c>0).hvplot.quadmesh(x='lon_rho', y='lat_rho', geo=True, frame_height=400,
                  rasterize=True, cmap='turbo', tiles='OSM')

In [None]:
# Read the quantiles back from the local netCDF file and repeat the
# selection (level nearest 0.45, positive values only) as a static plot.
ds_nc = xr.open_dataset('Hwave_quantile3.nc')
c = ds_nc.Hwave.sel(quantile=0.45, method='nearest')
c.where(c>0).plot()

In [None]:
#client.close(); cluster.shutdown()