# Calculating wave climate from the COAWST forecast

We rechunk the original NetCDF data into time-series-oriented Zarr chunks using rechunker, then use xarray `map_blocks` to process each block along the time dimension.


In [None]:
%%time
import dask.distributed
from dask.distributed import Client, performance_report
import numpy as np
import xarray as xr
import hvplot.xarray
import pandas as pd
from rechunker import rechunk
import fsspec
from pathlib import Path
import shutil
import zarr
import time

In [None]:
# Cluster profile used by the cells below: 'standard' (project 'science',
# queue 'compute') or 'usgs' (usgs QOS and partition).
ctype = 'standard'  # alternative: 'usgs'

In [None]:
# SLURM cluster for the USGS allocation: jobs carry the usgs QOS and partition.
# NOTE(review): newer dask-jobqueue releases rename `job_extra` to
# `job_extra_directives`; confirm against the installed version.
if ctype == 'usgs':
    from dask_jobqueue import SLURMCluster
    cluster = SLURMCluster(
        cores=36,
        memory='128GiB',
        walltime='03:00:00',
        interface='ib0',
        local_directory='$SCRATCH',
        job_extra=['--qos usgs', '--partition usgs'],
    )
    # Show the generated SLURM job script for inspection.
    print(cluster.job_script())

In [None]:
# SLURM cluster for the standard allocation: project 'science' on the 'compute' queue.
# NOTE(review): newer dask-jobqueue releases rename `project` to `account`;
# confirm against the installed version.
if ctype == 'standard':
    from dask_jobqueue import SLURMCluster
    cluster = SLURMCluster(
        cores=36,
        memory='128GiB',
        walltime='03:00:00',
        queue='compute',
        project='science',
        interface='ib0',
        local_directory='$SCRATCH',
    )
    # Show the generated SLURM job script for inspection.
    print(cluster.job_script())

In [None]:
# Connect a Dask client to the SLURM cluster created above.
client = Client(address=cluster)

In [None]:
# Request one SLURM job (each job has the cores/memory configured above).
cluster.scale(jobs=1)

In [None]:
# Display the cluster object (the notebook renders the cell's last expression).
cluster

In [None]:
# Local filesystem handle; an empty protocol string resolves to 'file' in fsspec.
fs = fsspec.filesystem('file')

In [None]:
#fs.glob('/vortexfs1/share/usgs-share/Projects/COAWST/2020/coawst_us_*.nc')

In [None]:
%%time
# Collect every coawst_us_20*.nc file in the 20* (year) directories.
# The glob order is not relied upon; the list is sorted in a later cell.
fgood = fs.glob('/vortexfs1/share/usgs-share/Projects/COAWST/20*/coawst_us_20*.nc')

In [None]:
# Number of files found.
len(fgood)

In [None]:
# Sort the file names (presumably chronological, since they embed a date stamp)
# and show the first and last file.
fgood = np.sort(fgood)
print(fgood[0], fgood[-1], sep='\n')

In [None]:
def rechunker_wrapper(source_store, target_store, temp_store, chunks=None, mem="2GiB", consolidated=True, verbose=True):
    """Rechunk a Zarr store or an xarray Dataset into a new Zarr store with rechunker.

    Any existing ``temp_store`` and ``target_store`` directories are deleted first.

    Parameters
    ----------
    source_store : str, pathlib.Path or xarray.Dataset
        Zarr store to read, or an already-opened Dataset to rechunk directly.
    target_store : str or pathlib.Path
        Output Zarr store (overwritten).
    temp_store : str or pathlib.Path
        Scratch store used by rechunker; removed after a successful run.
    chunks : dict, optional
        Mapping of dimension name -> target chunk size. A size larger than the
        dimension, or a dimension not listed, becomes one full-length chunk.
        ``None`` means every variable is written as a single full-size chunk.
    mem : str
        Memory limit passed to ``rechunker.rechunk``.
    consolidated : bool
        Consolidate the target store's metadata when done.
    verbose : bool
        Print progress messages.
    """
    # Previously the default None crashed at `chunks.keys()`.
    chunks = {} if chunks is None else chunks

    def maybe_convert_to_path(p):
        # Only strings are converted; Paths and Datasets pass through unchanged.
        return Path(p) if isinstance(p, str) else p

    source_store = maybe_convert_to_path(source_store)
    target_store = maybe_convert_to_path(target_store)
    temp_store = maybe_convert_to_path(temp_store)

    # erase temp and target stores so rechunker starts from a clean slate
    for store in (temp_store, target_store):
        if store.exists():
            shutil.rmtree(store)

    if isinstance(source_store, xr.Dataset):
        # work directly with an (possibly lazy) dataset
        g = source_store
        ds_chunk = source_store
    else:
        g = zarr.group(str(source_store))
        # open the same store with xarray to get dimension names and lengths
        ds_chunk = xr.open_zarr(str(source_store))

    # Build a per-variable chunk tuple in the variable's own dimension order:
    # the requested size capped at the dimension length, or the full length
    # for dimensions not named in `chunks`.
    sizes = ds_chunk.sizes
    group_chunks = {
        var: tuple(min(chunks.get(di, sizes[di]), sizes[di]) for di in ds_chunk[var].dims)
        for var in ds_chunk.variables
    }
    if verbose:
        print(f"Rechunking to: {group_chunks}")

    target_store = str(target_store)
    temp_store = str(temp_store)
    rechunked = rechunk(g, group_chunks, mem, target_store, temp_store=temp_store)
    rechunked.execute(retries=10)
    if consolidated:
        if verbose:
            print('consolidating metadata')
        zarr.convenience.consolidate_metadata(target_store)
    if verbose:
        print('removing temp store')
    shutil.rmtree(temp_store)
    if verbose:
        print('done')

In [None]:
def drop(ds):
    """Return ``ds`` without its scalar variables and wet/dry mask variables.

    Used as the ``preprocess`` hook of ``xr.open_mfdataset``. Every variable
    with no dimensions, plus ``wetdry_mask_psi``, ``wetdry_mask_rho``,
    ``wetdry_mask_u`` and ``wetdry_mask_v``, is removed before the files are
    combined.
    """
    masks = {'wetdry_mask_psi', 'wetdry_mask_rho', 'wetdry_mask_u', 'wetdry_mask_v'}
    # A single condition per variable, so no name is listed twice.
    drop_list = [var for var in ds.variables
                 if len(ds[var].dims) == 0 or var in masks]
    # Dataset.drop is deprecated; drop_vars is the variable-dropping replacement.
    return ds.drop_vars(drop_list)

In [None]:
# Root directory for all Zarr output. Earlier runs used
# /vortexfs1/scratch/aretxabaleta/rechunk/zarr and
# /vortexfs1/usgs/rsignell/alfredo/rechunk/zarr instead.
prefix = '/vortexfs1/scratch/rsignell/alfredo/rechunk/zarr'

In [None]:
# Zarr stores under `prefix`: per-chunk rechunker output, rechunker scratch
# space, and the final store that each chunk is appended to.
zarr_step, zarr_temp, zarr_chunked = (f'{prefix}/{leaf}' for leaf in ('step', 'tmp', 'coawst_fc'))

In [None]:
# Memory limit handed to rechunker for each rechunk pass.
max_mem = '16GB'
# Time steps stored in each source file, and time steps per output chunk.
# The main loop groups time_chunk_size / time_steps_per_file files per chunk,
# so time_chunk_size should be a multiple of time_steps_per_file.
time_steps_per_file = 12
time_chunk_size = 144
# Horizontal chunk sizes (earlier runs used 300 x 300).
x_chunk_size = y_chunk_size = 100

In [None]:
# Number of output time chunks needed to cover every file (ceiling division).
nt_chunks = -(-(len(fgood) * time_steps_per_file) // time_chunk_size)
nt_chunks

In [None]:
%%time
First=True
try:
    shutil.rmtree(zarr_chunked, ignore_errors=False, onerror=None)
except:
    pass

In [None]:
%%time

for i in range(1):
#for i in range(nt_chunks):
#for i in range(9,nt_chunks):
#for i in range(nt_chunks-3, nt_chunks):
    start = time.time()
    print(i)
    istart = i * int(time_chunk_size/time_steps_per_file)
    istop = int(np.min([(i+1) * int(time_chunk_size/time_steps_per_file), len(fgood)]))
    
    ds = xr.open_mfdataset(fgood[istart:istop], chunks={'ocean_time':time_steps_per_file}, decode_timedelta=False,
                           data_vars="minimal", coords="minimal", compat="override", parallel=True,
                           preprocess=drop)
    
    rechunker_wrapper(ds, zarr_step, zarr_temp, mem=max_mem, consolidated=True, verbose=False,
                  chunks={'s_rho':1, 's_w':1, 'ocean_time':144, 'ND':1, 'Nbed':1,
                         'eta_rho':100, 'xi_rho':100, 'eta_u':100, 'xi_u':100,
                         'eta_v':100, 'xi_v':100})
    

    # read back in the zarr chunk rechunker wrote
    ds = xr.open_zarr(zarr_step, consolidated=True)

    if First is True:
        ds.to_zarr(zarr_chunked, consolidated=True, mode='w')
        First = False
    else:
    #    ds.to_zarr(zarr_chunked, consolidated=True, append_dim='time')
        ds.to_zarr(zarr_chunked, consolidated=True, append_dim='ocean_time')
    print(f'completed in {start - time.time()} seconds')

#### Check the resulting dataset

In [None]:
# Open the final appended Zarr store to check the result.
ds_woopwoop = xr.open_zarr(zarr_chunked, consolidated=True)

In [None]:
%%time
# Plot one time series of `temp` from the rechunked store at indices
# [:, -1, 180, 500] — presumably (ocean_time, s_rho, eta_rho, xi_rho), i.e. the
# last s-level at one grid point; confirm against ds_woopwoop.temp.dims.
ds_woopwoop.temp[:,-1,180,500].hvplot(x='ocean_time', grid=True)

In [None]:
# Shut down the Dask client first, then the SLURM cluster.
client.close()
cluster.close()

#### Extra cells for investigating a particular Zarr chunk if something goes wrong

In [None]:
# Debug: open only the first six files of the current chunk (from `istart`)
# with the same options the main loop uses.
ds = xr.open_mfdataset(
    fgood[istart:istart + 6],
    chunks={'ocean_time': time_steps_per_file},
    decode_timedelta=False,
    data_vars="minimal",
    coords="minimal",
    compat="override",
    parallel=True,
    preprocess=drop,
)

In [None]:
# Files in the current chunk (istart/istop from whichever cell last set them).
fgood[istart:istop]

In [None]:
# Number of ocean_time steps in `ds`. Dataset.sizes is the supported mapping of
# dimension lengths; indexing Dataset.dims like a mapping is deprecated.
ds.sizes['ocean_time']

In [None]:
# Recompute the file range for one specific chunk (here chunk 9) to debug it,
# using the same arithmetic as the main loop.
i = 9
istart = i * (time_chunk_size // time_steps_per_file)
istop = min((i + 1) * (time_chunk_size // time_steps_per_file), len(fgood))

In [None]:
# Print the ocean_time values of each file in the current chunk, e.g. to find a
# bad file. Each dataset is closed after use so file handles are not leaked
# (the unused `t = []` list was dropped).
for f in fgood[istart:istop]:
    with xr.open_dataset(f) as ds:
        print(f)
        print(ds.ocean_time.values)

In [None]:
import os
# Move a file identified as bad out of the year directories. The glob that builds
# `fgood` only matches 20*/ directories, so bad_files/ is excluded from later runs.
# os.rename fails if the destination directory is missing, so create it first.
os.makedirs('/vortexfs1/share/usgs-share/Projects/COAWST/bad_files', exist_ok=True)
os.rename('/vortexfs1/share/usgs-share/Projects/COAWST/2012/coawst_us_20120902_13.nc',
          '/vortexfs1/share/usgs-share/Projects/COAWST/bad_files/coawst_us_20120902_13.nc')