## Parameters

In [None]:
# parameters

# mixed-layer depth of the slab model: kept at 1m here, scaled later
slab_model_H = 1

# Zarr stores for the GFS input and the slab-model output
GFS_zarr_store = "tmp_GFS.zarr"
slab_zarr_store = "tmp_slab.zarr"

# keyword arguments passed on to the Dask client
dask_kwargs = dict(n_workers=1, threads_per_worker=3, memory_limit=6e9)

# worker memory thresholds, each given as a fraction
dask_worker_config = dict(
    [
        ("distributed.worker.memory.target", 0.80),  # stay below this
        ("distributed.worker.memory.spill", 0.85),  # start spilling to disk
        ("distributed.worker.memory.pause", 0.90),  # pause worker threads
        ("distributed.worker.memory.terminate", 0.95),  # terminate the worker
    ]
)

## Tech preamble

Import modules and spin up Dask cluster.

In [None]:
import xarray as xr
import numpy as np
from datetime import datetime, timedelta
import dask
from dask.distributed import Client
import hvplot.xarray

In [None]:
# set to less aggressive memory management
# (applied before the client is created, presumably so that the workers
# started below pick up these thresholds)
dask.config.set(dask_worker_config)

# start cluster and connect
client = Client(**dask_kwargs)
# bare last expression: show the client summary as the cell output
client

## Load the data

In [None]:
# open the GFS Zarr store named in the parameters cell and show its contents
ds = xr.open_zarr(GFS_zarr_store)
display(ds)

## Interpolate the data

We need a time step that is well below the inertial period (1/2 day at the poles). To be on the safe side, let's go for one hour.

In [None]:
def upsample_data(data_obj, time_delta_ns=3600.0e9, time_name="time"):
    """Linearly interpolate data to equidistant time steps.

    Parameters
    ----------
    data_obj : xarray object
        Data with a datetime coordinate called ``time_name`` (this notebook
        passes DataArrays). The object passed in is not modified.
    time_delta_ns : float
        Spacing of the new time axis in nanoseconds (default: one hour).
    time_name : str
        Name of the time coordinate to interpolate along.

    Returns
    -------
    xarray object
        ``data_obj`` interpolated with ``method="slinear"`` onto the new time
        axis, with the time coordinate converted back to ``datetime64[ns]``.
    """
    # Interpolation is done on float nanoseconds. assign_coords returns a new
    # object, so the caller's time coordinate is left untouched.
    time_ns = data_obj.coords[time_name].astype("float")
    data_float_time = data_obj.assign_coords({time_name: time_ns})

    # np.arange excludes the stop value: the new axis starts at the first
    # original time stamp and ends before the last one.
    target_time_ns = np.arange(float(time_ns[0]), float(time_ns[-1]), time_delta_ns)

    # Pass the target axis as a mapping so that `time_name` is honoured
    # (a keyword argument would hard-code the dimension name).
    upsampled = data_float_time.interp({time_name: target_time_ns}, method="slinear")

    return upsampled.assign_coords(
        {time_name: upsampled.coords[time_name].astype("datetime64[ns]")}
    )

In [None]:
# upsample both wind-stress components with the default (hourly) time step
taux, tauy = (upsample_data(ds[component]) for component in ("taux", "tauy"))

In [None]:
# taux.nbytes covers one component only; the factor 2 presumably accounts
# for tauy having the same size
print(f"wind-stress data is {2 * taux.nbytes / 1e6} MB")

## Define the slab model

We'll integrate the slab-ocean model of [Pollard & Millard (1970)](https://doi.org/10.1016/0011-7471(70)90043-4)

$$
\frac{\partial (u,v)}{\partial t} + f\,(-v,u)
  = \frac{(\tau^x, \tau^y)}{H\rho} - \epsilon\,(u,v)
$$

where $(u, v)$ are the ocean velocity components, $f$ is the Coriolis parameter,
$(\tau^x, \tau^y)$ are the components of the wind stress, $\rho$ is the density of sea water,
$H$ is the mixed-layer depth, and $\epsilon$ is a linear damping coefficient.

With the complex $q=u+iv$ and $T=\tau^x + i\tau^y$ the model equation becomes:

$$
\frac{\partial q}{\partial t} + if\,q
  = \frac{T}{H\rho} - \epsilon\,q
$$

As [d'Asaro (1985)](https://doi.org/10.1175/1520-0485(1985)015%3C1043:TEFFTW%3E2.0.CO;2) points out, the response of $q$ to wind variations can be separated into a transient inertial response and an Ekman current. The Ekman current can be removed by integrating:
$$
\frac{\partial q}{\partial t} + (\epsilon + if)\,q
  = - \frac{1}{(\epsilon + if)\rho\,H} \frac{\partial T}{\partial t}
$$

This can be discretized to

$$
q_l = d_1 \, q_{l-1} + d_2 \, q_{l-2} + c_0 \, T_l + c_2 \, T_{l-2}
$$

with $d_1 = - 2if \Delta t$, $d_2 = 1 - 2 \epsilon \Delta t$, $c_0 = - 1 / ((\epsilon + if)\rho H)$, $c_2 = - c_0$.

In [None]:
def filter_windstress(
    taux, tauy,
    epsilon=1 / 5 / 24 / 3600,  # [1/s]
    rho=1035,  # [kg/m3]
    H=10,  # [m]
    f=None,  # [1/s]
):
    """Integrate the slab-ocean model forced by wind stress.

    With the complex wind stress ``T = taux + 1j * tauy`` the complex velocity
    ``q = u + 1j * v`` is stepped forward in time with

        q[l] = d_1 * q[l-1] + d_2 * q[l-2] + c_0 * T[l] + c_2 * T[l-2]

    where ``d_1 = -2j * f * dt``, ``d_2 = 1 - 2 * epsilon * dt``,
    ``c_0 = -1 / ((epsilon + 1j * f) * rho * H)`` and ``c_2 = -c_0``.
    The first two time steps of ``q`` are left at zero.

    Parameters
    ----------
    taux, tauy : xarray.DataArray
        Wind-stress components with a uniformly spaced ``time`` dimension.
        A ``lat`` coordinate is needed for the equatorial mask and, if ``f``
        is not given, for the Coriolis parameter.
    epsilon : float
        Linear damping coefficient in 1/s (default: 1 / 5 days).
    rho : float
        Density of sea water in kg/m3.
    H : float
        Mixed-layer depth in m.
    f : optional
        Coriolis parameter in 1/s. If None, it is computed from ``lat`` as
        ``2 * 7.2921e-5 * sin(lat)``.

    Returns
    -------
    slab_u, slab_v, slab_umag : xarray.DataArray
        Real part, imaginary part and magnitude of ``q``, named ``slab_u``,
        ``slab_v`` and ``slab_umag``. Values are masked where the wind stress
        is missing or exactly zero, and where ``abs(lat) <= 4``.

    Raises
    ------
    AssertionError
        If the time axis of ``taux`` is not equidistant.
    """
    # imported here because the notebook's import cell does not provide numba
    import numba

    # we need a uniform time step
    time_steps_ns = taux.time.diff("time").astype("float")
    assert time_steps_ns.std("time") / 1e9 < 1e-4, "time axis must be equidistant"
    dt = (time_steps_ns.mean("time") / 1e9).data  # ns --> s

    # maybe generate f from latitude info
    if f is None:
        f = 2 * 7.2921e-5 * np.sin(np.deg2rad(taux.coords["lat"]))

    # complex wind stress, masked with zeros where missing
    T = taux.astype("float32") + 1j * tauy.astype("float32")
    T = xr.where(~T.isnull(), T, 0)

    # integration coefficients
    c_0 = - 1 / (epsilon + 1j * f) / rho / H
    c_2 = - c_0
    d_1 = - 2j * f * dt
    d_2 = 1 - 2 * epsilon * dt

    # broadcast against T so that every coefficient carries a "time" axis and
    # can be handed to apply_ufunc with the same core dimension as T
    c_0 = xr.broadcast(xr.DataArray(c_0), T)[0]
    c_2 = xr.broadcast(xr.DataArray(c_2), T)[0]
    d_1 = xr.broadcast(xr.DataArray(d_1), T)[0]
    d_2 = xr.broadcast(xr.DataArray(d_2), T)[0]

    # helper function doing the actual integration along the leading axis
    @numba.jit
    def integrate(T, d_1, d_2, c_0, c_2):
        q = np.zeros_like(T)
        for l in range(2, T.shape[0]):
            q[l, ...] = (
                d_1[l-1, ...] * q[l-1, ...]
                + d_2[l-2, ...] * q[l-2, ...]
                + c_0[l, ...] * T[l, ...]
                + c_2[l-2, ...] * T[l-2, ...]
            )
        return q

    # apply integration to all data
    # (np.complex128 is what the removed alias `np.complex` stood for)
    q = xr.apply_ufunc(integrate, T, d_1, d_2, c_0, c_2,
                       vectorize=True,
                       input_core_dims=[['time'], ['time'],  ['time'], ['time'], ['time']],
                       output_core_dims=[['time']],
                       output_dtypes=[np.complex128],
                       dask='parallelized')

    # mask for undefined wind stress again
    # (missing values were set to zero above, so exact zeros are masked, too)
    q = q.where(T != 0)

    # mask the band around the equator where abs(lat) <= 4 deg
    # NOTE(review): an earlier comment said +/- 5 deg; the code has always used 4.0
    q = q.where(abs(q.coords["lat"]) > 4.0)

    # extract u and v from complex q and calc speed
    slab_u = q.real.astype("float32").rename("slab_u")
    slab_v = q.imag.astype("float32").rename("slab_v")

    slab_umag = ((slab_u**2 + slab_v**2)**0.5).rename("slab_umag")

    return slab_u, slab_v, slab_umag

## Run the slab model

We'll create the dataset with the slab model output and sub-sample the model to the original time resolution of the wind data.

The actual computation is triggered and performed in parallel when we store the data.

In [None]:
# run the slab model and collect its three outputs in one dataset
ds_slab = xr.Dataset(
    dict(
        zip(
            ("u_slab", "v_slab", "umag_slab"),
            filter_windstress(taux, tauy, H=slab_model_H),
        )
    )
)

# subsample: for each original time stamp, pick the nearest model time step
ds_slab = ds_slab.sel(time=ds.time, method="nearest")

# keep track of the mixed-layer depth used for this run
ds_slab.attrs["slab_model_H"] = slab_model_H

In [None]:
# show the structure of the slab-model dataset before it is stored
display(ds_slab)

In [None]:
%%time

# write the slab-model output to the Zarr store named in the parameters cell
# (mode="w" replaces a store that already exists)
ds_slab.to_zarr(slab_zarr_store, mode="w")

## Have a brief look at the data

In [None]:
# re-open the stored output; from here on ds_slab refers to the data in the Zarr store
ds_slab = xr.open_zarr(slab_zarr_store)

In [None]:
# one selection for all curves of this figure: grid point nearest to lon=337, lat=12
point = dict(lon=360 - 23, lat=12, method="nearest")

# upper panel: wind stress, lower panel: slab-model velocities
(
    ds["taux"].sel(**point).hvplot.line(label="taux")
    * ds["tauy"].sel(**point).hvplot.line(label="tauy")
    + ds_slab["u_slab"].sel(**point).hvplot.line(label="slab U")
    * ds_slab["v_slab"].sel(**point).hvplot.line(label="slab V")
    * ds_slab["umag_slab"].sel(**point).hvplot.line(label="slab UMAG")
).cols(1)

In [None]:
# one selection for all curves of this figure: grid point nearest to lon=337, lat=22
point = dict(lon=360 - 23, lat=22, method="nearest")

# upper panel: wind stress, lower panel: slab-model velocities
(
    ds["taux"].sel(**point).hvplot.line(label="taux")
    * ds["tauy"].sel(**point).hvplot.line(label="tauy")
    + ds_slab["u_slab"].sel(**point).hvplot.line(label="slab U")
    * ds_slab["v_slab"].sel(**point).hvplot.line(label="slab V")
    * ds_slab["umag_slab"].sel(**point).hvplot.line(label="slab UMAG")
).cols(1)

In [None]:
# one selection for all curves of this figure: grid point nearest to lon=337, lat=-32
point = dict(lon=360 - 23, lat=-32, method="nearest")

# upper panel: wind stress, lower panel: slab-model velocities
(
    ds["taux"].sel(**point).hvplot.line(label="taux")
    * ds["tauy"].sel(**point).hvplot.line(label="tauy")
    + ds_slab["u_slab"].sel(**point).hvplot.line(label="slab U")
    * ds_slab["v_slab"].sel(**point).hvplot.line(label="slab V")
    * ds_slab["umag_slab"].sel(**point).hvplot.line(label="slab UMAG")
).cols(1)

In [None]:
# one selection for all curves of this figure: grid point nearest to lon=95, lat=-5
point = dict(lon=95, lat=-5, method="nearest")

# upper panel: wind stress, lower panel: slab-model velocities
(
    ds["taux"].sel(**point).hvplot.line(label="taux")
    * ds["tauy"].sel(**point).hvplot.line(label="tauy")
    + ds_slab["u_slab"].sel(**point).hvplot.line(label="slab U")
    * ds_slab["v_slab"].sel(**point).hvplot.line(label="slab V")
    * ds_slab["umag_slab"].sel(**point).hvplot.line(label="slab UMAG")
).cols(1)

In [None]:
# look at the stored meridional slab velocity at the grid point nearest to lon=337, lat=12
ds_slab["v_slab"].sel({"lon": 360 - 23, "lat": 12}, method="nearest")

---

In [None]:
!echo "Finished: $(date -Ins) (UTC)"

---
See https://github.com/willirath/nia-prediction-low-latitudes for details.