In [1]:
"""
This script runs the BVEX model in a standalone mode. Key parameters are specified
in namelist.py. One complete run is split into a number of relays, each of which
consists of several sprints. One time slice is saved after each sprint (multiple time
steps), and one relay outputs one NetCDF file with multiple time slices.
"""

import time
import numpy.fft as nfft
from bvex import *
from functools import partial
import powerpax

In [2]:
# Unpack the nine values returned by setup_ic_grid; the three bound to `_` are discarded.
# NOTE(review): shearFrac and randSeed are not defined in the visible cells -- presumably
# they arrive through `from bvex import *`; confirm.
q0, _, Kx, Ky, Del2, _, _, x, y = setup_ic_grid(shearFrac, randSeed)

In [3]:
# The current state and the current time start out as q0 and 0.0.
qNow, tNow = q0, 0.0

In [4]:
# Run configuration for the standalone integration.
dt = 0.01        # time increment used to build the time axis fed to the scan
T = 10           # NOTE(review): not referenced in the visible cells -- confirm it is still needed
interval = 250   # one output is kept every `interval` steps
t_max = 8300     # total integration time, in the same units as dt
# Round before converting to int: int() alone truncates, so a quotient that lands
# just below a whole number (e.g. 0.3 / 0.1 == 2.9999999999999996) would lose a step.
n_steps = int(round(t_max / dt))

In [5]:
# One time value per step: tNow, tNow + dt, tNow + 2*dt, ... (n_steps entries).
t_set = tNow + dt * jnp.arange(n_steps)

In [14]:
def step_fn(carry, tNow):
    """Scan body: advance the state with one call to ``etdrk4``.

    Args:
        carry: state handed in from the previous step.
        tNow: time value for this step (this parameter shadows the
            notebook-level ``tNow`` inside the function).

    Returns:
        A ``(carry, output)`` pair in which both entries are the state
        returned by ``etdrk4``, so the per-step outputs collected by the scan
        are the states themselves. The second value returned by ``etdrk4``
        is discarded.
    """
    advanced, _unused_output = etdrk4(carry, tNow)
    return advanced, advanced

In [15]:
@partial(jax.jit, static_argnums=(1,2,3))
def do_sliced_scan(init_state, start, total_steps, interval):
    """Scan ``step_fn`` over ``total_steps`` time values, keeping a strided slice of outputs.

    ``start``, ``total_steps`` and ``interval`` are declared static via
    ``static_argnums``, so each distinct combination is compiled separately.

    Args:
        init_state: initial carry passed to ``step_fn``.
        start: forwarded as ``start`` to ``powerpax.sliced_scan``.
        total_steps: number of time values generated and scanned over.
        interval: forwarded as ``step`` to ``powerpax.sliced_scan``.

    Returns:
        The ``(final_state, trajectory)`` pair produced by
        ``powerpax.sliced_scan``.

    NOTE(review): ``tNow`` and ``dt`` are read from the notebook namespace
    rather than passed in. Because this function is jitted, their values are
    presumably fixed when it is first traced for a given set of static
    arguments -- confirm before changing them between calls.
    """
    # Time value handed to step_fn at every step of the scan.
    scan_times = tNow + jnp.arange(total_steps) * dt

    # NOTE(review): stop is total_steps - 1, not total_steps; if the slice bound is
    # exclusive the very last step can never be kept -- confirm this is intended.
    last_state, kept_states = powerpax.sliced_scan(
        step_fn,
        init_state,
        scan_times,
        length=total_steps,
        reverse=False,
        start=start,
        stop=total_steps - 1,
        step=interval,
    )
    return last_state, kept_states

In [23]:
# Run the full integration from q0, passing the step index that corresponds to
# t = 1000 as `start` and `interval` as the stride between kept outputs.
# Round before converting to an integer step index: int() alone truncates, so a
# float quotient that falls just short of a whole number would give a start
# index one step too early.
final_state, history = do_sliced_scan(
    q0,
    start=int(round(1000 / dt)),
    total_steps=n_steps,
    interval=interval,
)