# ONI / NINO-3.4 Variability Notebook

J. Krasting -- NOAA/GFDL

In [None]:
# Development mode: IPython's autoreload extension (mode 2) re-imports
# modules before each cell runs, so edits to package code (e.g. esnb) are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Set esnb's log level through the environment. This cell runs before
# `import esnb`, presumably so the setting is seen when the package loads.
os.environ["ESNB_LOG_LEVEL"] = "INFO"

## Framework Code and Diagnostic Setup

In [None]:
import esnb
from esnb import NotebookDiagnostic, RequestedVariable, CaseGroup2
from esnb.sites.gfdl import call_dmget

In [None]:
# Diagnostic metadata and the variable(s) this notebook requests.
diag_name = "ONI Variability"
diag_desc = "ONI and Nino34 Variability Diagnostics"
variables = [RequestedVariable("tos", "ocean_month")]
# User option: the ENSO index regions to compute. Each name must be one
# handled by the region-extraction cells further down.
user_options = {"enso_region": ["nino12", "nino3", "nino34", "nino4"]}
# NOTE(review): user-specific scratch path; adjust for other users/sites.
workdir = "/vftmp/John.Krasting/cm5-enso-20250904"
# The options are forwarded as keyword arguments; they are presumably what is
# read back later through diag.diag_vars.
diag = NotebookDiagnostic(
    diag_name, diag_desc, variables=variables, workdir=workdir, **user_options
)

In [None]:
# Experiments to analyze. Each CaseGroup2 carries an experiment identifier
# (NOTE(review): presumably a run/database ID -- confirm), the date range to
# load, a display name, and the color used for it in the spectra plot.
groups = [
    CaseGroup2("895", date_range=("0001-01-01", "0500-12-31"), name="CM4.0 (895)", plot_color="blue"),
    CaseGroup2("2916", date_range=("0001-01-01", "0150-12-31"), name="OM4_D5 (2916)", plot_color="orange"),
    CaseGroup2("3031", date_range=("0001-01-01", "0150-12-31"), name="B11_D5 (3031)", plot_color="green"),
    CaseGroup2("esm45-109", date_range=("0001-01-01", "0250-12-31"), name="ESM4.5 (esm45-109)", plot_color="purple"),
]

In [None]:
# Attach the case groups to the diagnostic; presumably this matches the
# requested variables to each group's data (esnb resolve step).
diag.resolve(groups)

In [None]:
# Open the resolved datasets; use_cache=True presumably reuses a previously
# written cache when one exists.
diag.open(use_cache=True)

In [None]:
# Persist the diagnostic state so a later diag.open(use_cache=True) can reuse
# it (assumed esnb semantics -- confirm).
diag.write_cache()

## Begin the User Diagnostic Code

#### Load Modules

In [None]:
import os

import cftime
import matplotlib.pyplot as plt
import momlevel as ml
import numpy as np
import xarray as xr
from momgrid.geoslice import geoslice

In [None]:
# GFDL-site helper applied to the opened datasets.
# NOTE(review): presumably makes them momgrid-aware so that geoslice and the
# areacello weights used below work -- confirm.
esnb.sites.gfdl.convert_to_momgrid(diag)

#### Diagnostic options and observational reference data

In [None]:
# Read back the "enso_region" user option (the list of ENSO regions that the
# extraction loops below iterate over).
# NOTE(review): the None fallback would make those `for region in enso_region`
# loops fail with TypeError, and every plot below also requires "nino34".
enso_region = diag.diag_vars.get("enso_region", None)

In [None]:
# Observational reference: area-weighted regional-mean SST from the NOAA ERSST
# file, one time series per requested ENSO region, stored in obs_enso_ts.
varname = "sst"
xdim = "lon"
ydim = "lat"

time_coder = xr.coders.CFDatetimeCoder(use_cftime=True)
dsobs = xr.open_dataset("/home/jpk/NOAA.ER.v2.sst.nc", decode_times=time_coder)
dsobs["area"] = ml.util.standard_grid_cell_area(dsobs[ydim], dsobs[xdim])

dsobs = dsobs.sel(time=slice("1891-01-01", None))

# Standard Nino-region boxes on the obs 0-360 longitude grid, given as
# ((lon_min, lon_max), (lat_min, lat_max)).
# NOTE(review): the latitude slices assume `lat` is stored ascending; with a
# descending axis .sel would return an empty selection -- confirm for this file.
obs_region_bounds = {
    "nino12": ((270, 280), (-10, 0)),  # 90W-80W, 10S-EQ
    "nino3": ((210, 270), (-5, 5)),  # 150W-90W, 5S-5N
    "nino34": ((190, 240), (-5, 5)),  # 170W-120W, 5S-5N
    "nino4": ((160, 210), (-5, 5)),  # 160E-150W, 5S-5N (standard Nino-4 box)
}

obs_enso_ts = {}
for region in enso_region:
    if region not in obs_region_bounds:
        # Skip rather than fall through and reuse the previous region's data.
        print(f"Unknown region: {region}")
        continue
    (x0, x1), (y0, y1) = obs_region_bounds[region]
    box = {xdim: slice(x0, x1), ydim: slice(y0, y1)}
    tos = dsobs[varname].sel(box)
    area = dsobs["area"].sel(box)
    obs_enso_ts[region] = tos.weighted(area).mean((xdim, ydim)).load()

### Timeseries plots

In [None]:
# Model grid names used by the cells below: the horizontal dimension names,
# the SST variable name, and the cell-area variable used as averaging weights.
# NOTE(review): tvar is not referenced by the visible cells (they use varname).
xdim = "xh"
ydim = "yh"
tvar = "tos"
areavar = "areacello"

In [None]:
# Loop over regions and groups to extract area-weighted regional-mean
# timeseries: enso_ts[region][group] -> DataArray.

varname = "tos"
varobj = diag.varmap[varname]

# The same Nino-region boxes as the observations, in the model's longitude
# convention (negative = west of Greenwich), given as (x range, y range).
# NOTE(review): Nino-4 starts at 160E == -200; this relies on geoslice
# accepting longitudes below -180, as the previous -190 bound already did.
model_region_bounds = {
    "nino12": ((-90, -80), (-10, 0)),
    "nino3": ((-150, -90), (-5, 5)),
    "nino34": ((-170, -120), (-5, 5)),
    "nino4": ((-200, -150), (-5, 5)),
}

enso_ts = {}

for region in enso_region:
    if region not in model_region_bounds:
        # Skip rather than fall through and reuse the previous region's data.
        print(f"Unknown region: {region}")
        continue
    x_range, y_range = model_region_bounds[region]
    enso_ts[region] = {}
    for group in diag.groups:
        ds = group.datasets[varobj]
        tos = geoslice(ds[varname], x=x_range, y=y_range)
        area = geoslice(ds[areavar], x=x_range, y=y_range)
        enso_ts[region][group] = tos.weighted(area).mean((xdim, ydim)).load()

In [None]:
def plot_oni(ts, detrend=True, ax=None):
    """Plot an ONI-style anomaly series and count warm and cold events.

    The series is optionally detrended (linear), its annual cycle is removed,
    and the anomalies are smoothed with a centered 3-month running mean.
    Months that belong to a run of at least 5 consecutive smoothed anomalies
    >= +0.5 (<= -0.5) are shaded red (blue). Each such run counts as one
    positive (negative) event.

    Parameters
    ----------
    ts : xarray.DataArray
        Monthly regional-mean SST with a ``time`` dimension.
        NOTE(review): the climatology is tiled from the first time step, so
        the series is assumed to start in January.
    detrend : bool, optional
        If True (default), remove a linear trend with
        ``momlevel.trend.linear_detrend`` before computing anomalies.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. A new 18x4 figure is created when None.

    Returns
    -------
    tuple
        ``(ax, tax, pos_events, neg_events)``: the axes, the time values of
        the plotted series, and the positive and negative event counts as 0-d
        numpy arrays.
    """
    threshold = 0.5  # anomaly magnitude (units of ts) for an event month
    min_run = 5  # consecutive months required for an event; keep it odd

    arr = ml.trend.linear_detrend(ts, mode="correct") if detrend else ts

    # Remove the annual cycle. annual_cycle is assumed to return the 12
    # monthly climatological values. np.resize tiles them to the series
    # length, and unlike repeating a list int(len / 12) times it also works
    # when the record does not end on a complete year.
    ac = ml.util.annual_cycle(arr)
    anom = arr - np.resize(np.asarray(ac.values), len(arr))

    # ONI smoothing: centered 3-month running (boxcar) mean of the anomalies.
    # The smoothing happens after the annual cycle is removed, so no residual
    # seasonal signal leaks into the index.
    anom = anom.rolling(time=3, center=True).mean()

    def persistent(mask):
        # A "core" month is the center of a fully True min_run window.
        # Widening each core by the same centered window flags every month of
        # the run, including its first and last months. This only works for
        # an odd min_run, where the centered window is symmetric.
        core = mask.rolling(time=min_run, center=True).sum() >= min_run
        widened = core.rolling(time=min_run, center=True, min_periods=1).sum() > 0
        return mask & widened

    def count_events(flags):
        # An event starts wherever a flagged month follows an unflagged one.
        previous = flags.shift(time=1, fill_value=False)
        return (flags & ~previous).sum().values

    pos_persistent = persistent(anom >= threshold)
    neg_persistent = persistent(anom <= -threshold)
    pos_events = count_events(pos_persistent)
    neg_events = count_events(neg_persistent)

    tax = anom.time.values

    if ax is None:
        _, ax = plt.subplots(figsize=(18, 4))

    ax.plot(tax, anom, color="black", linewidth=0.5)

    ax.fill_between(
        tax, threshold, anom, where=pos_persistent.values, color="red", alpha=0.7
    )
    ax.fill_between(
        tax, -threshold, anom, where=neg_persistent.values, color="blue", alpha=0.7
    )

    # Draw the zero line on the target axes, not on pyplot's "current" axes.
    ax.axhline(0, color="black", linewidth=0.5)
    ax.set_xlim(tax[0], tax[-1])

    return (ax, tax, pos_events, neg_events)

In [None]:
esnb.nbtools.setup_plots()

# One ONI panel for the observations plus one per model group.
ngroups = len(diag.groups) + 1
figsize = (esnb.nbtools.FULL_PAGE, esnb.nbtools.SINGLE_COLUMN * ngroups)
subplots = (ngroups, 1)

# (display name, Nino3.4 series, case group or None for the observations)
panels = [("NOAA_ERSST v2", obs_enso_ts["nino34"], None)]
panels += [(group.name, enso_ts["nino34"][group], group) for group in diag.groups]

fig = plt.figure(figsize=figsize, dpi=150)
for n, (name, arr, group) in enumerate(panels):
    ax = plt.subplot(*subplots, n + 1)
    ax, tax, pos_events, neg_events = plot_oni(arr, ax=ax)
    ax.text(0, 1.03, f"ONI (Nino3.4) - {name}", transform=ax.transAxes, fontsize=9)

    # Event frequency per decade of record (len(tax) is in months).
    nyears = len(tax) / 12.0
    pos_per_dec = round((pos_events / nyears) * 10, 2)
    neg_per_dec = round((neg_events / nyears) * 10, 2)

    ax.text(
        1.0,
        1.06,
        f"Positive Events: {pos_events} ({pos_per_dec} / decade)",
        transform=ax.transAxes,
        fontsize=7,
        ha="right",
    )
    ax.text(
        1.0,
        1.015,
        f"Negative Events: {neg_events} ({neg_per_dec} / decade)",
        transform=ax.transAxes,
        fontsize=7,
        ha="right",
    )

    # Observations have no case group, so metrics are only recorded for models.
    if group is not None:
        for key, value in (
            ("positive", float(pos_events)),
            ("negative", float(neg_events)),
            ("positive_per_dec", float(pos_per_dec)),
            ("negative_per_dec", float(neg_per_dec)),
            ("nmonths", int(len(tax))),
        ):
            group.add_metric("oni_events", (key, value))

### Wavelet Analysis

In [None]:
import xwavelet as xw

In [None]:
def scale_line_widths(ax, scale_factor):
    """
    Scale all line widths in a matplotlib axis by a given factor.

    Both Line2D artists (``ax.get_lines()``) and collections that expose
    ``get_linewidths`` (e.g. the LineCollections drawn by contour plots) are
    rescaled in place.

    Parameters:
    ax : matplotlib.axes.Axes
        The axis object containing the plots
    scale_factor : float
        Factor to multiply all line widths by
    """
    # Regular line plots carry a single scalar width.
    for line in ax.get_lines():
        line.set_linewidth(line.get_linewidth() * scale_factor)

    # Collections store per-element widths, which matplotlib may return as an
    # ndarray, a plain list, or a scalar. Converting to a float array scales
    # every form uniformly; a one-element list previously hit `list * float`
    # and raised TypeError.
    for collection in ax.collections:
        if not hasattr(collection, "get_linewidths"):
            continue
        widths = collection.get_linewidths()
        if widths is None:
            continue
        collection.set_linewidths(np.asarray(widths, dtype=float) * scale_factor)

In [None]:
# Wavelet power density of the Nino3.4 series: observations first, then one
# panel per model group, two panels per row.
# NOTE(review): ceil((ngroups + 2) / 2) == ceil(ngroups / 2) + 1, so the grid
# always has one more row than the panels need, while the figure height is
# sized from (ngroups + 1) / 2 -- confirm the intended layout.
figsize = (esnb.nbtools.FULL_PAGE, np.ceil((ngroups + 1) / 2) * 2.5)
subplots = (int(np.ceil((ngroups + 2) / 2)), 2)

# (display name, Nino3.4 series, case group or None for the observations).
# diag.groups are the keys of enso_ts, and group.name is the readable label;
# the raw group object would print as its repr.
panels = [("NOAA_ERSST v2", obs_enso_ts["nino34"], None)]
panels += [(group.name, enso_ts["nino34"][group], group) for group in diag.groups]

axes = []
fig = plt.figure(figsize=figsize, dpi=200)
for n, (name, arr, group) in enumerate(panels):
    ax = plt.subplot(*subplots, n + 1)

    result = xw.Wavelet(arr, scaled=True)
    _ = result.density(ax=ax)
    scale_line_widths(ax, 0.5)
    ax.text(
        0,
        1.03,
        f"Wavelet Density (Nino3.4) - {name}",
        transform=ax.transAxes,
        fontsize=6,
    )
    ax.set_xlabel(None)

    plt.subplots_adjust(hspace=0.4, wspace=0.4)
    axes.append(ax)

esnb.nbtools.panel_letters(axes)

### Frequency Spectra

In [None]:
def adjust_recent_line(ax, color=None, linewidth=None, label=None):
    """
    Restyle the line that was added to an axis most recently.

    Any argument left as None is not applied. After updating, an idle redraw
    of the figure canvas is requested.

    Parameters:
    -----------
    ax : matplotlib.axes.Axes
        Axis whose last Line2D should be modified
    color : str or tuple, optional
        Replacement color (any matplotlib color spec, e.g. 'red', '#FF5733')
    linewidth : float, optional
        Replacement line width
    label : str, optional
        Replacement legend label

    Returns:
    --------
    matplotlib.lines.Line2D
        The line that was modified

    Raises:
    -------
    ValueError
        If the axis has no line plots.
    """
    existing = ax.get_lines()
    if not existing:
        raise ValueError("No line plots found on the axis")

    newest = existing[-1]

    # Apply the updates in a fixed order: color, width, label.
    pending = (
        (newest.set_color, color),
        (newest.set_linewidth, linewidth),
        (newest.set_label, label),
    )
    for apply, value in pending:
        if value is not None:
            apply(value)

    ax.figure.canvas.draw_idle()
    return newest

In [None]:
# Wavelet-derived frequency spectra of all Nino3.4 series on one axis:
# observations in thick black, each model group in its configured color.
figsize = (esnb.nbtools.SINGLE_COLUMN, esnb.nbtools.SINGLE_COLUMN * 1.7)

fig = plt.figure(figsize=figsize, dpi=200)
ax = plt.subplot(1, 1, 1)

# (legend label, series, line color, line width). diag.groups are the keys of
# enso_ts, and group.name is the readable legend label.
spectra = [("NOAA_ERSST v2", obs_enso_ts["nino34"], "black", 1.5)]
spectra += [
    (group.name, enso_ts["nino34"][group], group.plot_color, 1.0)
    for group in diag.groups
]

for name, arr, color, linewidth in spectra:
    result = xw.Wavelet(arr, scaled=True)
    result.spectrum(ax=ax)
    adjust_recent_line(ax, color=color, linewidth=linewidth, label=name)

# Decorate once, after every spectrum is drawn. Inside the loop, the title
# text was stacked once per series and the legend was rebuilt each pass.
ax.grid(True, linewidth=0.2)
ax.legend(loc=4, fontsize=6)

ax.text(
    0.0, 1.02, "Frequency Spectra - Nino3.4 SST", fontsize=8, transform=ax.transAxes
)

### Write metrics file

In [None]:
# Write out the metrics recorded with group.add_metric(...) in the ONI cell.
diag.write_metrics()