In [None]:
from matplotlib import pyplot as plt
import xarray as xr
import numpy as np

In [None]:
# Stratification input: row 0 of Nsq.npy is used as Nsq, row 1 as the matching depths.
# NOTE(review): allow_pickle=True lets np.load execute pickled code; only keep it
# if Nsq.npy is a trusted file that really holds an object array.
Nsq = np.load('Nsq.npy', allow_pickle=True)
Nsq, depth = Nsq[0, :], Nsq[1, :]

In [None]:
from  multimodemodel import StaggeredGrid

nmodes = 25

# Regular lat/lon C-grid with 4 points per degree (plus the closing edge point)
# over 50 degrees of longitude and 20 degrees of latitude.
# The grid's z axis simply enumerates the vertical modes, not depth levels.
c_grid = StaggeredGrid.regular_lat_lon_c_grid(
    lon_start=-50.0, lon_end=0.0,
    lat_start=-10.0, lat_end=10.0,
    nx=50 * 4 + 1, ny=20 * 4 + 1,
    z=np.arange(nmodes),
)

In [None]:
from multimodemodel import MultimodeParameters, f_on_sphere
# Vertical-mode decomposition of the Nsq(depth) profile.
# omega = 7.272205e-05 equals 2*pi / 86400 (one revolution per 24 h).
multimode_params = MultimodeParameters(
    z=depth,
    Nsq=Nsq,
    nmodes=nmodes,
    coriolis_func=f_on_sphere(omega=7.272205e-05),
    on_grid=c_grid,
    no_slip=True,
)

In [None]:
# Dataset view of the mode parameters (psi, dpsi_dz, c, Nsq are used below).
ds = multimode_params.as_dataset

In [None]:
# Numerical derivative of dpsi_dz with respect to `depth` along axis 0
# (assumed to be the depth axis of ds.dpsi_dz — confirm).
a = np.gradient(ds.dpsi_dz, depth, axis=0)

In [None]:
# Shape check of the derivative array computed above.
np.shape(a)

In [None]:
 b = ds.Nsq[0] * ds.psi[0,:] / ds.c**2

In [None]:
# Mode structure functions psi, mode number on x and depth on y.
ds.psi.plot(y="depth", x="nmode")

In [None]:
# Vertical derivative of psi, mode number on x and depth on y.
ds.dpsi_dz.plot(y="depth", x="nmode")

In [None]:
import math as m
def plot_tensors(tensor: np.ndarray):
    """Plot four evenly spaced leading-index slices of a 3-D tensor.

    A 2x2 grid of pcolormesh panels is drawn; panel ``i`` (0..3) shows
    ``tensor[k, :, :]`` with ``k = (i + 1) * (tensor.shape[0] // 4) - 1`` and has
    its own colorbar and a "Modenumber k" title.

    Parameters
    ----------
    tensor : np.ndarray
        Array indexable as ``tensor[k, :, :]``; its first axis is treated as the
        mode number.
    """
    fig, axs = plt.subplots(2, 2, figsize=(15, 10), tight_layout=True)
    nmodes = tensor.shape[0]
    # Integer division equals math.floor(nmodes / 4) for non-negative ints
    # without the round-trip through float.
    stride = nmodes // 4
    for i, ax in enumerate(np.ravel(axs)):
        k = (i + 1) * stride - 1
        mesh = ax.pcolormesh(tensor[k, :, :])
        fig.colorbar(mesh, ax=ax)
        ax.set_title('Modenumber ' + str(k))

In [None]:
# Slices of the multimode tensor Q.
plot_tensors(tensor=multimode_params.Q)

In [None]:
# Slices of the multimode tensor R.
plot_tensors(tensor=multimode_params.R)

In [None]:
# Slices of the multimode tensor S.
plot_tensors(tensor=multimode_params.S)

In [None]:
# Slices of the multimode tensor T.
plot_tensors(tensor=multimode_params.T)

In [None]:
# Water-column thickness spanned by the depth profile.
H = abs(depth[0] - depth[-1])
# NOTE(review): 1.33e-7 is an unexplained constant; document its origin/units.
A = 1.33e-7 / H
# Per-mode coefficient A / c_n**2, stored on the parameters (presumably the
# linear damping rate used by linear_damping_* — confirm).
gamma = A / ds.c.values ** 2
multimode_params.gamma = gamma

In [None]:
def tau(x, y):
    """Wind field according to McCreary (1980).

    The zonal profile is ``cos(pi * x / delta_x)``, zeroed where
    ``|x| > delta_x / 2``. The meridional profile is
    ``(1 + y**2 / delta_y**2) * exp(-y**2 / delta_y**2)``. Here ``delta_x`` and
    ``delta_y`` are half the domain extents, taken from the corners of ``x``
    and ``y``.

    Parameters
    ----------
    x, y : np.ndarray
        2-D coordinate arrays (presumably of identical shape). The x extent is
        read along the last axis of ``x``; the y extent along the first axis
        of ``y``.

    Returns
    -------
    np.ndarray
        ``-5e-6 * wind_x * wind_y``.
    """
    # Half of the domain extent in each direction.
    delta_x = abs(x[0, 0] - x[0, -1]) / 2
    delta_y = abs(y[0, 0] - y[-1, 0]) / 2

    wind_x = np.cos(np.pi * x / delta_x)
    # NOTE(review): masking on |x| assumes x is centred on 0 — confirm for the
    # grid this is applied to.
    wind_x[abs(x) > delta_x / 2] = 0

    wind_y = (1 + y**2 / delta_y**2) * np.exp(-y**2 / delta_y**2)

    return -5e-6 * wind_x * wind_y

In [None]:
# from multimodemodel import Parameters, f_on_sphere
# params = []
# for i in range(multimode_params.nmodes):
#     params.append(
#         Parameters(
#             coriolis_func=f_on_sphere(omega=7.272205e-05),
#             on_grid=c_grid,
#             H=np.array([ds.H.values[i]]),
#             gamma = np.array([gamma[i]]),
#         )
#     )

In [None]:
from multimodemodel import (
    State, Variable,
 )

# Uniform zonal wind stress projected onto each vertical mode: mode k is scaled by
# ds.psi at index 0 of its first dimension (presumably the top depth level).
# NOTE(review): -0.05 is presumably a stress in N/m^2 — confirm.
surface_psi = ds.psi.values[0, :nmodes]
# Build the (mode, j, i) field by broadcasting instead of filling an np.empty
# buffer: a mismatch between nmodes and the grid's mode axis now raises instead
# of leaving uninitialised values in tau_x.
tau_x = -0.05 * surface_psi[:, np.newaxis, np.newaxis] * np.ones(c_grid.u.shape)

tau_x *= c_grid.u.mask


def zonal_wind(state, params):
    """Time-independent wind-forcing term for the right-hand side.

    ``state`` is ignored. The returned State carries only a ``u`` tendency,
    ``tau_x / params.rho_0 / H``, on the u grid with a NaT time stamp.
    """
    return State(u=Variable(tau_x / params.rho_0 / H, c_grid.u, np.datetime64("NaT")))

In [None]:
import functools as ft
import operator as op
from multimodemodel import (
    pressure_gradient_i, pressure_gradient_j,
    coriolis_i, coriolis_j,
    divergence_i, divergence_j,
    laplacian_mixing_u, laplacian_mixing_v,
    linear_damping_eta, linear_damping_u,
    linear_damping_v, advection_density,
    advection_momentum_u, advection_momentum_v,
)

# Tendency terms summed by `rhs`; the list order fixes the order of summation.
terms = [
    pressure_gradient_i,
    pressure_gradient_j,
    coriolis_i,
    coriolis_j,
    divergence_i,
    divergence_j,
    laplacian_mixing_u,
    laplacian_mixing_v,
    linear_damping_u,
    linear_damping_v,
    linear_damping_eta,
    advection_density,
    advection_momentum_u,
    advection_momentum_v,
    zonal_wind,
]


def rhs(state, params):
    """Total tendency of ``state``.

    First stores the eta component of ``divergence_j + divergence_i`` on
    ``state`` as the diagnostic variable ``w`` (eta grid, state's eta time).
    Then returns the sum of every callable in ``terms`` applied to
    ``(state, params)``.
    """
    # NOTE(review): divergence_i/j are evaluated again inside `terms`; they could
    # be reused if they do not depend on the diagnostic w.
    divergence = divergence_j(state, params) + divergence_i(state, params)
    w_field = Variable(divergence.eta.safe_data, c_grid.eta, state.eta.time)
    state.set_diagnostic_variable(w=w_field)
    tendencies = (term(state, params) for term in terms)
    return ft.reduce(op.add, tendencies)

In [None]:
def save_as_Dataset(state: State, params: MultimodeParameters):
    """Convert the prognostic fields of ``state`` into an xarray Dataset.

    ``u``, ``v`` and ``eta`` are stored as ``u_tilde``, ``v_tilde`` and
    ``h_tilde``. Coordinates ``x`` and ``y`` on dims ``("j", "i")`` are set to the
    mean of the u- and v-grid positions.

    Parameters
    ----------
    state : State
        Model state to convert.
    params : MultimodeParameters
        Not used by this function.

    Returns
    -------
    xarray.Dataset
        The three fields with the ``x``/``y`` coordinates attached.
    """
    dataset = state.variables["u"].as_dataarray.to_dataset(name='u_tilde')
    dataset['v_tilde'] = state.variables["v"].as_dataarray
    dataset['h_tilde'] = state.variables["eta"].as_dataarray
    x = (["j", "i"], (state.u.grid.x + state.v.grid.x) / 2)
    y = (["j", "i"], (state.u.grid.y + state.v.grid.y) / 2)
    # assign_coords returns a new Dataset rather than modifying in place, so its
    # result must be kept for the x/y coordinates to be attached.
    return dataset.assign_coords({"x": x, "y": y})

In [None]:
from multimodemodel import integrate, adams_bashforth3

time = 10 * 24 * 3600.  # 10 days, in seconds
# Time step: one tenth of the smallest u-grid spacing divided by the largest
# ds.c (assuming ds.c holds the modal wave speeds, i.e. a CFL-type bound).
step = c_grid.u.dx.min() / ds.c.values.max() / 10.
t0 = np.datetime64("2000-01-01")

In [None]:
def run(params, step, time, n_snapshots=5, u_threshold=10., u_abort=20.):
    """Integrate the model from rest and return the stored snapshots.

    The initial State has u, v, eta and q all set to ``Variable(None, grid, t0)``.
    It is advanced with ``rhs`` and ``adams_bashforth3``.

    Snapshot rules:

    * A snapshot is stored every ``(time // step) // n_snapshots`` steps (at
      least every step).
    * Whenever ``max |u|`` exceeds the current threshold (initially
      ``u_threshold``), the state is stored as well and the threshold rises
      by 1.
    * Once the threshold exceeds ``u_abort``, the integration stops early.
    * Each state is stored at most once.

    Parameters
    ----------
    params : MultimodeParameters
        Model parameters passed to ``integrate``.
    step : float
        Time step in seconds.
    time : float
        Total integration time in seconds.
    n_snapshots : int, optional
        Number of regularly spaced snapshot intervals (default 5).
    u_threshold : float, optional
        Initial ``max |u|`` threshold that triggers extra snapshots (default 10).
    u_abort : float, optional
        The run stops once the raised threshold exceeds this value (default 20).

    Returns
    -------
    xarray.Dataset
        All stored snapshots merged with ``xr.combine_by_coords``.
    """
    model_run = integrate(
        State(
            u=Variable(None, c_grid.u, t0),
            v=Variable(None, c_grid.v, t0),
            eta=Variable(None, c_grid.eta, t0),
            q=Variable(None, c_grid.q, t0),
        ),
        params,
        RHS=rhs,
        scheme=adams_bashforth3,
        step=step,
        time=time,
    )

    n_steps = time // step
    # Guard against a zero interval (modulo by zero) when the run has fewer
    # steps than n_snapshots.
    save_every = max(n_steps // n_snapshots, 1)

    output = []
    tolerance = u_threshold
    for i, next_state in enumerate(model_run):
        snapshot_due = i % save_every == 0
        exceeded = np.nanmax(abs(next_state.variables["u"].safe_data)) > tolerance
        # Store each state at most once; appending it twice would create
        # duplicate time stamps that combine_by_coords cannot order.
        if snapshot_due or exceeded:
            output.append(save_as_Dataset(next_state, params))
        if exceeded:
            tolerance += 1.
        if tolerance > u_abort:
            break

    return xr.combine_by_coords(output)

In [None]:
# from multiprocessing import Pool

# pool = Pool()
# out = pool.map(run, params)

In [None]:

# Main integration (slow): snapshots of u_tilde, v_tilde, h_tilde over the run.
out = run(params=multimode_params, step=step, time=time)

In [None]:
# The grid's z axis enumerates modes; rename it so it matches ds.psi's "nmode" dim.
# NOTE(review): not idempotent — re-running this cell fails once "z" is gone.
out = out.rename(z="nmode")

In [None]:
# Reconstruct physical fields from modal amplitudes: xr.dot sums over all
# dimensions shared by ds.psi and each *_tilde field (here "nmode").
out["u"] = xr.dot(ds.psi, out["u_tilde"])
out["v"] = xr.dot(ds.psi, out["v_tilde"])
out["h"] = xr.dot(ds.psi, out["h_tilde"])

In [None]:
from matplotlib import colors
# Modal v amplitude at column i=100 of snapshot 5, mode number against y.
out.v_tilde.isel(time=5, i=100).plot.pcolormesh(
    x='nmode', y='y', cmap='RdBu_r', figsize=(20, 10)
)

In [None]:
# Depth section of u at j=20 over the first 20 i points of snapshot 3.
# `norm=` needs a value; none is given here, so xarray chooses the colour scaling.
# NOTE(review): pass e.g. norm=colors.<SomeNorm>(...) if a specific scaling was intended.
out.u.isel(j=20, i=slice(0, 20), depth=slice(50, 200), time=3).plot.pcolormesh(
    x='x', y='depth', cmap='RdBu_r', figsize=(20, 10)
);

In [None]:
# Rebuild a State from snapshot 5 of the run so individual RHS terms can be
# inspected. Each prognostic variable takes its own saved field.
test_state = State(
    u=Variable(out.u_tilde.isel(time=5).values, c_grid.u, t0),
    v=Variable(out.v_tilde.isel(time=5).values, c_grid.v, t0),
    eta=Variable(out.h_tilde.isel(time=5).values, c_grid.eta, t0),
    # NOTE(review): q is not saved by save_as_Dataset, so it still receives
    # u_tilde here — confirm which field q should hold.
    q=Variable(out.u_tilde.isel(time=5).values, c_grid.q, t0)
)
# Same diagnostic w that `rhs` sets before summing the terms.
w = (divergence_j(test_state, multimode_params) + divergence_i(test_state, multimode_params)).eta.safe_data
test_state.set_diagnostic_variable(w=Variable(w, c_grid.eta, test_state.eta.time))

# All-zero State on the same grids; adding a term's tendency to it yields a
# complete State with u, v and eta.
zero_state = State(
    u=Variable(np.zeros(c_grid.u.shape), c_grid.u, t0),
    v=Variable(np.zeros(c_grid.v.shape), c_grid.v, t0),
    eta=Variable(np.zeros(c_grid.eta.shape), c_grid.eta, t0),
    q=Variable(np.zeros(c_grid.q.shape), c_grid.q, t0)
)

In [None]:
# For each RHS term, print its index in `terms` followed by the max |tendency|
# of u, v and eta (presumably to find the term driving the blow-up).
for term, func in enumerate(terms):
    next_state = zero_state + func(test_state, multimode_params)
    print(term)
    for field in ("u", "v", "eta"):
        print(np.nanmax(abs(next_state.variables[field].safe_data)))