# Sea Surface Temperature Bias

Instructions: Run this notebook on the GFDL system using the shared O-div Python environment.  

For more details on getting started, see the [Unified Notebook Template Documentation](https://docs.google.com/document/d/1cY-yWoEOANqsDICZWNFNkxbwUHjEXBL63mL6aBbVyyM/edit?tab=t.0#heading=h.hoyr4umbujp6)

In [None]:
# required by the esnb notebook framework
import esnb
from esnb import CaseGroup2, NotebookDiagnostic, RequestedVariable, nbtools
from esnb.sites.gfdl import call_dmget, convert_to_momgrid

In [None]:
# required by the diagnostic
import os

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import momgrid
import momlevel as ml
import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
from matplotlib.colors import BoundaryNorm, ListedColormap
from momgrid.geoslice import geoslice

# Point momgrid at the shared directory of precomputed grid weights.
# Any value already present in the environment is overwritten.
os.environ.update({"MOMGRID_WEIGHTS_DIR": "/nbhome/jpk/grid_weights"})

In [None]:
# Variables requested from each experiment: sea surface temperature ("tos")
# from the "ocean_month" component.
variables = [RequestedVariable("tos", "ocean_month")]

In [None]:
# setup the diagnostic
# NOTE(review): `mode` and `verbose` are not referenced elsewhere in this
# notebook; presumably they are read by the esnb framework -- TODO confirm.
mode = "prod"
verbose = True
diag_name = "SST_bias_NOAA_OISSTv2"
diag_desc = "Sea surface temperature bias analysis"
# "plot_region" lists the keys of `reg_settings` that get map figures;
# statistics are computed for every region regardless.
user_options = {"plot_region": ["global"], "plot": "documentation"}

diag = NotebookDiagnostic(diag_name, diag_desc, variables=variables, **user_options)

In [None]:
# Dora IDs of the experiments to analyze, mapped to the short labels used in
# the figure titles. A single mapping keeps every label attached to its
# experiment; two hand-maintained parallel lists could drift apart, and the
# `zip` that pairs labels with case groups would then silently mislabel or
# drop experiments.
# Other experiments analyzed previously: "odiv-384" -> "b05", "odiv-319" -> "b03"
# (labels could also be taken from the groups themselves via `x.name`).
experiments = {
    "odiv-516": "b11",
    "odiv-290": "b01",
}

# Analysis period; matches the 199101-202012 NOAA OISSTv2 files loaded below
analysis_period = ("1991-01-01", "2020-12-31")

ids = list(experiments)
names = list(experiments.values())

# initialize CaseGroup objects
groups = [CaseGroup2(x, date_range=analysis_period) for x in ids]

In [None]:
# Determine which files to load for the experiment groups. The resulting file
# list is read below as `diag.files` (presumably populated by this call).
diag.resolve(groups)

In [None]:
# Print the resolved file paths, one per line.
# This cell and the markdown cell that follows are necessary to run this
# notebook interactively on Dora.
for path in diag.files:
    print(path)

<i>(The files above are necessary to run the diagnostic.)</i>

In [None]:
# Check to see the dmget status before calling "open".
# The first call only reports status (status=True). The second call presumably
# issues the dmget request that stages the files -- confirm in esnb.sites.gfdl.
call_dmget(diag.files, status=True)
call_dmget(diag.files)

In [None]:
# Load the data as xarray datasets. Later cells read them per group as
# `group.datasets[variables[0]]` (presumably populated by this call).
diag.open()

## Main Diagnostic

In [None]:
# Every figure produced below is collected here so that the whole set can be
# exported to a PowerPoint deck at the end of the notebook.
all_figs = list()

In [None]:
# Load the shared matplotlib settings (from esnb's nbtools) used by every
# figure in this notebook.
nbtools.setup_plots()

In [None]:
# Load the NOAA OISSTv2 1991-2020 climatology: one file per model grid
# (OM4 / OM5, per the file names), each wrapped in a momgrid Gridset.
time_coder = xr.coders.CFDatetimeCoder(use_cftime=True)
oisst_dir = "/archive/jpk/datasets/observations/NOAA-OISST/processed"
dsobs = {
    grid: momgrid.Gridset(
        xr.open_dataset(
            f"{oisst_dir}/NOAA_OISST_v2_climo_199101_202012_v20250722_{grid.upper()}.nc",
            decode_times=time_coder,
        )
    ).data
    for grid in ("om4", "om5")
}

obs_om4 = dsobs["om4"].tos
obs_om5 = dsobs["om5"].tos

In [None]:
# Convert to momgrid and map each short experiment label to its CaseGroup
# for convenience.
convert_to_momgrid(diag)

# Labels are paired with groups by position. `zip` would silently drop
# experiments on a length mismatch, and later cells index `diag.groups` by the
# position of each label, so fail loudly instead.
if len(names) != len(diag.groups):
    raise ValueError(
        f"Got {len(names)} experiment labels for {len(diag.groups)} case groups; "
        "`names` must have exactly one entry per group."
    )
exps = dict(zip(names, diag.groups))

In [None]:
# For each experiment, average "tos" into annual means and then over all
# years, giving one time-mean field per experiment label.
annual = {}
for label, group in exps.items():
    tos_ds = group.datasets[variables[0]]
    tos_mean = ml.util.annual_average(tos_ds["tos"]).load().mean("time")
    # the model name selects the matching observation grid when plotting
    tos_mean.attrs["model"] = tos_ds.model
    annual[label] = tos_mean

In [None]:
# A function to plot an individual map panel


def plot_map_panel(ax, arr, obs=None, cmap=None, norm=None, label="", stats=True):
    """Draw one map panel of ``arr`` (or of the bias ``arr - obs``).

    Parameters
    ----------
    ax : cartopy GeoAxes
        Axis to draw on, created with the desired map projection.
    arr : xarray.DataArray
        Model field with ``geolon``/``geolat`` coordinates, plus ``areacello``
        when statistics are requested.
    obs : xarray.DataArray, optional
        Observations on the same grid as ``arr``. When given, the bias
        ``arr - obs`` is plotted; otherwise ``arr`` itself is plotted.
    cmap, norm : optional
        Colormap and normalization passed to ``pcolormesh``.
    label : str
        Prefix for the panel title.
    stats : bool
        If True and ``obs`` is given, compute area-weighted statistics of
        ``arr`` against ``obs`` and print them below the panel.

    Returns
    -------
    tuple
        ``(ax, mappable, panel_stats)``. ``panel_stats`` maps each statistic
        name except ``rsquared`` to its value rounded to two decimals, as a
        string. It is empty when no statistics were computed.
    """
    plotvar = arr if obs is None else arr - obs

    mappable = ax.pcolormesh(
        plotvar.geolon,
        plotvar.geolat,
        plotvar,
        transform=ccrs.PlateCarree(),
        cmap=cmap,
        norm=norm,
    )
    # Only call the panel a bias when observations were actually subtracted
    if obs is None:
        ax.set_title(f"{label} SST", fontsize=6)
    else:
        ax.set_title(f"{label} SST Bias (Model minus NOAA OISSTv2)", fontsize=6)

    # The statistics compare model against observations, so they need `obs`
    if not stats or obs is None:
        return (ax, mappable, {})

    raw_stats, _ = nbtools.calculate_stats(arr, obs, arr.areacello)
    panel_stats = {
        name: f"{round(float(value), 2)}"
        for name, value in raw_stats.items()
        if name != "rsquared"  # not shown on the panel
    }
    ax.text(
        0.5,
        -0.1,
        "  ".join(f"{name}={value}" for name, value in panel_stats.items()),
        ha="center",
        style="italic",
        transform=ax.transAxes,
    )
    return (ax, mappable, panel_stats)

In [None]:
# Plotting configuration for each analysis region. Every entry holds:
#   projection : cartopy projection for the map panels
#   levels     : (start, stop, step) passed to nbtools.gen_levs_and_cmap
#   xrange     : longitude window passed to geoslice (None = no cut)
#   yrange     : latitude window passed to geoslice (None = no cut)
#   hspace     : vertical spacing passed to plt.subplots_adjust
# Regions are processed in the order listed here. Each key is also the metric
# name that the region's statistics are registered under.


def _region(projection, levels, xrange, yrange, hspace):
    """Pack one region's settings into the dict layout the loops below expect."""
    return {
        "projection": projection,
        "levels": levels,
        "xrange": xrange,
        "yrange": yrange,
        "hspace": hspace,
    }


# contour levels shared by the polar maps and by the smaller regional maps
_polar_levels = (-2.0, 2.125, 0.125)
_regional_levels = (-4.5, 4.75, 0.25)

reg_settings = {
    "global": _region(
        ccrs.Robinson(central_longitude=-160), (-4.0, 4.25, 0.25), None, None, -0.4
    ),
    "arctic": _region(
        ccrs.NorthPolarStereo(), _polar_levels, (-298, 61), (60.0, 91.0), 0.2
    ),
    "southern_ocean": _region(
        ccrs.SouthPolarStereo(), _polar_levels, (-300, 60), (-60.0, -91.0), 0.2
    ),
    "nw_pacific": _region(
        ccrs.Miller(central_longitude=-200), _regional_levels, (-250, -150), (25, 60), 0.0
    ),
    "trop_indopac": _region(
        ccrs.Miller(central_longitude=-180), _regional_levels, (-280, -80), (-23, 23), -0.4
    ),
    "australia": _region(
        ccrs.Miller(central_longitude=-180), _regional_levels, (-300, -170), (0, -60), -0.4
    ),
    "north_atlantic": _region(
        ccrs.Miller(central_longitude=-60), _regional_levels, (-80, 0), (30, 75), -0.3
    ),
    "south_atlantic": _region(
        ccrs.Miller(central_longitude=-25), _regional_levels, (-75, 25), (-20, -60), -0.4
    ),
    "california": _region(
        ccrs.Miller(central_longitude=-115), _regional_levels, (-130, -100), (18, 46), 0.2
    ),
    # NOTE(review): the key misspells "Benguela" but is kept because it is used
    # as a metric name. Its yrange (35, -30) spans 35N-30S, which looks much
    # wider than the Benguela upwelling region -- confirm the intended window.
    "bengula": _region(
        ccrs.Miller(central_longitude=5), _regional_levels, (-10, 20), (35, -30), 0.2
    ),
    "peru": _region(
        ccrs.Miller(central_longitude=-85), _regional_levels, (-110, -60), (0, -40), 0.2
    ),
    "caribbean": _region(
        ccrs.Miller(central_longitude=-80), _regional_levels, (-110, -60), (5, 35), -0.3
    ),
}

In [None]:
# Loop over regions: draw maps for the regions listed in the "plot_region"
# option, and compute and register area-weighted bias statistics for every
# region.


def _subset_region(da, settings):
    """Return ``da`` cut to the region's lon/lat window (unchanged if uncut)."""
    if settings["xrange"] is None and settings["yrange"] is None:
        return da
    return momgrid.geoslice.geoslice(da, x=settings["xrange"], y=settings["yrange"])


for reg, settings in reg_settings.items():

    print(f"Processing region: {reg}")

    # (label, model, obs) per experiment, cut to the region. `annual` is in
    # the same order as `diag.groups`, which the registration below relies on.
    cases = []
    for k, arr in annual.items():
        obs = obs_om4 if "om4" in arr.model else obs_om5
        cases.append((k, _subset_region(arr, settings), _subset_region(obs, settings)))

    # Registered statistics come straight from calculate_stats (full precision,
    # all keys) for every region, so the metrics file does not depend on which
    # regions are plotted. The panels display their own rounded copy.
    stats = [
        nbtools.calculate_stats(arr, obs, arr.areacello)[0] for _, arr, obs in cases
    ]

    # make plots for the requested regions
    if reg in diag.diag_vars["plot_region"]:
        figsize, subplot = nbtools.get_figsize_subplots(len(diag.groups))
        fig = plt.figure(figsize=figsize, dpi=200)
        cmap, norm, _ = nbtools.gen_levs_and_cmap(*settings["levels"])

        axes, cbs = [], []
        for n, (k, arr, obs) in enumerate(cases):
            ax = plt.subplot(
                *subplot, n + 1, projection=settings["projection"], facecolor="gray"
            )
            ax, cb, _ = plot_map_panel(ax, arr, obs, cmap, norm, label=k.upper())
            axes.append(ax)
            cbs.append(cb)

        plt.subplots_adjust(hspace=settings["hspace"])

        cbar = nbtools.bottom_colorbar(
            fig, cbs[0], orientation="horizontal", extend="both"
        )
        cbar.set_label("Sea Surface Temperature Bias [degC]")

        # add letter labels for each panel
        nbtools.panel_letters(tuple(axes), -0.12, 1.17)

        # append figure to the complete list of figures
        all_figs.append(fig)

    # register statistics (entry n belongs to group n)
    for n, grp in enumerate(diag.groups):
        for k, v in stats[n].items():
            grp.add_metric(reg, (k, float(v)))

### Seasonal Cycle

In [None]:
# Load the NOAA OISSTv2 1991-2020 mean annual cycle: one file per model grid
# (OM4 / OM5, per the file names). This rebinds `obs_om4`/`obs_om5`, so cells
# above that use them must not be re-run after this one.
time_coder = xr.coders.CFDatetimeCoder(use_cftime=True)
oisst_dir = "/archive/jpk/datasets/observations/NOAA-OISST/processed"
dsobs = {
    grid: momgrid.Gridset(
        xr.open_dataset(
            f"{oisst_dir}/NOAA_OISST_v2_anncycle_199101_202012_v20250722_{grid.upper()}.nc",
            decode_times=time_coder,
        )
    ).data
    for grid in ("om4", "om5")
}

obs_om4 = dsobs["om4"].tos
obs_om5 = dsobs["om5"].tos

In [None]:
# Mean annual cycle of "tos" for each experiment. The plotting code below
# assumes 12 monthly time steps.
anncyc = {}
for label, group in exps.items():
    tos_ds = group.datasets[variables[0]]
    cycle = ml.util.annual_cycle(tos_ds["tos"]).load()
    # the model name selects the matching observation grid when plotting
    cycle.attrs["model"] = tos_ds.model
    anncyc[label] = cycle

In [None]:
# function to plot a single seasonal cycle panel


def plot_panel(ax, arr, obs, vmin=-1, vmax=1, label=""):
    """Draw a latitude-month diagram of the zonal-mean SST bias.

    Parameters
    ----------
    ax : matplotlib Axes
        Axis to draw on.
    arr, obs : xarray.DataArray
        Model and observed mean annual cycles on the same grid, with 12 time
        steps and an ``xh`` dimension that is averaged out. ``arr`` must carry
        ``areacello`` and ``geolat`` coordinates; ``areacello`` also weights
        the zonal mean of ``obs``.
    vmin, vmax : float
        Range of the 21 evenly spaced filled-contour levels. Values outside
        the range use the extended colors.
    label : str
        Prefix for the panel title.

    Returns
    -------
    tuple
        ``(ax, contour_set, stats)``. ``stats`` maps each statistic name
        except ``rsquared`` to its value rounded to two decimals, as a string.
    """
    months = np.arange(1, 13)
    area = arr.areacello
    lat = arr.geolat.mean("xh")

    # Area-weighted zonal means, with the time axis relabelled as months 1..12
    arr = arr.assign_coords({"time": months}).weighted(area).mean("xh")
    obs = obs.assign_coords({"time": months}).weighted(area).mean("xh")
    diff = arr - obs

    # Statistic weights: days per month times cos(latitude). Adding `diff * 0`
    # carries over its NaNs, which are then zeroed so missing points get no weight.
    days_per_month = np.array(
        [31.0, 28.0, 31.0, 30.0, 31.0, 30.0, 31.0, 31.0, 30.0, 31.0, 30.0, 31.0]
    )
    coslat = np.cos(np.deg2rad(lat))
    x, y = np.meshgrid(days_per_month, coslat)
    wgt = ((diff * 0.0) + (x * y).T).fillna(0.0)

    # Honor vmin/vmax; the defaults reproduce the previous fixed 0.1 spacing
    levels = np.linspace(vmin, vmax, 21)
    cb = ax.contourf(
        diff.time, lat, diff.T, levels=levels, cmap="RdBu_r", extend="both"
    )
    ax.set_title(f"{label} SST Bias (Model minus NOAA OISSTv2)")

    ax.set_xticks(months)
    ax.set_xticklabels(["J", "F", "M", "A", "M", "J", "J", "A", "S", "O", "N", "D"])

    raw_stats, _ = nbtools.calculate_stats(arr, obs, wgt)
    stats = {
        name: f"{round(float(value), 2)}"
        for name, value in raw_stats.items()
        if name != "rsquared"  # not shown on the panel
    }
    ax.text(
        0.5,
        -0.15,
        "  ".join(f"{name}={value}" for name, value in stats.items()),
        ha="center",
        style="italic",
        transform=ax.transAxes,
    )
    ax.grid(True, color="k", linestyle=":", linewidth=0.5)

    # crop the latitude axis at 78S
    ax.set_ylim(-78, None)

    return (ax, cb, stats)

In [None]:
# Seasonal cycle of the zonal-mean SST bias: one latitude-month panel per
# experiment, sharing a single colorbar.
nexps = len(diag.groups)
figsize, subplot = nbtools.get_figsize_subplots(nexps)

fig = plt.figure(figsize=figsize, dpi=200)
results = []
for n, (k, cycle) in enumerate(anncyc.items()):
    obs = obs_om4 if "om4" in cycle.model else obs_om5
    ax = plt.subplot(*subplot, n + 1, facecolor="gray")
    results.append(plot_panel(ax, cycle, obs, label=k.upper()))

plt.subplots_adjust(hspace=0.3)

axes, cbs, stats = zip(*results)

cbar = nbtools.bottom_colorbar(fig, cbs[0], orientation="horizontal", extend="both")
cbar.set_label("Sea Surface Temperature Bias [degC]")

# letter-label each panel
nbtools.panel_letters(axes)

# Register the panel statistics; `anncyc` is in `diag.groups` order, so panel n
# belongs to group n.
for n, grp in enumerate(diag.groups):
    for k, v in stats[n].items():
        grp.add_metric("zonal_seas_cycle", (k, float(v)))

all_figs.append(fig)

### Maps by Season

In [None]:
# Zero-based month indices (0 = January) making up each season, for use with
# `.isel(time=...)` on the 12-step annual cycle.
seasons = dict(
    DJF=[11, 0, 1],
    MAM=[2, 3, 4],
    JJA=[5, 6, 7],
    SON=[8, 9, 10],
)

In [None]:
# Seasonal-mean bias maps: one figure per season, one panel per experiment.
for season, month_idx in seasons.items():

    nexps = len(diag.groups)
    figsize, subplot = nbtools.get_figsize_subplots(nexps)

    fig = plt.figure(figsize=figsize, dpi=200)
    cmap, norm, boundaries = nbtools.gen_levs_and_cmap(-7, 7.5, 0.5)
    projection = ccrs.Robinson(central_longitude=-160)

    results = []
    for n, (k, cycle) in enumerate(anncyc.items()):
        obs = obs_om4 if "om4" in cycle.model else obs_om5
        # plain (unweighted) average of the season's months, model and obs alike
        model_mean = cycle.isel(time=month_idx).mean("time")
        obs_mean = obs.isel(time=month_idx).mean("time")
        ax = plt.subplot(*subplot, n + 1, projection=projection, facecolor="gray")
        results.append(
            plot_map_panel(
                ax, model_mean, obs_mean, cmap, norm, label=f"{k.upper()} {season}"
            )
        )

    axes, cbs, stats = zip(*results)

    cbar = nbtools.bottom_colorbar(fig, cbs[0], orientation="horizontal", extend="both")
    cbar.set_label("Sea Surface Temperature Bias [degC]")

    # letter-label each panel
    nbtools.panel_letters(axes)

    # register each group's statistics under e.g. "djf_bias"
    metric_name = f"{season.lower()}_bias"
    for n, grp in enumerate(diag.groups):
        for k, v in stats[n].items():
            grp.add_metric(metric_name, (k, float(v)))

    all_figs.append(fig)

### SST Trend

In [None]:
# Load the NOAA OISSTv2 1991-2020 annual-mean time series: one file per model
# grid (OM4 / OM5, per the file names). This rebinds `obs_om4`/`obs_om5`, so
# cells above that use them must not be re-run after this one.
time_coder = xr.coders.CFDatetimeCoder(use_cftime=True)
oisst_dir = "/archive/jpk/datasets/observations/NOAA-OISST/processed"
dsobs = {
    grid: momgrid.Gridset(
        xr.open_dataset(
            f"{oisst_dir}/NOAA_OISST_v2_annual_199101_202012_v20250722_{grid.upper()}.nc",
            decode_times=time_coder,
        )
    ).data
    for grid in ("om4", "om5")
}

obs_om4 = dsobs["om4"].tos
obs_om5 = dsobs["om5"].tos

In [None]:
# Linear trend (per year) of the observed annual-mean SST on the OM5 grid.
# Only the OM5-grid observations feed the reference trend map below.
obs_slope_om5 = ml.trend.calc_linear_trend(obs_om5, time_units="yr")["tos_slope"].load()

In [None]:
# Map of the observed 1991-2020 linear SST trend (OM5-grid observations)
side = nbtools.SINGLE_COLUMN
fig = plt.figure(figsize=(side, side), dpi=200)
ax = plt.subplot(
    1, 1, 1, projection=ccrs.Robinson(central_longitude=-160), facecolor="gray"
)

# same color levels as the model trend maps further below
cmap, norm, boundaries = nbtools.gen_levs_and_cmap(-0.1, 0.11, 0.01)
cb = ax.pcolormesh(
    obs_om5.geolon,
    obs_om5.geolat,
    obs_slope_om5,
    transform=ccrs.PlateCarree(),
    cmap=cmap,
    norm=norm,
)
ax.set_title("NOAA OISSTv2 1991-2020")

cbar = nbtools.bottom_colorbar(fig, cb, orientation="horizontal", extend="both")
cbar.set_label("Sea Surface Temperature Trend [degC yr-1]")

all_figs.append(fig)

In [None]:
# Linear trend (per year) of each experiment's annual-mean SST. The grid
# coordinates are attached so the trend can be mapped and area-weighted.
modtrend = {}
for label, group in exps.items():
    tos_ds = group.datasets[variables[0]]
    yearly = ml.util.annual_average(tos_ds["tos"])
    slope = ml.trend.calc_linear_trend(yearly, time_units="yr")["tos_slope"].load()
    slope.attrs["model"] = tos_ds.model
    modtrend[label] = slope.assign_coords(
        {"geolon": tos_ds.geolon, "geolat": tos_ds.geolat, "areacello": tos_ds.areacello}
    )

In [None]:
# Regional SST trend maps for the requested regions. The area-weighted mean
# trend of every region is registered as "<region>_trend".
for reg, settings in reg_settings.items():

    print(f"Processing region: {reg}")

    # model trend fields cut to the region, in `diag.groups` order
    regional = {}
    for k, arr in modtrend.items():
        if settings["xrange"] is not None or settings["yrange"] is not None:
            arr = momgrid.geoslice.geoslice(
                arr, x=settings["xrange"], y=settings["yrange"]
            )
        regional[k] = arr

    # area-weighted mean trend of each experiment over the region
    stats = [
        {"trend": float(arr.weighted(arr.areacello).mean(("yh", "xh")))}
        for arr in regional.values()
    ]

    # make plots for the requested regions
    if reg in diag.diag_vars["plot_region"]:
        figsize, subplot = nbtools.get_figsize_subplots(len(diag.groups))
        fig = plt.figure(figsize=figsize, dpi=200)
        cmap, norm, _ = nbtools.gen_levs_and_cmap(-0.1, 0.11, 0.01)

        axes, cbs = [], []
        for n, (k, arr) in enumerate(regional.items()):
            ax = plt.subplot(
                *subplot, n + 1, projection=settings["projection"], facecolor="gray"
            )
            ax, cb, _ = plot_map_panel(
                ax, arr, cmap=cmap, norm=norm, label=k.upper(), stats=False
            )
            ax.set_title(f"{k.upper()} SST Trend 1991-2020")
            axes.append(ax)
            cbs.append(cb)

        # same layout as the regional bias maps
        plt.subplots_adjust(hspace=settings["hspace"])

        cbar = nbtools.bottom_colorbar(
            fig, cbs[0], orientation="horizontal", extend="both"
        )
        cbar.set_label("Sea Surface Temperature Trend [degC yr-1]")

        # add letter labels for each panel, as on the regional bias maps
        nbtools.panel_letters(tuple(axes), -0.12, 1.17)

        all_figs.append(fig)

    # register statistics (entry n belongs to group n)
    for n, grp in enumerate(diag.groups):
        for k, v in stats[n].items():
            grp.add_metric(f"{reg}_trend", (k, float(v)))

### Process and synthesize metrics

In [None]:
# Read the registered metrics back and sort them by name into three families:
# regional trends ("*_trend"), seasonal diagnostics ("*_bias", "*_seas_cycle"),
# and the remaining per-region bias statistics.
# Ordered lists (not sets) keep the registration order, so the tables below
# have a stable row order. The iteration order of a set of strings changes
# between Python sessions because of hash randomization.
# NOTE: this rebinds `exps` from the {label: CaseGroup} dict built earlier to
# the list of experiment names in the metrics. Cells above that use
# `exps.items()` will fail if re-run after this cell.
data = diag.metrics
exps = list(data["RESULTS"]["Global"].keys())
metrics = list(data["RESULTS"]["Global"][exps[0]].keys())
trends = [x for x in metrics if "_trend" in x]
seasonal = [x for x in metrics if "_bias" in x or "_seas_cycle" in x]
regions = [x for x in metrics if x not in trends and x not in seasonal]

In [None]:
# Table of the "bias" statistic per region (rows) and experiment (columns)
results = {
    exp: {reg: data["RESULTS"]["Global"][exp][reg]["bias"] for reg in regions}
    for exp in exps
}

pd.DataFrame(results)

In [None]:
# Table of the area-weighted mean trend per region (rows) and experiment (columns)
results = {
    exp: {name: data["RESULTS"]["Global"][exp][name]["trend"] for name in trends}
    for exp in exps
}

pd.DataFrame(results)

In [None]:
# Table of the "bias" statistic for the seasonal diagnostics (rows) per experiment
results = {
    exp: {seas: data["RESULTS"]["Global"][exp][seas]["bias"] for seas in seasonal}
    for exp in exps
}

pd.DataFrame(results)

### Write Metrics to File

In [None]:
# Write every metric registered above (regional bias, seasonal, trends) to JSON
diag.write_metrics("SST_metrics.json")

### Make a PowerPoint Presentation of Figures

In [None]:
# Export every figure collected in `all_figs` to a PowerPoint deck
nbtools.save_pptx(all_figs,"SST_bias.pptx")