In [1]:
import diffrax as dfx
import equinox as eqx
import h5py
import jax
import jax.numpy as jnp
import numpy as np

from pastax.gridded import Gridded
from pastax.simulator import DeterministicSimulator
from pastax.trajectory import Location

from src.dynamics.linear_deterministic import LinearDeterministic

In [2]:
# Enable 64-bit precision in JAX. Without this flag JAX silently downcasts
# float64 requests (e.g. `jnp.asarray(..., dtype=float)`) to float32.
jax.config.update("jax_enable_x64", True)

In [3]:
# Input dataset, opened read-only. It provides per-sample `gdp` trajectories and
# `uc`/`uw`/`uh` gridded fields (2010-01-01 to 2011-01-01 per the file name).
f = h5py.File(
    name="/summer/meom/workdir/bertrava/gdp-uc-uw-uh_2010-01-01_2011-01-01.hdf5",
    mode="r",
)

In [4]:
# Fixed per-record sizes, taken from the first record of each input dataset:
# the trajectory length, and the array shape of the first field of each
# velocity record (used below as (time_len, lat_len, lon_len)).
traj_len = f["gdp"][0][0].size
uc_shape, uw_shape, uh_shape = (f[key][0][0].shape for key in ("uc", "uw", "uh"))

In [5]:
# Output file for the idealized dataset. Mode "w" creates it, truncating any
# existing file at this path, so re-running the notebook starts from scratch.
f_out = h5py.File(
    name="/summer/meom/workdir/bertrava/gdp-uc-uw-uh_idealized_2010-01-01_2011-01-01.hdf5",
    mode="w",
)

In [6]:
# Number of trajectories to simulate (at most 1000). It is capped at the record
# count of the shortest input dataset, so the sampling loop cannot index past
# the end of the input file and leave a half-written output behind.
n_samples = min(1000, *(len(f[key]) for key in ("gdp", "uc", "uw", "uh")))

In [7]:
# One fixed-length record per simulated drifter: lat/lon/time series of length
# `traj_len` (read from the input file) plus an integer id, LZF-compressed.
gdp_ds = f_out.create_dataset(
    "gdp",
    shape=(n_samples,),
    dtype=np.dtype([
        ("lat", "f4", (traj_len,)),
        ("lon", "f4", (traj_len,)),
        ("time", "i4", (traj_len,)),
        ("id", "i4"),
    ]),
    compression="lzf",
)

In [8]:
def create_field_dataset(f_, name, time_len, lat_len, lon_len, n_rows=None):
    """Create an LZF-compressed compound dataset holding one gridded (u, v) field per row.

    Each row stores ``u`` and ``v`` float32 arrays of shape
    ``(time_len, lat_len, lon_len)`` together with their ``time`` (int32),
    ``lat`` and ``lon`` (float32) coordinate vectors.

    Args:
        f_: Writable HDF5 file/group (anything exposing ``create_dataset``).
        name: Name of the dataset to create inside ``f_``.
        time_len: Length of the time axis of every row.
        lat_len: Length of the latitude axis of every row.
        lon_len: Length of the longitude axis of every row.
        n_rows: Number of rows; defaults to the notebook-level ``n_samples``.

    Returns:
        Whatever ``f_.create_dataset`` returns (the new dataset for h5py).
    """
    if n_rows is None:
        n_rows = n_samples
    field_shape = (time_len, lat_len, lon_len)
    row_dtype = np.dtype(
        [
            ("u", "f4", field_shape),
            ("v", "f4", field_shape),
            # Explicit 1-tuple, like the other coordinate fields. A bare int is
            # ambiguous for time_len == 1: NumPy has historically treated
            # ("time", "i4", 1) as a scalar field (with a deprecation warning).
            ("time", "i4", (time_len,)),
            ("lat", "f4", (lat_len,)),
            ("lon", "f4", (lon_len,)),
        ]
    )
    return f_.create_dataset(
        name,
        (n_rows,),
        # One row per chunk, so reading/writing sample i touches a single chunk.
        chunks=(1,),
        dtype=row_dtype,
        compression="lzf",
    )

In [9]:
# One output field dataset per velocity component, each shaped like its input.
uc_ds, uw_ds, uh_ds = (
    create_field_dataset(f_out, key, *shape)
    for key, shape in (("uc", uc_shape), ("uw", uw_shape), ("uh", uh_shape))
)

In [10]:
# Draw the "true" coefficients of the linear dynamics once for the whole
# idealized dataset, uniformly in [0.01, 0.05).
# - A fixed seed makes Restart & Run All regenerate the same dataset.
# - The values are stored as file attributes so the ground truth travels with
#   the generated HDF5 file (the trajectories alone do not record them).
SEED = 0
rng = np.random.default_rng(SEED)
drag_coef, wave_coef = rng.uniform(.01, .05, 2)
f_out.attrs["drag_coef"] = drag_coef
f_out.attrs["wave_coef"] = wave_coef
f_out.attrs["seed"] = SEED
drag_coef, wave_coef

(np.float64(0.03404194535118102), np.float64(0.02507911476752432))

In [11]:
# Deterministic forward simulator shared by every sample.
simulator = DeterministicSimulator()

# Fixed-step integration settings.
integration_horizon = 5  # days
integration_dt = 30 * 60  # seconds, i.e. a 30-minute solver step
# Solver steps needed to cover the horizon (5 d * 86 400 s/d // 1800 s = 240).
n_steps = int(integration_horizon * 86_400 // integration_dt)

# Linear dynamics parameterized by the drawn drag and wave coefficients.
dynamics = LinearDeterministic(drag_coef, wave_coef)

In [12]:
@eqx.filter_jit
def to_gridded(u, v, time, lat, lon):
    """Wrap one velocity record (u, v plus coordinates) in a pastax ``Gridded``.

    Compiled with ``eqx.filter_jit``. Array inputs are traced, so records
    with identical shapes reuse the same compiled function.

    Args:
        u, v: Velocity components. Presumably (time, lat, lon) arrays, matching
            the layout written by ``create_field_dataset`` (TODO confirm).
        time, lat, lon: Coordinate vectors of the grid.

    Returns:
        The ``Gridded`` built by ``Gridded.from_array`` with variables "u" and "v".
    """
    variables = dict(u=u, v=v)
    return Gridded.from_array(variables, time=time, latitude=lat, longitude=lon)

In [13]:
for i in range(n_samples):
    print(f"Sample {i+1}/{n_samples}")

    # Read each input record once and reuse it. Indexing an h5py compound
    # dataset with `[i]` reads (and LZF-decompresses) the whole row, so
    # repeated `f[...][i]` lookups would re-read the same data from disk.
    gdp_rec = f["gdp"][i]
    uc_rec = f["uc"][i]
    uw_rec = f["uw"][i]
    uh_rec = f["uh"][i]

    # Fields 0/1/2 of a gdp record are taken as lat/lon/time, matching the
    # output `gdp` layout (TODO confirm against the input file's dtype).
    # The start point is the first observed position, and the simulation is
    # saved at the observed time stamps.
    x0 = Location(np.asarray((gdp_rec[0][0], gdp_rec[1][0])))
    ts = jnp.asarray(gdp_rec[2], dtype=float)

    uc = to_gridded(uc_rec[0], uc_rec[1], uc_rec[2], uc_rec[3], uc_rec[4])
    uw = to_gridded(uw_rec[0], uw_rec[1], uw_rec[2], uw_rec[3], uw_rec[4])
    uh = to_gridded(uh_rec[0], uh_rec[1], uh_rec[2], uh_rec[3], uh_rec[4])

    # The returned step count goes into its own name. Assigning it back to the
    # global `n_steps` would feed sample i's output in as sample i+1's input
    # and make the cell non-idempotent on re-run.
    dt0, saveat, stepsize_controller, adjoint, max_steps, _ = simulator.get_diffeqsolve_best_args(
        ts, integration_dt, n_steps=n_steps, constant_step_size=True, save_at_steps=False, ad_mode="forward"
    )

    traj = simulator(
        dynamics=dynamics, args=(uc, uw, uh), x0=x0, ts=ts, solver=dfx.Tsit5(),
        dt0=dt0, saveat=saveat, stepsize_controller=stepsize_controller, adjoint=adjoint, max_steps=max_steps
    )

    # NOTE(review): the id field is written as 0 for every sample, so simulated
    # trajectories cannot be traced back to their source drifter. Confirm this
    # is intended, or store the source id / sample index instead.
    gdp_ds[i] = (traj.latitudes.value, traj.longitudes.value, traj.times.value, 0)
    # The velocity fields are copied through unchanged.
    uc_ds[i] = uc_rec
    uw_ds[i] = uw_rec
    uh_ds[i] = uh_rec

Sample 1/1000
Sample 2/1000
Sample 3/1000
Sample 4/1000
Sample 5/1000
Sample 6/1000
Sample 7/1000
Sample 8/1000
Sample 9/1000
Sample 10/1000
Sample 11/1000
Sample 12/1000
Sample 13/1000
Sample 14/1000
Sample 15/1000
Sample 16/1000
Sample 17/1000
Sample 18/1000
Sample 19/1000
Sample 20/1000
Sample 21/1000
Sample 22/1000
Sample 23/1000
Sample 24/1000
Sample 25/1000
Sample 26/1000
Sample 27/1000
Sample 28/1000
Sample 29/1000
Sample 30/1000
Sample 31/1000
Sample 32/1000
Sample 33/1000
Sample 34/1000
Sample 35/1000
Sample 36/1000
Sample 37/1000
Sample 38/1000
Sample 39/1000
Sample 40/1000
Sample 41/1000
Sample 42/1000
Sample 43/1000
Sample 44/1000
Sample 45/1000
Sample 46/1000
Sample 47/1000
Sample 48/1000
Sample 49/1000
Sample 50/1000
Sample 51/1000
Sample 52/1000
Sample 53/1000
Sample 54/1000
Sample 55/1000
Sample 56/1000
Sample 57/1000
Sample 58/1000
Sample 59/1000
Sample 60/1000
Sample 61/1000
Sample 62/1000
Sample 63/1000
Sample 64/1000
Sample 65/1000
Sample 66/1000
Sample 67/1000
Samp