# ENSO Monitoring Notebook

J. Krasting -- NOAA/GFDL

In [None]:
# Development mode: constantly refreshes module code
%load_ext autoreload
%autoreload 2

## Framework Code and Diagnostic Setup

In [None]:
import esnb
from esnb import NotebookDiagnostic, RequestedVariable, CaseGroup2
from esnb.sites.gfdl import call_dmget

In [None]:
%%time

# Configure the diagnostic: its name/description, the variables to load,
# user options, and the groups of experiments to compare. `diag.resolve`
# then works out which files are required (listed in the next cell).

# Define a mode (leave "prod" for now)
# NOTE(review): `mode` and `verbose` are not read by any cell in this notebook;
# presumably they are picked up by the esnb framework — confirm.
mode = "prod"

# Verbosity
verbose = True

# Give your diagnostic a name and a short description
diag_name = "ENSO Monitoring"
diag_desc = "Basic diagnostics of ENSO / tropical SST variability"

# Define what variables you would like to analyze. The first entry is the
# variable name and the second entry is the realm (post-processing dir).
#   (By default, monthly timeseries data will be loaded. TODO: add documentation
#    on how to select different frequencies, multiple realms to search, etc.)
variables = [
    RequestedVariable("tos", "ocean_month"),
]

# Optional: define runtime settings or options for your diagnostic
# These options are read back later via `diag.diag_vars`. Only the names
# "nino12", "nino3", "nino34" and "nino4" are recognized by the
# region-averaging cell; any other name is reported as unknown.
user_options = {"enso_region": ["nino12", "nino3", "nino34", "nino4"]}

# Initialize the diagnostic with its name, description, vars, and options
diag = NotebookDiagnostic(diag_name, diag_desc, variables=variables, **user_options)

# Define the groups of experiments to analyze. Provide a single dora id for one experiment
# or a list of IDs to aggregate multiple experiments into one; e.g. historical+future runs
groups = [
    CaseGroup2("cm5-9", date_range=("0041-01-01", "0060-12-31")),
    CaseGroup2("odiv-1", date_range=("0041-01-01", "0060-12-31")),
    CaseGroup2(
        [1188, 1243],
        # presumably the dimension along which the two experiments are
        # joined — TODO confirm against CaseGroup2's signature
        "time",
        name="ESM4 Historical + Future",
        date_range=("1993-01-01", "2022-12-31"),
    ),
]

# Combine the experiments with the diag request and determine what files need to be loaded:
diag.resolve(groups)

In [None]:
# Print the path of every file the diagnostic needs, one per line.
# This cell and the markdown cell that follows are necessary to run this
# notebook interactively on Dora.
# (A plain loop is used instead of a list comprehension: the comprehension was
# only run for its print side effect and overwrote IPython's `_` variable.)
for file_path in diag.files:
    print(file_path)

<i>(The files above are necessary to run the diagnostic.)</i>

In [None]:
# Report the dmget status of every required file before calling "open".
# NOTE(review): presumably status=True only queries the status rather than
# staging the files — confirm against esnb.sites.gfdl.call_dmget.
call_dmget(
    diag.files,
    status=True,
)

In [None]:
# Load the data as xarray datasets.
# The resolved files (listed and status-checked above) are opened here; later
# cells access them through `diag._datasets` and `group.datasets[variable]`.
diag.open()

## Begin the User Diagnostic Code

#### Load Modules

In [None]:
import os

import cftime
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

In [None]:
# Ocean trick - call momgrid on ocean model data to get the static info

import momgrid as mg
from momgrid.geoslice import geoslice

# Point momgrid at a directory of grid weight files (presumably used for
# regridding/static-grid lookups — confirm in momgrid docs).
# NOTE(review): hard-coded to one user's directory; other users may need to
# change this path.
os.environ["MOMGRID_WEIGHTS_DIR"] = "/nbhome/jpk/grid_weights"

# Replace each opened dataset with momgrid's Gridset version, which adds the
# static grid information. Later cells rely on fields such as `areacello`,
# `geolon` and `geolat` being present.
# NOTE(review): iterates the private `diag._datasets` attribute; this may break
# if esnb changes its internals.
for ds in diag._datasets:
    ds.replace(mg.Gridset(ds.dataset).data)

#### Custom functions

In [None]:
def is_overlap_time(x, y, tdim="time"):
    """Return True if the time spans of ``x`` and ``y`` overlap.

    Each span runs from the first to the last value of the ``tdim``
    coordinate, so that coordinate is assumed to be sorted in ascending order.
    Spans that share only an endpoint count as overlapping.

    Parameters
    ----------
    x, y : xarray.DataArray or xarray.Dataset
        Objects supporting ``obj[tdim].values`` (indexable, first/last element).
    tdim : str, optional
        Name of the time coordinate (default ``"time"``).

    Returns
    -------
    bool
        True when the closed intervals overlap. False when they are disjoint,
        or when their time values cannot be compared (e.g. different calendar
        types), since such series cannot share a time axis anyway.
    """
    start_x, end_x = x[tdim].values[0], x[tdim].values[-1]
    start_y, end_y = y[tdim].values[0], y[tdim].values[-1]
    try:
        # Closed intervals [a, b] and [c, d] overlap iff a <= d and c <= b.
        return bool(start_x <= end_y and start_y <= end_x)
    except TypeError:
        return False

### Part 1: Timeseries plots

In this section, we make time-series plots of the monthly, area-weighted mean SST in each of the requested Niño regions.

In [None]:
# Names shared by the analysis cells below: the horizontal grid dimensions,
# the SST variable, and the cell-area field used as averaging weights.
# NOTE(review): `tvar` is not referenced by any cell in this notebook; the
# cells use `varname` instead.
xdim, ydim = "xh", "yh"
tvar, areavar = "tos", "areacello"

In [None]:
# Read the user option listing the Niño regions to analyze. A missing option
# is normalized to an empty list so later cells can iterate over it directly
# (iterating over None would raise a TypeError).
enso_region = diag.diag_vars.get("enso_region", None)
if enso_region is None:
    enso_region = []

In [None]:
# Compute the area-weighted mean SST in each requested Niño region for every
# group of experiments. Results are stored as
#     my_data_dict[group.name][region] -> SST averaged over (xdim, ydim)
variable = diag.variables[0]   # first RequestedVariable object
varname = variable.varname  # varname resolves to "tos"
my_data_dict = {}

# geoslice bounds for each supported region. Longitudes are degrees east with
# western longitudes negative; values below -180 continue westward past the
# dateline (the map cells slice x=(-270, -70)). Niño 4 spans 160E-150W,
# i.e. (-200, -150) in this convention.
nino_region_bounds = {
    "nino12": {"x": (-90, -80), "y": (-10, 0)},
    "nino3": {"x": (-150, -90), "y": (-5, 5)},
    "nino34": {"x": (-170, -120), "y": (-5, 5)},
    "nino4": {"x": (-200, -150), "y": (-5, 5)},
}

# A missing option means "no regions". This is loop-invariant, so normalize
# it once here; the plotting cell below also iterates over `enso_region`.
enso_region = [] if enso_region is None else enso_region

for group in diag.groups:
    ds = group.datasets[variable]
    tos_by_region = {}
    for region in enso_region:
        bounds = nino_region_bounds.get(region)
        if bounds is None:
            # Skip unknown names rather than averaging a slice left over
            # from the previous region.
            print(f"Unknown region: {region}")
            continue
        tos = geoslice(ds[varname], **bounds)
        area = geoslice(ds[areavar], **bounds)
        tos = tos.weighted(area).mean((xdim, ydim))
        tos_by_region[region] = tos.load()
    my_data_dict[group.name] = tos_by_region

In [None]:
# One figure per Niño region, with every group's monthly SST time series.
# Groups whose time spans overlap share an x-axis. A group that overlaps no
# existing axis (e.g. a different model-year epoch) gets its own twin x-axis,
# stacked below the main plot. Each group's mean/stddev is saved as a metric.
for region in enso_region:
    fig = plt.figure(figsize=(6, 3), dpi=150)
    main_ax = plt.subplot(1, 1, 1)

    # axis -> first series drawn on it (used for the overlap test)
    plotted = {}

    for n, group in enumerate(diag.groups):
        da = my_data_dict[group.name].get(region)
        if da is None:
            # Region was not computed for this group (e.g. an unknown name).
            continue

        # Reuse the first axis whose series overlaps this one in time.
        target = next(
            (axis for axis, v in plotted.items() if is_overlap_time(v, da)),
            None,
        )
        if target is None:
            if not plotted:
                target = main_ax
            else:
                target = main_ax.twiny()
                # Put the extra time axis below the main plot, offsetting
                # each additional twin further so their ticks don't collide.
                target.spines["bottom"].set_position(
                    ("outward", 25 * len(plotted))
                )
                target.xaxis.set_ticks_position("bottom")
                target.xaxis.set_label_position("bottom")
                target.spines["bottom"].set_visible(True)
                target.spines["top"].set_visible(False)
            plotted[target] = da

        target.plot(da.time, da, color=f"C{n}", label=group.name)

        # Save scalar metrics exactly once per group and region.
        group.add_metric(f"{region}_timeseries", ("mean", float(da.mean())))
        group.add_metric(f"{region}_timeseries", ("stddev", float(da.std())))

    for ax in plotted:
        ax.grid(True)

    main_ax.text(
        0, 1.02, f"Monthly SST Nino Region: {region}", transform=main_ax.transAxes
    )

In [None]:
# Display the diagnostic's metrics (presumably those recorded with
# group.add_metric in the time-series cell above — TODO confirm).
diag.metrics

### Part 2: Maps 

In [None]:
import cartopy.crs as ccrs

# import momlevel.utils for calendar tools
from momlevel import util

In [None]:
import cartopy.crs as ccrs
import matplotlib.patches as mpatches


def add_nino_boxes(ax, edgecolor="red", linewidth=0.5):
    """Outline the four standard Niño regions on a cartopy map.

    Parameters
    ----------
    ax : cartopy GeoAxes
        Map axes to draw on.
    edgecolor : color, optional
        Outline color of the boxes (default ``"red"``).
    linewidth : float, optional
        Outline width of the boxes (default ``0.5``).

    Notes
    -----
    Longitudes are degrees east with western longitudes negative; values
    below -180 continue westward past the dateline, so Niño 4 (160E-150W) is
    written as (-200, -150). The boxes are drawn in a PlateCarree frame
    centered on 180 degrees, where none of them crosses the frame's edge.
    """
    nino_regions = {
        "Nino 1+2": {"lon": (-90, -80), "lat": (-10, 0)},
        "Nino 3": {"lon": (-150, -90), "lat": (-5, 5)},
        "Nino 3.4": {"lon": (-170, -120), "lat": (-5, 5)},
        "Nino 4": {"lon": (-200, -150), "lat": (-5, 5)},
    }

    # In this CRS the x coordinate is (longitude - 180 degrees).
    box_crs = ccrs.PlateCarree(central_longitude=180)

    for bounds in nino_regions.values():
        lon_min, lon_max = bounds["lon"]
        lat_min, lat_max = bounds["lat"]
        # Western edge shifted into the 180-centered frame, in [-180, 180).
        x0 = lon_min % 360 - 180
        rect = mpatches.Rectangle(
            (x0, lat_min),
            lon_max - lon_min,
            lat_max - lat_min,
            linewidth=linewidth,
            edgecolor=edgecolor,
            facecolor="none",
            transform=box_crs,
        )
        ax.add_patch(rect)

In [None]:
# Map of the annual-mean SST climatology over the tropical band for each
# group, with the Niño regions outlined.
title = "Annual Mean Climatology"

vmin = 16
vmax = 32

# cartopy map projection
projection = ccrs.PlateCarree(central_longitude=-160)

for group in diag.groups:
    ds = group.datasets[variable]
    # Tropical band: 25S-25N, from 270W (= 90E) eastward to 70W.
    ds = geoslice(ds, x=(-270, -70), y=(-25, 25))

    fig = plt.figure(figsize=(6, 2.5), dpi=200)
    ax = plt.subplot(1, 1, 1, projection=projection, facecolor="gray")
    # Average within each year, then across years.
    da = util.annual_average(ds[varname]).mean("time", keep_attrs=True)
    cb = da.plot.pcolormesh(
        ax=ax,
        x="geolon",
        y="geolat",
        transform=ccrs.PlateCarree(),
        add_colorbar=False,
        vmin=vmin,
        vmax=vmax,
    )
    ax.set_title(None)

    label = da.attrs.get("long_name", "")  # default to empty string if missing
    units = da.attrs.get("units", "")
    if units:
        label = f"{label} [{units}]"

    # add horizontal colorbar
    cbar = fig.colorbar(cb, ax=ax, orientation="horizontal", pad=0.1, shrink=0.5)
    cbar.set_label(label, fontsize=8)
    cbar.ax.tick_params(labelsize=7)

    add_nino_boxes(ax)

    # label the panel with the group's name (left) and the plot title (right)
    ax.text(0.0, 1.05, group.name, fontsize=7, transform=ax.transAxes)
    ax.text(
        1.0, 1.05, title, ha="right", style="italic", fontsize=5, transform=ax.transAxes
    )

In [None]:
# Regress the tropical SST anomalies on the global SST anomalies

from momlevel import trend

def regress(y, x):
    """Return the ordinary least-squares slope of ``y`` regressed on ``x``.

    The slope is sum((x - mean(x)) * (y - mean(y))) / sum((x - mean(x))**2),
    i.e. cov(x, y) / var(x). In this notebook it is applied per grid point via
    ``xr.apply_ufunc(..., vectorize=True)``, so ``y`` and ``x`` arrive as
    equal-length 1-D arrays along "time".

    NaNs propagate (any NaN in ``y`` gives a NaN slope). A constant ``x``
    makes the denominator zero.
    """
    x_dev = x - x.mean()
    y_dev = y - y.mean()
    covariance_sum = (x_dev * y_dev).sum()
    variance_sum = (x_dev ** 2).sum()
    return covariance_sum / variance_sum

In [None]:
# plotting options

vmin = -1.5
vmax = 1.5
cmap = "RdBu_r"
title = "Linear Regressions (Local + Nino3.4) of Annual Anomalies"

for group in diag.groups:
    ds = group.datasets[variable]

    # Local SST: annual means (removing the seasonal cycle, which would
    # otherwise dominate the regression), linearly detrended, expressed as
    # anomalies from the period mean.
    da = util.annual_average(ds[varname])
    da = trend.linear_detrend(da)
    da_anom = da - da.mean("time")

    # Nino 3.4 index (5S-5N, 170W-120W): area-weighted mean of the monthly SST,
    # then processed exactly like the local field so both series share the
    # same annual time axis.
    # NOTE(review): confirm util.annual_average / trend.linear_detrend accept
    # a 1-D (time-only) DataArray.
    da_trop = geoslice(ds[varname], x=(-170, -120), y=(-5, 5))
    da_trop = da_trop.weighted(da_trop.coords[areavar]).mean((xdim, ydim))
    da_trop = util.annual_average(da_trop)
    da_trop = trend.linear_detrend(da_trop)
    da_trop = da_trop - da_trop.mean("time")

    # Slope of the local anomaly regressed on the Nino 3.4 index, per grid point.
    regression_map = xr.apply_ufunc(
        regress,
        da_anom,
        da_trop,
        input_core_dims=[["time"], ["time"]],
        vectorize=True,
        dask="parallelized",
        output_dtypes=[float],
        dask_gufunc_kwargs={"allow_rechunk": True},
    )

    fig = plt.figure(dpi=200)
    ax = plt.subplot(1, 1, 1, projection=ccrs.Orthographic(-100, 20), facecolor="gray")
    cb = regression_map.plot.pcolormesh(
        ax=ax,
        x="geolon",
        y="geolat",
        transform=ccrs.PlateCarree(),
        add_colorbar=False,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
    )

    # Colorbar label describing the regression coefficient plotted here.
    units = ds[varname].attrs.get("units", "")
    label = "Regression on Nino3.4 index"
    if units:
        label = f"{label} [{units} per {units}]"

    # add horizontal colorbar
    cbar = fig.colorbar(cb, ax=ax, orientation="horizontal", pad=0.1, shrink=0.5)
    cbar.set_label(label, fontsize=8)
    cbar.ax.tick_params(labelsize=7)

    # label the panel with the group's name and the plot title
    ax.text(0.0, 1.07, group.name, fontsize=7, transform=ax.transAxes)
    ax.text(
        0.0, 1.03, title, ha="left", style="italic", fontsize=5, transform=ax.transAxes
    )

### Part 3: Seasonal / Annual Cycle Plots

In [None]:
# Hovmoller-style plot (month vs. longitude) of the mean annual cycle of SST
# averaged over 5S-5N, shown as anomalies from the annual mean at each
# longitude.
vmin = -2
vmax = 2

title = "Anomaly Relative to the Annual Mean (5S to 5N)"

for group in diag.groups:
    fig = plt.figure(figsize=(4, 4), dpi=150)
    ax = plt.subplot(1, 1, 1)

    # Equatorial band: 5S-5N, from 270W (= 90E) eastward to 70W.
    ds = geoslice(group.datasets[variable], x=(-270, -70), y=(-5, 5))
    da = ds[varname]

    # Area-weighted mean over the latitude dimension (ydim).
    da = da.weighted(da.coords[areavar]).mean(ydim)
    lon = ds.geolon.mean(ydim)

    # Mean annual cycle, as anomalies from its own annual mean.
    da = util.annual_cycle(da)
    da = da - da.mean("time")

    cb = ax.pcolormesh(lon, np.arange(1, 13), da, cmap="RdBu_r", vmin=vmin, vmax=vmax)
    ax.set_yticks(np.arange(1, 13))
    ax.set_yticklabels(["J", "F", "M", "A", "M", "J", "J", "A", "S", "O", "N", "D"])

    # Colorbar label built from this variable's own metadata.
    attrs = ds[varname].attrs
    label = f"{attrs.get('long_name', '')} anomaly".strip()
    if attrs.get("units"):
        label = f"{label} [{attrs['units']}]"

    # add horizontal colorbar
    cbar = fig.colorbar(cb, ax=ax, orientation="horizontal", pad=0.1, shrink=0.5)
    cbar.set_label(label, fontsize=8)
    cbar.ax.tick_params(labelsize=7)

    # label the panel with the group's name and the plot title
    ax.text(0.0, 1.07, group.name, fontsize=7, transform=ax.transAxes)
    ax.text(
        0.0, 1.03, title, ha="left", style="italic", fontsize=5, transform=ax.transAxes
    )

### Part 4: Write metrics file

In [None]:
# Write the collected diagnostic metrics out.
# NOTE(review): the output location and format are chosen by esnb's
# NotebookDiagnostic.write_metrics — confirm before relying on them.
diag.write_metrics()