In [None]:
import equinox as eqx
import h5py
import jax
import jax.numpy as jnp
import numpy as np

from pastax.gridded import Gridded
from pastax.simulator import DeterministicSimulator
from pastax.trajectory import Location

from src.dynamics.linear_deterministic import LinearDeterministic

In [2]:
# Enable 64-bit (double precision) types in JAX, which otherwise defaults to
# 32-bit. This must run before any JAX array is created.
jax.config.update("jax_enable_x64", True)

In [None]:
# Input HDF5 file, opened read-only. Its records are used below as templates
# for array sizes, initial positions and coordinates.
src_path = "/summer/meom/workdir/bertrava/gdp-uc_2010-01-01_2011-01-01.hdf5"
f = h5py.File(src_path, mode="r")

In [None]:
# Array sizes are taken from field 0 of the first record of each input dataset
# and reused for every sample written to the output file.
traj_len = f["gdp"][0][0].size
uc_shape = f["uc"][0][0].shape
# The simulation loop also fills "uw" and "uh" fields with np.full(uw_shape, ...)
# and np.full(uh_shape, ...); without these two definitions a fresh kernel
# fails there with a NameError.
# NOTE(review): assumes the input file contains "uw" and "uh" datasets laid out
# like "uc" (the loop reads f["uw"][i] and f["uh"][i]) - confirm.
uw_shape = f["uw"][0][0].shape
uh_shape = f["uh"][0][0].shape

In [None]:
# Output HDF5 file for the idealized samples. Mode "w" creates the file and
# truncates it if it already exists.
dst_path = "/summer/meom/workdir/bertrava/gdp-uc_idealized_2010-01-01_2011-01-01.hdf5"
f_out = h5py.File(dst_path, mode="w")

In [6]:
# Number of records in every output dataset, and number of iterations of the
# simulation loop that fills them.
n_samples = 10

In [7]:
# One compound record per sample: three arrays of length `traj_len`
# (lat, lon, time) plus a scalar integer id.
gdp_record_dtype = np.dtype(
    [
        ("lat", "f4", (traj_len,)),
        ("lon", "f4", (traj_len,)),
        ("time", "i4", (traj_len,)),
        ("id", "i4"),
    ]
)
gdp_ds = f_out.create_dataset(
    "gdp", shape=(n_samples,), dtype=gdp_record_dtype, compression="lzf"
)

In [8]:
def create_field_dataset(f_, name, time_len, lat_len, lon_len, n_records=None):
    """Create an LZF-compressed compound dataset storing one (u, v) field per record.

    Each record holds two `(time_len, lat_len, lon_len)` float32 arrays ("u" and
    "v") together with their "time" (int32), "lat" and "lon" (float32)
    coordinate vectors. The dataset is chunked one record per chunk.

    Parameters
    ----------
    f_ : object exposing ``create_dataset`` (e.g. an open ``h5py.File``)
        Where the dataset is created.
    name : str
        Name of the dataset inside ``f_``.
    time_len, lat_len, lon_len : int
        Lengths of the time, latitude and longitude axes.
    n_records : int, optional
        Number of records in the dataset. Defaults to the module-level
        ``n_samples``, which is what the original implementation always used.

    Returns
    -------
    The object returned by ``f_.create_dataset``.
    """
    if n_records is None:
        n_records = n_samples

    field_shape = (time_len, lat_len, lon_len)
    record_dtype = np.dtype(
        [
            ("u", "f4", field_shape),
            ("v", "f4", field_shape),
            # Explicit 1-tuple: a bare int is only shorthand and is ambiguous
            # for length 1 in NumPy's dtype specification.
            ("time", "i4", (time_len,)),
            ("lat", "f4", (lat_len,)),
            ("lon", "f4", (lon_len,)),
        ]
    )
    return f_.create_dataset(
        name,
        (n_records,),
        chunks=(1,),
        dtype=record_dtype,
        compression="lzf",
    )

In [None]:
# One output dataset per field consumed by the simulation loop. The loop writes
# to `uw_ds` and `uh_ds` as well as `uc_ds`, so all three must exist before it
# runs; otherwise a fresh kernel fails with a NameError.
uc_ds = create_field_dataset(f_out, "uc", *uc_shape)
# Shapes are read straight from the input file so this cell does not depend on
# any other variable being defined.
# NOTE(review): the dataset names "uw"/"uh" mirror the input names read in the
# loop - confirm they are the intended output names.
uw_ds = create_field_dataset(f_out, "uw", *f["uw"][0][0].shape)
uh_ds = create_field_dataset(f_out, "uh", *f["uh"][0][0].shape)

In [10]:
# Target value for each component of the combined velocity built in the
# simulation loop.
uv_max = 0.15
# Two coefficients drawn uniformly from [0.01, 0.05).
# NOTE(review): no random seed is set in the visible cells, so these values
# (and everything derived from them) change on every run.
drag_coef, wave_coef = np.random.uniform(low=0.01, high=0.05, size=2)
drag_coef, wave_coef

(np.float64(0.028736656374492194), np.float64(0.019996845764446054))

In [11]:
simulator = DeterministicSimulator()

seconds_per_day = 24 * 60 * 60
integration_horizon = 5  # days
integration_dt = 60 * 30  # seconds
# Number of whole `integration_dt` steps that fit in the integration horizon.
n_steps = int(integration_horizon * seconds_per_day // integration_dt)

In [12]:
@eqx.filter_jit
def to_gridded(u, v, time, lat, lon):
    """Bundle a (u, v) field and its coordinates into a `Gridded` object.

    The two arrays are passed to `Gridded.from_array` under the keys "u" and
    "v", with `time`, `lat` and `lon` supplied as the `time`, `latitude` and
    `longitude` coordinates. The call is compiled through `eqx.filter_jit`.
    """
    fields = {"u": u, "v": v}
    return Gridded.from_array(fields, time=time, latitude=lat, longitude=lon)

In [13]:
# The dynamics depend only on the two global coefficients, so build them once
# instead of on every iteration.
dynamics = LinearDeterministic.from_coefficients(drag_coef, wave_coef)


def _read_coords(dataset, index):
    """Return fields 2, 3 and 4 of record `index`, reading the record from disk once.

    These are the values passed on as the time, lat and lon coordinates.
    """
    record = dataset[index]
    return record[2], record[3], record[4]


for i in range(n_samples):
    print(f"Sample {i+1}/{n_samples}")

    # Read each compound record a single time; every `f[name][i]` access pulls
    # the whole record from the HDF5 file.
    gdp_rec = f["gdp"][i]
    uc_coords = _read_coords(f["uc"], i)
    uw_coords = _read_coords(f["uw"], i)
    uh_coords = _read_coords(f["uh"], i)

    # Initial position: first element of fields 0 and 1 of the input record
    # (presumably latitude and longitude, matching the output layout - confirm).
    x0 = Location(np.asarray((gdp_rec[0][0], gdp_rec[1][0])))
    ts = jnp.asarray(gdp_rec[2], dtype=float)

    # Draw spatially constant components such that, per component,
    #   uc + wave_coef * uh + drag_coef * uw == uv_max.
    uc_u, uc_v = np.random.normal(.1, .001, 2)
    uh_u, uh_v = (
        np.random.uniform(0, (uv_max - uc_u) / wave_coef, 1),
        np.random.uniform(0, (uv_max - uc_v) / wave_coef, 1)
    )
    uw_u = (uv_max - uc_u - uh_u * wave_coef) / drag_coef
    uw_v = (uv_max - uc_v - uh_v * wave_coef) / drag_coef

    # Broadcast each scalar component to a full constant field.
    uc_u = np.full(uc_shape, uc_u)
    uc_v = np.full(uc_shape, uc_v)
    uw_u = np.full(uw_shape, uw_u)
    uw_v = np.full(uw_shape, uw_v)
    uh_u = np.full(uh_shape, uh_u)
    uh_v = np.full(uh_shape, uh_v)

    uc = to_gridded(uc_u, uc_v, *uc_coords)
    uw = to_gridded(uw_u, uw_v, *uw_coords)
    uh = to_gridded(uh_u, uh_v, *uh_coords)

    # Keep the step budget returned by the helper in its own name so the
    # module-level `n_steps` is not overwritten; otherwise later iterations
    # (and re-runs of this cell) would be fed the previous iteration's result.
    dt0, saveat, stepsize_controller, adjoint, max_steps, _ = simulator.get_diffeqsolve_best_args(
        ts, integration_dt, n_steps=n_steps, constant_step_size=True, save_at_steps=False, ad_mode="forward"
    )

    traj = simulator(
        dynamics=dynamics, args=(uc, uw, uh), x0=x0, ts=ts,
        dt0=dt0, saveat=saveat, stepsize_controller=stepsize_controller, adjoint=adjoint, max_steps=max_steps
    )

    # Every simulated trajectory is stored with id 0.
    gdp_ds[i] = (traj.latitudes.value, traj.longitudes.value, traj.times.value, 0)
    uc_ds[i] = (uc_u, uc_v, *uc_coords)
    uw_ds[i] = (uw_u, uw_v, *uw_coords)
    uh_ds[i] = (uh_u, uh_v, *uh_coords)

# The output file is never closed in the visible cells; flush so the samples
# reach disk even if the kernel is stopped afterwards.
f_out.flush()

Sample 1/10


Sample 2/10
Sample 3/10
Sample 4/10
Sample 5/10
Sample 6/10
Sample 7/10
Sample 8/10
Sample 9/10
Sample 10/10
