In [1]:
import xarray as xr
import numpy as np 
import os
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import pandas as pd
from src import config_cesm
from src.models.diagnostics import roll_metric

# Experiment-1 input configurations. Each name selects the prediction directory
# "exp1_<name>"; "input2" is the reference configuration that the pairwise
# confidence-interval files ("exp1_input2_exp1_<name>_*.nc") compare against.
CNAMES = ["input2", "input3a", "input3b", "input3c", "input3d", "input4"]

def add_hatching(ax, significance_mask, x_edges, y_edges, hatch='///', edgecolor='k'):
    """Overlay an unfilled, hatched rectangle on every truthy cell of a 2-D mask.

    Parameters
    ----------
    ax : object exposing ``add_patch`` (a matplotlib Axes in this notebook).
    significance_mask : 2-D array-like indexed as ``[row, col]``; truthy cells are hatched.
    x_edges, y_edges : cell-edge sequences; cell ``(row, col)`` spans
        ``x_edges[col]..x_edges[col + 1]`` by ``y_edges[row]..y_edges[row + 1]``.
    hatch, edgecolor : forwarded unchanged to ``mpatches.Rectangle``.
    """
    n_rows, n_cols = significance_mask.shape

    for row in range(n_rows):
        for col in range(n_cols):
            # Guard clause: only flagged cells receive a patch.
            if not significance_mask[row, col]:
                continue

            left, bottom = x_edges[col], y_edges[row]
            width = x_edges[col + 1] - left
            height = y_edges[row + 1] - bottom

            # linewidth=0 hides the rectangle outline so only the hatch lines show.
            cell_patch = mpatches.Rectangle(
                (left, bottom),
                width,
                height,
                hatch=hatch,
                fill=False,
                edgecolor=edgecolor,
                linewidth=0,
            )
            ax.add_patch(cell_patch)

def plot_markers(ax, exceeds_persistence, x_centers, y_centers):
    """Draw a small black dot at the center of every cell whose flag is falsy.

    Parameters
    ----------
    ax : object exposing ``plot`` (a matplotlib Axes in this notebook).
    exceeds_persistence : 2-D array-like indexed as ``[row, col]``; a marker is
        drawn where the entry is falsy.
    x_centers, y_centers : cell-center coordinates, indexed by column and row.

    The loop bounds come from the mask's own shape instead of a fixed 6 x 12
    grid, so masks of any size are handled (the 6 x 12 case is unchanged).
    """
    marker_mask = np.asarray(exceeds_persistence)
    n_rows, n_cols = marker_mask.shape

    for i in range(n_rows):
        for j in range(n_cols):
            if not marker_mask[i, j]:
                ax.plot(x_centers[j], y_centers[i], '.k', markersize=4)

In [2]:
# Per-configuration ACC diagnostics, keyed by configuration name:
#   acc             -> "acc" variable of diagnostics/acc.nc
#   acc_agg         -> "acc" variable of diagnostics/acc_agg.nc
#   significance_ds -> confidence-interval dataset for the input2-vs-<name> comparison
acc, acc_agg, significance_ds = {}, {}, {}

for cname in CNAMES:
    exp_dir = os.path.join(config_cesm.PREDICTIONS_DIRECTORY, f"exp1_{cname}")
    acc_agg[cname] = xr.open_dataset(os.path.join(exp_dir, "diagnostics/acc_agg.nc"))["acc"]
    acc[cname] = xr.open_dataset(os.path.join(exp_dir, "diagnostics/acc.nc"))["acc"]

# input2 is the reference configuration, so it has no comparison file of its own.
for cname in CNAMES:
    if cname != "input2":
        ci_file = f"confidence_intervals/exp1_input2_exp1_{cname}_acc.nc"
        significance_ds[cname] = xr.open_dataset(
            os.path.join(config_cesm.ANALYSIS_RESULTS_DIRECTORY, ci_file)
        )


## Plot comparison of ACC (anomaly correlation coefficient)

Each scorecard shows the ACC for predicting SIC at different months (x-axis) and lead times (y-axis). The results are aggregated over a test set of 4 CESM ensemble member historical simulations and 5 separate neural network initializations for each configuration.

In [None]:
# Cache the per-configuration ACC arrays as netCDF files under
# <ANALYSIS_RESULTS_DIRECTORY>/exp1_inputs.
save_dir = os.path.join(config_cesm.ANALYSIS_RESULTS_DIRECTORY, "exp1_inputs")
os.makedirs(save_dir, exist_ok=True)

# Iterate over the results that are actually loaded. The previous loop used
# `cdicts`, which is not defined until a much later cell and therefore raised
# NameError on a fresh kernel.
for key in acc:
    acc[key].to_netcdf(os.path.join(save_dir, f"ACC_{key}.nc"))
    acc_agg[key].to_netcdf(os.path.join(save_dir, f"ACC_agg_{key}.nc"))


In [None]:
# Reload the cached ACC arrays written to <ANALYSIS_RESULTS_DIRECTORY>/exp1_inputs.
acc = {}
acc_agg = {}

save_dir = os.path.join(config_cesm.ANALYSIS_RESULTS_DIRECTORY, "exp1_inputs")
# Use CNAMES rather than `cdicts`, which is not defined at this point in the notebook.
for key in CNAMES:
    # open_dataarray reads the file's single data variable whatever it is called,
    # so it works both for arrays saved unnamed (stored by xarray as
    # "__xarray_dataarray_variable__") and for arrays saved with a name such as "acc".
    acc[key] = xr.open_dataarray(os.path.join(save_dir, f"ACC_{key}.nc"))
    acc_agg[key] = xr.open_dataarray(os.path.join(save_dir, f"ACC_agg_{key}.nc"))

#### Compute baseline forecasts

In [None]:
# Baseline forecasts: anomaly persistence and climatology, both passed through
# remove_climatology and then scored with the aggregated ACC.
# NOTE(review): remove_climatology, calculate_acc and input2_cdict are not defined or
# imported anywhere in the visible notebook, and climatology_broadcast / targets are
# only assigned in later cells. This cell depends on hidden kernel state — TODO
# import/define them before relying on Restart & Run All.
from src.models import models 

print(f"Loading persistence and climatology forecast...")
persistence_predictions = remove_climatology(models.anomaly_persistence(input2_cdict["DATA_SPLIT_SETTINGS"], None).predictions, climatology_broadcast)
climatology_predictions = remove_climatology(models.climatology_predictions(input2_cdict["DATA_SPLIT_SETTINGS"], None).predictions, climatology_broadcast)

print(f"Computing ACC")
# aggregate=True here; the non-aggregated persistence ACC is computed in the next cell.
acc_persist = calculate_acc(persistence_predictions, targets, dim=("x","y"), aggregate=True)
acc_clim = calculate_acc(climatology_predictions, targets, dim=("x","y"), aggregate=True)
print("done!")

In [None]:
# Non-aggregated persistence ACC; used as the comparison for "input2" in the
# bootstrap significance test below.
# NOTE(review): calculate_acc is not defined in the visible notebook — TODO confirm its source.
acc_persist_not_agg = calculate_acc(persistence_predictions, targets, dim=("x","y"), aggregate=False)


### Statistical significance

TODO: try this significance test w.r.t. the non-transformed ACCs (i.e., don't apply the Fisher $z$-transformation)

In [None]:
# Recompute the significance datasets with a bootstrap test.
# NOTE(review): this rebinds `significance_ds`, discarding the confidence-interval
# files loaded at the top of the notebook. `cdicts` and bootstrap_acc_significance
# are not defined above this cell — TODO confirm where they come from.
significance_ds = {}

for k in cdicts.keys(): 
    print(f"Computing bootstrap significance test for {k}")

    # for the sea ice only config, test the difference with persistence
    if k == "input2":
        significance_ds[k] = bootstrap_acc_significance(acc["input2"], acc_persist_not_agg)

    # otherwise, the sea ice only config is the baseline
    else:
        significance_ds[k] = bootstrap_acc_significance(acc["input2"], acc[k])


In [None]:
# Replace every significance dataset with its roll_acc-transformed version.
# NOTE(review): roll_acc is not defined or imported in the visible notebook (only
# roll_metric is imported from src.models.diagnostics) — TODO confirm the intended helper.
# Not idempotent: re-running this cell applies the transform a second time.
for k, ds in significance_ds.items():
    significance_ds[k] = roll_acc(ds)

In [None]:

# Two stacked scorecards: ACC of the sea-ice-only configuration (top) and the
# percent change of input4 relative to it (bottom).
fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(4,4), sharex=True)

# Cell edges of the 12-column (month) x 6-row (lead time) grid; the midpoints are
# used to center the tick labels on each cell.
x = np.arange(13)
y = np.arange(7)
x_centers = (x[:-1] + x[1:]) / 2
y_centers = (y[:-1] + y[1:]) / 2

# Average over the neural-network ensemble members.
acc_input2 = acc_agg["input2"].mean("nn_member_id")
acc_input4 = acc_agg["input4"].mean("nn_member_id")

cax = axs[0].pcolormesh(x, y, acc_input2, cmap='Spectral', shading='flat', vmin=0, vmax=1)
axs[0].set_yticks(y_centers, labels=np.arange(1,7,1))
axs[0].set_ylabel("Lead time")
axs[0].set_title(f"ACC for sea ice only")

# Relative change in percent: 100 * (input4 - input2) / input2, clipped to +/-10 by the color scale.
# NOTE(review): the title lists "sst, psl, z500" while other cells label input4 as
# "all" (sst, slp, z500, t2m) — TODO confirm which inputs input4 adds.
cax2 = axs[1].pcolormesh(x, y, 100 * (acc_input4 - acc_input2) / acc_input2, cmap='RdBu', shading='flat', vmin=-10, vmax=10)
axs[1].set_xticks(x_centers, labels=["J","F","M","A","M","J","J","A","S","O","N","D"])
axs[1].set_yticks(y_centers, labels=np.arange(1,7,1))
axs[1].set_ylabel("Lead time")
axs[1].set_title("ACC difference: add sst, psl, z500")

# Colorbar axes are placed at x=1.04 in figure coordinates, i.e. just outside the
# right edge of the figure; saving needs bbox_inches='tight' to include them.
cbar_ax = fig.add_axes([1.04, 0.56, 0.03, 0.35])
cbar_ax2 = fig.add_axes([1.04, 0.10, 0.03, 0.35])
plt.colorbar(cax, cax=cbar_ax, label=r'ACC', orientation='vertical')
plt.colorbar(cax2, cax=cbar_ax2, label=r'Percent change (%)', orientation='vertical')
plt.tight_layout()
# plt.savefig("figures/cesm/exp1_ACC_diff_inputs_percent.jpg", dpi=300, bbox_inches='tight')

In [None]:
# Six-panel ACC scorecard: panel (a) shows the sea-ice-only ACC, panels (b)-(f) show
# the percent change in ACC when each extra input is added. Cells whose p-value
# exceeds P_VALUE_CUTOFF are hatched.
P_VALUE_CUTOFF = 0.05

import matplotlib.patches as mpatches
import string

fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(9,5), sharey=True)

# Cell edges of the 12-column (month) x 6-row (lead time) grid and their midpoints.
x = np.arange(13)
y = np.arange(7)
x_centers = (x[:-1] + x[1:]) / 2
y_centers = (y[:-1] + y[1:]) / 2

acc_input2 = acc_agg["input2"].mean("nn_member_id")
# input_configs[i] is plotted in panel i+1 and titled with input_added[i].
input_configs = ["input3a", "input3b", "input3c", "input3d", "input4"]
input_added = ["sst", "slp", "z500", "t2m", "all"]

cax2 = axs[0,0].pcolormesh(x, y, acc_input2, cmap='Spectral', shading='flat', vmin=0, vmax=1)
axs[0,0].set_yticks(y_centers, labels=np.arange(1,7,1))
axs[0,0].set_ylabel("Lead time")
axs[0,0].set_title(f"ACC (SIC only)")
axs[0,0].set_xticks(x_centers, labels=["J","F","M","A","M","J","J","A","S","O","N","D"])

axs = axs.flatten()
for i, ax in enumerate(axs[1:]):
    input_config = input_configs[i]
    cax = ax.pcolormesh(x, y, 100 * (acc_agg[input_config].mean("nn_member_id") - acc_input2) / acc_input2, 
            cmap='RdBu', shading='flat', vmin=-15, vmax=15)
    # Mask is 1 where p_value > cutoff, so the hatching marks the cells that do
    # not pass the cutoff.
    significance_mask = (significance_ds[input_config].p_value > P_VALUE_CUTOFF).astype(int)
    add_hatching(ax, significance_mask, x, y)
    ax.set_xticks(x_centers, labels=["J","F","M","A","M","J","J","A","S","O","N","D"])
    ax.set_yticks(y_centers, labels=np.arange(1,7,1))
    # Raw f-string: "\D" is not a valid Python escape and triggers a warning in a
    # plain string literal; the rendered title is unchanged.
    ax.set_title(rf"$\Delta$ ACC: add {input_added[i]}")

# First panel of the second row.
axs[3].set_ylabel("Lead time")

# Two slim colorbar axes to the right of the panels (figure coordinates).
cbar_ax = fig.add_axes([0.95, 0.2, 0.01, 0.6])
cbar_ax2 = fig.add_axes([1.05, 0.2, 0.01, 0.6])

plt.colorbar(cax, cax=cbar_ax2, label=r'Percent change (%)', orientation='vertical')
plt.colorbar(cax2, cax=cbar_ax, label=r'ACC', orientation='vertical')

plt.subplots_adjust(hspace=0.5)

panel_labels = list(string.ascii_lowercase)

# Label the panels "a)", "b)", ... just above each panel's top-left corner.
for i, ax in enumerate(axs):
    ax.annotate(
        f"{panel_labels[i]})",
        xy=(0, 1), xycoords="axes fraction",
        xytext=(-0.1, 1.03), textcoords="axes fraction",
        ha="left", va="bottom",
        fontsize=11, fontweight="bold"
    )

# plt.savefig("figures/cesm_new/exp1_ACC_percent.png", dpi=300, bbox_inches='tight')

In [None]:
# Six-panel ACC scorecard: panel (a) shows the sea-ice-only ACC, panels (b)-(f) show
# the absolute ACC difference when each extra input is added. Cells whose p-value
# exceeds P_VALUE_CUTOFF are hatched.
P_VALUE_CUTOFF = 0.01

import matplotlib.patches as mpatches
import string

fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(9,5), sharey=True)

# Cell edges of the 12-column (month) x 6-row (lead time) grid and their midpoints.
x = np.arange(13)
y = np.arange(7)
x_centers = (x[:-1] + x[1:]) / 2
y_centers = (y[:-1] + y[1:]) / 2

acc_input2 = acc_agg["input2"].mean("nn_member_id")
# input_configs[i] is plotted in panel i+1 and titled with input_added[i].
input_configs = ["input3a", "input3b", "input3c", "input3d", "input4"]
input_added = ["sst", "slp", "z500", "t2m", "all"]

cax2 = axs[0,0].pcolormesh(x, y, acc_input2, cmap='Spectral', shading='flat', vmin=0, vmax=1)
axs[0,0].set_yticks(y_centers, labels=np.arange(1,7,1))
axs[0,0].set_ylabel("Lead time")
axs[0,0].set_title(f"ACC (SIC only)")
axs[0,0].set_xticks(x_centers, labels=["J","F","M","A","M","J","J","A","S","O","N","D"])

axs = axs.flatten()
for i, ax in enumerate(axs[1:]):
    input_config = input_configs[i]
    cax = ax.pcolormesh(x, y, acc_agg[input_config].mean("nn_member_id") - acc_input2, 
            cmap='RdBu', shading='flat', vmin=-0.05, vmax=0.05)
    # Mask is 1 where p_value > cutoff, so the hatching marks the cells that do
    # not pass the cutoff.
    significance_mask = (significance_ds[input_config].p_value > P_VALUE_CUTOFF).astype(int)
    add_hatching(ax, significance_mask, x, y)
    ax.set_xticks(x_centers, labels=["J","F","M","A","M","J","J","A","S","O","N","D"])
    ax.set_yticks(y_centers, labels=np.arange(1,7,1))
    # Raw f-string: "\D" is not a valid Python escape and triggers a warning in a
    # plain string literal; the rendered title is unchanged.
    ax.set_title(rf"$\Delta$ ACC: add {input_added[i]}")

# First panel of the second row.
axs[3].set_ylabel("Lead time")

# Two slim colorbar axes to the right of the panels (figure coordinates).
cbar_ax = fig.add_axes([0.95, 0.2, 0.01, 0.6])
cbar_ax2 = fig.add_axes([1.05, 0.2, 0.01, 0.6])

plt.colorbar(cax, cax=cbar_ax2, label=r'$\Delta$ ACC (unitless)', orientation='vertical')
plt.colorbar(cax2, cax=cbar_ax, label=r'ACC', orientation='vertical')

plt.subplots_adjust(hspace=0.5)

panel_labels = list(string.ascii_lowercase)

# Label the panels "a)", "b)", ... just above each panel's top-left corner.
for i, ax in enumerate(axs):
    ax.annotate(
        f"{panel_labels[i]})",
        xy=(0, 1), xycoords="axes fraction",
        xytext=(-0.1, 1.03), textcoords="axes fraction",
        ha="left", va="bottom",
        fontsize=11, fontweight="bold"
    )

# plt.savefig("figures/cesm_new/exp1_ACC_absolute.png", dpi=300, bbox_inches='tight')

## Plot RMSE comparison

In [None]:
# Per-configuration RMSE diagnostics, keyed by configuration name:
#   rmse                 -> "rmse" variable of diagnostics/rmse.nc
#   rmse_agg             -> "rmse" variable of diagnostics/rmse_agg.nc
#   significance_ds_rmse -> confidence-interval dataset for the input2-vs-<name> comparison
rmse, rmse_agg, significance_ds_rmse = {}, {}, {}

for cname in CNAMES:
    exp_dir = os.path.join(config_cesm.PREDICTIONS_DIRECTORY, f"exp1_{cname}")
    rmse_agg[cname] = xr.open_dataset(os.path.join(exp_dir, "diagnostics/rmse_agg.nc"))["rmse"]
    rmse[cname] = xr.open_dataset(os.path.join(exp_dir, "diagnostics/rmse.nc"))["rmse"]

# input2 is the reference configuration, so it has no comparison file of its own.
for cname in CNAMES:
    if cname != "input2":
        ci_file = f"confidence_intervals/exp1_input2_exp1_{cname}_rmse.nc"
        significance_ds_rmse[cname] = xr.open_dataset(
            os.path.join(config_cesm.ANALYSIS_RESULTS_DIRECTORY, ci_file)
        )


In [None]:
# Six-panel RMSE scorecard: panel (a) shows the sea-ice-only RMSE, panels (b)-(f)
# show the absolute RMSE difference when each extra input is added. Cells whose
# p-value exceeds P_VALUE_CUTOFF are hatched.
P_VALUE_CUTOFF = 0.01

import matplotlib.patches as mpatches
import string

fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(9,5), sharey=True)

# Cell edges of the 12-column (month) x 6-row (lead time) grid and their midpoints.
x = np.arange(13)
y = np.arange(7)
x_centers = (x[:-1] + x[1:]) / 2
y_centers = (y[:-1] + y[1:]) / 2

rmse_input2 = rmse_agg["input2"].mean("nn_member_id")
# input_configs[i] is plotted in panel i+1 and titled with input_added[i].
input_configs = ["input3a", "input3b", "input3c", "input3d", "input4"]
input_added = ["sst", "slp", "z500", "t2m", "all"]

# Reversed colormaps relative to the ACC figures (Spectral_r / RdBu_r).
cax2 = axs[0,0].pcolormesh(x, y, rmse_input2, cmap='Spectral_r', shading='flat', vmin=0.02, vmax=0.08)
axs[0,0].set_yticks(y_centers, labels=np.arange(1,7,1))
axs[0,0].set_ylabel("Lead time")
axs[0,0].set_title(f"RMSE (SIC only)")
axs[0,0].set_xticks(x_centers, labels=["J","F","M","A","M","J","J","A","S","O","N","D"])

axs = axs.flatten()
for i, ax in enumerate(axs[1:]):
    input_config = input_configs[i]
    cax = ax.pcolormesh(x, y, rmse_agg[input_config].mean("nn_member_id") - rmse_input2, 
            cmap='RdBu_r', shading='flat', vmin=-0.005, vmax=0.005)
    # Mask is 1 where p_value > cutoff, so the hatching marks the cells that do
    # not pass the cutoff.
    significance_mask = (significance_ds_rmse[input_config].p_value > P_VALUE_CUTOFF).astype(int)
    add_hatching(ax, significance_mask, x, y)
    ax.set_xticks(x_centers, labels=["J","F","M","A","M","J","J","A","S","O","N","D"])
    ax.set_yticks(y_centers, labels=np.arange(1,7,1))
    # Raw f-string: "\D" is not a valid Python escape and triggers a warning in a
    # plain string literal; the rendered title is unchanged.
    ax.set_title(rf"$\Delta$ RMSE: add {input_added[i]}")

# First panel of the second row.
axs[3].set_ylabel("Lead time")

# Two slim colorbar axes to the right of the panels (figure coordinates).
cbar_ax = fig.add_axes([0.95, 0.2, 0.01, 0.6])
cbar_ax2 = fig.add_axes([1.05, 0.2, 0.01, 0.6])

plt.colorbar(cax, cax=cbar_ax2, label=r'$\Delta$ RMSE', orientation='vertical')
plt.colorbar(cax2, cax=cbar_ax, label=r'RMSE', orientation='vertical')

plt.subplots_adjust(hspace=0.5)

panel_labels = list(string.ascii_lowercase)

# Label the panels "a)", "b)", ... just above each panel's top-left corner.
for i, ax in enumerate(axs):
    ax.annotate(
        f"{panel_labels[i]})",
        xy=(0, 1), xycoords="axes fraction",
        xytext=(-0.1, 1.03), textcoords="axes fraction",
        ha="left", va="bottom",
        fontsize=11, fontweight="bold"
    )

# savefig does not create missing directories, so make sure the target exists.
os.makedirs("figures/cesm_new", exist_ok=True)
plt.savefig("figures/cesm_new/exp1_RMSE_absolute.png", dpi=300, bbox_inches='tight')

In [None]:
# Six-panel RMSE scorecard: panel (a) shows the sea-ice-only RMSE, panels (b)-(f)
# show the percent change in RMSE when each extra input is added. Cells whose
# p-value exceeds P_VALUE_CUTOFF are hatched.
P_VALUE_CUTOFF = 0.01

import matplotlib.patches as mpatches
import string

fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(9,5), sharey=True)

# Cell edges of the 12-column (month) x 6-row (lead time) grid and their midpoints.
x = np.arange(13)
y = np.arange(7)
x_centers = (x[:-1] + x[1:]) / 2
y_centers = (y[:-1] + y[1:]) / 2

rmse_input2 = rmse_agg["input2"].mean("nn_member_id")
# input_configs[i] is plotted in panel i+1 and titled with input_added[i].
input_configs = ["input3a", "input3b", "input3c", "input3d", "input4"]
input_added = ["sst", "slp", "z500", "t2m", "all"]

# Reversed colormaps relative to the ACC figures (Spectral_r / RdBu_r).
cax2 = axs[0,0].pcolormesh(x, y, rmse_input2, cmap='Spectral_r', shading='flat', vmin=0.02, vmax=0.08)
axs[0,0].set_yticks(y_centers, labels=np.arange(1,7,1))
axs[0,0].set_ylabel("Lead time")
axs[0,0].set_title(f"RMSE (SIC only)")
axs[0,0].set_xticks(x_centers, labels=["J","F","M","A","M","J","J","A","S","O","N","D"])

axs = axs.flatten()
for i, ax in enumerate(axs[1:]):
    input_config = input_configs[i]
    cax = ax.pcolormesh(x, y, 100 * (rmse_agg[input_config].mean("nn_member_id") - rmse_input2) / rmse_input2, 
            cmap='RdBu_r', shading='flat', vmin=-10, vmax=10)
    # Mask is 1 where p_value > cutoff, so the hatching marks the cells that do
    # not pass the cutoff.
    significance_mask = (significance_ds_rmse[input_config].p_value > P_VALUE_CUTOFF).astype(int)
    add_hatching(ax, significance_mask, x, y)
    ax.set_xticks(x_centers, labels=["J","F","M","A","M","J","J","A","S","O","N","D"])
    ax.set_yticks(y_centers, labels=np.arange(1,7,1))
    # Raw f-string: "\D" is not a valid Python escape and triggers a warning in a
    # plain string literal; the rendered title is unchanged.
    ax.set_title(rf"$\Delta$ RMSE: add {input_added[i]}")

# First panel of the second row.
axs[3].set_ylabel("Lead time")

# Two slim colorbar axes to the right of the panels (figure coordinates).
cbar_ax = fig.add_axes([0.95, 0.2, 0.01, 0.6])
cbar_ax2 = fig.add_axes([1.05, 0.2, 0.01, 0.6])

plt.colorbar(cax, cax=cbar_ax2, label=r'Percent change RMSE', orientation='vertical')
plt.colorbar(cax2, cax=cbar_ax, label=r'RMSE', orientation='vertical')

plt.subplots_adjust(hspace=0.5)

panel_labels = list(string.ascii_lowercase)

# Label the panels "a)", "b)", ... just above each panel's top-left corner.
for i, ax in enumerate(axs):
    ax.annotate(
        f"{panel_labels[i]})",
        xy=(0, 1), xycoords="axes fraction",
        xytext=(-0.1, 1.03), textcoords="axes fraction",
        ha="left", va="bottom",
        fontsize=11, fontweight="bold"
    )

# savefig does not create missing directories, so make sure the target exists.
os.makedirs("figures/cesm_new", exist_ok=True)
plt.savefig("figures/cesm_new/exp1_RMSE_percent.png", dpi=300, bbox_inches='tight')

## Effect of normalization strategy on SST configuration skill

In [None]:
# Compute per-member ACC for the two alternative SST-input configurations
# (input3a_dev and input3a_std).
# NOTE(review): load_globals, get_broadcast_climatology, load_targets,
# load_model_predictions, calculate_acc and aggregate_acc are not defined or imported
# in the visible notebook — TODO confirm where they come from.
from src.experiment_configs.exp1_inputs import input3a_dev
from src.experiment_configs.exp1_inputs import input3a_std

# Rebinding cdicts / acc / acc_agg here changes what earlier cells see if they are re-run.
cdicts = {
    "input3a_dev": input3a_dev, 
    "input3a_std": input3a_std, 
}

# Targets and climatology are loaded once, from the input3a_dev configuration, and
# reused for both configurations in the loop below.
cdict = load_globals(input3a_dev)
climatology_broadcast = get_broadcast_climatology(cdict, "test")
targets = load_targets(cdict, "test", add_climatology_to_anomaly=False)

acc = {}
acc_agg = {}

# Number of neural-network ensemble members scored per configuration (indices 0..2).
num_nn_ens_members = 3
for key, config in cdicts.items():
    print(f"computing ACC for {key}")
    cdict = load_globals(config) 
    pred = load_model_predictions(cdict, nn_ens_avg=False, climatology_broadcasted=None, 
                                    add_climatology_to_anomaly=False)

    # Score each ensemble member separately, then stack the results along "nn_member_id".
    acc_temp_list = []
    for i in range(num_nn_ens_members): 
        acc_temp = calculate_acc(pred.isel(nn_member_id=i), 
                            targets, dim=("x","y"), aggregate=False)
        acc_temp_list.append(acc_temp)
    acc[key] = xr.concat(acc_temp_list, dim="nn_member_id")
    acc_agg[key] = aggregate_acc(acc[key], dim=("x","y"))

In [None]:
# Cache the ACC arrays computed above as netCDF files under
# <ANALYSIS_RESULTS_DIRECTORY>/exp1_inputs.
save_dir = os.path.join(config_cesm.ANALYSIS_RESULTS_DIRECTORY, "exp1_inputs")
# Create the directory if needed, matching the earlier save cell, so this cell
# does not depend on that one having been run first.
os.makedirs(save_dir, exist_ok=True)

for config in acc:
    acc[config].to_netcdf(os.path.join(save_dir, f"ACC_{config}.nc"))
    acc_agg[config].to_netcdf(os.path.join(save_dir, f"ACC_agg_{config}.nc"))

In [None]:
# Reload the cached ACC arrays needed for the normalization comparison.
acc = {}
acc_agg = {}

save_dir = os.path.join(config_cesm.ANALYSIS_RESULTS_DIRECTORY, "exp1_inputs")
for key in ["input2", "input3a", "input3a_dev", "input3a_std"]:
    # open_dataarray reads the file's single data variable whatever it is called,
    # so it works both for arrays saved unnamed (stored by xarray as
    # "__xarray_dataarray_variable__") and for arrays saved with a name such as "acc".
    acc[key] = xr.open_dataarray(os.path.join(save_dir, f"ACC_{key}.nc"))
    acc_agg[key] = xr.open_dataarray(os.path.join(save_dir, f"ACC_agg_{key}.nc"))

In [None]:
# Bootstrap significance of each SST configuration against the sea-ice-only
# configuration (input2), then apply roll_acc to every result.
# NOTE(review): bootstrap_acc_significance and roll_acc are not defined or imported
# in the visible notebook — TODO confirm where they come from.
significance_ds = {}

for k in ["input3a", "input3a_dev", "input3a_std"]: 
    print(f"Computing bootstrap significance test for {k}")

    significance_ds[k] = bootstrap_acc_significance(acc["input2"], acc[k])

for k, ds in significance_ds.items():
    significance_ds[k] = roll_acc(ds)

In [None]:
# 3 x 2 figure: left column shows the ACC of each SST-input configuration, right
# column shows its percent change relative to the sea-ice-only configuration, with
# hatching where p_value > P_VALUE_CUTOFF.
# NOTE(review): P_VALUE_CUTOFF is not set in this cell; it keeps whatever value the
# most recently executed plotting cell assigned.
fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(7,6))

# Cell edges of the 12-column (month) x 6-row (lead time) grid and their midpoints.
x = np.arange(13)
y = np.arange(7)
x_centers = (x[:-1] + x[1:]) / 2
y_centers = (y[:-1] + y[1:]) / 2

# Ensemble-mean ACC per configuration; the variable names follow the panel titles
# below (input3a_std -> z-score, input3a -> min-max, input3a_dev -> no norm).
acc_zscore_norm = acc_agg["input3a_std"].mean("nn_member_id")
acc_minmax_norm = acc_agg["input3a"].mean("nn_member_id")
acc_nonorm = acc_agg["input3a_dev"].mean("nn_member_id")
acc_sic_only = acc_agg["input2"].mean("nn_member_id")

# Left column: ACC. Only the first mesh handle is kept, for the shared colorbar.
cax = axs[0,0].pcolormesh(x, y, acc_zscore_norm, cmap='Spectral', shading='flat', vmin=0, vmax=1)
axs[0,0].set_yticks(y_centers, labels=np.arange(1,7,1))
axs[0,0].set_ylabel("Lead time")
axs[0,0].set_title(f"Sea ice + sst (z-score norm)")

axs[1,0].pcolormesh(x, y, acc_minmax_norm, cmap='Spectral', shading='flat', vmin=0, vmax=1)
axs[1,0].set_yticks(y_centers, labels=np.arange(1,7,1))
axs[1,0].set_ylabel("Lead time")
axs[1,0].set_title(f"Sea ice + sst (min-max norm)")

axs[2,0].pcolormesh(x, y, acc_nonorm, cmap='Spectral', shading='flat', vmin=0, vmax=1)
axs[2,0].set_yticks(y_centers, labels=np.arange(1,7,1))
axs[2,0].set_ylabel("Lead time")
axs[2,0].set_title(f"Sea ice + sst (no norm)")

# Right column: percent change relative to the sea-ice-only ACC, same color scale in all rows.
cax2 = axs[0,1].pcolormesh(x, y, 100 * (acc_zscore_norm - acc_sic_only) / acc_sic_only,
                    cmap='RdBu', shading='flat', vmin=-15, vmax=15)
axs[0,1].set_yticks(y_centers, labels=np.arange(1,7,1))
axs[0,1].set_ylabel("Lead time")
axs[0,1].set_title(f"ACC diff from sea ice only")
significance_mask = (significance_ds["input3a_std"].p_value > P_VALUE_CUTOFF).astype(int)
add_hatching(axs[0,1], significance_mask, x, y)

axs[1,1].pcolormesh(x, y, 100 * (acc_minmax_norm - acc_sic_only) / acc_sic_only,
                    cmap='RdBu', shading='flat', vmin=-15, vmax=15)
axs[1,1].set_yticks(y_centers, labels=np.arange(1,7,1))
axs[1,1].set_ylabel("Lead time")
axs[1,1].set_title(f"ACC diff from sea ice only")
significance_mask = (significance_ds["input3a"].p_value > P_VALUE_CUTOFF).astype(int)
add_hatching(axs[1,1], significance_mask, x, y)

axs[2,1].pcolormesh(x, y, 100 * (acc_nonorm - acc_sic_only) / acc_sic_only,
                    cmap='RdBu', shading='flat', vmin=-15, vmax=15)
axs[2,1].set_yticks(y_centers, labels=np.arange(1,7,1))
axs[2,1].set_ylabel("Lead time")
axs[2,1].set_title(f"ACC diff from sea ice only")
significance_mask = (significance_ds["input3a_dev"].p_value > P_VALUE_CUTOFF).astype(int)
add_hatching(axs[2,1], significance_mask, x, y)

# Month initials on every panel.
for i in range(3):
    for j in range(2):
        axs[i,j].set_xticks(x_centers, labels=["J","F","M","A","M","J","J","A","S","O","N","D"])

# Horizontal colorbars placed below the figure (negative y in figure coordinates),
# one under each column.
cbar_ax = fig.add_axes([0.1, -0.05, 0.35, 0.02])
cbar_ax2 = fig.add_axes([0.6, -0.05, 0.35, 0.02])

plt.colorbar(cax, cax=cbar_ax, label=r'ACC', orientation='horizontal')
plt.colorbar(cax2, cax=cbar_ax2, label=r'Percent change (%)', orientation='horizontal')

plt.tight_layout()
# NOTE(review): the "figures/cesm" directory is not created here; savefig fails if it is missing.
plt.savefig("figures/cesm/exp1_ACC_diff_sst_norm.jpg", dpi=300, bbox_inches='tight')

## Noise experiment

In [None]:
# Compute and cache per-member ACC for the input_noise configuration.
# NOTE(review): load_globals, get_broadcast_climatology, load_targets,
# load_model_predictions, calculate_acc and aggregate_acc are not defined or imported
# in the visible notebook — TODO confirm where they come from.
from src.experiment_configs.exp1_inputs import input_noise

cdict = load_globals(input_noise)
climatology_broadcast = get_broadcast_climatology(cdict, "test")
targets = load_targets(cdict, "test", add_climatology_to_anomaly=False)

# Five ensemble members here, versus three in the normalization experiment above.
num_nn_ens_members = 5
pred = load_model_predictions(cdict, nn_ens_avg=False, climatology_broadcasted=None, 
                                add_climatology_to_anomaly=False)

# Score each ensemble member separately, then stack the results along "nn_member_id".
acc_temp_list = []
for i in range(num_nn_ens_members): 
    acc_temp = calculate_acc(pred.isel(nn_member_id=i), 
                        targets, dim=("x","y"), aggregate=False)
    acc_temp_list.append(acc_temp)
acc_noise = xr.concat(acc_temp_list, dim="nn_member_id")
acc_agg_noise = aggregate_acc(acc_noise, dim=("x","y"))

# File names follow the ACC_<key>.nc / ACC_agg_<key>.nc pattern the loading cell
# expects for key "input_noise". The directory is assumed to exist already.
save_dir = os.path.join(config_cesm.ANALYSIS_RESULTS_DIRECTORY, "exp1_inputs")
acc_noise.to_netcdf(os.path.join(save_dir, f"ACC_input_noise.nc"))
acc_agg_noise.to_netcdf(os.path.join(save_dir, f"ACC_agg_input_noise.nc"))

In [None]:
# Reload the cached ACC arrays needed for the noise comparison.
acc = {}
acc_agg = {}

save_dir = os.path.join(config_cesm.ANALYSIS_RESULTS_DIRECTORY, "exp1_inputs")
for key in ["input2", "input_noise"]:
    # open_dataarray reads the file's single data variable whatever it is called,
    # so it works both for arrays saved unnamed (stored by xarray as
    # "__xarray_dataarray_variable__") and for arrays saved with a name such as "acc".
    acc[key] = xr.open_dataarray(os.path.join(save_dir, f"ACC_{key}.nc"))
    acc_agg[key] = xr.open_dataarray(os.path.join(save_dir, f"ACC_agg_{key}.nc"))

In [None]:
# Bootstrap significance of the noise configuration against the sea-ice-only
# configuration (input2), then apply roll_acc to the result.
# NOTE(review): bootstrap_acc_significance and roll_acc are not defined or imported
# in the visible notebook — TODO confirm where they come from.
significance_ds = {}

for k in ["input_noise",]: 
    print(f"Computing bootstrap significance test for {k}")

    significance_ds[k] = bootstrap_acc_significance(acc["input2"], acc[k])

for k, ds in significance_ds.items():
    significance_ds[k] = roll_acc(ds)

In [None]:
# Two panels: ACC of the noise configuration (left) and its percent change relative
# to the sea-ice-only configuration (right), hatched where p_value > P_VALUE_CUTOFF.
# NOTE(review): P_VALUE_CUTOFF is not set in this cell; it keeps whatever value the
# most recently executed plotting cell assigned.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(7, 2))

# Cell edges of the 12-column (month) x 6-row (lead time) grid and their midpoints.
x = np.arange(13)
y = np.arange(7)
x_centers = (x[:-1] + x[1:]) / 2
y_centers = (y[:-1] + y[1:]) / 2

# Rebinds acc_noise (previously the per-member array) to the ensemble-mean aggregated ACC.
acc_noise = acc_agg["input_noise"].mean("nn_member_id")
acc_sic_only = acc_agg["input2"].mean("nn_member_id")

cax = axs[0].pcolormesh(x, y, acc_noise, cmap='Spectral', shading='flat', vmin=0, vmax=1)
axs[0].set_yticks(y_centers, labels=np.arange(1,7,1))
axs[0].set_ylabel("Lead time")
axs[0].set_title(f"Sea ice + 6 channels of noise")

cax2 = axs[1].pcolormesh(x, y, 100 * (acc_noise - acc_sic_only) / acc_sic_only,
                    cmap='RdBu', shading='flat', vmin=-15, vmax=15)
axs[1].set_yticks(y_centers, labels=np.arange(1,7,1))
axs[1].set_ylabel("Lead time")
axs[1].set_title(f"ACC diff from sea ice only")
significance_mask = (significance_ds["input_noise"].p_value > P_VALUE_CUTOFF).astype(int)
add_hatching(axs[1], significance_mask, x, y)

for i in range(2):
    axs[i].set_xticks(x_centers, labels=["J","F","M","A","M","J","J","A","S","O","N","D"])

# Horizontal colorbars placed below the figure (negative y in figure coordinates).
cbar_ax = fig.add_axes([0.1, -0.05, 0.35, 0.05])
cbar_ax2 = fig.add_axes([0.6, -0.05, 0.35, 0.05])

plt.colorbar(cax, cax=cbar_ax, label=r'ACC', orientation='horizontal')
plt.colorbar(cax2, cax=cbar_ax2, label=r'Percent change (%)', orientation='horizontal')

plt.tight_layout()
# NOTE(review): the "figures/cesm" directory is not created here; savefig fails if it is missing.
plt.savefig("figures/cesm/exp1_ACC_diff_noise.jpg", dpi=300, bbox_inches='tight')

## Temporal dependence of ACC

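### Composite analysis
Here we select instances where the ACC of input4 exceeds the ACC of input2 by more than `diff_thresh=0.3`. The visualizations are for initialization in October, so that we capture the … (TODO: this sentence was left unfinished; state what the October initialization is meant to capture.)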

In [None]:
# Load the model-input files of the four test ensemble members and concatenate
# them along a new "member_id" dimension.
test_members = ['r2i1251p1f1', 'r2i1281p1f1', 'r2i1301p1f1', 'r3i1041p1f1']
# NOTE(review): hard-coded absolute, user-specific path; consider deriving it from config_cesm.
data_dir = '/scratch/users/yucli/cesm_data_processed/data_pairs/seaice_plus_temp_plus_atm_minmax'

inputs_da = []
for member_id in test_members: 
    inputs_da.append(xr.open_dataset(os.path.join(data_dir, f"inputs_member_{member_id}.nc")))

# Despite the "_da" suffix this is an xarray Dataset (built from open_dataset
# results); later cells read its "data" variable.
inputs_da = xr.concat(inputs_da, dim="member_id")

In [None]:
# Ensemble-mean ACC difference, input4 minus input2.
# NOTE(review): `acc` was last rebound by the noise-experiment loading cell, which
# only holds "input2" and "input_noise"; acc["input4"] needs the full set of
# configurations to be reloaded first — TODO confirm the intended cell order.
acc_diff_all = acc["input4"].mean("nn_member_id") - acc["input2"].mean("nn_member_id")

# Keep only entries where input4 beats input2 by more than diff_thresh; the mask is
# True for every sample that exceeds the threshold at one or more lead times.
diff_thresh = 0.3
acc_diff = acc_diff_all.where(acc_diff_all > diff_thresh, drop=True)
mask = acc_diff.notnull().any(dim="lead_time")

In [None]:
# Keep the samples selected by `mask`, then take the Dataset's "data" variable
# (attribute access on the Dataset, not the underlying array of a DataArray).
inputs_subset = inputs_da.where(mask, drop=True).data

In [None]:
# Load the target files of the four test members, concatenate them along
# "member_id" and keep the samples selected by `mask`.
targets_da = []
for member_id in test_members: 
    targets_da.append(xr.open_dataset(os.path.join(data_dir, f"targets_member_{member_id}.nc")))

targets_da = xr.concat(targets_da, dim="member_id")
# `.data` selects the Dataset's "data" variable (targets_da is a Dataset here).
targets_subset = targets_da.where(mask, drop=True).data

# Composite target: mean over members and start months, restricted to samples
# whose start_prediction_month falls in October (month == 10).
month_init_targets = targets_subset.where(targets_subset.start_prediction_month.dt.month == 10, drop=True).mean(("member_id", "start_prediction_month"))

In [None]:
# Composite of the selected input samples whose start_prediction_month is October.
month_init_inputs = inputs_subset.where(inputs_subset.start_prediction_month.dt.month == 10, drop=True)
# Count the non-null samples at the first index of the three trailing dimensions.
# NOTE(review): positional indexing assumes a 5-D array whose first two dimensions
# are the sample dimensions — TODO confirm the dimension order.
print(month_init_inputs[:,:,0,0,0].notnull().sum().values)
month_init_inputs = month_init_inputs.mean(("member_id", "start_prediction_month"))

# 6 x 6 grid: panels 0-29 show input channels, panels 30-35 show targets at lead 1-6.
fig, axs = plt.subplots(nrows=6, ncols=6, figsize=(12,12), sharex=True, sharey=True)

axs = axs.flatten()

# Assumes the input has 30 or more channels; only the first 30 are plotted.
for i in range(30):
    axs[i].pcolormesh(month_init_inputs.isel(channel=i), vmin=-0.1, vmax=0.1, cmap="RdBu_r")
    axs[i].set_title(month_init_inputs.channel[i].values)

for i in range(30, 36): 
    lead = i - 30
    axs[i].pcolormesh(month_init_targets.isel(lead_time=lead), vmin=-0.1, vmax=0.1, cmap="RdBu_r")
    axs[i].set_title(f"Target (lead {lead + 1})")

In [None]:
# Load the per-member predictions of the input2 and input4 configurations.
# NOTE(review): load_model_predictions, load_globals, input2 and input4 are not
# defined or imported in the visible notebook — TODO confirm where they come from.
pred_2 = load_model_predictions(load_globals(input2), nn_ens_avg=False, climatology_broadcasted=None, 
                                add_climatology_to_anomaly=False)
pred_4 = load_model_predictions(load_globals(input4), nn_ens_avg=False, climatology_broadcasted=None, 
                                add_climatology_to_anomaly=False)

# Keep the samples selected by `mask` and average over the neural-network ensemble.
pred_2_subset = pred_2.where(mask, drop=True).mean("nn_member_id")
pred_4_subset = pred_4.where(mask, drop=True).mean("nn_member_id")

# Composite predictions for samples whose start_prediction_month is October.
month_init_pred2 = pred_2_subset.where(pred_2_subset.start_prediction_month.dt.month == 10, drop=True).mean(("member_id", "start_prediction_month"))
month_init_pred4 = pred_4_subset.where(pred_4_subset.start_prediction_month.dt.month == 10, drop=True).mean(("member_id", "start_prediction_month"))

In [None]:
# Map version of the composite figure on a south-polar stereographic projection:
# 30 input channels, 6 target leads, then 6 leads each of the two predictions.
import cartopy.crs as ccrs

month_init_inputs = inputs_subset.where(inputs_subset.start_prediction_month.dt.month == 10, drop=True)
print(month_init_inputs[:,:,0,0,0].notnull().sum().values)
month_init_inputs = month_init_inputs.mean(("member_id", "start_prediction_month"))

# NOTE(review): reference_grid is not defined in the visible notebook — TODO confirm its source.
lon = reference_grid.lon.data
lat = reference_grid.lat.data

fig, axs = plt.subplots(nrows=8, ncols=6, figsize=(12,16), sharex=True, sharey=True,
                        subplot_kw={'projection': ccrs.SouthPolarStereo()})

axs = axs.flatten()

# Panels 0-29: composite input channels.
for i in range(30):
    axs[i].pcolormesh(lon, lat, month_init_inputs.isel(channel=i), transform=ccrs.PlateCarree(), vmin=-0.1, vmax=0.1, cmap="RdBu_r")
    axs[i].set_title(month_init_inputs.channel[i].values)

# Panels 30-35: composite targets at lead 1-6.
for i in range(30, 36): 
    lead = i - 30
    axs[i].pcolormesh(lon, lat, month_init_targets.isel(lead_time=lead), transform=ccrs.PlateCarree(), vmin=-0.1, vmax=0.1, cmap="RdBu_r")
    axs[i].set_title(f"Target (lead {lead + 1})")

# Panels 36-41: input2 predictions, titled "input1" (a later cell states that these
# labels follow the paper's nomenclature).
for i in range(36, 42):
    axs[i].pcolormesh(lon, lat, month_init_pred2.isel(lead_time=i-36),transform=ccrs.PlateCarree(), vmin=-0.1, vmax=0.1, cmap="RdBu_r")
    axs[i].set_title(f"input1 pred")

# Panels 42-47: input4 predictions, titled "input3".
for i in range(42, 48):
    axs[i].pcolormesh(lon, lat, month_init_pred4.isel(lead_time=i-42), transform=ccrs.PlateCarree(), vmin=-0.1, vmax=0.1, cmap="RdBu_r")
    axs[i].set_title(f"input3 pred")

for i in range(0, 48):
    axs[i].coastlines()

plt.savefig("figures/cesm/exp1_inputs_composite.jpg", dpi=300, bbox_inches='tight')

In [None]:
from sklearn.decomposition import PCA

In [None]:
# EOF/PCA decomposition of selected input channels. For each channel this stores:
#   eofs_dict[var]    -> spatial patterns, dims ("mode", "y", "x")
#   pcs_dict[var]     -> principal-component values per sample, unstacked back to
#                        ("member_id", "start_prediction_month", "mode")
#   var_exp_dict[var] -> explained-variance ratio per mode
input_vars = ["psl_lag1", "icefrac_lag1", "geopotential_lag1", "temp_lag1"]

eofs_dict = {}
pcs_dict = {}
var_exp_dict = {}
n_components = 4

for input_var in input_vars:
    # Collapse (member_id, start_prediction_month) into a single "samples" axis.
    stacked = inputs_da["data"].sel(channel=input_var).stack(samples=("member_id", "start_prediction_month"))

    # Remove each sample's spatial mean, then pin the dimension order explicitly.
    # The flatten below and the EOF reshape further down must share one (y, x)
    # layout; relying on the stored dimension order (and unpacking it as nx, ny)
    # could transpose the maps or fail on non-square grids.
    anomalies = (stacked - stacked.mean(("x","y"))).transpose("samples", "y", "x")
    nsamples, ny, nx = anomalies.shape

    # Rows are samples, columns are grid points flattened in (y, x) order.
    anom_matrix = anomalies.values.reshape(nsamples, ny * nx)

    pca = PCA(n_components=n_components)
    pca.fit(anom_matrix)
    pcs = pca.transform(anom_matrix)
    eofs_flat = pca.components_
    # Same (y, x) layout as the flatten above, so the labels below are correct.
    eofs = eofs_flat.reshape(n_components, ny, nx)

    eofs_dict[input_var] = xr.DataArray(
        eofs,
        dims=("mode", "y", "x"),
        coords={"mode": np.arange(n_components), "y": stacked.y, "x": stacked.x},
        name="EOF"
    )

    # Reuse the stacked sample index so the PCs can be unstacked back to
    # (member_id, start_prediction_month).
    samples_index = stacked["samples"]
    pcs_dict[input_var] = xr.DataArray(
        pcs,
        dims=("samples", "mode"),
        coords={"samples": samples_index, "mode": np.arange(n_components)},
        name="PC"
    ).unstack("samples")

    var_exp_dict[input_var] = xr.DataArray(
        pca.explained_variance_ratio_,
        dims=("mode",),
        coords={"mode": np.arange(n_components)},
        name="VarianceExplained"
    )



In [None]:
# Plot the leading (mode 0) EOF of the psl_lag1 channel on the native grid indices.
# NOTE(review): the output file is named "sam.jpg", so this mode is presumably read
# as the Southern Annular Mode — confirm. The "figures/illustrations" directory must
# already exist.
plt.figure(figsize=(5,4))
plt.pcolormesh(eofs_dict["psl_lag1"].isel(mode=0), cmap="RdBu")
plt.title("Sea level pressure (leading mode)")
plt.colorbar()
plt.savefig("figures/illustrations/sam.jpg",dpi=300,bbox_inches='tight')

In [None]:
# Joint distributions (2-D histograms) of |PC1 of psl_lag1| against ACC, for samples
# whose start_prediction_month is `month`, one column per lead time.
# Rows: ACC of input2, ACC of input4, and their difference (input4 - input2).
time = pcs_dict["icefrac_lag1"].start_prediction_month
month = 10

fig, axs = plt.subplots(figsize=(8,5), nrows=3, ncols=6, sharex=True)

# NOTE(review): `cmap` is never used below; the hist2d calls pass "cubehelix_r" directly.
cmap = plt.get_cmap("plasma").copy()

# Absolute value of the leading psl PC, flattened over the remaining dimensions.
psl_pc = np.abs(pcs_dict["psl_lag1"].isel(mode=0).where(time.dt.month == month, drop=True).values.flatten())
# NOTE(review): needs `acc` to contain "input2" and "input4"; psl_pc and the
# flattened ACC arrays must have the same length and ordering — TODO confirm
# that the dimension orders match.
acc_input2 = acc["input2"].mean("nn_member_id").where(time.dt.month == month, drop=True)
acc_input4 = acc["input4"].mean("nn_member_id").where(time.dt.month == month, drop=True)

# Column titles: the target month for lead 1-6 when month == 10.
months = ['Oct', 'Nov', 'Dec', 'Jan', 'Feb', 'Mar']
for i in range(6):
    axs[0,i].hist2d(psl_pc, acc_input2.isel(lead_time=i).values.flatten(),
                    bins=[15,15], density=True, cmap="cubehelix_r")

    axs[1,i].hist2d(psl_pc, acc_input4.isel(lead_time=i).values.flatten(), 
                    bins=[15,15], density=True, cmap="cubehelix_r")

    axs[2,i].hist2d(psl_pc, (acc_input4 - acc_input2).isel(lead_time=i).values.flatten(), 
                    bins=[15,19], density=True, cmap="cubehelix_r")

    axs[0,i].set_ylim([-0.25, 1])
    axs[1,i].set_ylim([-0.25, 1])
    axs[2,i].set_ylim([-0.4, 0.4])

    for j in range(3):
        axs[j,i].set_xlim([0, 20])
        axs[j,i].set_xticks([0, 10, 20])
        
        # Only the first column keeps its y tick labels.
        if i > 0:
            axs[j,i].set_yticklabels([])
        
        axs[j,i].grid(color='0.3')
    
    # The first-column y labels are re-set on every iteration; this is harmless but redundant.
    axs[0,0].set_ylabel("ACC (input1)") #this uses the same nomenclature as the paper 
    axs[1,0].set_ylabel("ACC (input3)") 
    axs[2,0].set_ylabel("ACC difference") 
    axs[2,i].set_xlabel(rf"abs($PC_1^{{\mathrm{{psl}}}}$)")
    axs[0,i].set_title(f"Lead {i+1} ({months[i]})", fontsize=11)

plt.savefig("figures/cesm/exp1_SAM_ACC_joint_dist_sep_init.jpg", bbox_inches='tight', dpi=300)

In [None]:
# Display the explained-variance ratio of each PCA mode for the psl_lag1 channel.
var_exp_dict["psl_lag1"]