# Documentation

**Author:** Spencer Ressel

**Created:** June 14th, 2023

---

This code numerically solves the governing equations from Matsuno (1966). It was initially written by Daniel Lloveras as a project for the course ATM S 582.  The solver uses the pseudospectral method with leapfrog time differencing to solve the equations of motion.

---

# Imports

In [None]:
import os
os.chdir(f"/home/disk/eos7/sressel/research/thesis-work/python/numerical_solver/")
import sys
import numpy as np
import xarray as xr
from scipy import special
from numpy.fft import fft, ifft, fftfreq

import matplotlib.pyplot as plt
plt.rcParams.update({'font.size':22})
import matplotlib.ticker as mticker
import matplotlib.colors as mcolors
from matplotlib.gridspec import GridSpec
from matplotlib.animation import FuncAnimation
from tqdm import tqdm

# Cartopy
from cartopy import crs as ccrs
from cartopy import feature as cf
from cartopy import util as cutil
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter, LongitudeLocator, LatitudeLocator

import sys
sys.path.insert(0, '/home/disk/eos7/sressel/research/thesis-work/python/auxiliary_functions/')
import ipynb.fs.full.mjo_mean_state_diagnostics as mjo
from ipynb.fs.full.bmh_colors import bmh_colors
from ipynb.fs.full.round_out import round_out

# Define physical constants

In [None]:
# ============================ Fundamental constants ============================ #
GRAVITY = 9.81                  # gravitational acceleration, g [m s^-2]
EQUIVALENT_DEPTH = 250.         # equivalent depth, H [m]
CORIOLIS_PARAMETER = 2.29e-11   # meridional gradient of the Coriolis parameter, ß [m^-1 s^-1]
EARTH_RADIUS = 6371.0072e3      # Earth radius, R_e [m]
AIR_DENSITY = 1.225             # air density, ρ [kg m^-3]

# ============================= Conversion factors ============================== #
METERS_PER_DEGREE = 110e3       # approximate meters per degree of latitude/longitude
SECONDS_PER_DAY = 86400

# ============================= Derived quantities ============================== #
# Shallow-water gravity-wave speed c = sqrt(gH) [m s^-1]
gravity_wave_phase_speed = np.sqrt(GRAVITY*EQUIVALENT_DEPTH)

# Equatorial scales used to (re-)dimensionalize Matsuno (1966):
#   T = (ßc)^(-1/2) [s],   L = (c/ß)^(1/2) [m]
time_scale = (CORIOLIS_PARAMETER*gravity_wave_phase_speed)**(-1/2)
length_scale = (gravity_wave_phase_speed/CORIOLIS_PARAMETER)**(1/2)

# Linear damping rate ε [s^-1], written as a multiple of 1/T (currently zero)
friction_coefficient = 0.0/time_scale

print(f"Dry Gravity Wave Phase Speed: {gravity_wave_phase_speed:>10.2f} m/s")
print(f"Time Scale:                   {time_scale*24/SECONDS_PER_DAY:>10.2f} hours")
print(f"Length Scale:                 {length_scale/1e3:>10.2f} km")
# =============================================================================== #

# Define simulation parameters

In [None]:
########################### Define simulation grid ############################
n_days                   = 30                    # number of days in simulation
n_time_steps             = 2**14                 # number of time points (t = 0 included)
meridional_domain_length = 5000e3                # length of half y domain in m
meridional_step_size     = 100e3                 # meridional grid spacing in m
zonal_domain_length      = 2*np.pi*EARTH_RADIUS  # length of x domain in m
zonal_step_size          = 200e3                 # requested zonal grid spacing in m (snapped below)

# The spectral x-derivative (fftfreq(nx, Δx)) treats the zonal grid as periodic
# with period nx*Δx. np.arange(-Lx/2, Lx/2, 200 km) would give 201 points, i.e. a
# period of 40,200 km instead of 2πR ≈ 40,030 km, so a wave with a whole number of
# wavelengths per circumference would be discontinuous across the periodic seam.
# Snap Δx to the nearest spacing that divides the domain exactly.
nx = max(1, int(round(zonal_domain_length/zonal_step_size)))  # number of zonal grid points
zonal_step_size = zonal_domain_length/nx                      # zonal grid spacing in m

simulation_length = n_days*SECONDS_PER_DAY       # simulation length in seconds
time_points = np.linspace(
    0,
    simulation_length,
    n_time_steps
)                                                # array of simulation time points
time_step = np.diff(time_points)[0]              # length of a time step in s

meridional_gridpoints = np.arange(
    -meridional_domain_length,
     meridional_domain_length,
     meridional_step_size
)                                                # meridional grid points

# nx equally spaced points; x = +Lx/2 is the periodic image of x = -Lx/2
zonal_gridpoints = -zonal_domain_length/2 + zonal_step_size*np.arange(nx)

nt = len(time_points)                            # number of time points
ny = len(meridional_gridpoints)                  # number of meridional grid points

# Calculate CFL numbers
CFL_x = gravity_wave_phase_speed*time_step/zonal_step_size
CFL_y = gravity_wave_phase_speed*time_step/meridional_step_size

# Stability threshold used for both directions. Leapfrog needs |ω|Δt < 1 and the
# largest resolved spectral wavenumber is π/Δx, giving cΔt/Δx < 1/π in 1-D; the
# extra 1/√2 is presumably a 2-D (equal-spacing) safety factor.
CFL_LIMIT = 1/(np.sqrt(2)*np.pi)

print(f"{'Simulation Parameters':^48}")
print(f"{'':=^48}")
print(
    f"{'Lx =':4}" +
    f"{zonal_domain_length/1e3:>6.0f}{' km':<6}" +
    f"{'| Δx = ':>5}" +
    f"{zonal_step_size/1e3:>8.1f}" +
    f"{' km':<5}" +
    f"{'| nx = ':<5}" +
    f"{nx:>5.0f}"
)
print(
    f"{'Ly =':4}" +
    f"{2*meridional_domain_length/1e3:>6.0f}" +
    f"{' km':<6}{'| Δy = ':>5}" +
    f"{meridional_step_size/1e3:>8.1f}" +
    f"{' km':<5}" +
    f"{'| ny = ':<4}" +
    f"{ny:>5.0f}"
)
print(
    f"{'T  =':4}" +
    f"{simulation_length/SECONDS_PER_DAY:>6.0f}" +
    f"{' days':<6}{'| Δt = ':>5}" +
    f"{time_step:>8.1f}" +
    f"{' sec':<5}" +
    f"{'| nt = ':<5}" +
    f"{nt:>5.0f}"
)
print(f"{'':=^48}")
for cfl_name, cfl_value in (("CFL_x", CFL_x), ("CFL_y", CFL_y)):
    print(f"{cfl_name} = {cfl_value:0.3f}", end="")
    if cfl_value < CFL_LIMIT:
        print(", numerically stable")
    else:
        print(f", CFL exceeds 1/(√2π) ≈ {CFL_LIMIT:0.3f}, numerically unstable!!")
print(f"{'':=^48}")
###########################################################################################

# Compute solutions

## Initial Conditions

The initial conditions are re-dimensionalized versions of Eqs. 17 & 19 from Matsuno (1966).
The meridional structure of the mode is given by parabolic cylinder functions:

*$\psi_{m}(\hat{y}) = e^{-\frac{1}{2}\hat{y}^{2}}\, H_{m}(\hat{y}), \qquad \hat{y} = y/L$*

where $H_{m}(\hat{y})$ is the physicist's Hermite polynomial of order $m$. The horizontal and temporal structure is taken to be wavelike:

$\begin{pmatrix}u \\ v \\ \phi\end{pmatrix}(x, y, t) \propto \begin{pmatrix}\tilde{u}(y) \\ \tilde{v}(y) \\ \tilde{\phi}(y)\end{pmatrix} e^{i(kx+\omega t)}$

where $k$ is the zonal wavenumber, $\omega$ is the frequency, and $t=0$ for the initial condition. With this sign convention, $\omega > 0$ (for $k > 0$) corresponds to westward phase propagation. As the equations are linear, if the initial state is balanced, the solution will remain balanced for all time. Non-balanced initial conditions will generate high-frequency gravity waves.

In [None]:
##################################### Define initial condition #####################################
# Initialize (time, y, x) arrays; only the t = 0 slice is filled in this cell
zonal_velocity = np.zeros((nt,ny,nx))
meridional_velocity = np.zeros((nt,ny,nx))
geopotential_height = np.zeros((nt,ny,nx))

# Specify wave parameters
initial_wave = 'Rossby'                          # 'single-perturbation', 'Kelvin', 'Rossby', 'EIG' or 'WIG'
initial_geopotential_anomaly = 1*100/AIR_DENSITY # m^2 s^-2 (1 hPa divided by ρ)
n_wavelengths = 1                                # per Earth circumference
mode_number = 1                                  # meridional mode number m
initial_wavenumber = 2*np.pi*n_wavelengths/zonal_domain_length


def _matsuno_mode_structure(frequency, wavenumber, mode):
    """Return the un-normalized t = 0 fields (u, v, ϕ) of a Matsuno (1966) mode.

    With ψ_n evaluated at y/L (L = length_scale) and c the gravity-wave speed:
        v = Re[ i(1/ß)(ω^2 - (ck)^2) ψ_m × e^(ikx) ]                       [m s^-1]
        u = Re[ L  (0.5(ω-ck) ψ_(m+1) + m(ω+ck) ψ_(m-1)) × e^(ikx) ]       [m s^-1]
        ϕ = Re[ cL (0.5(ω-ck) ψ_(m+1) - m(ω+ck) ψ_(m-1)) × e^(ikx) ]       [m^2 s^-2]

    Each returned array has shape (ny, nx).
    """
    scaled_y = meridional_gridpoints/length_scale
    zonal_phase = np.exp(1j*wavenumber*zonal_gridpoints)
    c = gravity_wave_phase_speed

    psi_m = mjo.parabolic_cylinder_function(scaled_y, mode)
    psi_m_plus_1 = mjo.parabolic_cylinder_function(scaled_y, mode+1)
    psi_m_minus_1 = mjo.parabolic_cylinder_function(scaled_y, mode-1)

    upper_term = 0.5*(frequency - c*wavenumber)*psi_m_plus_1
    lower_term = mode*(frequency + c*wavenumber)*psi_m_minus_1

    v = np.real(np.einsum(
        'i,j->ij',
        1j*(1/CORIOLIS_PARAMETER)*(frequency**2 - c**2*wavenumber**2)*psi_m,
        zonal_phase
    ))
    u = (c/CORIOLIS_PARAMETER)**(1/2)*np.real(
        np.einsum('i,j->ij', upper_term + lower_term, zonal_phase)
    )
    phi = (c/CORIOLIS_PARAMETER)**(1/2)*c*np.real(
        np.einsum('i,j->ij', upper_term - lower_term, zonal_phase)
    )
    return u, v, phi


# Calculate wave structure
if initial_wave == 'single-perturbation':
    # Non-negative ψ_0 bump confined to |x| <= Lx/4 (amplitude c; not rescaled below)
    geopotential_height[0] = gravity_wave_phase_speed*np.einsum(
            'i,j->ij',
            mjo.parabolic_cylinder_function(meridional_gridpoints/length_scale, 0),
            np.cos(initial_wavenumber*zonal_gridpoints)
    )
    geopotential_height[0, :, np.abs(zonal_gridpoints) > zonal_domain_length/4] = 0
    geopotential_height[geopotential_height < 0] = 0

elif initial_wave == 'Kelvin':
    #### Kelvin wave initial condition (v = 0) ####
    # u(x,y,t=0) = c × ψ(y/L, 0) × e^(ikx)    [m s^-1]
    # ϕ(x,y,t=0) = c^2 × ψ(y/L, 0) × e^(ikx)  [m^2 s^-2]
    kelvin_structure = np.real(
        np.einsum(
            'i,j->ij',
            mjo.parabolic_cylinder_function(meridional_gridpoints/length_scale, 0),
            np.exp(1j*initial_wavenumber*zonal_gridpoints)
        )
    )
    zonal_velocity[0] = gravity_wave_phase_speed*kelvin_structure
    geopotential_height[0] = gravity_wave_phase_speed**2*kelvin_structure

    # Rescale the anomalies so that ϕ[0] has maximum 'initial_geopotential_anomaly'
    amplitude_scale = initial_geopotential_anomaly/np.max(geopotential_height[0])
    zonal_velocity[0] *= amplitude_scale
    geopotential_height[0] *= amplitude_scale

elif initial_wave in ('Rossby', 'EIG', 'WIG'):
    # Frequencies follow the e^(i(kx + ωt)) convention: ω > 0 (k > 0) is westward propagation
    meridional_mode_term = (CORIOLIS_PARAMETER/gravity_wave_phase_speed)*(2*mode_number + 1)
    if initial_wave == 'Rossby':
        # Low-frequency Rossby dispersion relation (ω^2 term neglected)
        initial_frequency = (
            CORIOLIS_PARAMETER*initial_wavenumber
            /(initial_wavenumber**2 + meridional_mode_term)
        )
    else:
        # High-frequency inertia-gravity dispersion relation; EIG eastward (ω < 0), WIG westward (ω > 0)
        gravity_frequency = gravity_wave_phase_speed*np.sqrt(initial_wavenumber**2 + meridional_mode_term)
        initial_frequency = -gravity_frequency if initial_wave == 'EIG' else gravity_frequency

    zonal_velocity[0], meridional_velocity[0], geopotential_height[0] = _matsuno_mode_structure(
        initial_frequency, initial_wavenumber, mode_number
    )

    # Rescale the anomalies so that ϕ[0] has maximum 'initial_geopotential_anomaly'
    amplitude_scale = initial_geopotential_anomaly/np.max(geopotential_height[0])
    zonal_velocity[0] *= amplitude_scale
    meridional_velocity[0] *= amplitude_scale
    geopotential_height[0] *= amplitude_scale

else:
    raise ValueError(
        f"Unknown initial_wave {initial_wave!r}; expected one of "
        "'single-perturbation', 'Kelvin', 'Rossby', 'EIG', 'WIG'"
    )
#####################################################################################################################

# Initialize a westward basic state everywhere
# zonal_velocity[0] += -1

##################### Specify output file directory #####################
# Encode the run configuration (wavenumber, mode, nondimensional friction εT, wave type) in the name
simulation_name = (
    f"dry-Matsuno"
  + f"_k={initial_wavenumber*zonal_domain_length/(2*np.pi)}"
  + f"_m={mode_number}"
  + f"_eps={time_scale*friction_coefficient:0.2f}"
  + f"_{initial_wave}_wave"
)
output_file_directory = (
    f"output/dry-Matsuno/{simulation_name}"
)
# makedirs also creates missing parents (e.g. output/dry-Matsuno on a fresh checkout),
# which os.mkdir cannot, and is a no-op if the directory already exists
os.makedirs(output_file_directory, exist_ok=True)
print(f"Output file directory: {os.getcwd()}/{output_file_directory} \n")
#########################################################################


######################################### Plot initial conditions #########################################
# Specify plotting parameters
modified_cmap = mjo.modified_colormap('bwr', 'white', 0.15, 0.05)
plt.style.use('default')
plt.rcParams.update({'font.size':24})

# Quiver spacing: about 2*n_quiver_points arrows in x and n_quiver_points in y.
# max(1, ...) keeps the slice step valid on very coarse grids.
n_quiver_points = 10
zonal_quiver_plot_spacing = max(1, int((1/n_quiver_points/2)*zonal_domain_length/zonal_step_size))
meridional_quiver_plot_spacing = max(1, int((1/n_quiver_points)*2*meridional_domain_length/meridional_step_size))

# Equivalent pressure anomaly p' = ρϕ, converted from Pa to hPa
# (multiplication already returns a new array, so no explicit copy is needed)
equivalent_pressure = geopotential_height*AIR_DENSITY/100

# contour_scale = int(np.ceil(np.max(geopotential_height[0])))
contour_scale = np.max(equivalent_pressure[0])
quiver_scale = np.max(zonal_velocity[0])
if not quiver_scale > 0:
    # e.g. 'single-perturbation' starts with u = 0; quiver needs a positive scale
    quiver_scale = 1.0
grid_scaling = 1e-6                              # plot x and y in units of 10^6 m

contour_args = {
    'levels' :  np.linspace(
                    -contour_scale,
                    contour_scale,
                    11
                ),
    'norm'   :  mcolors.CenteredNorm(),
    'cmap'   :  modified_cmap,
}

quiver_args = {
    'color'       : 'k',
    'width'       : 0.0025,
    'angles'      : 'xy',
    'scale_units' : 'xy',
    'scale'       : quiver_scale
}

# Create figure
fig = plt.figure(figsize=(16,6),dpi=300)
gs = GridSpec(1, 2, width_ratios = [100,2], figure=fig)
gs.update(left=0.1, right=0.9, top=0.99, bottom=0.1, wspace=0.05)

# Create axes
cbar_ax = fig.add_subplot(gs[-1])
ax = fig.add_subplot(gs[0])

# Title
ax.set_title('Initial Condition', pad=10)

# Plot geopotential height/pressure anomalies
cont = ax.contourf(
    zonal_gridpoints*grid_scaling,
    meridional_gridpoints*grid_scaling,
    # geopotential_height[0],
    equivalent_pressure[0],
    **contour_args
)
cbar = fig.colorbar(cont, cax=cbar_ax)
cbar.set_ticks(contour_args['levels'])
# cbar.set_label(r"$\frac{m^2}{s^2}$", rotation=0, labelpad=15)
cbar.set_label(r"$hPa$", rotation=0, labelpad=15)

# Plot wind vectors
quiv = ax.quiver(
    zonal_gridpoints[::zonal_quiver_plot_spacing]*grid_scaling,
    meridional_gridpoints[::meridional_quiver_plot_spacing]*grid_scaling,
    zonal_velocity[0,::meridional_quiver_plot_spacing,::zonal_quiver_plot_spacing],
    meridional_velocity[0,::meridional_quiver_plot_spacing,::zonal_quiver_plot_spacing],
    **quiver_args
)

ax.quiverkey(
    quiv,
    0.79, 1.02,
    U=quiver_scale,
    label=f'{quiver_scale:0.2f} m/s',
    coordinates='figure', labelpos='E',
    color='black', labelcolor='black'
)

# Plot a line along the equator
ax.axhline(y=0, ls='--', alpha=0.5, color='gray', lw=0.5)

# Set tick labels in longitude/latitude coordinates
longitude_ticks = np.arange(-180+60, 180+60, 60)
longitude_labels = mjo.tick_labeller(longitude_ticks, direction='lon')
ax.set_xticks(longitude_ticks*METERS_PER_DEGREE*grid_scaling, labels=longitude_labels)
latitude_ticks = np.arange(-50, 50+20, 20)
latitude_labels = mjo.tick_labeller(latitude_ticks, direction='lat')
ax.set_yticks(latitude_ticks*METERS_PER_DEGREE*grid_scaling, labels=latitude_labels)

# Set aspect ratio
ax.set_aspect('auto')

# Save output figure
plt.savefig(f"{output_file_directory}/{simulation_name}_intial-condition.png", bbox_inches='tight')

## Solve

The equations are solved using the leapfrog and pseudo-spectral methods. The spatial derivatives are calculated using the pseudo-spectral method via Fourier transforms, e.g.:

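$\frac{\partial u}{\partial x} = \mathcal{F}^{-1}\left[\, ik\, \mathcal{F}[u] \,\right]$, where $\mathcal{F}$ is the discrete Fourier transform along $x$ and $k$ is the zonal wavenumber.

The temporal derivatives are approximated using the leapfrog method, $q[t] = q[t-2] + 2\Delta t \, F(u[t-1], \, v[t-1], \, \phi[t-1])$ for each prognostic variable $q \in \{u, v, \phi\}$, with the linear friction term evaluated implicitly at time $t$. The first time step is solved using the forward Euler method.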

In [None]:
########## Fourier arrays for taking derivatives using the pseudo-spectral method ##########
# Angular wavenumbers/frequencies in FFT ordering: 2π × fftfreq(n, spacing)
zonal_wavenumber      = 2*np.pi*fftfreq(nx, zonal_step_size)       # [rad m^-1]
meridional_wavenumber = 2*np.pi*fftfreq(ny, meridional_step_size)  # [rad m^-1]
frequencies           = 2*np.pi*fftfreq(nt, time_step)             # [rad s^-1]
############################################################################################


def _spectral_derivatives(u, v, phi):
    """Return (du/dx, dv/dy, dϕ/dx, dϕ/dy) for one (y, x) time slice.

    Each derivative is Re[ifft(ik × fft(field))] along the relevant axis
    (x = axis 1, y = axis 0), so both directions are treated as periodic.
    """
    ik_x = 1j*zonal_wavenumber[None, :]
    il_y = 1j*meridional_wavenumber[:, None]
    du_dx   = np.real(ifft(ik_x*fft(u,   axis=1), axis=1))
    dv_dy   = np.real(ifft(il_y*fft(v,   axis=0), axis=0))
    dphi_dx = np.real(ifft(ik_x*fft(phi, axis=1), axis=1))
    dphi_dy = np.real(ifft(il_y*fft(phi, axis=0), axis=0))
    return du_dx, dv_dy, dphi_dx, dphi_dy


# ßy on a (ny, 1) column so it broadcasts across x
beta_y = CORIOLIS_PARAMETER*meridional_gridpoints[:, None]

################################## Compute first step using forward Euler method ##################################
du_dx, dv_dy, dphi_dx, dphi_dy = _spectral_derivatives(
    zonal_velocity[0], meridional_velocity[0], geopotential_height[0]
)

# u[1] = u[0] - Δt × (-ßyv[0] + dϕ[0]/dx + εu[0])
zonal_velocity[1] = zonal_velocity[0] - time_step*(
    -beta_y*meridional_velocity[0] + dphi_dx + friction_coefficient*zonal_velocity[0]
)

# v[1] = v[0] - Δt × (ßyu[0] + dϕ[0]/dy + εv[0])
meridional_velocity[1] = meridional_velocity[0] - time_step*(
    beta_y*zonal_velocity[0] + dphi_dy + friction_coefficient*meridional_velocity[0]
)

# ϕ[1] = ϕ[0] - Δt × (c^2 × (du[0]/dx + dv[0]/dy) + εϕ[0])
geopotential_height[1] = geopotential_height[0] - time_step*(
    gravity_wave_phase_speed**2*(du_dx + dv_dy) + friction_coefficient*geopotential_height[0]
)

# Meridional boundary condition: no flow through the walls
meridional_velocity[1, 0] = 0.
meridional_velocity[1, -1] = 0.
########################################################################################################################

#################################### Step forward using leapfrog time-differencing ####################################
# Friction is evaluated implicitly at time t, giving a constant 1/(1+2εΔt) factor
implicit_damping = 1/(1+friction_coefficient*2*time_step)

for it in tqdm(range(2, nt), position=0, leave=True, ncols=100):
    du_dx, dv_dy, dphi_dx, dphi_dy = _spectral_derivatives(
        zonal_velocity[it-1], meridional_velocity[it-1], geopotential_height[it-1]
    )

    # u[t] = 1/(1+2εΔt) × (u[t-2] - 2Δt × (-ßyv[t-1] + dϕ[t-1]/dx))
    zonal_velocity[it] = implicit_damping*(
        zonal_velocity[it-2] - 2*time_step*(-beta_y*meridional_velocity[it-1] + dphi_dx)
    )

    # v[t] = 1/(1+2εΔt) × (v[t-2] - 2Δt × (ßyu[t-1] + dϕ[t-1]/dy))
    meridional_velocity[it] = implicit_damping*(
        meridional_velocity[it-2] - 2*time_step*(beta_y*zonal_velocity[it-1] + dphi_dy)
    )

    # ϕ[t] = 1/(1+2εΔt) × (ϕ[t-2] - 2Δt × c^2(du[t-1]/dx + dv[t-1]/dy))
    geopotential_height[it] = implicit_damping*(
        geopotential_height[it-2] - 2*time_step*(gravity_wave_phase_speed**2*(du_dx + dv_dy))
    )

    # Meridional boundary condition: no flow through the walls
    meridional_velocity[it, 0] = 0.
    meridional_velocity[it, -1] = 0.

# Simulation Output

## Plot solutions

### Horizontal Structure

In [None]:
# Specify plotting parameters
plt.style.use('default')
plt.rcParams.update({'font.size':24})
modified_cmap = mjo.modified_colormap('bwr', 'white', 0.15, 0.05)

# Equivalent pressure perturbation p' = ρϕ, converted from Pa to hPa
# (multiplication already returns a new array, so no explicit copy is needed)
equivalent_pressure = geopotential_height*AIR_DENSITY/100
# Colour and arrow scales rounded up to whole units. max(1, ...) keeps them positive:
# contourf needs strictly increasing levels and quiver needs a positive scale.
# contour_scale = max(1, int(np.ceil(np.max(geopotential_height))))
contour_scale = max(1, int(np.ceil(np.max(equivalent_pressure))))
quiver_scale = max(1, int(np.ceil(np.max(zonal_velocity))))
grid_scaling = 1e-6                              # plot x and y in units of 10^6 m

end_frame = -10                                  # the "final" panel shows the 10th-from-last time step

contour_args = {
    'levels' : np.linspace(-contour_scale, contour_scale, 11),
    'norm'   : mcolors.CenteredNorm(),
    'cmap'   : modified_cmap,
}

quiver_args = {
    'color'       : 'k',
    'width'       : 0.0025,
    'angles'      : 'xy',
    'scale_units' : 'xy',
    'scale'       : quiver_scale
}

# Longitude/latitude tick positions and labels shared by both panels
longitude_ticks = np.arange(-180+60, 180+60, 60)
longitude_labels = mjo.tick_labeller(longitude_ticks, direction='lon')
latitude_ticks = np.arange(-50, 50+20, 20)
latitude_labels = mjo.tick_labeller(latitude_ticks, direction='lat')


def _plot_horizontal_panel(axis, frame):
    """Draw pressure contours, wind vectors, the equator and lon/lat ticks for time index `frame`.

    Returns the (contour set, quiver) artists.
    """
    contour_set = axis.contourf(
        zonal_gridpoints*grid_scaling,
        meridional_gridpoints*grid_scaling,
        # geopotential_height[frame],
        equivalent_pressure[frame],
        **contour_args
    )
    arrows = axis.quiver(
        zonal_gridpoints[::zonal_quiver_plot_spacing]*grid_scaling,
        meridional_gridpoints[::meridional_quiver_plot_spacing]*grid_scaling,
        zonal_velocity[frame,::meridional_quiver_plot_spacing,::zonal_quiver_plot_spacing],
        meridional_velocity[frame,::meridional_quiver_plot_spacing,::zonal_quiver_plot_spacing],
        **quiver_args
    )
    axis.axhline(y=0, ls='--', alpha=0.5, color='gray', lw=0.5)
    axis.set_xticks(longitude_ticks*METERS_PER_DEGREE*grid_scaling, labels=longitude_labels)
    axis.set_yticks(latitude_ticks*METERS_PER_DEGREE*grid_scaling, labels=latitude_labels)
    axis.set_aspect('auto')
    return contour_set, arrows


# Create figure
fig = plt.figure(figsize=(16,9),dpi=300)
gs = GridSpec(2, 2, width_ratios = [100,2], figure=fig)
gs.update(left=0.1, right=0.9, top=0.9, bottom=0.1, wspace=0.05, hspace=0.25)

# Label figure
fig.suptitle('Initial and Final Conditions')

#### Initial Condition ####
cbar_ax = fig.add_subplot(gs[:, -1])
ax0 = fig.add_subplot(gs[0, 0])
cont, quiv = _plot_horizontal_panel(ax0, 0)

cbar = fig.colorbar(cont, cax=cbar_ax)
cbar.set_ticks(contour_args['levels'])
# cbar.set_label(r"$\frac{m^2}{s^2}$", rotation=0, labelpad=15)
cbar.set_label(r"hPa", labelpad=15)

# Add a key for the wind vectors
ax0.quiverkey(
    quiv,
    0.79, 0.925,
    U=quiver_scale,
    label=f'{quiver_scale} m/s',
    coordinates='figure', labelpos='E',
    color='black', labelcolor='black'
)

#### Final Condition ####
ax1 = fig.add_subplot(gs[1, 0])
cont1, quiv1 = _plot_horizontal_panel(ax1, end_frame)

plt.savefig(f"{output_file_directory}/{simulation_name}_horizontal-structure.png", bbox_inches='tight')

### Temporal Structure

In [None]:
# Plotting style for the time-series figure
plt.style.use('bmh')
plt.rcParams.update({'font.size':24})

# Time index and (y, x) grid point of the largest geopotential anomaly anywhere in the run
max_index = np.argmax(geopotential_height)
[t_index, y_index, x_index] = np.unravel_index(max_index, geopotential_height.shape)

[fig, ax] = plt.subplots(figsize=(16,6))
ax.set_title(f"{initial_wave} Wave Amplitude over time", pad=15)

# Time series at that grid point: ϕ/c (so it shares velocity units), then u, then v
for _series, _label in (
    (geopotential_height[:, y_index, x_index]/gravity_wave_phase_speed, r"$\frac{\phi}{c}$"),
    (zonal_velocity[:, y_index, x_index], 'u'),
    (meridional_velocity[:, y_index, x_index], 'v'),
):
    ax.plot(time_points/SECONDS_PER_DAY, _series, lw=3, label=_label)

# Dotted reference lines at the global maximum and minimum of ϕ/c
for _extreme in (np.max(geopotential_height), np.min(geopotential_height)):
    ax.axhline(y=_extreme/gravity_wave_phase_speed, color='black', ls=':', alpha=0.75)

ax.set_xlabel('Time (days)')
ax.set_ylabel(r"$\frac{m}{s}$", rotation=0, labelpad=20, fontsize=32)
ax.legend(loc='upper right', fontsize=18)

plt.savefig(f"{output_file_directory}/{simulation_name}_temporal-structure.png", bbox_inches='tight')

### Animate solutions

In [None]:
# Specify plotting parameters
plt.style.use('default')
plt.rcParams.update({'font.size':24})

# Specify animation parameters
n_frames = 50
starting_frame = 0
frame_interval = int((nt - starting_frame)/n_frames)
# frame_interval = 2
end_frame = starting_frame + n_frames*frame_interval
frames = np.arange(starting_frame, end_frame, frame_interval)

# Calculate equivalent pressure anomaly in hPa
equivalent_pressure = np.copy(geopotential_height)*AIR_DENSITY/100

contour_scale = int(np.ceil(np.max(equivalent_pressure)))
# contour_scale = int(np.ceil(np.max(geopotential_height)))
quiver_scale = int(np.ceil(np.max(zonal_velocity)))
grid_scaling = 1e-6

contour_args = {
    'levels' : np.linspace(-contour_scale, contour_scale, 11),
    'norm'   : mcolors.CenteredNorm(),
    'cmap'   : modified_cmap, 
    # 'extend' : 'both'    
}

quiver_args = {
    'color'       : 'k',
    'width'       : 0.0025,
    'angles'      : 'xy',
    'scale_units' : 'xy',
    'scale'       : quiver_scale
}

# Create figure
fig = plt.figure(figsize=(16,6),dpi=300)
gs = GridSpec(1, 2, width_ratios = [100,2], figure=fig)
gs.update(left=0.1, right=0.9, top=0.9, bottom=0.1, wspace=0.05, hspace=0.25)

# Create axes
ax = fig.add_subplot(gs[0])
cbar_ax = fig.add_subplot(gs[1])

# Title
ax.set_title(f"Time: {time_points[starting_frame]/SECONDS_PER_DAY:0.1f} days")

# Plot geopotential/pressure anomalies
cont = ax.contourf(
    zonal_gridpoints*grid_scaling,
    meridional_gridpoints*grid_scaling,
    # geopotential_height[starting_frame],
    equivalent_pressure[starting_frame],
    **contour_args
)
cbar = fig.colorbar(cont, cax=cbar_ax)
cbar.set_ticks(contour_args['levels'])
# cbar.set_label(r"$\frac{m^2}{s^2}$", rotation=0, labelpad=15)
cbar.set_label(r"hPa", rotation=0, labelpad=15)                    

# Plot wind vectors
quiv = ax.quiver(
    zonal_gridpoints[::zonal_quiver_plot_spacing]*grid_scaling,
    meridional_gridpoints[::meridional_quiver_plot_spacing]*grid_scaling,
    zonal_velocity[starting_frame,::meridional_quiver_plot_spacing,::zonal_quiver_plot_spacing],
    meridional_velocity[starting_frame,::meridional_quiver_plot_spacing,::zonal_quiver_plot_spacing],
    **quiver_args
)

ax.quiverkey(
    quiv,          
    0.79, 0.925,          
    U=quiver_scale,          
    label=f'{quiver_scale} m/s',          
    coordinates='figure', labelpos='E',          
    color='black', labelcolor='black'
)

# Plot line at the equator
ax.axhline(y=0, ls='--', alpha=0.5, color='gray', lw=0.5)

# Set tick labels in longitude/latitude coordinates
longitude_ticks = np.arange(-180+60, 180+60, 60)
longitude_labels = mjo.tick_labeller(longitude_ticks, direction='lon')
ax.set_xticks(longitude_ticks*METERS_PER_DEGREE*grid_scaling, labels=longitude_labels)

latitude_ticks = np.arange(-50, 50+20, 20)
latitude_labels = mjo.tick_labeller(latitude_ticks, direction='lat')
ax.set_yticks(latitude_ticks*METERS_PER_DEGREE*grid_scaling, labels=latitude_labels)

ax.set_aspect('auto')

def update(current_frame):
    """Draw one animation frame of the unfiltered solution.

    Parameters
    ----------
    current_frame : int
        Time index into the solution arrays. FuncAnimation passes the
        values of ``frames`` straight through.

    Returns
    -------
    tuple
        The filled-contour set and the quiver drawn for this frame.

    Notes
    -----
    Each call removes the contour set and quiver added by the previous
    call before drawing new ones. Without this, the axes gain one pair of
    artists per frame, so every later frame renders more slowly, and stale
    arrows can show through unfilled regions. The initial ``cont``/``quiv``
    created before the animation are left in place because the quiver key
    is attached to ``quiv``.
    """
    # plotting_index = starting_frame + frame_interval*current_frame
    plotting_index = current_frame

    # Remove the artists drawn by the previous call (nothing on the first call)
    for stale_artist in getattr(update, "_previous_artists", ()):
        stale_artist.remove()

    ax.set_title(f"Time: {time_points[plotting_index]/SECONDS_PER_DAY:0.1f} days")
    cont = ax.contourf(
        zonal_gridpoints*grid_scaling,
        meridional_gridpoints*grid_scaling,
        # geopotential_height[plotting_index],
        equivalent_pressure[plotting_index],
        **contour_args
    )
    # Wind vectors subsampled on the same grid as the initial frame
    quiv = ax.quiver(
        zonal_gridpoints[::zonal_quiver_plot_spacing]*grid_scaling,
        meridional_gridpoints[::meridional_quiver_plot_spacing]*grid_scaling,
        zonal_velocity[plotting_index,::meridional_quiver_plot_spacing,::zonal_quiver_plot_spacing],
        meridional_velocity[plotting_index,::meridional_quiver_plot_spacing,::zonal_quiver_plot_spacing],
        **quiver_args
    )

    # Remember this frame's artists so the next call can remove them
    update._previous_artists = (cont, quiv)
    return cont, quiv
    
# Create animation
# frames is wrapped in tqdm so a progress bar advances as frames are drawn;
# interval is the delay between frames in milliseconds.
anim = FuncAnimation(fig, update, frames=tqdm(frames, ncols=100, position=0, leave=True), interval=300)

# Save animation
# NOTE(review): writing .mp4 needs a movie writer such as ffmpeg installed — confirm on the target machine.
anim.save(f"{output_file_directory}/{simulation_name}_animation.mp4", dpi=300)

## Filter Solutions

In [None]:
# Specify which wave type to plot.
# 'Kelvin'/'Rossby' use a low-pass time filter (slow variability);
# 'EIG'/'WIG' use a high-pass time filter (fast variability).
filtered_wave = 'Rossby'

# Filter cutoff period in seconds (the filter receives the frequency 1/filter_period)
filter_period = 5*SECONDS_PER_DAY

# The fields are transposed before filtering and transposed back afterwards,
# moving the leading time axis to the last position — presumably because the
# mjo Butterworth helpers filter along the last axis (confirm in
# mjo_mean_state_diagnostics). 1/time_step is the sampling frequency.
if filtered_wave in ('Kelvin', 'Rossby'):
    time_filtered_zonal_velocity = mjo.butter_lowpass_filter(zonal_velocity.T, 1/filter_period, 1/time_step).T
    time_filtered_meridional_velocity = mjo.butter_lowpass_filter(meridional_velocity.T, 1/filter_period, 1/time_step).T
    time_filtered_geopotential_height = mjo.butter_lowpass_filter(geopotential_height.T, 1/filter_period, 1/time_step).T

elif filtered_wave in ('EIG', 'WIG'):
    time_filtered_zonal_velocity = mjo.butter_highpass_filter(zonal_velocity.T, 1/filter_period, 1/time_step).T
    time_filtered_meridional_velocity = mjo.butter_highpass_filter(meridional_velocity.T, 1/filter_period, 1/time_step).T
    time_filtered_geopotential_height = mjo.butter_highpass_filter(geopotential_height.T, 1/filter_period, 1/time_step).T

else:
    # Fail loudly: otherwise the time_filtered_* arrays are either undefined
    # or silently left over from a previous run, and later plots are wrong.
    raise ValueError(
        f"Unknown filtered_wave {filtered_wave!r}; "
        "expected 'Kelvin', 'Rossby', 'EIG' or 'WIG'"
    )

print("Time Filtering Complete")
print("=======================")

# NOTE: the zonal-direction filtering below is currently disabled; the
# time-filtered fields are used as they are.
# frequencies = np.fft.fftfreq(nt, time_step)
# phase_speed = -np.repeat(np.einsum('i,j->ji', zonal_wavenumber, frequencies)[:, np.newaxis,:], len(meridional_gridpoints), axis=1)

# # Depending on the wave type specified, filter out waves propagating in the wrong direction
# if filtered_wave == 'Kelvin' or filtered_wave == 'EIG':
#     filter_condition = np.where(zonal_wavenumber != (2*np.pi)/zonal_domain_length)
#     # filter_condition = np.where(phase_speed <= 0)

# elif filtered_wave == 'Rossby' or filtered_wave == 'WIG':
#     filter_condition = np.where(phase_speed >= 0)

# zonal_velocity_fft = np.fft.fft2(time_filtered_zonal_velocity, axes=[0,2])
# zonal_velocity_fft[filter_condition] = 0+1j*0
# zonal_velocity_filtered = np.real(np.fft.ifft2(zonal_velocity_fft, axes=[0,2]))

# meridional_velocity_fft = np.fft.fft2(time_filtered_meridional_velocity, axes=[0,2])
# meridional_velocity_fft[filter_condition] = 0+1j*0
# meridional_velocity_filtered = np.real(np.fft.ifft2(meridional_velocity_fft, axes=[0,2]))

# geopotential_height_fft = np.fft.fft2(time_filtered_geopotential_height, axes=[0,2])
# geopotential_height_fft[filter_condition] = 0+1j*0
# geopotential_height_filtered = np.real(np.fft.ifft2(geopotential_height_fft, axes=[0,2]))

# print("Direction Filtering Complete")
# print("============================")

### Plot filtered solutions

In [None]:
# No direction filtering is applied, so the "filtered" fields are the
# time-filtered arrays themselves.
geopotential_height_filtered = time_filtered_geopotential_height
zonal_velocity_filtered = time_filtered_zonal_velocity
meridional_velocity_filtered = time_filtered_meridional_velocity

# Specify plotting parameters
modified_cmap = mjo.modified_colormap('bwr', 'white', 0.15, 0.05)
grid_scaling = 1e-6
# Reference arrow speed: largest filtered zonal wind magnitude, rounded up.
# Clamped to at least 1 m/s so quiver never receives a zero or negative scale
# (ceil(max(u)) is <= 0 when the filtered wind is weak or purely westward).
quiver_scale = max(1, int(np.ceil(np.max(np.abs(zonal_velocity_filtered)))))
# Time index shown in the lower ("final") panel
end_frame = -1

contour_args = {
    # 'levels':np.linspace(-50,50,11),
    # 11 levels spanning the full range of the unfiltered geopotential anomaly
    'levels' : np.linspace(np.min(geopotential_height), np.max(geopotential_height), 11),
    'norm':mcolors.CenteredNorm(),
    'cmap':modified_cmap,
    'extend':'both'
}

quiver_args = {
    'color':'k',
    'width':0.0025,
    'angles':'xy',
    'scale_units':'xy',
    'scale':quiver_scale
}

# Create figure: two stacked map panels sharing one colorbar column
fig = plt.figure(figsize=(16,9),dpi=300)
gs = GridSpec(2, 2, width_ratios = [100,2], figure=fig)
gs.update(left=0.1, right=0.9, top=0.9, bottom=0.1, wspace=0.05, hspace=0.25)

# Label figure
fig.suptitle('Initial and Final Conditions')

#### Initial Condition ####
# Add axes for the colorbar (spanning both rows) and for the initial condition
cbar_ax = fig.add_subplot(gs[:, -1])
ax0 = fig.add_subplot(gs[0, 0])

# Plot contours of geopotential height anomalies.
# NOTE(review): the initial panel shows the unfiltered fields while the
# final panel shows the filtered ones — confirm this is intended.
cont = ax0.contourf(
    zonal_gridpoints*grid_scaling,
    meridional_gridpoints*grid_scaling,
    geopotential_height[0],
    **contour_args
)
cbar = fig.colorbar(cont, cax=cbar_ax)
cbar.set_ticks(contour_args['levels'])
cbar.set_label(r"$\frac{m^2}{s^2}$", rotation=0, labelpad=15)

# Plot vectors of wind anomalies (arrays indexed [time, meridional, zonal])
quiv = ax0.quiver(
    zonal_gridpoints[::zonal_quiver_plot_spacing]*grid_scaling,
    meridional_gridpoints[::meridional_quiver_plot_spacing]*grid_scaling,
    zonal_velocity[0,::meridional_quiver_plot_spacing,::zonal_quiver_plot_spacing],
    meridional_velocity[0,::meridional_quiver_plot_spacing,::zonal_quiver_plot_spacing],
    **quiver_args
)

# Add a key for the wind vectors
ax0.quiverkey(
    quiv,
    0.79, 0.925,
    U=quiver_scale,
    label=f'{quiver_scale} m/s',
    coordinates='figure', labelpos='E',
    color='black', labelcolor='black'
)

# Plot a line on the equator
ax0.axhline(y=0, ls='--', alpha=0.5, color='gray', lw=0.5)

# Label y axis
ax0.set_ylabel(r" y (10$^{3}$ km)")

# Set the plot aspect
ax0.set_aspect('auto')
#########################

#### Final Condition ####
# Add axes for final condition
ax1 = fig.add_subplot(gs[1, 0])

# Geopotential Height contours
cont1 = ax1.contourf(
    zonal_gridpoints*grid_scaling,
    meridional_gridpoints*grid_scaling,
    geopotential_height_filtered[end_frame],
    **contour_args
)

# Wind vectors
quiv1 = ax1.quiver(
    zonal_gridpoints[::zonal_quiver_plot_spacing]*grid_scaling,
    meridional_gridpoints[::meridional_quiver_plot_spacing]*grid_scaling,
    zonal_velocity_filtered[end_frame,::meridional_quiver_plot_spacing,::zonal_quiver_plot_spacing],
    meridional_velocity_filtered[end_frame,::meridional_quiver_plot_spacing,::zonal_quiver_plot_spacing],
    **quiver_args
)

# Equator line
ax1.axhline(y=0, ls='--', alpha=0.5, color='gray', lw=0.5)

# Label axes
ax1.set_xlabel(r" x (10$^{3}$ km)")
ax1.set_ylabel(r" y (10$^{3}$ km)")

# Aspect ratio
ax1.set_aspect('auto')

#### Time Filtered Solution

In [None]:
# Temporal filter applied to the solution: 'high-pass' keeps variability
# faster than filter_period, 'low-pass' keeps variability slower than it.
filter_type = 'high-pass'
# Cutoff period in seconds (the filter receives the frequency 1/filter_period)
filter_period = 3*SECONDS_PER_DAY

# Fields are transposed so the leading time axis becomes the last axis for
# the filter, then transposed back (presumably the mjo helpers filter along
# the last axis — confirm). 1/time_step is the sampling frequency.
if filter_type == 'high-pass':
    time_filtered_zonal_velocity = mjo.butter_highpass_filter(zonal_velocity.T, 1/filter_period, 1/time_step).T
    time_filtered_meridional_velocity = mjo.butter_highpass_filter(meridional_velocity.T, 1/filter_period, 1/time_step).T
    time_filtered_geopotential_height = mjo.butter_highpass_filter(geopotential_height.T, 1/filter_period, 1/time_step).T

elif filter_type == 'low-pass':
    time_filtered_zonal_velocity = mjo.butter_lowpass_filter(zonal_velocity.T, 1/filter_period, 1/time_step).T
    time_filtered_meridional_velocity = mjo.butter_lowpass_filter(meridional_velocity.T, 1/filter_period, 1/time_step).T
    time_filtered_geopotential_height = mjo.butter_lowpass_filter(geopotential_height.T, 1/filter_period, 1/time_step).T

else:
    # Fail loudly rather than silently reusing stale time_filtered_* arrays
    # (whose filter no longer matches the filter_type used in plot titles).
    raise ValueError(
        f"Unknown filter_type {filter_type!r}; expected 'high-pass' or 'low-pass'"
    )

##### Horizontal Structure

In [None]:
#### Plot one frame of the time-filtered solution ####
modified_cmap = mjo.modified_colormap('bwr', 'white', 0.15, 0.05)
plt.style.use('default')
plt.rcParams.update({'font.size':24})

# Calculate the equivalent pressure anomaly in hPa
# (geopotential [m^2 s^-2] times air density [kg m^-3] gives Pa; /100 -> hPa)
time_filtered_equivalent_pressure = np.copy(time_filtered_geopotential_height)*AIR_DENSITY/100

# Time index to plot: 1000 steps before the end of the run, clamped to the
# first step for shorter runs (a bare -1000 would raise IndexError there).
frame = max(-1000, -len(time_filtered_equivalent_pressure))

# contour_scale = int(np.ceil(np.max(geopotential_height[0])))
# NOTE: contour_scale only feeds the commented-out explicit levels below.
contour_scale = int(np.ceil(np.max(time_filtered_equivalent_pressure[frame])))
# Reference arrow speed: largest filtered zonal wind magnitude in this frame,
# rounded up and clamped to at least 1 m/s so quiver never gets a zero or
# negative scale.
quiver_scale = max(1, int(np.ceil(np.max(np.abs(time_filtered_zonal_velocity[frame])))))
# quiver_scale = 500
grid_scaling = 1e-6

contour_args = {
    # 'levels' :  np.linspace(
    #                 -contour_scale,
    #                 contour_scale,
    #                 11
    #             ),
    'levels' :  11,
    'norm'   :  mcolors.CenteredNorm(),
    'cmap'   :  modified_cmap,
    # 'extend' :  'both'
}

quiver_args = {
    'color'       : 'k',
    'width'       : 0.0025,
    'angles'      : 'xy',
    'scale_units' : 'xy',
    'scale'       : quiver_scale
}

# One wide map panel plus a narrow colorbar column
fig = plt.figure(figsize=(16,6),dpi=300)
gs = GridSpec(1, 2, width_ratios = [100,2], figure=fig)
gs.update(left=0.1, right=0.9, top=0.99, bottom=0.1, wspace=0.05)

cbar_ax = fig.add_subplot(gs[-1])
ax = fig.add_subplot(gs[0])

ax.set_title(f"{filter_period/SECONDS_PER_DAY:0.0f} day {filter_type} filtered solution", pad=10)

cont = ax.contourf(
    zonal_gridpoints*grid_scaling,
    meridional_gridpoints*grid_scaling,
    # geopotential_height[0],
    time_filtered_equivalent_pressure[frame],
    **contour_args
)
cbar = fig.colorbar(cont, cax=cbar_ax)
# cbar.set_ticks(contour_args['levels'])
# cbar.set_label(r"$\frac{m^2}{s^2}$", rotation=0, labelpad=15)
cbar.set_label(r"$hPa$", rotation=0, labelpad=15)

# Wind vectors (arrays indexed [time, meridional, zonal]), drawn over the fill
quiv = ax.quiver(
    zonal_gridpoints[::zonal_quiver_plot_spacing]*grid_scaling,
    meridional_gridpoints[::meridional_quiver_plot_spacing]*grid_scaling,
    time_filtered_zonal_velocity[frame,::meridional_quiver_plot_spacing,::zonal_quiver_plot_spacing],
    time_filtered_meridional_velocity[frame,::meridional_quiver_plot_spacing,::zonal_quiver_plot_spacing],
    **quiver_args
)

# Reference arrow placed just above the axes, in figure coordinates
ax.quiverkey(
    quiv,
    0.79, 1.02,
    U=quiver_scale,
    label=f'{quiver_scale} m/s',
    coordinates='figure', labelpos='E',
    color='black', labelcolor='black'
)

ax.set_xlabel(r"x (10$^{3}$ km)")
ax.set_ylabel(r"y (10$^{3}$ km)")

# Equator line
ax.axhline(y=0, ls='--', alpha=0.5, color='gray', lw=0.5)

ax.set_aspect('auto')
# ax.grid()

##### Temporal Structure

In [None]:
plt.style.use('bmh')
plt.rcParams.update({'font.size':24})

# Grid point (y_index, x_index) where the filtered geopotential anomaly is
# largest at the first time step. The argmax is taken over the whole 2-D
# field, so the location does not depend on the values along any single row
# (e.g. the domain edge), matching the amplification-factor diagnostic.
y_index, x_index = np.unravel_index(
    np.argmax(time_filtered_geopotential_height[0]),
    np.shape(time_filtered_geopotential_height[0])
)

[fig, ax] = plt.subplots(figsize=(16,6))
ax.set_title(f"{filter_period/SECONDS_PER_DAY:0.0f} day {filter_type} filtered \n {initial_wave} wave amplitude over time", pad=15)

# Geopotential divided by the gravity-wave phase speed c has units of m/s,
# so it can share the vertical axis with u and v.
ax.plot(
    time_points/SECONDS_PER_DAY,
    time_filtered_geopotential_height[:, y_index, x_index]/gravity_wave_phase_speed,
    lw=3,
    label=r"$\frac{\phi}{c}$"
)

ax.plot(
    time_points/SECONDS_PER_DAY,
    time_filtered_zonal_velocity[:, y_index, x_index],
    lw=3,
    label='u'
)

ax.plot(
    time_points/SECONDS_PER_DAY,
    time_filtered_meridional_velocity[:, y_index, x_index],
    lw=3,
    label='v'
)

ax.set_xlabel('Time (days)')
ax.set_ylabel(r"$\frac{m}{s}$", rotation=0, labelpad=20, fontsize=32)

ax.legend(loc='upper right', fontsize=18)

plt.show()

#### Animate filtered solutions

In [None]:
# Specify plotting parameters
plt.style.use('default')
plt.rcParams.update({'font.size':24})

n_frames = 50
starting_frame = 0
# At least one time step between frames; np.arange cannot take a step of 0,
# which happens when the run has fewer than n_frames steps.
frame_interval = max(1, int((nt - starting_frame)/n_frames))
frames = np.arange(starting_frame, nt, frame_interval)

# Calculate equivalent pressure anomaly in hPa (geopotential times air
# density gives Pa; /100 -> hPa). The unfiltered version is still assigned
# because the unfiltered animation reads the equivalent_pressure global.
equivalent_pressure = np.copy(geopotential_height)*AIR_DENSITY/100
# Filtered field that is animated below; recomputed here so this cell does
# not depend on the horizontal-structure cell having been run first.
time_filtered_equivalent_pressure = np.copy(time_filtered_geopotential_height)*AIR_DENSITY/100

# Colour levels and reference arrow are scaled to the filtered fields that
# are actually plotted (rounded up, at least 1, so levels/scale are valid).
contour_scale = max(1, int(np.ceil(np.max(np.abs(time_filtered_equivalent_pressure)))))
# contour_scale = int(np.ceil(np.max(geopotential_height)))
quiver_scale = max(1, int(np.ceil(np.max(np.abs(time_filtered_zonal_velocity)))))
grid_scaling = 1e-6

contour_args = {
    'levels' : np.linspace(-contour_scale, contour_scale, 11),
    'norm'   : mcolors.CenteredNorm(),
    'cmap'   : modified_cmap,
    # 'extend' : 'both'
}

quiver_args = {
    'color'       : 'k',
    'width'       : 0.0025,
    'angles'      : 'xy',
    'scale_units' : 'xy',
    'scale'       : quiver_scale
}

# One wide map panel plus a narrow colorbar column
fig = plt.figure(figsize=(16,6),dpi=300)
gs = GridSpec(1, 2, width_ratios = [100,2], figure=fig)
gs.update(left=0.1, right=0.9, top=0.9, bottom=0.1, wspace=0.05, hspace=0.25)

ax = fig.add_subplot(gs[0])
cbar_ax = fig.add_subplot(gs[1])

# Title: model time of the first frame in days
ax.set_title(f"Time: {time_points[starting_frame]/SECONDS_PER_DAY:0.1f} days")

cont = ax.contourf(
    zonal_gridpoints*grid_scaling,
    meridional_gridpoints*grid_scaling,
    # geopotential_height[starting_frame],
    time_filtered_equivalent_pressure[starting_frame],
    **contour_args
)
cbar = fig.colorbar(cont, cax=cbar_ax)
cbar.set_ticks(contour_args['levels'])
# cbar.set_label(r"$\frac{m^2}{s^2}$", rotation=0, labelpad=15)
cbar.set_label(r"hPa", rotation=0, labelpad=15)

# Wind vectors (arrays indexed [time, meridional, zonal]), drawn over the fill
quiv = ax.quiver(
    zonal_gridpoints[::zonal_quiver_plot_spacing]*grid_scaling,
    meridional_gridpoints[::meridional_quiver_plot_spacing]*grid_scaling,
    time_filtered_zonal_velocity[starting_frame,::meridional_quiver_plot_spacing,::zonal_quiver_plot_spacing],
    time_filtered_meridional_velocity[starting_frame,::meridional_quiver_plot_spacing,::zonal_quiver_plot_spacing],
    **quiver_args
)

# Reference arrow in figure coordinates, bound to this initial quiver
ax.quiverkey(
    quiv,
    0.79, 0.925,
    U=quiver_scale,
    label=f'{quiver_scale} m/s',
    coordinates='figure', labelpos='E',
    color='black', labelcolor='black'
)

ax.set_xlabel('x ($10^3$ km)')
ax.set_ylabel('y ($10^3$ km)')
ax.set_aspect('auto')

# Equator line
ax.axhline(y=0, ls='--', alpha=0.5, color='gray', lw=0.5)

def update(current_frame):
    """Draw one animation frame of the time-filtered solution.

    Parameters
    ----------
    current_frame : int
        Time index into the filtered solution arrays. FuncAnimation passes
        the values of ``frames`` straight through.

    Returns
    -------
    tuple
        The filled-contour set and the quiver drawn for this frame.

    Notes
    -----
    Each call removes the contour set and quiver added by the previous
    call before drawing new ones. Without this, the axes gain one pair of
    artists per frame, so every later frame renders more slowly, and stale
    arrows can show through unfilled regions. The initial ``cont``/``quiv``
    created before the animation are left in place because the quiver key
    is attached to ``quiv``.
    """
    # plotting_index = starting_frame + frame_interval*current_frame
    plotting_index = current_frame

    # Remove the artists drawn by the previous call (nothing on the first call)
    for stale_artist in getattr(update, "_previous_artists", ()):
        stale_artist.remove()

    ax.set_title(f"Time: {time_points[plotting_index]/SECONDS_PER_DAY:0.1f} days")
    cont = ax.contourf(
        zonal_gridpoints*grid_scaling,
        meridional_gridpoints*grid_scaling,
        # geopotential_height[plotting_index],
        time_filtered_equivalent_pressure[plotting_index],
        **contour_args
    )
    # Wind vectors subsampled on the same grid as the initial frame
    quiv = ax.quiver(
        zonal_gridpoints[::zonal_quiver_plot_spacing]*grid_scaling,
        meridional_gridpoints[::meridional_quiver_plot_spacing]*grid_scaling,
        time_filtered_zonal_velocity[plotting_index,::meridional_quiver_plot_spacing,::zonal_quiver_plot_spacing],
        time_filtered_meridional_velocity[plotting_index,::meridional_quiver_plot_spacing,::zonal_quiver_plot_spacing],
        **quiver_args
    )

    # Remember this frame's artists so the next call can remove them
    update._previous_artists = (cont, quiv)
    return cont, quiv
    
# Build the animation: tqdm wraps the frame indices so a progress bar
# advances while frames are rendered; interval is the frame delay in ms.
anim = FuncAnimation(
    fig,
    update,
    frames=tqdm(frames, ncols=100, position=0, leave=True),
    interval=300,
)

# Output directory (ends with a separator) and a file name that records the
# run configuration: k, mode number m, filter period/type and wave type.
# NOTE(review): if initial_wavenumber is an angular wavenumber (units of
# 2*pi/length), the k written here is 2*pi times the integer wavenumber — confirm.
anim_dir = "/home/disk/eos7/sressel/research/thesis-work/python/numerical_solver/numerical_solver_output/dry-Matsuno/"
anim_file_name = (
    f"dry-Matsuno_k={initial_wavenumber*zonal_domain_length}"
    f"_m={mode_number}"
    f"_{filter_period/SECONDS_PER_DAY:0.0f}-day"
    f"_{filter_type}-filtered_{initial_wave}_wave.mp4"
)
anim.save(f"{anim_dir}{anim_file_name}", dpi=300)

## Calculate amplification factor

In [None]:
# Grid point of the largest geopotential anomaly in the initial condition
y_max, x_max = np.unravel_index(
    np.argmax(geopotential_height[0]),
    np.shape(geopotential_height[0]),
)

# Step-to-step ratio phi^{n+1} / phi^{n} at that fixed point; 1 means no
# growth or decay between consecutive time steps.
# NOTE(review): for a propagating wave the value at a fixed point passes
# through zero, so the ratio spikes there — the alternative below compares
# domain maxima instead:
#   np.max(time_filtered_geopotential_height[i]) / np.max(time_filtered_geopotential_height[i-1])
peak_phi = geopotential_height[:, y_max, x_max]
amplification_factor = [peak_phi[step] / peak_phi[step - 1] for step in range(1, nt)]

plt.style.use('bmh')
plt.rcParams.update({'font.size':22})
fig, ax = plt.subplots(figsize=(16, 5))
ax.set_title('Amplification Factor over Time', pad=15)

# Ratios start at the second time step, hence time_points[1:]
ax.plot(time_points[1:]/SECONDS_PER_DAY, amplification_factor)

# Neutral-amplification reference line
ax.axhline(y=1, color='gray', alpha=0.75, ls=':')

ax.xaxis.set_major_locator(mticker.MaxNLocator(nbins=n_days, prune='lower'))

ax.set_xlabel('Time (days)')
ax.set_ylabel(r"$\frac{ϕ_{n+1}}{ϕ_{n}}$", rotation=0, labelpad=40, fontsize=38)

# ax.set_ylim(0.99, 1.01)
ax.set_xlim(0, n_days)

fig.savefig(f"{output_file_directory}/{simulation_name}_amplification-factor.png", dpi=300)