# Produce Example Data
_cf. EERIE Cycle 1 Spin-up_

In [1]:
import xarray as xr
import numpy as np
import pandas as pd
import dask
from dask.distributed import Client, LocalCluster
import intake

from scipy.spatial import cKDTree
import subprocess
import re

In [2]:
# Scratch directory that receives every zarr store and netCDF file written below.
# Later cells build paths as data_dir + '/name', so the trailing slash here
# produces a double slash in those paths.
data_dir = '/scratch/b/b382615/dask_example_scratch/' 

In [None]:
## Local Cluster
# Start a single-machine Dask cluster (64 workers x 2 threads) and connect to it.
cluster = LocalCluster(n_workers=64, threads_per_worker=2)
client = Client(cluster)

# Short host name: output of `hostname`, cut at the first dot.
hostname_result = subprocess.run(['hostname'], capture_output=True, text=True)
remote_node = hostname_result.stdout.strip().split('.')[0]

# Port number parsed out of the dashboard URL, for setting up port forwarding.
port_match = re.search(r':(\d+)/', client.dashboard_link)
port = port_match.group(1)
print(f"Forward with Port = {remote_node}:{port}")

# Last expression of the cell: shows the dashboard link as cell output.
client.dashboard_link

In [None]:
# Open the EERIE intake catalogue and name the simulation to read from.
cat = intake.open_catalog("https://raw.githubusercontent.com/eerie-project/intake_catalogues/main/eerie.yaml")
model = 'icon-esm-er'
expid = 'eerie-control-1950'
version = 'v20231106'
gridspec = 'native'

# Walk the catalogue hierarchy down to the ocean output on the chosen grid.
dat = cat['dkrz.disk.model-output'][model][expid][version]['ocean'][gridspec]

# Open the daily-mean 2D fields as a dask-backed dataset (2 time steps per
# chunk, all cells in one chunk), then keep only 2010-01-01 .. 2040-01-01.
daily_mean_source = dat['2d_daily_mean'](chunks={'time': 2, 'ncells': -1})
ds = daily_mean_source.to_dask().sel(time=slice('2010-01-01', '2040-01-01'))

In [5]:
## Fast NN Interpolator
#    NB: nearest-neighbour lookup is not acceptable from a conservation or
#    accuracy point of view, but it's great to make a lot of data ;)

# One (lon, lat) row per native-grid cell.
points_native = np.column_stack((ds.lon, ds.lat))

# Make 0.1 degree grid: 3600 x 1800 points starting at -180 / -90, with the
# upper end point (180 / 90) left out.
n_lat, n_lon = 1800, 3600
lons = np.linspace(-180, 180, n_lon, endpoint=False)
lats = np.linspace(-90, 90, n_lat, endpoint=False)
lon_mesh, lat_mesh = np.meshgrid(lons, lats)

# One (lon, lat) row per target-grid point, in row-major (lat, lon) order.
points_regrid = np.column_stack((lon_mesh.ravel(), lat_mesh.ravel()))

# For every target point, look up the index of the closest native cell
# (distance measured directly in lon/lat coordinate space).
tree = cKDTree(points_native)
_, nn_indices = tree.query(points_regrid)


In [6]:
# Shape of the target (lat, lon) grid that each regridded field is folded into.
lat_mesh_shape = lat_mesh.shape

# Nearest-neighbour indices wrapped as a DataArray so they can be handed to
# xr.apply_ufunc with their own core dimension.
nn_indices_da = xr.DataArray(nn_indices, dims='nn_index')


def nn_native_regrid_parallel(var_native, nn_indices):
    """Pick nearest-neighbour values from a native-grid array and grid them.

    Parameters
    ----------
    var_native : array
        Values on the native grid, indexed by cell.
    nn_indices : array of int
        For each target-grid point, the index into ``var_native`` to copy.

    Returns
    -------
    array
        The selected values reshaped to ``lat_mesh_shape``.
    """
    gathered = var_native[nn_indices]
    return gathered.reshape(lat_mesh_shape)

In [7]:
def regrid(da):
    """Nearest-neighbour regrid a native-grid field onto the regular lat/lon grid.

    Parameters
    ----------
    da : xarray.DataArray
        Field with an ``ncells`` dimension. ``ncells`` is used as a core
        dimension; any other dimensions are looped over (``vectorize=True``).

    Returns
    -------
    xarray.DataArray
        ``da`` with ``ncells`` replaced by ``lat`` / ``lon`` dimensions of
        size ``n_lat`` / ``n_lon``, labelled with ``lats`` / ``lons``. The
        output dtype is declared as float32.
    """
    regridded = xr.apply_ufunc(
        nn_native_regrid_parallel,
        da, nn_indices_da,
        input_core_dims=[['ncells'], ['nn_index']],
        output_core_dims=[['lat', 'lon']],
        # Sizes of the new output dimensions go through dask_gufunc_kwargs;
        # passing output_sizes directly to apply_ufunc is deprecated.
        dask_gufunc_kwargs={'output_sizes': {'lat': n_lat, 'lon': n_lon}},
        # One dtype per output, given as a list as documented.
        output_dtypes=[np.float32],
        vectorize=True,
        dask='parallelized',
    )
    return regridded.assign_coords(lat=('lat', lats), lon=('lon', lons))

In [None]:
# Regrid three fields: zonal wind stress, plus the depth-index-0 level of
# the `u` and `to` variables (stored below as 'u' and 'sst').
tau_x = regrid(ds['atmos_fluxes_stress_xw'])
u = regrid(ds['u'].isel(depth=0))
sst = regrid(ds['to'].isel(depth=0))

# Bundle the regridded fields into one dataset.
ds_new = xr.Dataset(dict(tau_x=tau_x, u=u, sst=sst))

In [None]:
# Show the dataset summary as the cell output.
ds_new

## Save to Disk:
1. zarr: chunk size time = 1, space = -1 (each chunk spans the full lat/lon field)
2. zarr: chunk size time = 20, space = 5x
3. Series of netCDF files (one file per year)

In [10]:
## 1:
# One time step per chunk, each chunk spanning the full lat/lon field.
target_store = data_dir + '/example_data_chunks1.zarr'
ds_new = ds_new.chunk({'time': 1, 'lat': -1, 'lon': -1})

# Store every data variable uncompressed.
encoding = {name: {'compressor': None} for name in ds_new.data_vars}
ds_new.to_zarr(target_store, mode='w', encoding=encoding)

In [None]:
## 3: series of netCDF files, one per calendar year

import os

# Re-open the store written in step 1, keeping its on-disk chunking.
ds_new = xr.open_zarr(data_dir+'/example_data_chunks1.zarr', chunks={})

# Output directory and file-name prefix for the yearly files.
output_path = data_dir + '/example_mfdataset/'
output_name = "output_"

# Make sure the target directory exists before anything is written into it.
os.makedirs(output_path, exist_ok=True)

# Unique calendar years present on the time axis.
years = pd.to_datetime(ds_new.time.values).year.unique()

# One sub-dataset and one matching output path per year, in the same order.
chunked_datasets = [ds_new.sel(time=str(year)) for year in years]
paths = [output_path + output_name + f"{year}.nc" for year in years]

# Write all yearly files in a single save_mfdataset call.
xr.save_mfdataset(
    datasets=chunked_datasets,
    paths=paths,
    mode="w",
    engine="netcdf4",
    format="NETCDF4",
    unlimited_dims=["time"],
    compute=True)

In [5]:
## 2: zarr store with time = 20, lat = 36, full-lon chunks
source_store = data_dir + '/example_data_chunks1.zarr'
target_store = data_dir + '/example_data_chunks2.zarr'

# Read the step-1 store with the new dask chunking applied on open.
ds_new = xr.open_zarr(source_store, chunks={'time': 20, 'lat': 36, 'lon': -1})

# Store every data variable uncompressed.
encoding = {name: {'compressor': None} for name in ds_new.data_vars}
ds_new.to_zarr(target_store, mode='w', encoding=encoding)

In [None]:
# Additional zarr store with time = 100, lat = 4, full-lon chunks,
# produced from the chunks2 store.
source_store = data_dir + '/example_data_chunks2.zarr'
target_store = data_dir + '/example_data_chunks3.zarr'

# Read the previous store with the new dask chunking applied on open.
ds_new = xr.open_zarr(source_store, chunks={'time': 100, 'lat': 4, 'lon': -1})

# Store every data variable uncompressed.
encoding = {name: {'compressor': None} for name in ds_new.data_vars}
ds_new.to_zarr(target_store, mode='w', encoding=encoding)