This notebook estimates parameters of a drift function by solving a linear least-squares problem.

Parameters are estimated using global satellite data and drifter velocities for the period 2010–2020:

- Geostrophic currents from DUACS (https://doi.org/10.48670/moi-00148),
- Stokes drift from WAVERYS/MFWAM (https://doi.org/10.48670/moi-00022),
- Wind stress and 10 m velocity from ERA5 (https://doi.org/10.48670/moi-00185),
- Drogued SVP drifters from the GDP (https://doi.org/10.25921/x46c-3620).

In [None]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import cartopy.crs as ccrs
import cmocean.cm as cmo
from IPython.display import display, Math
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jrd
import matplotlib.pyplot as plt
import numpy as np
import optax
import xarray as xr


from src.spherical_harmonics import gen_associated_legendre, sph_harm_y

In [None]:
gdp_ds = xr.open_zarr("data/gdp_interp.zarr")

# Problem overview

## 1. Original formulation from Rio *et al.* (2014)

The drift is modeled as:
$$
\vec{v}_d(\lambda, \phi) = \vec{u}_g(\lambda, \phi) + \beta(\lambda, \phi) \vec{\tau}(\lambda, \phi) e^{i \theta(\lambda, \phi)}
$$
where:

- $\vec{v}_d$ is the drifter velocity at the lat-lon point $(\lambda, \phi)$,
- $\vec{u}_g$ is the geostrophic current velocity,
- $\vec{\tau}$ is the wind stress at the ocean surface,
- $\beta$ and $\theta$ are the parameters of the empirical Ekman model.

$\beta$ and $\theta$ are subject to the following physical constraints: 
$$
\beta(\lambda, \phi) > 0, \ \forall (\lambda, \phi) \\
\theta(\lambda, \phi) \in [0, 2\pi], \ \forall (\lambda, \phi)
$$
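
As a concrete illustration of the model above, the following minimal sketch evaluates the drift at a single point using Python complex numbers, with the eastward component as the real part and the northward component as the imaginary part. The current and wind stress values are made up for illustration; the amplitude and angle reuse the background values that appear later in this notebook.

```python
import cmath
import math

# Illustrative values only (not taken from the data)
u_g = 0.10 + 0.05j          # geostrophic current (m/s), east + 1j * north
tau = 0.08 + 0.00j          # wind stress (N/m^2), purely eastward here
beta = 0.25                 # Ekman amplitude (m^2 s / kg)
theta = math.radians(48.0)  # Ekman rotation angle (rad)

# Multiplying by exp(1j * theta) rotates the stress vector counterclockwise by theta
u_ekman = beta * tau * cmath.exp(1j * theta)
v_d = u_g + u_ekman

print(f"Ekman component: {u_ekman.real:.4f} east, {u_ekman.imag:.4f} north (m/s)")
print(f"Modelled drift:  {v_d.real:.4f} east, {v_d.imag:.4f} north (m/s)")
```

The modulus of the Ekman component is $\beta |\vec{\tau}|$ and its direction is that of the wind stress rotated by $\theta$, which is all the empirical model encodes.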

## 2. Spatial dependency using spherical harmonics decompositions

In Rio *et al.*, $\beta$ and $\theta$ are fitted per 4° $\times$ 4° spatial bin.
We propose instead to model their spatial dependency using spherical harmonics decompositions:
$$
\beta(\lambda, \phi) = \sum_l \sum_m \beta_m^l Y_m^l(\lambda, \phi) \\
\theta(\lambda, \phi) = \sum_l \sum_m \theta_m^l Y_m^l(\lambda, \phi)
$$

The resulting optimization problem is:
$$
\underset{\beta_m^l, \theta_m^l}{\operatorname{argmin}} \left \| \vec{v}_d(\lambda, \phi) - \vec{u}_g(\lambda, \phi) - \beta(\lambda, \phi) \vec{\tau}(\lambda, \phi) e^{i \theta(\lambda, \phi)} \right \|_2^2
$$
which is nonlinear and nonconvex.

## 3. Convex reformulation

Introducing the complex field:
$$
\alpha(\lambda, \phi) = \beta(\lambda, \phi) e^{i \theta(\lambda, \phi)}
$$
the problem becomes:
$$
\underset{\alpha_m^l}{\operatorname{argmin}} \left \| \vec{v}_d(\lambda, \phi) - \vec{u}_g(\lambda, \phi) - \alpha(\lambda, \phi) \vec{\tau}(\lambda, \phi) \right \|_2^2
$$
where $\alpha(\lambda, \phi)$ is represented by a complex spherical harmonic decomposition:
$$
\alpha(\lambda, \phi) = \sum_l \sum_m \alpha_m^l Y_m^l(\lambda, \phi)
$$

This is a complex-valued least-squares problem, linear in the complex parameters $\alpha_m^l$.

The original $\beta$ and $\theta$ parameters of the empirical model of Rio *et al.* are recovered as:
$$
\beta = |\alpha| \\
\theta = \arg(\alpha)
$$
which automatically satisfy the original constraints.
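
To see why the reformulation helps, consider the simplest case of a single, spatially constant $\alpha$. The least-squares problem then has the closed-form solution $\alpha = \sum_i \overline{\tau_i} y_i / \sum_i |\tau_i|^2$, and $\beta$ and $\theta$ follow from its modulus and argument. The sketch below checks this on synthetic, noise-free data; the numerical values and the helper name `fit_constant_alpha` are illustrative assumptions, not part of the notebook.

```python
import cmath
import math
import random


def fit_constant_alpha(tau, y):
    """Closed-form least-squares estimate of a constant complex alpha in y = alpha * tau."""
    num = sum(t.conjugate() * yi for t, yi in zip(tau, y))
    den = sum(abs(t) ** 2 for t in tau)
    return num / den


random.seed(0)
alpha_true = 0.25 * cmath.exp(-1j * math.radians(48.0))  # illustrative value

# Synthetic wind stress samples and the corresponding noise-free targets
tau = [complex(random.gauss(0, 0.1), random.gauss(0, 0.1)) for _ in range(1000)]
y = [alpha_true * t for t in tau]

alpha_hat = fit_constant_alpha(tau, y)

beta_hat = abs(alpha_hat)                           # non-negative by construction
theta_hat = cmath.phase(alpha_hat) % (2 * math.pi)  # wrapped from (-pi, pi] to [0, 2*pi)

print(beta_hat, math.degrees(theta_hat))
```

`cmath.phase` returns values in $(-\pi, \pi]$, so the modulo is only needed when $\theta$ has to lie in $[0, 2\pi)$.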

## 4. Computational challenge and optimization strategy

At the global scale, 10 years of data correspond to approximately 20 million observations.
A direct least-squares solution would require materializing a dense design matrix of size 20M $\times$ 1k, which is prohibitive in memory ($\approx$ 320 GB in complex128, at 16 bytes per entry).

In anticipation of more physically and/or statistically elaborate (possibly nonlinear) models, we therefore recast the problem as a minibatch optimization problem rather than relying on spatial binning.

Defining:
$$
\vec{x_i} = (\lambda_i, \phi_i, \vec{\tau}(\lambda_i, \phi_i)) \\
\vec{y_i} = \vec{v}_d(\lambda_i, \phi_i) - \vec{u}_g(\lambda_i, \phi_i)
$$
and the forward model:
$$
f_\alpha(\vec{x_i}) = \vec{\tau}(\lambda_i, \phi_i) \sum_l \sum_m \alpha_m^l Y_m^l(\lambda_i, \phi_i)
$$
we minimize the empirical risk:
$$
\mathcal{L}(\alpha) = \frac{1}{N} \sum_{i=1}^N \frac{1}{2} \left \| \vec{y_i} - f_\alpha(\vec{x_i}) \right \|^2_2
$$
The minimization is performed using gradient-based optimizers.
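
The minibatch idea can be sketched on a small synthetic problem. The example below uses plain minibatch gradient descent in NumPy, with a random complex matrix `A` standing in for the products $\vec{\tau}(\lambda_i, \phi_i) Y_m^l(\lambda_i, \phi_i)$. It is an illustration under these assumptions, not the optimizer used in this notebook, and the sizes, learning rate and variable names are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
n_obs, n_coeffs = 2000, 5

# Synthetic stand-in for the design matrix tau_i * Y_m^l(lambda_i, phi_i)
A = rng.normal(size=(n_obs, n_coeffs)) + 1j * rng.normal(size=(n_obs, n_coeffs))
alpha_true = rng.normal(size=n_coeffs) + 1j * rng.normal(size=n_coeffs)
y = A @ alpha_true  # noise-free targets


def risk(alpha, A_b, y_b):
    """Empirical risk: mean of 0.5 * |y_i - f_alpha(x_i)|^2 over the batch."""
    return 0.5 * np.mean(np.abs(y_b - A_b @ alpha) ** 2)


alpha = np.zeros(n_coeffs, dtype=np.complex128)
learning_rate, batch_size = 0.1, 200

for epoch in range(50):
    perm = rng.permutation(n_obs)  # reshuffle the observations at every epoch
    for start in range(0, n_obs, batch_size):
        b = perm[start:start + batch_size]
        residual = y[b] - A[b] @ alpha
        # Steepest-descent direction for a real loss of complex parameters
        # (gradient with respect to conj(alpha), up to a constant factor)
        alpha = alpha + learning_rate * (A[b].conj().T @ residual) / len(b)

print(risk(alpha, A, y))
```

Only one batch of rows of the design matrix is needed at a time, which is what makes the approach tractable when the full matrix does not fit in memory.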

# Implementation

## 1. Data preparation

In [None]:
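total_points = gdp_ds.points.size

# Inputs: positions (degrees) and complex wind stress (east + 1j * north)
x_full = (
    jnp.asarray(gdp_ds.lat.values),
    jnp.asarray(gdp_ds.lon.values),
    jnp.asarray((gdp_ds.eastward_stress + 1j * gdp_ds.northward_stress).values),
)

# Complex drifter and geostrophic velocities (east + 1j * north)
u_drifters = jnp.asarray((gdp_ds.ve + 1j * gdp_ds.vn).values, dtype=jnp.complex128)
u_geos = jnp.asarray((gdp_ds.ugos + 1j * gdp_ds.vgos).values, dtype=jnp.complex128)

# Target: drifter velocity minus geostrophic current
y_full = u_drifters - u_geos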

## 2. Spherical harmonics decomposition

In [None]:
# Spherical harmonics setup

lmax = 8
n_coeffs = (lmax + 1) ** 2
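idx = jnp.arange(n_coeffs)

# degree l and order m of each flat coefficient index (m runs from -l to l within degree l)
L = jnp.repeat(jnp.arange(lmax + 1), 2 * jnp.arange(lmax + 1) + 1)
M = idx - L * (L + 1)

def get_sph_harm_bases(lat, lon):
    # lat, lon in degrees; returns the basis matrix with one row per point, one column per (l, m)
    lat = jnp.deg2rad(lat)
    lon = jnp.deg2rad(lon)
    colat = jnp.pi / 2 - lat  # colatitude in [0, pi]
    phi = lon % (2 * jnp.pi)  # longitude wrapped to [0, 2*pi)
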
    legendre = gen_associated_legendre(lmax, colat, is_normalized=True)

    Y = jax.vmap(
        lambda l, m: sph_harm_y(l, m, colat, phi, n_max=lmax, legendre=legendre)
    )(L, M).T

    return Y
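
The flat coefficient layout encoded by `L` and `M` above can be checked with a few lines of plain Python. This sketch assumes only what the cell shows: degree $l$ occupies the $2l + 1$ consecutive indices starting at $l^2$, and the order is recovered as $m = \text{idx} - l(l + 1)$. The helper name `degree_order` is hypothetical.

```python
import math


def degree_order(idx):
    """Map a flat coefficient index to (l, m), with m running from -l to l within degree l."""
    l = math.isqrt(idx)  # degree l owns the indices l**2 .. (l + 1)**2 - 1
    m = idx - l * (l + 1)
    return l, m


lmax = 2
pairs = [degree_order(idx) for idx in range((lmax + 1) ** 2)]
print(pairs)
# [(0, 0), (1, -1), (1, 0), (1, 1), (2, -2), (2, -1), (2, 0), (2, 1), (2, 2)]
```

With `lmax = 8` the same layout gives the 81 coefficients used in this notebook.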

In [None]:
# Loss function: mean squared error + regularization

batch_size = 20_000_000

# background values ("climatology" from Rio et al. 2014)
beta_bg = 0.25
theta_bg = jnp.deg2rad(48.0)
alpha_bg = beta_bg * jnp.exp(1j * theta_bg)


def residual_fun(alpha_coeffs, lat_batch, lon_batch, tau_batch, y_batch):
    Y_batch = get_sph_harm_bases(lat_batch, lon_batch)

    alpha_batch = Y_batch @ alpha_coeffs
    alpha_bg_batch = jnp.where(lat_batch > 0, alpha_bg.conj(), alpha_bg)
    alpha_batch += alpha_bg_batch

    pred_batch = alpha_batch * tau_batch

    residual_batch = jnp.abs(y_batch - pred_batch) ** 2

    return residual_batch


lambda_reg = 1e-4
deg_norm = 2 * L + 1
reg_weight = L * (L + 1) / deg_norm


def loss_fun(alpha_coeffs, args):
    (lat_batch, lon_batch, tau_batch), y_batch = args

    residuals = residual_fun(alpha_coeffs, lat_batch, lon_batch, tau_batch, y_batch)
    
    # latitude-weighted mean squared error
    weight = jnp.cos(jnp.deg2rad(lat_batch))
    loss = jnp.nansum(residuals * weight) / jnp.nansum(weight)

    # degree-dependent regularization
    reg = lambda_reg * jnp.nansum(reg_weight * jnp.abs(alpha_coeffs) ** 2)
    loss += reg

    return loss


val_grad_loss_fun = jax.jit(jax.value_and_grad(loss_fun))
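
To get a feel for the degree-dependent penalty defined above, the following sketch tabulates the weight $l(l + 1)/(2l + 1)$ for each degree in plain Python. The function name `reg_weight_for_degree` is introduced here for illustration only.

```python
def reg_weight_for_degree(l):
    """Penalty weight l * (l + 1) / (2 * l + 1) applied to |alpha_m^l|**2."""
    return l * (l + 1) / (2 * l + 1)


lmax = 8  # same truncation as in the notebook
weights = [reg_weight_for_degree(l) for l in range(lmax + 1)]

for l, w in enumerate(weights):
    print(f"l = {l}: weight = {w:.3f}")
```

The degree 0 coefficient (the global mean) is not penalized, and the weight then grows roughly linearly with degree, so fine spatial scales are damped more strongly than large ones.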

In [None]:
# Do optimization

key = jrd.key(0)

# Random initialization (currently overridden by the constant initialization below)
key, subkey = jrd.split(key)
beta_init = 1 + jrd.normal(subkey, shape=(n_coeffs,))
key, subkey = jrd.split(key)
theta_init = jrd.uniform(subkey, shape=(n_coeffs,), minval=-jnp.pi, maxval=jnp.pi)
alpha_init = beta_init * jnp.exp(1j * theta_init)

# Constant initialization actually used for the fit
alpha_init = jnp.full((n_coeffs,), 1 + 1j * 0)

solver = optax.lbfgs()

n_epochs = 1000
opt_state = solver.init(alpha_init)
alpha_coeffs = alpha_init
best_coeffs = alpha_init
best_loss = jnp.inf
not_improved_epochs = 0

for epoch in range(n_epochs):
    # create random batches
    key, subkey = jrd.split(key)
    perm = jrd.permutation(subkey, total_points)
    n_batches = total_points // batch_size
    for i in range(n_batches):
        # select the observations of the current batch
        batch_idx = perm[i * batch_size : (i + 1) * batch_size]
        lat_batch = x_full[0][batch_idx]
        lon_batch = x_full[1][batch_idx]
        tau_batch = x_full[2][batch_idx]
        y_batch = y_full[batch_idx]

        args = (lat_batch, lon_batch, tau_batch), y_batch

        # objective on the current batch, also used by the L-BFGS line search
        f = lambda x: loss_fun(x, args)
        value_and_grad = optax.value_and_grad_from_state(f)

        loss_value, grads = value_and_grad(alpha_coeffs, state=opt_state)

        updates, opt_state = solver.update(
            grads, opt_state, alpha_coeffs, value=loss_value, grad=grads, value_fn=f
        )
        alpha_coeffs = optax.apply_updates(alpha_coeffs, updates)
    
    val_idx = perm[n_batches * batch_size:]  # use remaining points for validation
    lat_val = x_full[0][val_idx]
    lon_val = x_full[1][val_idx]
    tau_val = x_full[2][val_idx]
    y_val = y_full[val_idx]
    loss_value = loss_fun(alpha_coeffs, ((lat_val, lon_val, tau_val), y_val))

    if loss_value < best_loss:
        best_loss = loss_value
        best_coeffs = alpha_coeffs
        not_improved_epochs = 0
    else:
        not_improved_epochs += 1

    if (epoch + 1) % 10 == 0 or epoch == n_epochs - 1 or epoch == 0: 
        print(f"Epoch {epoch+1}/{n_epochs}, Validation loss: {loss_value:.6e}")

    if not_improved_epochs >= 10:
        print("Early stopping: no improvement in validation loss for 10 consecutive epochs.")
        break

In [None]:
alphas_fit = best_coeffs
betas_fit = jnp.abs(alphas_fit)
thetas_fit = jnp.angle(alphas_fit, deg=True)

In [None]:
alpha_rio2014_da = xr.open_zarr("data/rio_2014/alpha.zarr")["__xarray_dataarray_variable__"]  # for comparison

beta_rio2014_da = xr.DataArray(
    jnp.abs(alpha_rio2014_da.values),
    coords=alpha_rio2014_da.coords,
    dims=alpha_rio2014_da.dims,
)
theta_rio2014_da = xr.DataArray(
    jnp.angle(alpha_rio2014_da.values, deg=True),
    coords=alpha_rio2014_da.coords,
    dims=alpha_rio2014_da.dims,
)

In [None]:
global_lat = alpha_rio2014_da.lat.values
global_lon = alpha_rio2014_da.lon.values

global_lat2d, global_lon2d = jnp.meshgrid(global_lat, global_lon, indexing="ij")

In [None]:
global_lat_flat = global_lat2d.ravel()
global_lon_flat = global_lon2d.ravel()

global_Y = get_sph_harm_bases(global_lat_flat, global_lon_flat)

In [None]:
global_alpha = global_Y @ alphas_fit

global_alpha = global_alpha.reshape(alpha_rio2014_da.shape)

global_alpha += jnp.where(global_lat2d > 0, alpha_bg.conj(), alpha_bg)

global_beta = jnp.abs(global_alpha)
global_theta = jnp.angle(global_alpha, deg=True)

In [None]:
beta_fit_da = xr.DataArray(global_beta, coords=alpha_rio2014_da.coords, dims=alpha_rio2014_da.dims)
theta_fit_da = xr.DataArray(global_theta, coords=alpha_rio2014_da.coords, dims=alpha_rio2014_da.dims)

In [None]:
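# mask the fitted fields where Rio et al. (2014) has no data (NaN rather than 0, so that means are not biased)
beta_fit_m_da = beta_fit_da.where(beta_rio2014_da.notnull())
theta_fit_m_da = theta_fit_da.where(theta_rio2014_da.notnull())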

In [None]:
display(Math(
    r"\beta_e = {:.2f} \,\text{{m}}^2\text{{s}}/\text{{kg}}, \, \theta_e = {:.2f}^\circ".format(
        beta_fit_m_da.mean().item(), np.abs(theta_fit_m_da).mean().item()
    )
))

In [None]:
fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = plt.subplots(
    3, 2, figsize=(30, 20), subplot_kw={"projection": ccrs.PlateCarree()}
)

beta_rio2014_da.plot(ax=ax1, cmap=cmo.amp, vmin=0, vmax=3)
ax1.coastlines()

theta_rio2014_da.plot(ax=ax2, cmap=cmo.balance, vmin=-90, vmax=90)
ax2.coastlines()

beta_fit_m_da.plot(ax=ax3, cmap=cmo.amp, vmin=0, vmax=3)
ax3.coastlines()

theta_fit_m_da.plot(ax=ax4, cmap=cmo.balance, vmin=-90, vmax=90)
ax4.coastlines()

beta_fit_da.plot(ax=ax5, cmap=cmo.amp, vmin=0, vmax=3)
ax5.coastlines()

theta_fit_da.plot(ax=ax6, cmap=cmo.balance, vmin=-90, vmax=90)
ax6.coastlines()

fig.tight_layout()
plt.show()