# Run the slab model

## Parameters

In [None]:
# parameters

# Intermediate zarr stores: GFS wind-stress input and slab-model output.
GFS_zarr_store = "tmp_GFS.zarr"
slab_zarr_store = "tmp_slab.zarr"

# Settings for the local dask.distributed.Client started below.
dask_kwargs = dict(n_workers=1, threads_per_worker=2, memory_limit=6e9)

## Tech preamble

Import modules and spin up a Dask cluster.

In [None]:
import xarray as xr
from datetime import datetime, timedelta
from dask.distributed import Client
import hvplot.xarray
from pathlib import Path

The memory requirements of the integration are currently relatively high, so we use a single Dask worker with enough memory (see `dask_kwargs` above for its thread count and memory limit).

If we run this with GitHub Actions, there will be a total of [6GB memory](https://docs.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners#supported-runners-and-hardware-resources) for everything we do, so let's stay below this even in development.

In [None]:
# Start a local Dask cluster configured by `dask_kwargs` from the parameters cell;
# the bare `client` expression displays the cluster summary.
client = Client(**dask_kwargs)
client

## Loading the data

In [None]:
ds = xr.open_zarr(store=GFS_zarr_store)  # lazily open the GFS wind-stress store

## Slab model

In [None]:
import numpy as np

In [None]:
def upsample_data(data_obj, time_delta_ns=3600.0e9, time_name="time"):
    """Linearly interpolate ``data_obj`` onto a regular time axis.

    Parameters
    ----------
    data_obj : xarray.DataArray or xarray.Dataset
        Data with a datetime64 dimension coordinate ``time_name``. The input
        object is not modified.
    time_delta_ns : float
        Step of the new time axis in nanoseconds (default: one hour).
    time_name : str
        Name of the time dimension to interpolate along.

    Returns
    -------
    Same type as ``data_obj``
        Interpolated with ``method="slinear"`` onto a uniform
        ``datetime64[ns]`` axis. The axis starts at the first original time,
        is spaced by ``time_delta_ns``, and ends at the last grid point not
        after the last original time (inclusive).
    """
    # Interpolate on float nanoseconds so the new axis can be built with plain
    # float arithmetic. assign_coords returns a new object, so the caller's
    # time coordinate stays untouched (the old code overwrote it in place).
    time_float = data_obj.coords[time_name].astype("float")
    data_float = data_obj.assign_coords({time_name: time_float})

    start = float(time_float.values[0])
    stop = float(time_float.values[-1])
    # Unlike np.arange(start, stop, step), keep the last original time when it
    # lies on the new grid. The tolerance absorbs float round-off in the step
    # count, and the clip stops round-off from pushing the last point past
    # ``stop`` (which would interpolate to NaN).
    n_steps = int(np.floor((stop - start) / time_delta_ns + 1e-6))
    new_time = np.minimum(start + time_delta_ns * np.arange(n_steps + 1), stop)

    # Pass a dict so a non-default ``time_name`` is honoured.
    upsampled = data_float.interp({time_name: new_time}, method="slinear")
    return upsampled.assign_coords(
        {time_name: upsampled.coords[time_name].astype("datetime64[ns]")}
    )

In [None]:
# Bring both wind-stress components onto the default hourly time axis.
taux, tauy = (upsample_data(ds[component]) for component in ("taux", "tauy"))

In [None]:
taux.nbytes / 10**6  # size of the upsampled zonal stress in MB

In [None]:
def filter_windstress(taux, tauy):
    """Integrate a damped slab-ocean model forced by wind stress.

    The complex slab velocity ``q = u + i v`` is stepped in time with the
    three-level scheme::

        q[l] = d_2 * q[l-2] + d_1[l-1] * q[l-1] + c_2 * T[l-2]

    with ``T = taux + i tauy``, ``d_1 = -2 i dt f``,
    ``d_2 = 1 - 2 dt epsilon`` and ``c_2 = 2 dt / (rho H)``, starting from
    ``q = 0``.

    Parameters
    ----------
    taux, tauy : xarray.DataArray
        Zonal and meridional wind stress sharing a ``time`` dimension with a
        uniform step and a ``lat`` coordinate in degrees. Units presumably
        N m-2 -- TODO confirm against the GFS preprocessing.

    Returns
    -------
    slab_u, slab_v, slab_umag : xarray.DataArray
        Real part, imaginary part (both float32) and magnitude of ``q``, named
        ``"slab_u"``, ``"slab_v"`` and ``"slab_umag"``. Points where the
        forcing was missing (NaN) or exactly zero are NaN.

    Raises
    ------
    AssertionError
        If the time step of ``taux`` is not uniform.
    """
    import numba

    # The scheme assumes a uniform time step.
    time_step_ns = taux.time.diff("time").astype("float")
    assert time_step_ns.std("time") / 1e9 < 1e-4
    # Plain Python floats (not 0-d arrays) so numba can freeze them as
    # constants in the jitted closure below.
    dt = float(time_step_ns.mean("time")) / 1e9  # ns --> s
    epsilon = 1 / 5 / 24 / 3600  # 1 / (5 days), in 1/s
    f = 2 * 7.2921e-5 * np.sin(np.deg2rad(taux.coords["lat"]))  # 2 * Omega * sin(lat)
    rho = 1035  # presumably seawater density in kg m-3
    H = 20  # presumably slab depth in m

    c_2 = 2 * dt / rho / H
    d_1 = - 2j * dt * f
    d_2 = 1 - 2 * dt * epsilon

    T = taux.astype("float32") + 1j * tauy.astype("float32")
    T = xr.where(~T.isnull(), T, 0)
    # apply_ufunc(dask="parallelized") needs the core dimension in one chunk.
    if T.chunks is not None:
        T = T.chunk({"time": -1})

    @numba.jit
    def integrate(T, d_1):
        # NOTE(review): forcing enters as T[l-2]; a centred leapfrog would use
        # T[l-1] -- confirm the lag is intended.
        q = np.zeros_like(T)
        for l in range(2, T.shape[0]):
            q[l, ...] = (d_2 * q[l-2, ...] +
                         d_1[l-1, ...] * q[l-1, ...] +
                         c_2 * T[l-2, ...])
        return q

    # `d_1 + 0.0 * T` broadcasts the lat-only d_1 to T's full shape (and
    # chunks) so both inputs carry the `time` core dimension. With
    # vectorize=True, `integrate` receives one 1-D time series per point.
    q = xr.apply_ufunc(integrate, T, d_1 + 0.0 * T,
                       vectorize=True,
                       input_core_dims=[['time'], ['time']],
                       output_core_dims=[['time']],
                       output_dtypes=[np.complex128],  # np.complex was removed in NumPy 1.24
                       dask='parallelized')
    # Blank out points whose forcing was missing or zero.
    q = q.where(T != 0)

    slab_u = q.real.astype("float32").rename("slab_u")
    slab_v = q.imag.astype("float32").rename("slab_v")

    slab_umag = ((slab_u**2 + slab_v**2)**0.5).rename("slab_umag")

    return slab_u, slab_v, slab_umag

In [None]:
# Run the slab model on the hourly stress, then keep only the time steps
# nearest to the original GFS times.
ds_slab = xr.Dataset(
    dict(zip(("u_slab", "v_slab", "umag_slab"), filter_windstress(taux, tauy)))
).sel(time=ds.time, method="nearest")
ds_slab

In [None]:
%%time

# Write the slab-model fields, replacing any existing store (mode="w").
# With to_zarr's default compute=True, the lazy dask graph is computed here.
ds_slab.to_zarr(slab_zarr_store, mode="w")

In [None]:
ds_slab = xr.open_zarr(store=slab_zarr_store)  # reload the written output lazily

In [None]:
# Wind stress (overlaid taux/tauy) stacked above the slab-current magnitude at
# the grid point nearest lon=337 (23°W, assuming 0-360 longitudes), lat=12.
point = dict(lon=360 - 23, lat=12, method="nearest")
(
    ds["taux"].sel(**point).hvplot.line(label="taux")
    * ds["tauy"].sel(**point).hvplot.line(label="tauy")
    + ds_slab["umag_slab"].sel(**point).hvplot.line(label="NIA")
).cols(1)

In [None]:
# Log a completion timestamp (`date -Ins` = ISO 8601 with nanoseconds; GNU date option).
!echo "Finished: $(date -Ins)"

---
See https://github.com/willirath/nia-prediction-low-latitutdes for details.