# Wave Filter Process

In [None]:

import time
import sys
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import dask.array as da
import netCDF4
from scipy import signal

# Create a class for wave filtering
class WaveFilter:
    """Wavenumber-frequency (space-time FFT) filter for equatorial waves.

    The main program calls the steps in this order:
    ``load_data`` -> ``detrend_data`` -> ``fft_transform`` -> ``apply_filter``
    -> ``inverse_fft`` -> ``create_output``.

    Parameters
    ----------
    ds : xarray.Dataset
        Input data containing ``var``. Owned by the caller; it is not closed here.
    var : str
        Name of the variable to filter.
    coverage : dict
        Keyword arguments forwarded to ``.sel`` (e.g. ``{'lat': slice(-10, 10),
        'lon': slice(150, 270)}``).
    wave_name : str
        "KW" (Kelvin waves) or "ER" (equatorial Rossby waves), case-insensitive.
    n : int
        Meridional mode number used by the ER dispersion-relation limit.
    units : str
        Units written into the output attributes.
    """

    def __init__(self, ds, var, coverage, wave_name, n, units):
        self.var = var
        self.coverage = coverage
        self.wave_name = wave_name
        self.long_name = None
        self.n = n

        self.ds = ds            # caller-provided dataset (load lazily with dask chunks upstream)
        self.data = None
        self.units = units

        self.filtered_data = None
        self.fftdata = None

    def load_data(self):
        """Select the region and arrange the variable as (time, lat, lon).

        Renames ``latitude``/``longitude`` coordinates to ``lat``/``lon`` and
        averages over ``depth`` if present. Latitude is sorted ascending *before*
        selecting, so a ``slice(south, north)`` in ``coverage`` also works on
        datasets stored north-to-south. NaNs are replaced by 0 so the FFTs do
        not propagate them, and time is put in a single chunk.
        """
        if "latitude" in self.ds.coords:
            self.ds = self.ds.rename({'latitude': 'lat'})
        if "longitude" in self.ds.coords:
            self.ds = self.ds.rename({'longitude': 'lon'})
        if "depth" in self.ds.coords:
            self.ds = self.ds.mean("depth")

        data = self.ds[self.var].sortby('lat').sel(**self.coverage)
        data = data.transpose('time', 'lat', 'lon')
        data = data.fillna(0)                   # fill NaN with 0 to keep the FFTs finite
        self.data = data.chunk({'time': -1})    # one chunk along time for the FFTs

    def detrend_data(self):
        """Remove the low-frequency seasonal band and linear trend, then taper in time.

        When the record is longer than a third of a year, Fourier components with
        frequencies between 1/365 and 3/365 cycles per sample-day are zeroed
        (``spd = 1`` sample per day is assumed). The series is then linearly
        detrended along time (``scipy.signal.detrend``) and multiplied by a
        Tukey window (alpha = 0.05). Stores the result in ``self.detrend`` with
        shape (time, lat, lon).
        """
        ntim, nlat, nlon = self.data.shape
        spd = 1  # number of samples per day
        # dask FFTs need every transformed axis in a single chunk; time is
        # transformed here and (time, lon) later by fft2 in apply_filter.
        data_rechunked = self.data.data.rechunk({0: -1, 2: -1})

        # Short records skip the seasonal removal but are still detrended.
        datain = data_rechunked
        if ntim > 365 * spd / 3:
            rf = da.fft.rfft(data_rechunked, axis=0)
            # rfft of ntim samples -> ntim // 2 + 1 frequencies
            freq = da.fft.rfftfreq(ntim, d=1. / float(spd))
            rf[(freq <= 3. / 365) & (freq >= 1. / 365), :, :] = 0.0
            datain = da.fft.irfft(rf, axis=0, n=ntim)

        # detrend the data
        self.detrend = da.apply_along_axis(signal.detrend, 0, datain)
        # apply a Tukey taper along time to reduce spectral leakage
        window = signal.windows.tukey(ntim, 0.05, True)
        self.detrend = self.detrend * window[:, np.newaxis, np.newaxis]

    def fft_transform(self):
        """Build the (time, lon) wavenumber/frequency grids matching ``fft2`` output.

        The FFT itself is taken in ``apply_filter``. ``self.knum`` holds integer
        zonal wavenumbers and ``self.freq`` the absolute frequency (cycles per
        sample), both of shape (time, lon). For negative-frequency rows the
        wavenumber sign is flipped, so every coefficient is labelled (+/-k, |f|).
        """
        ntim, nlon = self.data.shape[0], self.data.shape[2]
        # NOTE(review): the leading minus presumably makes positive knum mean
        # eastward propagation under numpy's FFT sign convention -- confirm.
        self.wavenumber = -da.fft.fftfreq(nlon) * nlon            # shape: (lon,)
        self.frequency = da.fft.fftfreq(ntim, d=1. / float(1))    # shape: (time,)

        self.knum_ori, self.freq_ori = da.meshgrid(self.wavenumber, self.frequency)  # (time, lon)
        self.knum = da.where(self.freq_ori < 0, -self.knum_ori, self.knum_ori)       # (time, lon)
        self.freq = da.abs(self.freq_ori)                                            # (time, lon)

    def apply_filter(self):
        """FFT the detrended data over (time, lon) and zero coefficients outside the band.

        Band per wave type (periods in days, integer zonal wavenumbers):
        KW: 20-180 days, k = 1..14; ER: 180-450 days, k = -10..-1.
        ``self.mask`` is True where a coefficient is removed. Optional
        equivalent-depth limits are added by ``apply_wave_filter``.

        Raises
        ------
        ValueError
            If ``wave_name`` is not "KW" or "ER".
        """
        wave = self.wave_name.lower()
        if wave == "kw":
            self.tMin, self.tMax = 20, 180
            self.kmin, self.kmax = 1, 14
            self.hmin, self.hmax = None, None  # e.g. 0.025, 90 to enable depth limits
        elif wave == "er":
            self.tMin, self.tMax = 180, 450
            self.kmin, self.kmax = -10, -1
            self.hmin, self.hmax = None, None  # e.g. 0.003, 90 to enable depth limits
        else:
            raise ValueError(
                f"Unsupported wave_name {self.wave_name!r}; expected 'KW' or 'ER'."
            )

        print(f"T: {self.tMin}-{self.tMax} days | k: {self.kmin}-{self.kmax} | h = {self.hmin}-{self.hmax} m")

        self.fmin, self.fmax = 1 / self.tMax, 1 / self.tMin
        ntim, nlat, nlon = self.data.shape
        self.mask = da.zeros((ntim, nlon), dtype=bool)

        if self.kmin is not None:
            self.mask = self.mask | (self.knum < self.kmin)
        if self.kmax is not None:
            self.mask = self.mask | (self.kmax < self.knum)

        if self.fmin is not None:
            self.mask = self.mask | (self.freq < self.fmin)
        if self.fmax is not None:
            self.mask = self.mask | (self.fmax < self.freq)

        # Dispersion-relation (equivalent depth) limits; no-op when hmin/hmax are None.
        self.apply_wave_filter(self.wave_name)

        self.fftdata = da.fft.fft2(self.detrend, axes=(0, 2))   # shape: (time, lat, lon)
        self.mask = da.repeat(self.mask[:, np.newaxis, :], nlat, axis=1)
        self.fftdata = da.where(self.mask, 0.0, self.fftdata)

    def apply_wave_filter(self, wave_name):
        """Add equivalent-depth (dispersion-curve) limits to ``self.mask``.

        Frequency (cycles/day) and wavenumber are nondimensionalised with the
        equatorial scales ``sqrt(beta * c)`` and ``sqrt(c / beta)``, where
        ``c = sqrt(g * h)`` for equivalent depth ``h``.

        * KW: mask ``omega - k < 0`` for ``hmin`` and ``omega - k > 0`` for ``hmax``.
        * ER: mask on the sign of ``omega * (k**2 + 2n + 1) + k``, whose zero is
          the curve ``omega = -k / (k**2 + 2n + 1)``.

        A limit that is None is skipped.

        Parameters
        ----------
        wave_name : str or None
            "KW" or "ER" (case-insensitive); ``self.wave_name`` is used if None.
        """
        g = 9.8          # gravitational acceleration (m/s^2)
        beta = 2.28e-11  # equatorial beta-plane parameter
        a = 6.37e6       # Earth radius (m)
        n = self.n

        wave = (wave_name if wave_name is not None else self.wave_name).lower()

        def scaled(h):
            """Return nondimensional (omega, k) grids for equivalent depth ``h``."""
            c = np.sqrt(g * h)
            omega = 2. * np.pi * self.freq / 24. / 3600. / np.sqrt(beta * c)
            k = self.knum / a * np.sqrt(c / beta)
            return omega, k

        if wave == "kw":
            if self.hmin is not None:
                omega, k = scaled(self.hmin)
                self.mask = self.mask | (omega - k < 0)
            if self.hmax is not None:
                omega, k = scaled(self.hmax)
                self.mask = self.mask | (omega - k > 0)
        elif wave == "er":
            if self.hmin is not None:
                omega, k = scaled(self.hmin)
                self.mask = self.mask | (omega * (k ** 2 + (2 * n + 1)) + k < 0)
            if self.hmax is not None:
                omega, k = scaled(self.hmax)
                self.mask = self.mask | (omega * (k ** 2 + (2 * n + 1)) + k > 0)

    def inverse_fft(self):
        """Inverse 2D FFT over (time, lon); keep the real part as the filtered field."""
        self.filtered_data = da.fft.ifft2(self.fftdata, axes=(0, 2)).real

    @staticmethod
    def _long_name_for(wave_name):
        """Return the descriptive name for a wave code (case-insensitive), or None."""
        name = (wave_name or "").upper()
        if name == 'KW':
            return 'Kelvin Waves'
        if 'ER' in name:
            return 'Equatorial Rossby Waves'
        return None

    def create_output(self):
        """Compute the filtered field and wrap it in an ``xarray.DataArray``.

        Uses the (time, lat, lon) coordinates of the selected input and records
        the band limits and units as attributes. The input dataset is owned by
        the caller and is not closed here.
        """
        self.long_name = self._long_name_for(self.wave_name)

        self.wave_data = xr.DataArray(self.filtered_data.compute(),
                                      coords={'time': self.data.time,
                                              'lat': self.data.lat,
                                              'lon': self.data.lon},
                                      dims=['time', 'lat', 'lon'])
        self.wave_data.attrs.update({
            'name'           : self.wave_name,
            'long_name'      : self.long_name,
            'min_wavenumber' : self.kmin,
            'max_wavenumber' : self.kmax,
            'min_period'     : self.tMin,
            'max_period'     : self.tMax,
            'min_frequency'  : self.fmin,
            'max_frequency'  : self.fmax,
            'units'          : self.units,
        })
        return self.wave_data

In [None]:
# Main program
start = time.time()

print("BEGIN")
print(sys.version)

# Input files per event: one NetCDF file per velocity component
file_address = {
    "2023": {"uo": "/kaggle/input/wave-energy-analysis/Raw Data/u_current.2023_2024.nc",
             "vo": "/kaggle/input/wave-energy-analysis/Raw Data/v_current.2023_2024.nc"},
    "2018": {"uo": "/kaggle/input/wave-energy-analysis/u_current.2018_2019.nc",
             "vo": "/kaggle/input/wave-energy-analysis/v_current.2018_2019.nc"},
    "2014": {"uo": "/kaggle/input/wave-energy-analysis/u_current.2014_2015.nc",
             "vo": "/kaggle/input/wave-energy-analysis/v_current.2014_2015.nc"}
}
# Region of interest (passed to WaveFilter as .sel keyword arguments)
coverage = {
    'lat': slice(-10, 10),
    'lon': slice(150, 270)
}
event = '2023'

# Wave types to extract; `n` is the mode number handed to WaveFilter.
equ_waves = {
    'KW': {'wave_name': 'KW', 'n': 0, 'data': []},
    'ER': {'wave_name': 'ER', 'n': 1, 'data': []}
}

# Filtered wave Dataset per velocity component
results = {
    'uo': None,
    'vo': None
}

# The climatology is shared by both components, so it is opened once. The
# `with` blocks close every file even if a processing step fails.
with xr.open_dataset("/kaggle/input/wave-energy-analysis/climatology.1993_2021.nc",
                     chunks="auto") as ds_clim:
    for num, (var, path) in enumerate(file_address[event].items()):
        print(f"Variable {num+1}:", var)
        print(f"path: {path}")

        # Keep the opened file (`raw`) so that `with` can close it; the depth
        # mean below is a derived Dataset.
        with xr.open_dataset(path, chunks="auto") as raw:
            # Raw data, averaged over depth
            ds = raw.mean('depth')

            # Climatology data (time and depth mean)
            clim = ds_clim[var].mean(['time', 'depth'])

            # Anomaly data.
            # NOTE(review): the last axis is cut to its first 1920 points,
            # presumably so the grid matches the climatology before
            # join='override' -- confirm.
            ds_var, clim_aligned = xr.align(ds[var][:, :, :1920], clim, join='override')
            anom = ds_var - clim_aligned
            data = anom.to_dataset()

            for i, key in enumerate(equ_waves):
                wave_name = equ_waves[key]['wave_name']
                n = equ_waves[key]['n']
                print(f"{i+1}: {wave_name}, n = {n}")

                # Velocity components: units match the 'm/s' dataset attribute below.
                wave_filter = WaveFilter(data, var, coverage, wave_name, n, units='m/s')

                # filtering process
                wave_filter.load_data()
                wave_filter.detrend_data()
                wave_filter.fft_transform()
                wave_filter.apply_filter()
                wave_filter.inverse_fft()

                # store the (computed, in-memory) result
                equ_waves[key]['data'] = wave_filter.create_output()

                print(f"{i+1}/{len(equ_waves)} Completed")

            # CREATE DATASET from the in-memory results (no references to `raw` remain)
            t = equ_waves['ER']['data'].time
            lat = equ_waves['ER']['data'].lat
            lon = equ_waves['ER']['data'].lon

            waves = xr.Dataset(
                data_vars={
                    key: (['time', 'lat', 'lon'], equ_waves[key]['data'].values, equ_waves[key]['data'].attrs)
                    for key in equ_waves if isinstance(equ_waves[key]['data'], xr.DataArray)
                },
                coords=dict(
                    time=t,
                    lat=lat,
                    lon=lon,
                ),
                attrs=dict(
                    description=f'FFT-Filtered Sea Current of {event}',
                    source='Global Ocean Physics Reanalysis',
                    unit='m/s'
                )
            )

            results[var] = waves

            # Export settings (the export itself is currently disabled)
            comp = dict(zlib=True, complevel=5)
            encoding = {name: comp for name in waves.data_vars}
            #waves.to_netcdf(
            #    f'EW.SLA.{event}.noaa.nc',
            #    engine='netcdf4',
            #    encoding=encoding
            #)
            waves.close()

end = time.time()
print("Runtime: %8.1f seconds." % (end - start))
print('DONE!')

# Calculate the Energy Budget

In [None]:
def calculate_energy(u, v, h, *, g=9.81, scale_by_phase_speed=False):
    """
    Calculate the energy of the filtered Kelvin ('KW') and Rossby ('ER') waves.

    For each wave type the energy density is
    ``0.5 * (u**2 + v**2 + P**2)``, where the potential term is
    ``P = g * h`` by default, or ``P = g * h / c`` when
    ``scale_by_phase_speed`` is True. The speed ``c`` is 2.25 (KW) or
    0.55 (ER) m/s.

    NOTE(review): with h presumably in metres, ``(g * h)**2`` has units
    m^4/s^4 while ``u**2`` is m^2/s^2. Dividing by ``c`` makes the terms
    consistent, which looks like the intent of the (previously unused)
    speed values. The default keeps the original formula -- confirm which
    one the analysis should use.

    Parameters
    ----------
    u, v : mapping with 'KW' and 'ER' entries (e.g. xarray.Dataset)
        Filtered zonal / meridional velocity.
    h : mapping with 'KW' and 'ER' entries
        Filtered sea level anomaly.
    g : float, keyword-only
        Gravitational acceleration.
    scale_by_phase_speed : bool, keyword-only
        Divide ``g * h`` by the wave speed before squaring.

    Returns
    -------
    xarray.Dataset
        Variables 'KW' and 'ER' holding the energy fields.
    """
    # wave speed per wave type (labelled "group velocity" in the original)
    wave_speed = {'KW': 2.25, 'ER': 0.55}

    energy = {}
    for wave in ('KW', 'ER'):
        print(wave)
        potential = g * h[wave]
        if scale_by_phase_speed:
            potential = potential / wave_speed[wave]
        energy[wave] = 0.5 * (u[wave]**2 + v[wave]**2 + potential**2)

    # export as dataset
    return xr.Dataset({
        'KW': energy['KW'],
        'ER': energy['ER']
    })

In [None]:
event = "2023"

# Sea level anomaly (SLA) wave fields for each event
waves_path = {
    "2023" : "/kaggle/input/eof-results/2023-2024/2023-2024.EW.SLA.nc",
    "2018" : "/kaggle/input/eof-results/2018-2019/2018-2019.EW.SLA.nc",
    "2014" : "/kaggle/input/eof-results/2014-2015/2014-2015.EW.SLA.nc"
}
equ_waves = xr.open_dataset(waves_path[event], chunks="auto")

# Two-calendar-year analysis window for each event
coverage = {
    "2023" : dict(time = slice("2023-01", "2024-12")),
    "2018" : dict(time = slice("2018-01", "2019-12")),
    "2014" : dict(time = slice("2014-01", "2015-12"))
}

# Restrict velocities (u, v) and SLA (h) to the event window, in that order
period = coverage[event]
u, v, h = (field.sel(**period) for field in (results["uo"], results["vo"], equ_waves))

# Energy of each wave type
energy = calculate_energy(u, v, h)

In [None]:
# Export the energy dataset; every variable gets the same zlib settings (level 7).
comp = dict(zlib=True, complevel=7)
encoding = dict.fromkeys(energy.data_vars, comp)
energy.to_netcdf(f'{event}.Wave_Energy.nc', engine='netcdf4', encoding=encoding)

# Data Visualization

In [None]:
import xarray as xr
import numpy as np

# Inputs per El Niño event:
# (year, energy file, analysis period (start, stop), KW EOF file, ER EOF file)
_event_files = [
    ("2023",
     "/kaggle/input/wave-energy-analysis/2023.Wave_Energy.nc",
     ("2023-06", "2024-05"),
     "/kaggle/input/eof-results/Full Year/2023-2024.EOF.KW.nc",
     "/kaggle/input/eof-results/Full Year/2023-2024.EOF.ER.nc"),
    ("2018",
     "/kaggle/input/wave-energy-analysis/2018.Wave_Energy.nc",
     ("2018-10", "2019-07"),
     "/kaggle/input/eof-results/Full Year/2018-2019.EOF.KW.nc",
     "/kaggle/input/eof-results/Full Year/2018-2019.EOF.ER.nc"),
    ("2014",
     "/kaggle/input/wave-energy-analysis/2014.Wave_Energy.nc",
     ("2014-11", "2015-06"),
     "/kaggle/input/eof-results/Full Year/2014-2015.EOF.KW.nc",
     "/kaggle/input/eof-results/Full Year/2014-2015.EOF.ER.nc"),
]

wave_energy = {
    year: dict(
        energy=xr.open_dataset(energy_nc, chunks="auto"),
        period=dict(time=slice(*months)),
        eof_kw=xr.open_dataset(kw_nc),
        eof_er=xr.open_dataset(er_nc),
    )
    for year, energy_nc, months, kw_nc, er_nc in _event_files
}

## Monthly Energy Budget

In [None]:
import matplotlib.pyplot as plt 
import numpy as np 

colors = ['tab:blue', 'tab:orange', 'tab:green']

t = np.arange(1, 25)  # month index 1..24

var = "ER"

if var == "KW":
    lat_band = slice(-2, 2)
    title = 'Kelvin Waves'
else:
    lat_band = slice(-5, 5)
    title = 'Equatorial Rossby Waves'


def _monthly_budget(year):
    """Return the monthly energy budget of ``var`` for one event as a NumPy array.

    Energy is summed within each month (month-end bins; the '1ME' alias needs
    a recent pandas, 2.2+), then over the latitude band ``lat_band`` and all
    longitudes. The time axis is deliberately not bound to a global name:
    ``time`` is the module used by the main program cell.
    """
    field = wave_energy[year]['energy'][var].sel(lat=lat_band)
    return field.resample(time='1ME').sum().sum(['lat', 'lon']).values


budget2023 = _monthly_budget('2023')
budget2018 = _monthly_budget('2018')
budget2014 = _monthly_budget('2014')

In [None]:
# Composite energy budget for the 2018 and 2014 events: the month-by-month
# mean and population standard deviation across the two events
# (zip stops at the shorter record).
paired_months = list(zip(budget2018, budget2014))
mean = [np.mean([first, second]) for first, second in paired_months]
std = [np.std([first, second]) for first, second in paired_months]

In [None]:
x = np.arange(1, 72)      # candidate bar positions 1..71
x1 = x[0::3]              # every third position: 1, 4, ..., 70 (24 bars)

ticks = np.arange(1, 25)  # month labels 1..24

In [None]:
factor = 1e5  # budgets are plotted in units of 1e5
data1 = budget2023 / factor
data2 = np.array(mean) / factor
yerr2 = np.array(std) / factor

fig, ax = plt.subplots(figsize=(10,4), dpi=300)
# Paired bars per month: align='edge' with a negative width puts 2023/24 to
# the left of each position, and the composite (with error bars) to the right.
ax.bar(x1, data1, color='tab:blue', label='2023/24 El Niño', align='edge', width=-1)
ax.bar(x1, data2, yerr=yerr2, capsize=4, color='tab:grey', label='Other El Niño', align='edge', width=1)

ax.set_xticks(x1)
ax.set_xticklabels(ticks)
ax.legend()
ax.set_title(f"{title} Energy Budget", fontweight='bold')
ax.set_ylabel("10⁵ W/m²")
ax.set_xlabel("Month-")  # NOTE(review): trailing "-" looks unintended -- confirm label

# Save before showing: with some backends plt.show() closes the figure, so a
# later savefig can write an empty image.
fig.savefig(f"Energy.Budget.{var}.png", dpi=300, bbox_inches='tight')
plt.show()

## Wave Energy vs Taux

In [None]:
import pandas as pd 
import numpy as np 
import xarray as xr

# Wave Energy Dataset: energy file and period of interest per event
wave_energy = {
    "2023": dict(
        energy = xr.open_dataset("/kaggle/input/wave-energy-analysis/2023.Wave_Energy.nc", chunks="auto"),
        period = dict(time=slice("2023-06", "2024-05"))
    ),
    "2018": dict(
        energy = xr.open_dataset("/kaggle/input/wave-energy-analysis/2018.Wave_Energy.nc", chunks="auto"),
        period = dict(time=slice("2018-10", "2019-07"))
    ),
    "2014": dict(
        energy = xr.open_dataset("/kaggle/input/wave-energy-analysis/2014.Wave_Energy.nc", chunks="auto"),
        period = dict(time=slice("2014-11", "2015-06"))
    )
}

# Wind Stress Dataset
wind_stress = xr.open_dataset("/kaggle/input/warmwatervolume/wind_stress.1994_2025.ERA5.nc")
wind_stress = wind_stress.rename({"latitude":"lat", "longitude":"lon"})
# Sort latitude ascending so label slices such as slice(-10, 10) used for the
# tau_x indices select data even if the file stores latitude north-to-south
# (NOTE(review): ERA5 products often do -- confirm for this file). This is a
# no-op when latitude already ascends.
wind_stress = wind_stress.sortby("lat")

# Zonal Wind Stress and its 1994-2021 climatology
taux = wind_stress['eastward_stress']
clim_taux = taux.sel(time=slice("1994-01", "2021-12")).mean("time")

# Taux anomaly for the two calendar years of each event
taux_2023 = taux.sel(time=slice("2023-01", "2024-12")) - clim_taux
taux_2018 = taux.sel(time=slice("2018-01", "2019-12")) - clim_taux
taux_2014 = taux.sel(time=slice("2014-01", "2015-12")) - clim_taux

In [None]:
def taux_idx(years, data, lat, lon, time):
    """Average each anomaly field over a lat/lon box and a time window.

    Parameters
    ----------
    years : iterable of str
        Event keys, paired element-wise with ``data``.
    data : iterable of xarray.DataArray
        Wind-stress anomaly fields with ``lat``/``lon``/``time`` dimensions.
    lat, lon : slice
        Label-based selections passed to ``.sel``.
    time : slice
        Positional selection passed to ``.isel`` (this parameter shadows the
        ``time`` module inside the function only).

    Returns
    -------
    dict
        Always has the keys "2023", "2018", "2014" (None unless computed),
        followed by any other key from ``years``; values are Python floats.
    """
    result = dict.fromkeys(("2023", "2018", "2014"))
    for year, field in zip(years, data):
        box = field.sel(lat=lat, lon=lon).isel(time=time)
        result[year] = float(np.mean(box).values)
    return result

years = ["2023", "2018", "2014"]
datas = [taux_2023, taux_2018, taux_2014]

# ER index: 10S-10N, 150-270E, time steps 0-24
taux_index_er = taux_idx(years, datas, lat=slice(-10, 10), lon=slice(150, 270), time=slice(0, 25))
# KW index: 2S-2N, 150-270E, time steps 10-12
# (presumably Nov-Jan for monthly data starting in January -- confirm)
taux_index_kw = taux_idx(years, datas, lat=slice(-2, 2), lon=slice(150, 270), time=slice(10, 13))

In [None]:
# Total wave energy per event: KW over 2S-2N for Nov (year 0) - Jan (year 1);
# ER over 10S-10N for the whole record.
total_energy = {year: {'KW': None, 'ER': None} for year in ("2023", "2018", "2014")}

for year in ("2023", "2018", "2014"):
    nov_to_jan = slice(f"{int(year)}-11", f"{int(year) + 1}-01")

    for var in ('KW', 'ER'):
        field = wave_energy[year]["energy"][var]
        if var == 'KW':
            we = field.sel(lat=slice(-2, 2), time=nov_to_jan).sum(['time', 'lat', 'lon']).values
        else:
            we = field.sel(lat=slice(-10, 10)).sum(['time', 'lat', 'lon']).values
        total_energy[year][var] = we
        print(year, var, we)

In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8,4), ncols=2, dpi=300, constrained_layout=True)

colors = ['tab:blue', 'tab:orange', 'tab:green']

xfactor = 1e-2  # tau_x index plotted in units of 1e-2 N/m^2
yfactor = 1e5   # wave energy plotted in units of 1e5

# Loop names (event_energy, tau_er, tau_kw) are chosen so the module-level
# `taux` wind-stress DataArray and `data` are not overwritten.
for i, (year, event_energy) in enumerate(total_energy.items()):
    print(year)
    yrtext = f"{year[2:]}/{str(int(year)+1)[2:]}"

    # ER (right panel)
    ere = event_energy['ER'] / yfactor
    tau_er = taux_index_er[year] / xfactor
    ax[1].scatter(tau_er, ere, color=colors[i], zorder=2)
    print(tau_er)
    # Annotate each point; the third event's label is shifted to the left
    if i == 2:
        ax[1].annotate(yrtext, (tau_er-0.1, ere), textcoords="offset points",  xytext=(0, 0), ha="left", fontsize=8)
    else:
        ax[1].annotate(yrtext, (tau_er+0.02, ere), textcoords="offset points",  xytext=(0, 0), ha="left", fontsize=8)

    # KW (left panel)
    kwe = event_energy['KW'] / yfactor
    tau_kw = taux_index_kw[year] / xfactor
    print(tau_kw)
    ax[0].scatter(tau_kw, kwe, color=colors[i], zorder=2)
    if i == 2:
        ax[0].annotate(yrtext, (tau_kw, kwe+0.02), textcoords="offset points",  xytext=(-5, 0), ha="left", fontsize=8)
    else:
        ax[0].annotate(yrtext, (tau_kw, kwe), textcoords="offset points",  xytext=(-5, 0), ha="right", fontsize=8)

ax[0].set_title(r"(a) KW Energy vs $\tau_{x}$ Index", fontweight='bold')
ax[1].set_title(r"(b) ER Energy vs $\tau_{x}$ Index", fontweight='bold')

# (a) KW: Nov-Jan window with a fixed x range
ax[0].set_xlim(-0.4, 0.6)
ax[0].set_xticks([-0.4, -0.2, 0.0, 0.2, 0.4, 0.6])
ax[0].set_xlabel(r"Nov$^{0}$-Jan$^{1}$ $\tau_{x}$ Index ($10^{-2}$ N/m$^2$)")
ax[0].set_ylabel(r"Nov$^{0}$-Jan$^{1}$ Wave Energy ($10^{5}$ W/m$^2$)")

# (b) ER
ax[1].set_xlabel(r"$\tau_{x}$ Index ($10^{-2}$ N/m$^2$)")
ax[1].set_ylabel(r"Wave Energy ($10^{5}$ W/m$^2$)")

for panel in ax:
    panel.axvline(0, color='black', alpha=0.8, zorder=1)
    panel.grid(ls="--", color='k', alpha=0.2, zorder=0)

# Save (if enabled) before plt.show(), which may close the figure.
#fig.savefig("WaveEnergy_vs_Taux.png", dpi=300, bbox_inches="tight")
plt.show()

## Wave Energy Composite

In [None]:
import xarray as xr
import numpy as np

# Inputs per El Niño event: energy file, analysis period, and KW/ER EOF results.
wave_energy = {}
for _year, _energy_nc, (_first, _last), _kw_nc, _er_nc in (
    ("2023", "/kaggle/input/wave-energy-analysis/2023.Wave_Energy.nc", ("2023-06", "2024-05"),
     "/kaggle/input/eof-results/Full Year/2023-2024.EOF.KW.nc",
     "/kaggle/input/eof-results/Full Year/2023-2024.EOF.ER.nc"),
    ("2018", "/kaggle/input/wave-energy-analysis/2018.Wave_Energy.nc", ("2018-10", "2019-07"),
     "/kaggle/input/eof-results/Full Year/2018-2019.EOF.KW.nc",
     "/kaggle/input/eof-results/Full Year/2018-2019.EOF.ER.nc"),
    ("2014", "/kaggle/input/wave-energy-analysis/2014.Wave_Energy.nc", ("2014-11", "2015-06"),
     "/kaggle/input/eof-results/Full Year/2014-2015.EOF.KW.nc",
     "/kaggle/input/eof-results/Full Year/2014-2015.EOF.ER.nc"),
):
    _entry = {}
    _entry["energy"] = xr.open_dataset(_energy_nc, chunks="auto")
    _entry["period"] = {"time": slice(_first, _last)}
    _entry["eof_kw"] = xr.open_dataset(_kw_nc)
    _entry["eof_er"] = xr.open_dataset(_er_nc)
    wave_energy[_year] = _entry

In [None]:
import matplotlib.pyplot as plt 
import numpy as np 
from matplotlib.ticker import FuncFormatter
import cmocean

# cmocean "dense" colormap. NOTE(review): the contourf call in this cell passes
# cmap="Reds", so this object is not used there.
cmap = cmocean.cm.dense

# Tick-label formatters for matplotlib's FuncFormatter (value, position).
def lat_format(v, _):
    """Format a latitude tick: positive -> 'N', negative -> 'S', zero -> '0'."""
    if v > 0:
        return f'{v:.0f}°N'
    if v < 0:
        return f'{abs(v):.0f}°S'
    return '0'


def lon_format(v, _):
    """Format a longitude tick given in degrees east (0-360): E below 180, W above."""
    return (f'{360 - v:.0f}°W' if v > 180
            else f'{v:.0f}°E' if v < 180
            else '180°')


fig, ax = plt.subplots(figsize=(10, 9), nrows=3, sharey=True, constrained_layout=True, dpi=500)

cbarticks = np.linspace(0, 5, 11)

titles = ["(a) 2023/24 El Niño",
          "(b) 2018/19 El Niño",
          "(c) 2014/15 El Niño"]

var = "KW"

for n, (year, data) in enumerate(wave_energy.items()):
    # EOF amplitude for the wave type being plotted (statistics are printed).
    # NOTE(review): assumes the ER EOF file also provides an 'amp' variable -- confirm.
    eof = data['eof_kw'] if var == "KW" else data['eof_er']
    idx = eof['amp']
    idx_mean = np.mean(idx.values)
    idx_std = np.std(idx.values)
    limit = idx_mean + idx_std  # threshold for the optional strong-event composite below
    # `times`, not `time`: keep the `time` module (used by the main program) unshadowed.
    times = data["energy"][var]["time"]

    if var == "KW":
        title = f"{titles[n]}: Kelvin Waves Energy"
        selection = slice(f"{int(year)}-11", f"{int(year)+1}-01")
    else:
        title = f"{titles[n]}: Equatorial Rossby Waves Energy"
        selection = slice(f"{int(year)}-01", f"{int(year)+1}-12")

    composite = data["energy"][var].sel(time=selection)
    print(year)
    print(f"{var}: {idx_mean:.2f} + {idx_std:.2f}")
    print("Total initiation: ", composite.shape[0])
    composite = composite.mean(dim="time") * 100   # plotted in units of 1e-2
    # Optional strong-event composite (would need to replace the time mean above):
    #composite = composite.where(times[idx > limit]).mean("time")

    lon = composite['lon']
    lat = composite['lat']
    c = ax[n].contourf(lon, lat, composite, cmap="Reds", levels=cbarticks, extend='both')

    ax[n].set_title(title, fontweight="bold")
    ax[n].xaxis.set_major_formatter(FuncFormatter(lon_format))
    ax[n].yaxis.set_major_formatter(FuncFormatter(lat_format))
    ax[n].set_ylabel("Latitude")
    if n == 2:
        ax[n].set_xlabel("Longitude")

cbar = fig.colorbar(c, ax=ax, orientation='horizontal', aspect=40, pad=0.03, ticks=cbarticks[::2], label=r"$10^{-2}$ W/m$^2$")
# Save before showing: plt.show() may close the figure with some backends.
fig.savefig(f"Composite.Wave_Energy.{var}.png", dpi=500, bbox_inches="tight")
plt.show()