In [None]:
import cartopy.crs as ccrs
from cmocean import cm 
from dino import Experiment
from matplotlib import colors
from matplotlib import pyplot as plt
import numpy as np
import xarray as xr
import cftime as cft
import xnemogcm as xn
import xgcm

In [None]:
path = "/data/dkamm/nemo_output/DINO/"
sfx_run = 'HigherRidge'  # run labelled 'Salt-flux from S_star' / 'Salt-restoring' in the plots
dino_exp_sfx = Experiment(path, sfx_run)

In [None]:
path = "/data/dkamm/nemo_output/DINO/"
emp_run = 'HigherRidgeEmP'  # run labelled 'E-P' in the plots
dino_exp_emp = Experiment(path, emp_run)

In [None]:
# ACC transport time series of the two forcing set-ups on one axis.
fig, axs = plt.subplots(1, 1, figsize=(12, 6))
acc_runs = (
    (dino_exp_sfx, 'midnightblue', 'Salt-flux from S_star'),
    (dino_exp_emp, 'darkred', 'E-P'),
)
for run, colour, run_label in acc_runs:
    run.get_ACC().plot(ax=axs, color=colour, label=run_label)
axs.set_ylabel('ACC transport [Sv]')
axs.set_title('')
axs.set_xlabel('time [years]')
axs.legend()
axs.grid()

In [None]:
path = "/data/dkamm/nemo_output/DINO/"
restart_run = 'HigherRidgeEmP/restart40'  # restart40 sub-directory of the E-P run
dino_exp = Experiment(path, restart_run)

In [None]:
# ACC transport time series of the restart run.
fig, axs = plt.subplots(1, 1, figsize=(12, 6))
restart_acc = dino_exp.get_ACC()
restart_acc.plot(ax=axs, color='darkred')
axs.set(ylabel='ACC transport [Sv]', title='')
axs.set_xlabel('time [years]')
axs.grid()

In [None]:
# Meridional density sections (x_c = 30, last year) of the salt-restoring and
# E-P runs, sharing a single colorbar. The outermost y rows and the bottom
# level are trimmed. NOTE(review): get_rho(z=2000.) presumably returns density
# referenced to 2000 m — confirm in dino.Experiment.
fig, axs = plt.subplots(1, 2, figsize=(15, 6))
section = dict(t_y=-1, x_c=30, y_c=slice(1, -1), z_c=slice(0, -1))
a = dino_exp_sfx.get_rho(z=2000.).isel(section)
b = dino_exp_emp.get_rho(z=2000.).isel(section)

contour_kw = dict(x='gphit', y='gdept_0', cmap=cm.dense_r, add_colorbar=False)
c = a.plot.contourf(levels=36, ax=axs[0], **contour_kw)
cbar1 = fig.colorbar(c, ax=axs, label=r'$\rho$ [kg/m$^3$]')
# Reuse panel 0's exact contour levels. Passing only vmin/vmax would make
# xarray switch to np.linspace levels, which differ from the MaxNLocator
# levels on panel 0 and therefore from the shared colorbar.
b.plot.contourf(levels=c.levels, ax=axs[1], **contour_kw)

for ax, title in zip(axs, ('Salt-restoring', 'E-P')):
    ax.invert_yaxis()  # depth increases downward
    ax.set_xlabel('latitude [°N]')
    ax.set_ylabel('depth [m]')  # y is gdept_0 (depth of T points), not sigma_0
    ax.set_title(title)

In [None]:
def _mld_climatology(month, first_year=100):
    """Time mean of ``dino_exp.data.mldr10_1`` over every `month` from `first_year` on.

    The result carries an extra ``x_globe`` coordinate equal to ``glamt - 30``.
    """
    data = dino_exp.data
    in_window = (data['t_m.year'] >= first_year) & (data['t_m.month'] == month)
    clim = data.mldr10_1.where(in_window, drop=True).mean('t_m')
    return clim.assign_coords(x_globe=clim.glamt - 30)


mld_sep = _mld_climatology(9)
mld_mar = _mld_climatology(3)

In [None]:
# March mixed-layer depth on a Robinson projection.
plt.figure(figsize=(10, 10))
mld_map_kw = dict(
    x='x_globe',
    y='gphit',
    levels=30,
    cmap=cm.deep_r,
    transform=ccrs.PlateCarree(),
    subplot_kws=dict(projection=ccrs.Robinson()),
)
a = mld_mar.plot.contourf(**mld_map_kw)
# Label only the x_globe = 0 meridian and a fixed set of latitudes.
labelled_lats = [-70, -45, -20, 0, 20, 45, 70]
a.axes.gridlines(draw_labels=["x", "y", "geo"], xlocs=[0], ylocs=labelled_lats)
plt.title('')
# NOTE(review): 'm' is received as matplotlib's `visible` argument; cartopy
# gridlines are already drawn above, so this is probably redundant — confirm.
plt.grid('m')

In [None]:
from matplotlib.colors import LogNorm

# Volume-weighted T-S diagram of the last year of `dino_exp`, with density contours.
last_year = dino_exp.data.isel(t_y=-1)
wet = dino_exp.domain.tmask == 1.
soce_da = last_year.soce.where(wet)
# Transpose every field to soce's dimension order before flattening. xarray
# orders broadcast dimensions by operand, so when e1t/e2t are 2-D (y_c, x_c)
# the product e1t * e2t * e3t comes out as (y_c, x_c, z_c). Raveled as-is,
# the volume weights would fall out of step with soce/toce.
ts_dims = soce_da.dims
soce = soce_da.values.ravel()
toce = last_year.toce.where(wet).transpose(*ts_dims).values.ravel()
vol = (
    dino_exp.domain.e1t * dino_exp.domain.e2t * last_year.e3t
).transpose(*ts_dims).values.ravel()

smin, smax = 34.5, 37.5
tmin, tmax = -2, 29
Sg, Tg = np.meshgrid(np.linspace(smin, smax, 100), np.linspace(tmin, tmax, 100))

# Simplified equation of state with coefficients from the `nameos` namelist.
# No pressure-dependent terms are included, so the contours are surface-referenced.
# 1026 is added as reference density (presumably the run's rho0 — TODO confirm).
nml = dino_exp.namelist['nameos']
Ta, Sa = Tg - 10., Sg - 35.
rho = (
    - nml['rn_a0'] * (1. + 0.5 * nml['rn_lambda1'] * Ta) * Ta
    + nml['rn_b0'] * (1. - 0.5 * nml['rn_lambda2'] * Sa) * Sa
    - nml['rn_nu'] * Ta * Sa  # T-S cross term must use the 2-D grid Tg, not the 1-D axis
) + 1026

fig, ax = plt.subplots()
# hexbin drops NaN (land) points together with their C weights.
hb = ax.hexbin(soce, toce,
               C=vol, reduce_C_function=np.sum,
               extent=(smin, smax, tmin, tmax), gridsize=50, bins='log',
               cmap=cm.matter)
fig.colorbar(hb, ax=ax, label='volume [m$^3$]')
cp = ax.contour(Sg, Tg, rho, levels=np.arange(1021, 1029, 0.5), linestyles='dashed', colors='black')
cl = ax.clabel(cp, fontsize=10, inline=True, fmt="%.1f")
ax.set_ylabel(r'Temperature ($^\circ$C)')
ax.set_xlabel("Salinity (g / kg)")

In [None]:
def _residual_moc(exp):
    """Return ``exp.get_MOC`` of ``voce + voce_eiv`` (z=2000) using a single experiment.

    Taking both the velocities and the ``get_MOC`` receiver from the same `exp`
    avoids mixing runs. ``voce_eiv`` is presumably the eddy-induced velocity (TODO confirm).
    """
    return exp.get_MOC(exp.data.voce + exp.data.voce_eiv, z=2000)


moc = _residual_moc(dino_exp_emp)
moc_sfx = _residual_moc(dino_exp_sfx)

In [None]:
# Overturning streamfunction in density space: E-P run (left) vs salt-restoring run (right).
fig, axs = plt.subplots(1, 2, figsize=(15, 6), sharey=True)
contour_kw = dict(x='y_f', y='rho', cmap='RdBu_r', add_colorbar=False)
a = (-moc).plot.contourf(levels=32, ax=axs[0], **contour_kw)
cbar1 = fig.colorbar(a, ax=axs, label=r'$\psi$ [Sv]')
# Reuse panel 0's exact levels. vmin/vmax alone would give np.linspace levels
# that do not match the shared colorbar.
b = (-moc_sfx).plot.contourf(levels=a.levels, ax=axs[1], **contour_kw)

axs[0].invert_yaxis()  # sharey=True, so this flips both panels
axs[0].set_ylabel(r'$\sigma_{0}$ [ $kg$ / $m^3$ - 1000 ]')
axs[1].set_ylabel('')
for ax, title in zip(axs, ('E-P', 'Salt-restoring')):
    ax.set_xlabel('latitude [°N]')
    ax.set_title(title)

In [None]:
# get_BTS output (presumably the barotropic streamfunction), plus an x_globe = glamf - 30 coordinate.
bts = dino_exp.get_BTS().pipe(lambda da: da.assign_coords(x_globe=da.glamf - 30))

In [None]:
# Last-year BTS (first 100 y_f rows) on a Robinson projection.
plt.figure(figsize=(10, 10))
bts_last = bts.isel(t_y=-1, y_f=slice(0, 100))
bts_map_kw = dict(
    x='x_globe',
    y='gphif',
    levels=30,
    cmap=cm.balance,
    transform=ccrs.PlateCarree(),
    subplot_kws=dict(projection=ccrs.Robinson()),
)
a = bts_last.plot.contourf(**bts_map_kw)
labelled_lats = [-70, -45, -20, 0, 20, 45, 70]
a.axes.gridlines(draw_labels=["x", "y", "geo"], xlocs=[0], ylocs=labelled_lats)
plt.title('')
# NOTE(review): 'm' is received as matplotlib's `visible` argument; likely redundant with gridlines above.
plt.grid('m')
plt.tight_layout()

In [None]:
def _meridional_gradient(field):
    """Mean over x_c of the xgcm 'Y' derivative of `field` in a fixed box.

    `field` is masked to wet t-points of the E-P run, differentiated along the
    'Y' axis, and restricted to z_c[0:24] and y_f[13:46] before averaging.
    """
    wet_field = field.where(dino_exp_emp.domain.tmask == 1.0)
    grad = dino_exp_emp.grid.derivative(wet_field, 'Y')
    return grad.isel(z_c=slice(0, 24), y_f=slice(13, 46)).mean('x_c')


# NOTE(review): despite the "dz" names these are derivatives along Y, not
# vertical derivatives; the names are kept because later cells use them.
drdz, dtdz, dsdz = (
    _meridional_gradient(dino_exp_emp.data[var]) for var in ('rhop', 'toce', 'soce')
)

In [None]:
# Last-year minus first-year change of the drdz, dtdz and dsdz sections.
gradient_fields = (drdz, dtdz, dsdz)
fig, axs = plt.subplots(1, len(gradient_fields), figsize=(15, 6))
for panel, field in zip(axs, gradient_fields):
    change = field.isel(t_y=-1) - field.isel(t_y=0)
    change.plot(yincrease=False, ax=panel)

In [None]:
import scipy.sparse as sparse
import scipy.sparse.linalg as la

In [None]:
#@staticmethod
def _get_dynmodes(Nsq, e3w, e3t, nmodes=2):
    """
    Calculate the 1st nmodes ocean dynamic vertical modes.
    Based on
    http://woodshole.er.usgs.gov/operations/sea-mat/klinck-html/dynmodes.html
    by John Klinck, 1999.
    """
    nmodes = 5#min((nmodes, len(Nsq) - 2))
    # 2nd derivative matrix plus boundary conditions
    Ndz     = (Nsq * e3w)
    e3t     = e3t
    #Ndz_m1  = np.roll(Ndz, -1)
    #e3t_p1  = np.roll(e3t, 1)
    d0  = np.r_[1. / Ndz[1] / e3t[0],
               (1. / Ndz[2:-1] + 1. / Ndz[1:-2]) / e3t[1:-2],
               1. / Ndz[-2] / e3t[-2]]
    d1  = np.r_[0., -1. / Ndz[1:-1] / e3t[1:-1]]
    dm1 = np.r_[-1. / Ndz[1:-1] / e3t[0:-2], 0.]
    diags = np.vstack((d0, d1, dm1))
    d2dz2 = sparse.dia_matrix((diags, (0, 1, -1)), shape=(len(Nsq)-1, len(Nsq)-1))
    # Solve generalized eigenvalue problem for eigenvalues and vertical
    # Horizontal velocity modes
    eigenvalues, modes = la.eigs(d2dz2, k=nmodes+1, which='SM')
    mask = (eigenvalues.imag == 0) & (eigenvalues >= 1e-10)
    eigenvalues = eigenvalues[mask]
    # Sort eigenvalues and modes and truncate to number of modes requests
    index = np.argsort(eigenvalues)
    eigenvalues = eigenvalues[index[:nmodes]].real
    # Modal speeds
    ce = 1 / np.sqrt(eigenvalues)
    return(ce)

In [None]:
# Mask the E-P run's stratification and layer thicknesses to wet points.
# Nsq and e3w live on the w grid (z_f) while tmask lives on the t grid (z_c).
# Passing tmask directly to `.where` would broadcast and add a spurious z_c axis.
# Re-label tmask's vertical dimension as z_f instead (positionally: w level k
# is the top of t cell k). It has no coordinates, so alignment is by size and
# a size mismatch raises instead of silently broadcasting.
tmask_emp = dino_exp_emp.domain.tmask
wmask_emp = xr.DataArray(
    tmask_emp.data,
    dims=['z_f' if dim == 'z_c' else dim for dim in tmask_emp.dims],
)
Nsq = dino_exp_emp.get_N_squared().where(wmask_emp == 1.)
e3w = dino_exp_emp.data.e3w.where(wmask_emp == 1.)
e3t = dino_exp_emp.data.e3t.where(tmask_emp == 1.)

In [None]:
# Single water column (last year, x_c=30, y_c=100) as a quick check of _get_dynmodes.
isel = {'y_c': 100, 'x_c': 30, 't_y': -1}
# e3t must be the t-level thickness; e3w was previously passed here by mistake.
test = _get_dynmodes(Nsq=Nsq.isel(isel), e3w=e3w.isel(isel), e3t=e3t.isel(isel))

In [None]:
# Modal speeds for every water column, mapped column by column over the vertical
# core dimensions (lazy: computed when accessed).
func = xr.apply_ufunc(
    _get_dynmodes,
    Nsq.chunk({'z_f': -1}),
    e3w.chunk({'z_f': -1}),
    e3t.chunk({'z_c': -1}),
    kwargs={'nmodes': 2},
    input_core_dims=[['z_f'], ['z_f'], ['z_c']],
    output_core_dims=[['mode']],
    vectorize=True,
    dask='parallelized',
    output_dtypes=[float],
    dask_gufunc_kwargs={'output_sizes': {'mode': 2}},
)


In [None]:
import scipy.sparse as sp
import scipy.sparse.linalg as la
from scipy.linalg import eig

def compute_vmodes_1D(Nsqr, dzc=None, dzf=None, nmodes=2):
    r"""
    Compute vertical-mode phase speeds of one water column from its stratification.

    The grid is ordered downward (first point at the surface, last point at the
    bottom). Two staggered grids are used: cell centres ("p" points, spacing
    `dzc`) and cell top interfaces ("w" points, spacing `dzf`), with `Nsqr`
    given on the interfaces. No normalization is applied.

    Parameters
    ----------
    Nsqr : (N,) ndarray
        Squared Brunt-Vaisala frequency at the cell top interfaces. It is
        inverted, so it must be non-zero.
    dzc : (N,) ndarray
        Vertical grid spacing at cell centres (required despite the None default).
    dzf : (N,) ndarray
        Vertical grid spacing at cell interfaces (required despite the None default).
    nmodes : int, optional
        Number of baroclinic modes. One more eigenvalue is computed for the
        barotropic mode, so ``nmodes + 1`` speeds are returned.

    Returns
    -------
    c : (nmodes + 1,) ndarray
        Pseudo phase speeds ``c = 1 / sqrt(lambda)`` ordered by increasing
        eigenvalue ``lambda`` (i.e. fastest first). The modal structures are
        currently not returned.

    Raises
    ------
    ValueError
        If `dzc` or `dzf` is not given.

    Notes
    -----
    The vertical modes are defined by

    .. math:: (\phi'/N^2)' + \lambda\phi = 0

    with boundary condition :math:`\phi'=0` at the bottom and
    :math:`g\phi' + N^2\phi=0` at the surface (:math:`\phi'=0` for a rigid lid).
    The problem is discretized with second-order finite differences on the
    staggered grid. Sparse ARPACK is used when possible; short columns, where
    ARPACK requires ``nmodes + 1 < N - 1``, use a dense solver.
    """
    g = 9.80665

    if dzc is None or dzf is None:
        raise ValueError("must specify grid increments dzc, dzf")
    Nsqr = np.asarray(Nsqr)
    dzc = np.asarray(dzc)
    dzf = np.asarray(dzf)
    Nz = Nsqr.size
    dz_surf = .25 * (dzc[0] + dzf[0])  # this is approx for NEMO grid

    # w-to-p derivative, row i: (w[i] - w[i+1]) / dzc[i]. No w exists below
    # the last cell, which imposes the bottom condition w = 0.
    Dw2p = sp.diags([1. / dzc, -1. / dzc[:-1]], [0, 1], shape=(Nz, Nz), format="lil")
    # p-to-w derivative, row i: (p[i-1] - p[i]) / dzf[i]. Row 0 is replaced by
    # the surface boundary condition (free surface; rigid lid as g -> inf).
    Dp2w = sp.diags([1. / dzf[1:], -1. / dzf], [-1, 0], shape=(Nz, Nz), format="lil")
    Nsqog = Nsqr[0] / g
    Dp2w[0, 0] = -Nsqog * (1 - Nsqog * dz_surf)

    # formulation of the problem : -dz(dz(p)/N^2) = lambda * p
    A = -Dw2p @ sp.diags(1. / Nsqr) @ Dp2w

    k = nmodes + 1
    if k < Nz - 1:
        ev, _ = la.eigs(A.tocsc(), k=k, which='SM')
    else:
        # ARPACK cannot return k >= Nz - 1 eigenvalues of a sparse matrix.
        ev = eig(A.toarray(), right=False)

    # select and arrange modes
    ev = ev[np.isfinite(ev)].real
    ev = np.sort(ev)[:nmodes + 1]
    return 1. / ev**.5

In [None]:
def get_vmodes(exp, nmodes=2):
    """Compute modal phase speeds for every water column of `exp`.

    Maps `_get_dynmodes` over all non-vertical dimensions with
    ``xr.apply_ufunc`` (vectorized per column, dask-parallelized, lazy).

    Parameters
    ----------
    exp : dino.Experiment
        Experiment providing ``get_N_squared()`` (on ``z_f``), ``data.e3w``
        (on ``z_f``), ``data.e3t`` (on ``z_c``) and ``domain.tmask``.
    nmodes : int, optional
        Number of modal speeds returned per column; forwarded to `_get_dynmodes`.

    Returns
    -------
    xarray.DataArray
        Speeds with a trailing ``mode`` dimension of length `nmodes`, in the
        order returned by `_get_dynmodes`. Columns yielding fewer speeds are
        NaN-padded; extra speeds are dropped. Vertical structure functions and
        their normalization are not computed.
    """
    Nsq = exp.get_N_squared()
    # NOTE(review): only e3t is masked to wet points; Nsq/e3w are passed unmasked.
    e3t = exp.data.e3t.where(exp.domain.tmask == 1.0)

    def _column_speeds(nsq, dzw, dzt):
        # dask='parallelized' requires every column to return exactly `nmodes` values.
        ce = np.asarray(_get_dynmodes(nsq, dzw, dzt, nmodes=nmodes), dtype=float).ravel()[:nmodes]
        out = np.full(nmodes, np.nan)
        out[:ce.size] = ce
        return out

    return xr.apply_ufunc(
        _column_speeds,
        Nsq.chunk({'z_f': -1}),
        exp.data.e3w.chunk({'z_f': -1}),
        e3t.chunk({'z_c': -1}),
        input_core_dims=[['z_f'], ['z_f'], ['z_c']],
        output_core_dims=[['mode']],
        vectorize=True,
        dask='parallelized',
        output_dtypes=[float],
        dask_gufunc_kwargs={'output_sizes': {'mode': nmodes}},
    )

In [None]:
nsq_profile = dino_exp.get_N_squared().isel(z_f=slice(1, -1), t_y=-1, x_c=30, y_c=120)
nsq_profile.plot(yscale='log')

In [None]:
# One interior column (surface and bottom levels trimmed) of the restart run.
vmode_col = dict(t_y=-1, x_c=30, y_c=100)
inner_w = slice(1, -1)
test = compute_vmodes_1D(
    dino_exp.get_N_squared().isel(z_f=inner_w, **vmode_col).values,
    dzc=dino_exp.data.e3t.where(dino_exp.domain.tmask == 1.0).isel(z_c=slice(1, -1), **vmode_col).values,
    dzf=dino_exp.data.e3w.isel(z_f=inner_w, **vmode_col).values,
)

In [None]:
# N^2 of the restart run for its last year only.
Nsq = dino_exp.get_N_squared().isel({'t_y': -1})

In [None]:
# Replace negative and missing (NaN) N^2 values by 0.
Nsq = Nsq.where(Nsq >= 0, 0.)

In [None]:
# NOTE(review): allow_pickle=True can execute arbitrary code stored in the file;
# only load files from a trusted source.
stored = np.load('vertical_structure_functions_dataset.npy', allow_pickle=True)
wmodes, pmodes, (tuple1, tuple2), bla, blub = stored

In [None]:
# Display the entry at index 1 of the loaded `pmodes`.
pmodes[1]

In [None]:
e3t_column = dict(t_y=-1, x_c=30, y_c=100)
dino_exp_emp.data.e3t.where(dino_exp_emp.domain.tmask == 1.0).isel(e3t_column)

In [None]:
vmode_speeds = get_vmodes(dino_exp, nmodes=2)
vmode_speeds.isel(x_c=30, y_c=100, t_y=-1).plot()

In [None]:
# Display the dataset held by the restart-run experiment.
dino_exp.data

In [None]:
# Rebind Nsq to the E-P run's N^2 for all years (replaces the restart-run Nsq of earlier cells).
Nsq = dino_exp_emp.get_N_squared()

In [None]:
nsq_last_year = Nsq.isel(t_y=-1)
nsq_last_year.min().values

In [None]:
# Display the N^2 DataArray currently bound to Nsq.
Nsq