In [6]:
import xarray as xr
import numpy as np 
import scipy
import os
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.gridspec as gridspec
import cartopy.crs as ccrs
from sklearn.decomposition import PCA
import pandas as pd
import string
from src import config_cesm
from src.utils import util_cesm
from src.models.diagnostics import roll_metric
# Grid whose x / y coordinates are used for the regression maps further down.
reference_grid = util_cesm.generate_sps_grid()
# Experiment-1 input configurations. "input2" is the SIC-only baseline that the
# other configurations are compared against in the scorecards below.
CNAMES = ["input2", "input3a", "input3b", "input3c", "input3d", "input4"]

def add_hatching(ax, significance_mask, x_edges, y_edges, hatch='///', edgecolor='k'):
    """Overlay an unfilled hatched rectangle on every truthy cell of a 2-D mask.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes the rectangles are added to.
    significance_mask : 2-D array-like, shape (Ny, Nx)
        Cells whose value is truthy receive a hatched rectangle.
    x_edges, y_edges : 1-D array-like of length Nx + 1 and Ny + 1
        Cell boundaries along the column and row axes respectively.
    hatch : str
        Matplotlib hatch pattern.
    edgecolor : str
        Colour of the hatch lines.
    """
    # Row-major (row, col) indices of the flagged cells, same order as a
    # nested row/column scan.
    flagged_cells = np.argwhere(np.asarray(significance_mask))
    for row, col in flagged_cells:
        left, right = x_edges[col], x_edges[col + 1]
        bottom, top = y_edges[row], y_edges[row + 1]
        ax.add_patch(
            mpatches.Rectangle(
                (left, bottom),
                right - left,
                top - bottom,
                hatch=hatch,
                fill=False,
                edgecolor=edgecolor,
                linewidth=0,
            )
        )

def plot_markers(ax, exceeds_persistence, x_centers, y_centers):
    """Draw a small black dot on every cell where the mask is False.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    exceeds_persistence : 2-D array-like of bools, shape (n_rows, n_cols)
        True where the forecast beats persistence; dots mark the False cells.
    x_centers : 1-D array-like of length n_cols
        Column centre coordinates.
    y_centers : 1-D array-like of length n_rows
        Row centre coordinates.
    """
    # Iterate over the mask's real shape instead of a hard-coded 6 x 12 grid,
    # so differently sized scorecards work (6 x 12 behaves as before).
    mask = np.asarray(exceeds_persistence)
    n_rows, n_cols = mask.shape
    for i in range(n_rows):
        for j in range(n_cols):
            if not mask[i, j]:
                ax.plot(x_centers[j], y_centers[i], '.k', markersize=4)

In [7]:
# Per-configuration ACC diagnostics (aggregated and un-aggregated) produced by
# the diagnostics step for each experiment-1 run.
acc_agg = {}
acc = {}
for cname in CNAMES:
    run_dir = os.path.join(config_cesm.PREDICTIONS_DIRECTORY, f"exp1_{cname}")
    acc_agg[cname] = xr.open_dataset(os.path.join(run_dir, "diagnostics/acc_agg.nc"))["acc"]
    acc[cname] = xr.open_dataset(os.path.join(run_dir, "diagnostics/acc.nc"))["acc"]

# Significance of each configuration's ACC difference from the SIC-only
# baseline ("input2"); the baseline itself has no comparison file.
significance_ds = {
    cname: xr.open_dataset(
        os.path.join(config_cesm.ANALYSIS_RESULTS_DIRECTORY, f"confidence_intervals/exp1_input2_exp1_{cname}_acc.nc")
    )
    for cname in CNAMES
    if cname != "input2"
}


## Plot comparison of ACC (anomaly correlation coefficient)

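Each scorecard shows the ACC of sea ice concentration (SIC) predictions by target month (x-axis) and lead time (y-axis). Results are aggregated over a test set of four historical simulations from different CESM ensemble members, and over five independently initialized neural networks per configuration. In the difference panels, hatching marks cells where the difference from the SIC-only baseline is not statistically significant (p > 0.01).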

In [None]:
P_VALUE_CUTOFF = 0.01  # cells whose p_value exceeds this are hatched (difference not significant)

# Panel a): ACC of the SIC-only baseline ("input2"); panels b)-f): percent change
# in ACC relative to that baseline when each additional input is added.
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(9,5), sharey=True)

# pcolormesh cell edges (12 months x 6 leads) and the centres used for ticks.
x = np.arange(13)
y = np.arange(7)
x_centers = (x[:-1] + x[1:]) / 2
y_centers = (y[:-1] + y[1:]) / 2
month_labels = ["J","F","M","A","M","J","J","A","S","O","N","D"]
lead_labels = np.arange(1,7,1)

acc_input2 = acc_agg["input2"].mean("nn_member_id")
input_configs = ["input3a", "input3b", "input3c", "input3d", "input4"]
input_added = ["sst", "slp", "z500", "t2m", "all"]

# NOTE(review): flat shading requires each member-mean field to be
# (lead_time, month) = (6, 12) in that order to match the edge arrays — confirm.
cax2 = axs[0,0].pcolormesh(x, y, acc_input2, cmap='Spectral', shading='flat', vmin=0, vmax=1)
axs[0,0].set_yticks(y_centers, labels=lead_labels)
axs[0,0].set_ylabel("Lead time")
axs[0,0].set_title("ACC (SIC only)")
axs[0,0].set_xticks(x_centers, labels=month_labels)

axs = axs.flatten()
for ax, input_config, added in zip(axs[1:], input_configs, input_added):
    percent_change = 100 * (acc_agg[input_config].mean("nn_member_id") - acc_input2) / acc_input2
    cax = ax.pcolormesh(x, y, percent_change, cmap='RdBu', shading='flat', vmin=-15, vmax=15)
    significance_mask = (significance_ds[input_config].p_value > P_VALUE_CUTOFF).astype(int)
    add_hatching(ax, significance_mask, x, y)
    ax.set_xticks(x_centers, labels=month_labels)
    ax.set_yticks(y_centers, labels=lead_labels)
    # Raw f-string: "\D" is an invalid escape in a plain string literal
    # (SyntaxWarning on Python 3.12+); the rendered title is unchanged.
    ax.set_title(rf"$\Delta$ ACC: add {added}")

axs[3].set_ylabel("Lead time")

# Two colorbars to the right of the grid: absolute ACC (inner), percent change (outer).
cbar_ax = fig.add_axes([0.95, 0.2, 0.01, 0.6])
cbar_ax2 = fig.add_axes([1.05, 0.2, 0.01, 0.6])

plt.colorbar(cax, cax=cbar_ax2, label=r'Percent change (%)', orientation='vertical')
plt.colorbar(cax2, cax=cbar_ax, label=r'ACC', orientation='vertical')

plt.subplots_adjust(hspace=0.5)

for label, ax in zip(string.ascii_lowercase, axs):
    ax.annotate(
        f"{label})",
        xy=(0, 1), xycoords="axes fraction",
        xytext=(-0.1, 1.03), textcoords="axes fraction",
        ha="left", va="bottom",
        fontsize=11, fontweight="bold"
    )

out_path = "figures/cesm_new/exp1_ACC_percent.pdf"
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # savefig does not create directories
plt.savefig(out_path, bbox_inches='tight')

In [None]:
P_VALUE_CUTOFF = 0.01  # cells whose p_value exceeds this are hatched (difference not significant)

# Panel a): ACC of the SIC-only baseline ("input2"); panels b)-f): absolute
# change in ACC relative to that baseline when each additional input is added.
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(9,5), sharey=True)

# pcolormesh cell edges (12 months x 6 leads) and the centres used for ticks.
x = np.arange(13)
y = np.arange(7)
x_centers = (x[:-1] + x[1:]) / 2
y_centers = (y[:-1] + y[1:]) / 2
month_labels = ["J","F","M","A","M","J","J","A","S","O","N","D"]
lead_labels = np.arange(1,7,1)

acc_input2 = acc_agg["input2"].mean("nn_member_id")
input_configs = ["input3a", "input3b", "input3c", "input3d", "input4"]
input_added = ["sst", "slp", "z500", "t2m", "all"]

# NOTE(review): flat shading requires each member-mean field to be
# (lead_time, month) = (6, 12) in that order to match the edge arrays — confirm.
cax2 = axs[0,0].pcolormesh(x, y, acc_input2, cmap='Spectral', shading='flat', vmin=0, vmax=1)
axs[0,0].set_yticks(y_centers, labels=lead_labels)
axs[0,0].set_ylabel("Lead time")
axs[0,0].set_title("ACC (SIC only)")
axs[0,0].set_xticks(x_centers, labels=month_labels)

axs = axs.flatten()
for ax, input_config, added in zip(axs[1:], input_configs, input_added):
    acc_change = acc_agg[input_config].mean("nn_member_id") - acc_input2
    cax = ax.pcolormesh(x, y, acc_change, cmap='RdBu', shading='flat', vmin=-0.05, vmax=0.05)
    significance_mask = (significance_ds[input_config].p_value > P_VALUE_CUTOFF).astype(int)
    add_hatching(ax, significance_mask, x, y)
    ax.set_xticks(x_centers, labels=month_labels)
    ax.set_yticks(y_centers, labels=lead_labels)
    # Raw f-string: "\D" is an invalid escape in a plain string literal
    # (SyntaxWarning on Python 3.12+); the rendered title is unchanged.
    ax.set_title(rf"$\Delta$ ACC: add {added}")

axs[3].set_ylabel("Lead time")

# Two colorbars to the right of the grid: absolute ACC (inner), ACC change (outer).
cbar_ax = fig.add_axes([0.95, 0.2, 0.01, 0.6])
cbar_ax2 = fig.add_axes([1.05, 0.2, 0.01, 0.6])

plt.colorbar(cax, cax=cbar_ax2, label=r'$\Delta$ ACC (unitless)', orientation='vertical')
plt.colorbar(cax2, cax=cbar_ax, label=r'ACC', orientation='vertical')

plt.subplots_adjust(hspace=0.5)

for label, ax in zip(string.ascii_lowercase, axs):
    ax.annotate(
        f"{label})",
        xy=(0, 1), xycoords="axes fraction",
        xytext=(-0.1, 1.03), textcoords="axes fraction",
        ha="left", va="bottom",
        fontsize=11, fontweight="bold"
    )

out_path = "figures/cesm_new/exp1_ACC_absolute.pdf"
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # savefig does not create directories
plt.savefig(out_path, bbox_inches='tight')

## Plot RMSE comparison

In [10]:
# Per-configuration RMSE diagnostics (aggregated and un-aggregated) produced by
# the diagnostics step for each experiment-1 run.
rmse_agg = {}
rmse = {}
for cname in CNAMES:
    run_dir = os.path.join(config_cesm.PREDICTIONS_DIRECTORY, f"exp1_{cname}")
    rmse_agg[cname] = xr.open_dataset(os.path.join(run_dir, "diagnostics/rmse_agg.nc"))["rmse"]
    rmse[cname] = xr.open_dataset(os.path.join(run_dir, "diagnostics/rmse.nc"))["rmse"]

# Significance of each configuration's RMSE difference from the SIC-only
# baseline ("input2"); the baseline itself has no comparison file.
significance_ds_rmse = {
    cname: xr.open_dataset(
        os.path.join(config_cesm.ANALYSIS_RESULTS_DIRECTORY, f"confidence_intervals/exp1_input2_exp1_{cname}_rmse.nc")
    )
    for cname in CNAMES
    if cname != "input2"
}


In [None]:
P_VALUE_CUTOFF = 0.01  # cells whose p_value exceeds this are hatched (difference not significant)

# Panel a): RMSE of the SIC-only baseline ("input2"); panels b)-f): absolute
# change in RMSE relative to that baseline when each additional input is added.
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(9,5), sharey=True)

# pcolormesh cell edges (12 months x 6 leads) and the centres used for ticks.
x = np.arange(13)
y = np.arange(7)
x_centers = (x[:-1] + x[1:]) / 2
y_centers = (y[:-1] + y[1:]) / 2
month_labels = ["J","F","M","A","M","J","J","A","S","O","N","D"]
lead_labels = np.arange(1,7,1)

rmse_input2 = rmse_agg["input2"].mean("nn_member_id")
input_configs = ["input3a", "input3b", "input3c", "input3d", "input4"]
input_added = ["sst", "slp", "z500", "t2m", "all"]

# NOTE(review): flat shading requires each member-mean field to be
# (lead_time, month) = (6, 12) in that order to match the edge arrays — confirm.
cax2 = axs[0,0].pcolormesh(x, y, rmse_input2, cmap='Spectral_r', shading='flat', vmin=0.02, vmax=0.08)
axs[0,0].set_yticks(y_centers, labels=lead_labels)
axs[0,0].set_ylabel("Lead time")
axs[0,0].set_title("RMSE (SIC only)")
axs[0,0].set_xticks(x_centers, labels=month_labels)

axs = axs.flatten()
for ax, input_config, added in zip(axs[1:], input_configs, input_added):
    rmse_change = rmse_agg[input_config].mean("nn_member_id") - rmse_input2
    # Reversed colormap: increases in RMSE are drawn red.
    cax = ax.pcolormesh(x, y, rmse_change, cmap='RdBu_r', shading='flat', vmin=-0.005, vmax=0.005)
    significance_mask = (significance_ds_rmse[input_config].p_value > P_VALUE_CUTOFF).astype(int)
    add_hatching(ax, significance_mask, x, y)
    ax.set_xticks(x_centers, labels=month_labels)
    ax.set_yticks(y_centers, labels=lead_labels)
    # Raw f-string: "\D" is an invalid escape in a plain string literal
    # (SyntaxWarning on Python 3.12+); the rendered title is unchanged.
    ax.set_title(rf"$\Delta$ RMSE: add {added}")

axs[3].set_ylabel("Lead time")

# Two colorbars to the right of the grid: absolute RMSE (inner), RMSE change (outer).
cbar_ax = fig.add_axes([0.95, 0.2, 0.01, 0.6])
cbar_ax2 = fig.add_axes([1.05, 0.2, 0.01, 0.6])

plt.colorbar(cax, cax=cbar_ax2, label=r'$\Delta$ RMSE', orientation='vertical')
plt.colorbar(cax2, cax=cbar_ax, label=r'RMSE', orientation='vertical')

plt.subplots_adjust(hspace=0.5)

for label, ax in zip(string.ascii_lowercase, axs):
    ax.annotate(
        f"{label})",
        xy=(0, 1), xycoords="axes fraction",
        xytext=(-0.1, 1.03), textcoords="axes fraction",
        ha="left", va="bottom",
        fontsize=11, fontweight="bold"
    )

out_path = "figures/cesm_new/exp1_RMSE_absolute.pdf"
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # savefig does not create directories
plt.savefig(out_path, bbox_inches='tight')

In [None]:
P_VALUE_CUTOFF = 0.01  # cells whose p_value exceeds this are hatched (difference not significant)

# Panel a): RMSE of the SIC-only baseline ("input2"); panels b)-f): percent
# change in RMSE relative to that baseline when each additional input is added.
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(9,5), sharey=True)

# pcolormesh cell edges (12 months x 6 leads) and the centres used for ticks.
x = np.arange(13)
y = np.arange(7)
x_centers = (x[:-1] + x[1:]) / 2
y_centers = (y[:-1] + y[1:]) / 2
month_labels = ["J","F","M","A","M","J","J","A","S","O","N","D"]
lead_labels = np.arange(1,7,1)

rmse_input2 = rmse_agg["input2"].mean("nn_member_id")
input_configs = ["input3a", "input3b", "input3c", "input3d", "input4"]
input_added = ["sst", "slp", "z500", "t2m", "all"]

# NOTE(review): flat shading requires each member-mean field to be
# (lead_time, month) = (6, 12) in that order to match the edge arrays — confirm.
cax2 = axs[0,0].pcolormesh(x, y, rmse_input2, cmap='Spectral_r', shading='flat', vmin=0.02, vmax=0.08)
axs[0,0].set_yticks(y_centers, labels=lead_labels)
axs[0,0].set_ylabel("Lead time")
axs[0,0].set_title("RMSE (SIC only)")
axs[0,0].set_xticks(x_centers, labels=month_labels)

axs = axs.flatten()
for ax, input_config, added in zip(axs[1:], input_configs, input_added):
    percent_change = 100 * (rmse_agg[input_config].mean("nn_member_id") - rmse_input2) / rmse_input2
    # Reversed colormap: increases in RMSE are drawn red.
    cax = ax.pcolormesh(x, y, percent_change, cmap='RdBu_r', shading='flat', vmin=-10, vmax=10)
    significance_mask = (significance_ds_rmse[input_config].p_value > P_VALUE_CUTOFF).astype(int)
    add_hatching(ax, significance_mask, x, y)
    ax.set_xticks(x_centers, labels=month_labels)
    ax.set_yticks(y_centers, labels=lead_labels)
    # Raw f-string: "\D" is an invalid escape in a plain string literal
    # (SyntaxWarning on Python 3.12+); the rendered title is unchanged.
    ax.set_title(rf"$\Delta$ RMSE: add {added}")

axs[3].set_ylabel("Lead time")

# Two colorbars to the right of the grid: absolute RMSE (inner), percent change (outer).
cbar_ax = fig.add_axes([0.95, 0.2, 0.01, 0.6])
cbar_ax2 = fig.add_axes([1.05, 0.2, 0.01, 0.6])

plt.colorbar(cax, cax=cbar_ax2, label=r'Percent change RMSE', orientation='vertical')
plt.colorbar(cax2, cax=cbar_ax, label=r'RMSE', orientation='vertical')

plt.subplots_adjust(hspace=0.5)

for label, ax in zip(string.ascii_lowercase, axs):
    ax.annotate(
        f"{label})",
        xy=(0, 1), xycoords="axes fraction",
        xytext=(-0.1, 1.03), textcoords="axes fraction",
        ha="left", va="bottom",
        fontsize=11, fontweight="bold"
    )

out_path = "figures/cesm_new/exp1_RMSE_percent.pdf"
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # savefig does not create directories
plt.savefig(out_path, bbox_inches='tight')

## Permute and predict

In [13]:
def plot_perm(ax, arr, title, lag_max, vmin=-10, vmax=10, fontsize=7, cmap="Reds_r", cutoff=6):
    """Draw an annotated lag-by-lead heat map of permutation scores.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    arr : 2-D array-like, shape (n_lags, n_leads)
        Row ``k`` holds the scores for lag ``k + 1``. Rows are flipped before
        drawing, so lag 1 ends up at the top and the largest lag at the bottom.
    title : str
        Panel title.
    lag_max : int
        Number of lag rows to tick and annotate.
    vmin, vmax : float
        Colour limits passed to ``pcolormesh``.
    fontsize : float
        Font size of the per-cell value annotations.
    cmap : str
        Colormap name.
    cutoff : float
        Cells with ``|value| > cutoff`` get white text, others black.

    Returns
    -------
    The mesh returned by ``ax.pcolormesh`` (for building a colorbar).
    """
    A = np.flip(arr, axis=0)
    im = ax.pcolormesh(A, cmap=cmap, vmin=vmin, vmax=vmax)

    # Take the number of lead columns from the data instead of assuming 6.
    n_leads = np.shape(A)[1]
    x_centers = np.arange(0.5, n_leads + 0.5)
    ax.set_xticks(x_centers, labels=np.arange(1, n_leads + 1))

    # After the flip, the bottom row is lag `lag_max` and the top row is lag 1.
    y_centers = np.arange(0.5, lag_max + 0.5)
    ax.set_yticks(y_centers, labels=np.arange(lag_max, 0, -1))

    for j in range(n_leads):
        for k in range(lag_max):
            val = A[k, j]
            ax.text(
                j + 0.5, k + 0.5, f"{val:.2f}",
                ha="center", va="center",
                fontsize=fontsize,
                # Light text on large-magnitude (dark) cells for legibility.
                color="white" if np.abs(val) > cutoff else "black"
            )

    ax.set_title(title)
    return im

def _percent_change_by_lag(metric):
    """Permute-and-predict skill change for the all-inputs run ("exp1_input4").

    For each variable in ``var_names`` and each lag ``1..var_lags[var]``, read
    the aggregated diagnostics of the run in which that variable/lag was
    permuted. Reduce them over ``month`` and ``nn_member_id`` and express the
    result as a percent change from the unpermuted run.

    Parameters
    ----------
    metric : str
        Diagnostic name used both in the file names and as the variable name
        inside them ("rmse" or "acc" here).

    Returns
    -------
    (dict, xarray.DataArray)
        ``{var: ndarray of shape (var_lags[var], n_leads)}``, with row ``lag - 1``
        holding that lag, and the unpermuted baseline it was compared against.
    """
    diag_dir = os.path.join(config_cesm.PREDICTIONS_DIRECTORY, "exp1_input4", "diagnostics")

    def lead_mean(fname):
        # Reduce inside the context manager so the file handle is released.
        with xr.open_dataset(os.path.join(diag_dir, fname)) as ds:
            return ds[metric].mean(("month", "nn_member_id")).load()

    baseline = lead_mean(f"{metric}_agg.nc")
    permuted = {}
    for var in var_names:
        changes = np.zeros((var_lags[var], baseline.size))
        for lag in range(1, var_lags[var] + 1):
            score = lead_mean(f"{metric}_permute_{var}_lag{lag}_agg.nc")
            changes[lag - 1, :] = (100 * (score - baseline) / baseline).values
        permuted[var] = changes
    return permuted, baseline


# Permuted variables, their display labels, and how many lags each one has.
var_names = ["sst", "psl", "geopotential", "t2m", "icefrac"]
label_var_names = ["sst", "slp", "z500", "t2m", "sic"]
var_lags = {"sst": 6, "psl": 6, "geopotential": 6, "t2m": 6, "icefrac": 12}

rmse_permuted, rmse_orig = _percent_change_by_lag("rmse")
acc_permuted, acc_orig = _percent_change_by_lag("acc")

In [None]:
# Permute-and-predict importance in terms of ACC. The four non-ice predictors
# form a 2x2 block on the left; sea ice spans both rows on the right.
fig = plt.figure(figsize=(10, 4))
gs = gridspec.GridSpec(2, 3, width_ratios=[1, 1, 1], wspace=0.35, hspace=0.45)

axs_left = [fig.add_subplot(gs[0, 0]), fig.add_subplot(gs[0, 1]), fig.add_subplot(gs[1, 0]), fig.add_subplot(gs[1, 1])]
ax_ice = fig.add_subplot(gs[:, 2])

vmin, vmax = -10, 0

for i, ax in enumerate(axs_left):
    var = var_names[i]
    plot_perm(
        ax,
        acc_permuted[var],
        label_var_names[i],
        var_lags[var],
        vmin=vmin, vmax=vmax
    )
    if i in (2, 3):  # bottom row of the 2x2 block
        ax.set_xlabel("Lead")
    if i in (0, 2):  # left column of the 2x2 block
        ax.set_ylabel("Lag")

# Every panel uses the same vmin/vmax, so the ice panel's mesh drives the shared colorbar.
im_ice = plot_perm(
    ax_ice,
    acc_permuted["icefrac"],
    label_var_names[var_names.index("icefrac")],
    var_lags["icefrac"],
    vmin=vmin, vmax=vmax
)
ax_ice.set_xlabel("Lead")
ax_ice.set_ylabel("Lag")

for label, ax in zip(string.ascii_lowercase, axs_left + [ax_ice]):
    ax.annotate(
        f"{label})",
        xy=(0, 1), xycoords="axes fraction",
        xytext=(-0.12, 1.03), textcoords="axes fraction",
        ha="left", va="bottom",
        fontsize=11, fontweight="bold"
    )

cbar = fig.colorbar(im_ice, ax=axs_left + [ax_ice], fraction=0.03, pad=0.03, extend="min")
cbar.set_label("% change in ACC")
out_path = "figures/cesm_new/exp1_permute_and_predict_ACC.pdf"
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # savefig does not create directories
plt.savefig(out_path, bbox_inches='tight')


In [None]:
# Permute-and-predict importance in terms of RMSE. The four non-ice predictors
# form a 2x2 block on the left; sea ice spans both rows on the right.
fig = plt.figure(figsize=(10, 4))
gs = gridspec.GridSpec(2, 3, width_ratios=[1, 1, 1], wspace=0.35, hspace=0.45)

axs_left = [fig.add_subplot(gs[0, 0]), fig.add_subplot(gs[0, 1]), fig.add_subplot(gs[1, 0]), fig.add_subplot(gs[1, 1])]
ax_ice = fig.add_subplot(gs[:, 2])

vmin, vmax = 0, 20

for i, ax in enumerate(axs_left):
    var = var_names[i]
    plot_perm(
        ax,
        rmse_permuted[var],
        label_var_names[i],
        var_lags[var],
        vmin=vmin, vmax=vmax, cmap="Reds", cutoff=12
    )
    if i in (2, 3):  # bottom row of the 2x2 block
        ax.set_xlabel("Lead")
    if i in (0, 2):  # left column of the 2x2 block
        ax.set_ylabel("Lag")

# Every panel uses the same vmin/vmax, so the ice panel's mesh drives the shared colorbar.
im_ice = plot_perm(
    ax_ice,
    rmse_permuted["icefrac"],
    label_var_names[var_names.index("icefrac")],
    var_lags["icefrac"],
    vmin=vmin, vmax=vmax, cmap="Reds", cutoff=12
)
ax_ice.set_xlabel("Lead")
ax_ice.set_ylabel("Lag")

for label, ax in zip(string.ascii_lowercase, axs_left + [ax_ice]):
    ax.annotate(
        f"{label})",
        xy=(0, 1), xycoords="axes fraction",
        xytext=(-0.12, 1.03), textcoords="axes fraction",
        ha="left", va="bottom",
        fontsize=11, fontweight="bold"
    )

cbar = fig.colorbar(im_ice, ax=axs_left + [ax_ice], fraction=0.03, pad=0.03, extend="max")
cbar.set_label("% change in RMSE")
out_path = "figures/cesm_new/exp1_permute_and_predict_rmse.pdf"
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # savefig does not create directories
plt.savefig(out_path, bbox_inches='tight')


## Effect of normalization strategy on SST configuration skill

In [None]:
# Compute per-member ACC for the two alternative SST-normalization configs and
# aggregate it.
# NOTE(review): load_globals, get_broadcast_climatology, load_targets,
# load_model_predictions, calculate_acc and aggregate_acc are not imported
# anywhere in this notebook, so this cell raises NameError on a fresh kernel.
# Add the missing imports (source module not visible here).
from src.experiment_configs.exp1_inputs import input3a_dev
from src.experiment_configs.exp1_inputs import input3a_std

cdicts = {
    "input3a_dev": input3a_dev, 
    "input3a_std": input3a_std, 
}

# Targets are loaded once (via the input3a_dev config) and reused for both configs.
cdict = load_globals(input3a_dev)
# NOTE(review): climatology_broadcast is never used below; predictions are
# loaded with climatology_broadcasted=None.
climatology_broadcast = get_broadcast_climatology(cdict, "test")
targets = load_targets(cdict, "test", add_climatology_to_anomaly=False)

# NOTE(review): this rebinds the global acc / acc_agg dicts that the scorecard
# section filled from the diagnostics files.
acc = {}
acc_agg = {}

# Only the first 3 NN members are scored here, while the noise experiment uses
# 5. TODO confirm this is intended.
num_nn_ens_members = 3
for key, config in cdicts.items():
    print(f"computing ACC for {key}")
    cdict = load_globals(config) 
    pred = load_model_predictions(cdict, nn_ens_avg=False, climatology_broadcasted=None, 
                                    add_climatology_to_anomaly=False)

    # Per-member, un-aggregated ACC, then stacked along a new nn_member_id dim.
    acc_temp_list = []
    for i in range(num_nn_ens_members): 
        acc_temp = calculate_acc(pred.isel(nn_member_id=i), 
                            targets, dim=("x","y"), aggregate=False)
        acc_temp_list.append(acc_temp)
    acc[key] = xr.concat(acc_temp_list, dim="nn_member_id")
    acc_agg[key] = aggregate_acc(acc[key], dim=("x","y"))

In [None]:
save_dir = os.path.join(config_cesm.ANALYSIS_RESULTS_DIRECTORY, "exp1_inputs")
# to_netcdf does not create missing parent directories.
os.makedirs(save_dir, exist_ok=True)

# Cache the per-sample and aggregated ACC of every config currently in `acc`.
# The reload cells read these back via the default
# "__xarray_dataarray_variable__" name, which presumes the arrays are unnamed.
for key, acc_da in acc.items():
    acc_da.to_netcdf(os.path.join(save_dir, f"ACC_{key}.nc"))
    acc_agg[key].to_netcdf(os.path.join(save_dir, f"ACC_agg_{key}.nc"))

In [None]:
# Reload the cached per-sample and aggregated ACC for the SIC-only baseline,
# the default SST config and the two alternative normalizations.
save_dir = os.path.join(config_cesm.ANALYSIS_RESULTS_DIRECTORY, "exp1_inputs")
norm_keys = ["input2", "input3a", "input3a_dev", "input3a_std"]
acc = {
    key: xr.open_dataset(os.path.join(save_dir, f"ACC_{key}.nc"))["__xarray_dataarray_variable__"]
    for key in norm_keys
}
acc_agg = {
    key: xr.open_dataset(os.path.join(save_dir, f"ACC_agg_{key}.nc"))["__xarray_dataarray_variable__"]
    for key in norm_keys
}

In [None]:
# Significance of each config's ACC difference from the SIC-only baseline
# ("input2"), computed from the per-sample ACC arrays.
# NOTE(review): bootstrap_acc_significance and roll_acc are not defined or
# imported in this notebook, so this cell raises NameError on a fresh kernel.
# The top-of-notebook imports bring in roll_metric, which is otherwise unused
# and may be the intended helper. Confirm.
significance_ds = {}

for k in ["input3a", "input3a_dev", "input3a_std"]: 
    print(f"Computing bootstrap significance test for {k}")

    significance_ds[k] = bootstrap_acc_significance(acc["input2"], acc[k])

# Post-process each result with roll_acc (its effect is not visible here;
# presumably it re-aligns months for plotting. Confirm).
for k, ds in significance_ds.items():
    significance_ds[k] = roll_acc(ds)

In [None]:
# Local copy of the cutoff so this cell does not silently depend on an earlier
# cell having defined it.
P_VALUE_CUTOFF = 0.01

fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(7,6))

# pcolormesh cell edges (12 months x 6 leads) and the centres used for ticks.
x = np.arange(13)
y = np.arange(7)
x_centers = (x[:-1] + x[1:]) / 2
y_centers = (y[:-1] + y[1:]) / 2

acc_zscore_norm = acc_agg["input3a_std"].mean("nn_member_id")
acc_minmax_norm = acc_agg["input3a"].mean("nn_member_id")
acc_nonorm = acc_agg["input3a_dev"].mean("nn_member_id")
acc_sic_only = acc_agg["input2"].mean("nn_member_id")

# One row per normalization: (member-mean ACC, significance_ds key, left-panel title).
norm_rows = [
    (acc_zscore_norm, "input3a_std", "Sea ice + sst (z-score norm)"),
    (acc_minmax_norm, "input3a", "Sea ice + sst (min-max norm)"),
    (acc_nonorm, "input3a_dev", "Sea ice + sst (no norm)"),
]

for row, (acc_norm, sig_key, title) in enumerate(norm_rows):
    ax_acc, ax_diff = axs[row, 0], axs[row, 1]

    # Left column: absolute ACC.
    mesh_acc = ax_acc.pcolormesh(x, y, acc_norm, cmap='Spectral', shading='flat', vmin=0, vmax=1)
    ax_acc.set_yticks(y_centers, labels=np.arange(1,7,1))
    ax_acc.set_ylabel("Lead time")
    ax_acc.set_title(title)

    # Right column: percent change from SIC only, hatched where p_value > cutoff.
    mesh_diff = ax_diff.pcolormesh(x, y, 100 * (acc_norm - acc_sic_only) / acc_sic_only,
                                   cmap='RdBu', shading='flat', vmin=-15, vmax=15)
    ax_diff.set_yticks(y_centers, labels=np.arange(1,7,1))
    ax_diff.set_ylabel("Lead time")
    ax_diff.set_title("ACC diff from sea ice only")
    significance_mask = (significance_ds[sig_key].p_value > P_VALUE_CUTOFF).astype(int)
    add_hatching(ax_diff, significance_mask, x, y)

    # Every row shares vmin/vmax, so the first row's meshes drive the colorbars.
    if row == 0:
        cax, cax2 = mesh_acc, mesh_diff

month_labels = ["J","F","M","A","M","J","J","A","S","O","N","D"]
for ax in axs.flat:
    ax.set_xticks(x_centers, labels=month_labels)

cbar_ax = fig.add_axes([0.1, -0.05, 0.35, 0.02])
cbar_ax2 = fig.add_axes([0.6, -0.05, 0.35, 0.02])

plt.colorbar(cax, cax=cbar_ax, label=r'ACC', orientation='horizontal')
plt.colorbar(cax2, cax=cbar_ax2, label=r'Percent change (%)', orientation='horizontal')

plt.tight_layout()
out_path = "figures/cesm/exp1_ACC_diff_sst_norm.jpg"
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # savefig does not create directories
plt.savefig(out_path, dpi=300, bbox_inches='tight')

## Noise experiment

In [None]:
# Score the "input_noise" configuration (sea ice plus noise channels) for each
# NN member and cache the per-sample and aggregated ACC.
# NOTE(review): load_globals, get_broadcast_climatology, load_targets,
# load_model_predictions, calculate_acc and aggregate_acc are not imported in
# this notebook, so this cell raises NameError on a fresh kernel.
from src.experiment_configs.exp1_inputs import input_noise

cdict = load_globals(input_noise)
# NOTE(review): climatology_broadcast is computed but never used; predictions
# are loaded with climatology_broadcasted=None.
climatology_broadcast = get_broadcast_climatology(cdict, "test")
targets = load_targets(cdict, "test", add_climatology_to_anomaly=False)

num_nn_ens_members = 5
pred = load_model_predictions(cdict, nn_ens_avg=False, climatology_broadcasted=None, 
                                add_climatology_to_anomaly=False)

# Per-member, un-aggregated ACC, stacked along a new nn_member_id dim and then
# aggregated with aggregate_acc over ("x", "y").
acc_temp_list = []
for i in range(num_nn_ens_members): 
    acc_temp = calculate_acc(pred.isel(nn_member_id=i), 
                        targets, dim=("x","y"), aggregate=False)
    acc_temp_list.append(acc_temp)
acc_noise = xr.concat(acc_temp_list, dim="nn_member_id")
acc_agg_noise = aggregate_acc(acc_noise, dim=("x","y"))

# NOTE(review): to_netcdf fails if save_dir does not already exist.
save_dir = os.path.join(config_cesm.ANALYSIS_RESULTS_DIRECTORY, "exp1_inputs")
acc_noise.to_netcdf(os.path.join(save_dir, f"ACC_input_noise.nc"))
acc_agg_noise.to_netcdf(os.path.join(save_dir, f"ACC_agg_input_noise.nc"))

In [None]:
# Reload the cached per-sample and aggregated ACC for the SIC-only baseline and
# the noise configuration.
save_dir = os.path.join(config_cesm.ANALYSIS_RESULTS_DIRECTORY, "exp1_inputs")
noise_keys = ["input2", "input_noise"]
acc = {
    key: xr.open_dataset(os.path.join(save_dir, f"ACC_{key}.nc"))["__xarray_dataarray_variable__"]
    for key in noise_keys
}
acc_agg = {
    key: xr.open_dataset(os.path.join(save_dir, f"ACC_agg_{key}.nc"))["__xarray_dataarray_variable__"]
    for key in noise_keys
}

In [None]:
# Significance of the noise configuration's ACC difference from the SIC-only
# baseline ("input2").
# NOTE(review): bootstrap_acc_significance and roll_acc are not defined or
# imported in this notebook, so this cell raises NameError on a fresh kernel.
significance_ds = {}

for k in ["input_noise",]: 
    print(f"Computing bootstrap significance test for {k}")

    significance_ds[k] = bootstrap_acc_significance(acc["input2"], acc[k])

# Post-process each result with roll_acc (effect not visible here. Confirm).
for k, ds in significance_ds.items():
    significance_ds[k] = roll_acc(ds)

In [None]:
# Local copy of the cutoff so this cell does not silently depend on an earlier
# cell having defined it.
P_VALUE_CUTOFF = 0.01

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(7, 2))

# pcolormesh cell edges (12 months x 6 leads) and the centres used for ticks.
x = np.arange(13)
y = np.arange(7)
x_centers = (x[:-1] + x[1:]) / 2
y_centers = (y[:-1] + y[1:]) / 2

# Distinct name so the per-sample `acc_noise` from the computation cell is not
# overwritten by its member mean.
acc_noise_mean = acc_agg["input_noise"].mean("nn_member_id")
acc_sic_only = acc_agg["input2"].mean("nn_member_id")

cax = axs[0].pcolormesh(x, y, acc_noise_mean, cmap='Spectral', shading='flat', vmin=0, vmax=1)
axs[0].set_yticks(y_centers, labels=np.arange(1,7,1))
axs[0].set_ylabel("Lead time")
axs[0].set_title("Sea ice + 6 channels of noise")

# Percent change from SIC only, hatched where p_value > cutoff.
cax2 = axs[1].pcolormesh(x, y, 100 * (acc_noise_mean - acc_sic_only) / acc_sic_only,
                    cmap='RdBu', shading='flat', vmin=-15, vmax=15)
axs[1].set_yticks(y_centers, labels=np.arange(1,7,1))
axs[1].set_ylabel("Lead time")
axs[1].set_title("ACC diff from sea ice only")
significance_mask = (significance_ds["input_noise"].p_value > P_VALUE_CUTOFF).astype(int)
add_hatching(axs[1], significance_mask, x, y)

month_labels = ["J","F","M","A","M","J","J","A","S","O","N","D"]
for ax in axs:
    ax.set_xticks(x_centers, labels=month_labels)

cbar_ax = fig.add_axes([0.1, -0.05, 0.35, 0.05])
cbar_ax2 = fig.add_axes([0.6, -0.05, 0.35, 0.05])

plt.colorbar(cax, cax=cbar_ax, label=r'ACC', orientation='horizontal')
plt.colorbar(cax2, cax=cbar_ax2, label=r'Percent change (%)', orientation='horizontal')

plt.tight_layout()
out_path = "figures/cesm/exp1_ACC_diff_noise.jpg"
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # savefig does not create directories
plt.savefig(out_path, dpi=300, bbox_inches='tight')

## Temporal dependence of ACC

### EOF analysis

In [2]:
# Test-set CESM members and the processed "seaice_plus_all" data pairs.
test_members = ['r2i1251p1f1', 'r2i1281p1f1', 'r2i1301p1f1', 'r3i1041p1f1']
data_dir = '/scratch/users/yucli/cesm_data_processed/data_pairs/seaice_plus_all'

# Concatenate the per-member input and target files along a new "member_id" dim.
inputs_da = xr.concat(
    [xr.open_dataset(os.path.join(data_dir, f"inputs_member_{m}.nc")) for m in test_members],
    dim="member_id",
)
targets_da = xr.concat(
    [xr.open_dataset(os.path.join(data_dir, f"targets_member_{m}.nc")) for m in test_members],
    dim="member_id",
)

In [3]:
# Channels to decompose. Only psl_lag1 is active; the other candidates stay commented out.
input_vars = ["psl_lag1",] #"icefrac_lag1", "geopotential_lag1", "t2m_lag1"]

eofs_dict = {}
pcs_dict = {}
var_exp_dict = {}
n_components = 4

for input_var in input_vars:
    # Pin the dim order to (y, x, samples) before flattening. The reshape back
    # to (mode, y, x) below then matches the flattening order, whatever dim
    # order the data are stored in. Unpacking the raw shape as (nx, ny) and
    # reshaping to (ny, nx) silently transposed the EOFs when the stored order
    # was (x, y).
    stacked = (
        inputs_da["data"]
        .sel(channel=input_var)
        .stack(samples=("member_id", "start_prediction_month"))
        .transpose("y", "x", "samples")
    )
    ny, nx, nsamples = stacked.shape

    # Samples x grid-points matrix. Each sample's domain mean is removed here;
    # sklearn's PCA additionally centres every grid point over the samples.
    anom_matrix = (stacked - stacked.mean(("x","y"))).values.reshape(ny * nx, nsamples).T

    pca = PCA(n_components=n_components)
    pca.fit(anom_matrix)
    pcs = pca.transform(anom_matrix)  # (nsamples, n_components)
    eofs = pca.components_.reshape(n_components, ny, nx)

    eofs_dict[input_var] = xr.DataArray(
        eofs,
        dims=("mode", "y", "x"),
        coords={"mode": np.arange(n_components), "y": stacked.y, "x": stacked.x},
        name="EOF"
    )

    # Restore the (member_id, start_prediction_month) dims on the PCs.
    samples_index = stacked["samples"]
    pcs_dict[input_var] = xr.DataArray(
        pcs,
        dims=("samples", "mode"),
        coords={"samples": samples_index, "mode": np.arange(n_components)},
        name="PC"
    ).unstack("samples")

    var_exp_dict[input_var] = xr.DataArray(
        pca.explained_variance_ratio_,
        dims=("mode",),
        coords={"mode": np.arange(n_components)},
        name="VarianceExplained"
    )



In [None]:
# Map of the leading EOF (mode 0) of lag-1 sea level pressure.
fig_eof, ax_eof = plt.subplots(figsize=(5,4))
eof_mesh = ax_eof.pcolormesh(eofs_dict["psl_lag1"].isel(mode=0), cmap="RdBu_r")
ax_eof.set_title("Sea level pressure (leading mode)")
fig_eof.colorbar(eof_mesh, ax=ax_eof)
# plt.savefig("figures/illustrations/sam.jpg",dpi=300,bbox_inches='tight')

In [None]:
# Predictions of the SIC-only ("input2") and SIC + slp ("input3b") models,
# plus the verifying targets under the key "targets".
pred = {
    run_name: xr.open_dataset(
        f"/scratch/users/yucli/sicpred_model_predictions/exp1_{run_name}/UNetRes3_best_predictions.nc"
    )["predictions"]
    for run_name in ("input2", "input3b")
}
pred["targets"] = targets_da["data"]

def linregress(x, y):
    """Least-squares slope of ``y`` on ``x`` with an autocorrelation-adjusted p-value.

    The correlation ``r`` is tested with a two-sided t-test. Its degrees of
    freedom use the effective sample size
    ``n_eff = n * (1 - a_x * a_y) / (1 + a_x * a_y)``, where ``a_x`` and ``a_y``
    are the lag-1 autocorrelations of ``x`` and ``y``.

    Parameters
    ----------
    x, y : 1-D array-like of equal length

    Returns
    -------
    (slope, p_value)
    """
    fit = scipy.stats.linregress(x, y)
    slope, r = fit.slope, fit.rvalue

    lag1_x = np.corrcoef(x[:-1], x[1:])[0, 1]
    lag1_y = np.corrcoef(y[:-1], y[1:])[0, 1]
    rho = lag1_x * lag1_y
    n_eff = len(x) * (1 - rho) / (1 + rho)

    dof = n_eff - 2
    t_stat = r * np.sqrt(dof / (1 - r**2))
    p_value = 2 * scipy.stats.t.sf(np.abs(t_stat), df=dof)
    return slope, p_value

# Regress SIC at every grid point onto the leading psl PC (a SAM-like index),
# using only forecasts started in October (so the lag-1 psl index is the
# September one), at lead times 1-5. For lead = 5 this corresponds to February
# SIC predictions.
init_month = 10
slopes = {}
pvalues = {}

# The regressor is the same for every config and lead, so build it once.
regressor = pcs_dict["psl_lag1"].isel(mode=0).stack(sample=("start_prediction_month", "member_id"))
regressor = regressor.where(regressor.start_prediction_month.dt.month == init_month, drop=True)

for input_config, ds in pred.items():
    # Select the October starts once per config, stacking in the same
    # (start month, member) order as the regressor so the "sample" core dims line up.
    oct_starts = ds.where(ds.start_prediction_month.dt.month == init_month, drop=True)
    oct_starts = oct_starts.stack(sample=("start_prediction_month", "member_id"))

    for lead in range(1, 6):
        subset = oct_starts.sel(lead_time=lead)

        # linregress runs once per remaining (non-"sample") index combination.
        slope, p_value = xr.apply_ufunc(
            linregress,
            regressor, subset,
            input_core_dims=[["sample"], ["sample"]],
            output_core_dims=[[], []],
            vectorize=True,
        )

        slopes[f"{input_config}_lead{lead}"] = slope
        pvalues[f"{input_config}_lead{lead}"] = p_value


In [49]:
def compute_pfdr(p_values, alpha_fdr = 0.02):
    """Benjamini-Hochberg p-value threshold for a field of p-values.

    NaNs (e.g. masked grid points) are ignored. The threshold is the largest
    sorted p-value ``p_(k)`` satisfying ``p_(k) <= (k / n) * alpha_fdr``; every
    p-value ``<=`` it is significant at false-discovery rate ``alpha_fdr``.

    Parameters
    ----------
    p_values : array-like
        numpy array or xarray DataArray of any shape.
    alpha_fdr : float
        Target false-discovery rate.

    Returns
    -------
    float
        The threshold, or 0.0 when no p-value passes (including the all-NaN
        case), so a ``p < threshold`` mask then selects nothing.
    """
    p_vals_sorted = np.sort(np.asarray(p_values).ravel())  # NaNs sort to the end
    n = int(np.count_nonzero(~np.isnan(p_vals_sorted)))
    if n == 0:
        return 0.0
    passes = p_vals_sorted[:n] <= (np.arange(1, n + 1) / n) * alpha_fdr
    if not passes.any():
        # Previously this indexed an empty array and raised IndexError.
        return 0.0
    return p_vals_sorted[np.flatnonzero(passes)[-1]]

# FDR significance masks for each regression. For model configs (keys
# starting with "input"), a grid point is kept only if it is significant in
# every NN member; the targets get a single mask.
# NOTE(review): the masks are later contoured against (x, y), which assumes the
# p-value arrays are ordered (y, x) after the member selection. Confirm.
significance_masks = {}
for k, v in pvalues.items():
    if k.startswith("input"):
        # Use the actual member count rather than a hard-coded 5 x 80 x 80 buffer.
        member_masks = []
        for i in range(v.sizes["nn_member_id"]):
            p_member = v.isel(nn_member_id=i)
            # Benjamini-Hochberg rejects every p-value <= the threshold (inclusive).
            member_masks.append((p_member <= compute_pfdr(p_member)).values)
        significance_masks[k] = np.all(np.stack(member_masks), axis=0)
    else:
        significance_masks[k] = v <= compute_pfdr(v)


In [None]:
fig, axs = plt.subplots(3, 5, figsize=(10, 6), sharex=True, sharey=True)

land_mask = xr.open_dataset("/oak/stanford/groups/earlew/yuchen/cesm_data/grids/icefrac_land_mask.nc")["mask"]
x, y = reference_grid.x, reference_grid.y
months = ["Oct", "Nov", "Dec", "Jan", "Feb"]
# Slopes are scaled per standard deviation of PC1 and to percent. The leading
# minus sign flips the regression's sign (presumably to fix the PC1 polarity. Confirm).
scaling = pcs_dict["psl_lag1"].isel(mode=0).std()

# One row per source: (slopes/masks key prefix, average over NN members?, row label).
row_specs = [
    ("targets", False, "Truth"),
    ("input2", True, "SIC only model"),
    ("input3b", True, "SIC + slp model"),
]

for col, month_name in enumerate(months):
    lead_key = f"lead{col+1}"
    for row, (prefix, member_mean, _) in enumerate(row_specs):
        slope_field = slopes[f"{prefix}_{lead_key}"]
        if member_mean:
            slope_field = slope_field.mean("nn_member_id")
        mesh = axs[row, col].pcolormesh(x, y, -100 * scaling * slope_field, cmap="seismic_r", vmin=-8, vmax=8)
        if row == 0:
            im = mesh  # the truth-row mesh drives the shared colorbar
        # Stipple grid points flagged in significance_masks.
        axs[row, col].contourf(x, y, significance_masks[f"{prefix}_{lead_key}"], levels=[0.5, 1.5], hatches=["..."], colors='none')
    axs[0, col].set_title(f"Lead {col+1} ({month_name})")

# Outline and white-fill the mask == 1 region (presumably land), and hide ticks.
for ax in axs.flat:
    ax.contour(x, y, land_mask, levels=[0.5, 1.5], colors='grey', linewidths=1)
    ax.contourf(x, y, land_mask, levels=[0.5, 1.5], colors='white')
    ax.set_xticks([])
    ax.set_yticks([])

for row, (_, _, row_label) in enumerate(row_specs):
    axs[row, 0].set_ylabel(row_label)

cbar_ax = fig.add_axes([0.92, 0.2, 0.01, 0.6])
cbar = fig.colorbar(im, cax=cbar_ax)
cbar.set_label("Linear change in SIC (%) per stdev PC1")


for letter, ax in zip(string.ascii_lowercase, axs.flat):
    ax.annotate(
        f"{letter})",
        xy=(0, 1), xycoords="axes fraction",
        xytext=(0.05, 0.85), textcoords="axes fraction",
        ha="left", va="bottom",
        fontsize=11, fontweight="bold"
    )

plt.savefig("figures/cesm_new/exp1_SAM_regression.png", dpi=300, bbox_inches='tight')


In [None]:
# Scratch check: contour where the lead-5 target regression has p < 0.05 on a
# South Polar Stereographic map.
# FIXME(review): `lon` and `lat` are not defined anywhere in this notebook, so
# this cell raises NameError on a fresh kernel. Supply longitude/latitude
# arrays matching the p-value grid (perhaps from the reference grid; confirm
# the attribute names).
_, ax = plt.subplots(1, 1, figsize=(4, 4), subplot_kw={'projection': ccrs.SouthPolarStereo()})

ax.contour(lon, lat, pvalues["targets_lead5"].data < 0.05, transform=ccrs.PlateCarree())

In [None]:
# Joint distributions of |PC1 of psl| (October starts only) against per-sample
# ACC. Rows: the SIC-only model (input2), the all-inputs model (input4) and
# their difference. Columns: leads 1-6.
#
# Per-sample ACC is read straight from the diagnostics files. By this point in
# the notebook the global `acc` dict has been rebound by the normalization and
# noise cells and no longer holds "input4".
acc_per_sample = {
    cname: xr.open_dataset(
        os.path.join(config_cesm.PREDICTIONS_DIRECTORY, f"exp1_{cname}", "diagnostics/acc.nc")
    )["acc"]
    for cname in ("input2", "input4")
}

# Only "psl_lag1" is decomposed in the EOF cell (pcs_dict has no
# "icefrac_lag1"), so take the time axis from the psl PCs.
time = pcs_dict["psl_lag1"].start_prediction_month
month = 10

fig, axs = plt.subplots(figsize=(8,5), nrows=3, ncols=6, sharex=True)

# NOTE(review): hist2d pairs the flattened PC and ACC values element-wise. This
# assumes both arrays flatten in the same (member, start-month) order. Confirm
# the dim order of acc.nc.
psl_pc = np.abs(pcs_dict["psl_lag1"].isel(mode=0).where(time.dt.month == month, drop=True).values.flatten())
acc_input2 = acc_per_sample["input2"].mean("nn_member_id").where(time.dt.month == month, drop=True)
acc_input4 = acc_per_sample["input4"].mean("nn_member_id").where(time.dt.month == month, drop=True)

months = ['Oct', 'Nov', 'Dec', 'Jan', 'Feb', 'Mar']
for i in range(6):
    axs[0,i].hist2d(psl_pc, acc_input2.isel(lead_time=i).values.flatten(),
                    bins=[15,15], density=True, cmap="cubehelix_r")

    axs[1,i].hist2d(psl_pc, acc_input4.isel(lead_time=i).values.flatten(), 
                    bins=[15,15], density=True, cmap="cubehelix_r")

    axs[2,i].hist2d(psl_pc, (acc_input4 - acc_input2).isel(lead_time=i).values.flatten(), 
                    bins=[15,19], density=True, cmap="cubehelix_r")

    axs[0,i].set_ylim([-0.25, 1])
    axs[1,i].set_ylim([-0.25, 1])
    axs[2,i].set_ylim([-0.4, 0.4])

    for j in range(3):
        axs[j,i].set_xlim([0, 20])
        axs[j,i].set_xticks([0, 10, 20])

        if i > 0:
            axs[j,i].set_yticklabels([])

        axs[j,i].grid(color='0.3')

    axs[2,i].set_xlabel(rf"abs($PC_1^{{\mathrm{{psl}}}}$)")
    axs[0,i].set_title(f"Lead {i+1} ({months[i]})", fontsize=11)

# Row labels follow the paper's nomenclature (input2 -> "input1", input4 -> "input3").
axs[0,0].set_ylabel("ACC (input1)")
axs[1,0].set_ylabel("ACC (input3)")
axs[2,0].set_ylabel("ACC difference")

out_path = "figures/cesm/exp1_SAM_ACC_joint_dist_sep_init.jpg"
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # savefig does not create directories
plt.savefig(out_path, bbox_inches='tight', dpi=300)

In [None]:
# Fraction of variance explained by each retained psl EOF mode (PCA explained_variance_ratio_).
var_exp_dict["psl_lag1"]