In [46]:
import xarray as xr
import numpy as np 
import os
import types
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import pandas as pd
import gc
import string

from src.utils import util_cesm
from src import config_cesm
from src.models import models_util
from src.models import models 
import cartopy.crs as ccrs

# Grid built by the project helper util_cesm.generate_sps_grid(); its `lon` /
# `lat` arrays are the pcolormesh coordinates used by plot_pred_example.
# NOTE(review): "sps" presumably means South Polar Stereographic (the example
# maps use ccrs.SouthPolarStereo) — confirm against src.utils.util_cesm.
reference_grid = util_cesm.generate_sps_grid()

def add_hatching(ax, significance_mask, x_edges, y_edges, hatch='///', edgecolor='k'):
    """Overlay hatched, unfilled rectangles on every flagged cell of a grid.

    Meant to sit on top of ``pcolormesh(x_edges, y_edges, C, shading='flat')``
    where ``C`` and ``significance_mask`` share the shape ``(Ny, Nx)``: row
    ``i`` spans ``y_edges[i]..y_edges[i+1]`` and column ``j`` spans
    ``x_edges[j]..x_edges[j+1]``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes that receives the patches (via ``ax.add_patch``).
    significance_mask : array-like of shape (Ny, Nx)
        Truthy cells get hatched. Anything ``np.asarray`` accepts works
        (NumPy array, xarray DataArray, 0/1 ints, booleans).
    x_edges, y_edges : sequence of numbers
        Cell edges; at least ``Nx + 1`` and ``Ny + 1`` entries respectively.
    hatch : str
        Matplotlib hatch pattern.
    edgecolor : color
        Colour of the hatch lines.

    Returns
    -------
    list
        The Rectangle patches that were added, in row-major order. Callers
        may ignore it.

    Raises
    ------
    ValueError
        If the mask is not 2-D or the edge sequences are too short for it.
    """
    # Convert once: indexing an xarray DataArray element by element inside a
    # Python double loop creates a new DataArray per cell.
    mask = np.asarray(significance_mask).astype(bool)
    if mask.ndim != 2:
        raise ValueError(f"significance_mask must be 2-D, got shape {mask.shape}")
    ny, nx = mask.shape
    if len(x_edges) < nx + 1 or len(y_edges) < ny + 1:
        raise ValueError(
            f"need >= {nx + 1} x_edges and >= {ny + 1} y_edges for a mask of shape "
            f"{mask.shape}, got {len(x_edges)} and {len(y_edges)}"
        )

    patches = []
    # np.argwhere yields the (row, col) pairs of True cells in row-major order.
    for i, j in np.argwhere(mask):
        rect = mpatches.Rectangle(
            (x_edges[j], y_edges[i]),
            x_edges[j + 1] - x_edges[j],
            y_edges[i + 1] - y_edges[i],
            hatch=hatch,
            fill=False,
            edgecolor=edgecolor,
            linewidth=0,  # no outline; only the hatch pattern is meant to show
        )
        ax.add_patch(rect)
        patches.append(rect)
    return patches


## Plot loss curves

In [53]:
import wandb

api = wandb.Api()

# W&B entity/project holding every run referenced below.
WANDB_PROJECT = "ychnli-stanford-university/sea-ice-prediction"
LOSS_KEYS = ["train_loss", "val_loss"]

# One run id per experiment; a list means the experiment spans several runs
# whose histories are concatenated in the listed order.
run_ids = {
    "exp2_vol1": "swp4ff58",
    "exp2_vol2": "kf0n8hcz",
    "exp2_vol3": "869w9ohv",
    "exp2_vol4": ["nyurxe4m", "on2qsrjy"],
}


def _fetch_loss_history(run_id):
    """Return an (n_rows, 2) array of [train_loss, val_loss] logged by one W&B run."""
    history = api.run(f"{WANDB_PROJECT}/{run_id}").history(keys=LOSS_KEYS)
    # Select by name rather than slicing off the first column: the column
    # order of the returned DataFrame is not something to rely on.
    # NOTE(review): Run.history() returns a *sampled* history (samples=500 by
    # default); if a run logged more rows than that, row index != epoch and the
    # x-axis of the loss plot is off. Consider scan_history() — confirm.
    return history[LOSS_KEYS].to_numpy()


run_histories = {}
for exp, run_id in run_ids.items():
    # Treat a single id as a one-element list so both cases share one path.
    ids = run_id if isinstance(run_id, list) else [run_id]
    run_histories[exp] = np.concatenate([_fetch_loss_history(r) for r in ids], axis=0)


In [None]:
colors = ["tab:blue", "tab:orange", "tab:red", "tab:purple"]
data_vol = [1, 4, 16, 64]          # CESM ensemble members per experiment, in run_histories order
batches_per_cesm_ens_member = 31   # training batches contributed by one ensemble member

# Loss vs. training iterations: each logged row is rescaled by how many
# batches one pass over that experiment's data contains.
fig, ax = plt.subplots(figsize=(6, 4))

for run_idx, losses in enumerate(run_histories.values()):
    n_members = data_vol[run_idx]
    line_color = colors[run_idx]
    row = np.arange(len(losses[:, 0]))

    # The +0.01 offset keeps the first training point off x=0, which a log axis cannot show.
    train_x = (row + 0.01) * n_members * batches_per_cesm_ens_member
    val_x = (row + 0.5) * n_members * batches_per_cesm_ens_member

    ax.plot(train_x, losses[:, 0], label=f"{n_members} ens members - train",
            linestyle="dotted", color=line_color)
    ax.plot(val_x, losses[:, 1], label=f"{n_members} ens members - val", color=line_color)

ax.set_xlabel("Iterations")
ax.set_ylabel("Loss")
ax.semilogx()
ax.set_xlim(10, 3e4)
ax.legend()
ax.set_title("Data scaling loss curves")
fig.savefig("figures/cesm_new/exp2_loss.pdf", bbox_inches='tight')

## Compute ACC (anomaly correlation coefficient)

In [13]:
# Aggregated ACC per data-volume experiment, plus p-values for the ACC change
# between consecutive volumes (cnames[i-1] -> cnames[i]).
PREDICTIONS_DIR = "/scratch/users/yucli/sicpred_model_predictions"
CONFIDENCE_INTERVALS_DIR = "/oak/stanford/groups/earlew/yuchen/sicpred/analysis_results/confidence_intervals"

cnames = ["vol1", "vol2", "vol3", "vol4"]

acc = {}
significance = {}

for i, config in enumerate(cnames):
    acc_path = f"{PREDICTIONS_DIR}/exp2_{config}/diagnostics/acc_agg.nc"
    # `with` + .load(): pull the values into memory and close the netCDF file
    # instead of leaving one handle open per experiment.
    with xr.open_dataset(acc_path) as ds:
        acc[config] = ds["acc"].mean("nn_member_id").load()
    if i > 0:
        p_value_path = f"{CONFIDENCE_INTERVALS_DIR}/exp2_{cnames[i-1]}_exp2_{config}_acc.nc"
        with xr.open_dataset(p_value_path) as ds:
            significance[config] = ds["p_value"].load()
    

In [None]:
fig, axs = plt.subplots(nrows=4, ncols=2, figsize=(8, 8), sharex=True, sharey=True)
P_VALUE_CUTOFF = 0.01

# The first volume has nothing to diff against, so the top-right slot is
# removed and reused for the two colorbars.
fig.delaxes(axs[0, 1])

# Cell edges for shading='flat': 12 months on x, 6 lead times on y.
x = np.arange(13)
y = np.arange(7)
x_centers = (x[:-1] + x[1:]) / 2
y_centers = (y[:-1] + y[1:]) / 2

vol_factor = [1, 4, 16, 64]    # ensemble members used by vol1..vol4
N_SAMPLES_PER_MEMBER = 1956    # panel-title n scales linearly with vol_factor

acc_mesh = None   # first ACC mesh, reused for the shared ACC colorbar
diff_mesh = None  # first Delta-ACC mesh, reused for the shared difference colorbar

for i, key in enumerate(cnames):
    # Left column: ACC of this experiment.
    ax = axs[i, 0]
    mesh = ax.pcolormesh(x, y, acc[key], cmap='Spectral', shading='flat', vmin=0, vmax=1)
    if acc_mesh is None:
        acc_mesh = mesh
    ax.set_yticks(y_centers, labels=np.arange(1, 7, 1))
    ax.set_ylabel("Lead time")
    member_word = "member" if vol_factor[i] == 1 else "members"
    ax.set_title(rf"{vol_factor[i]} ens. {member_word} ($n={N_SAMPLES_PER_MEMBER * vol_factor[i]}$)", fontsize=12)

    if i == 0:
        continue

    # Right column: ACC gain over the previous data volume.
    diff_ax = axs[i, 1]
    # rf-string so "\D" is not an (invalid) escape sequence.
    diff_ax.set_title(rf"$\Delta$ ACC ({vol_factor[i-1]} → {vol_factor[i]})")
    mesh = diff_ax.pcolormesh(
        x, y, acc[key] - acc[cnames[i-1]],
        cmap="Blues", shading="flat", vmin=0, vmax=0.12
    )
    if diff_mesh is None:
        diff_mesh = mesh
    # Hatch the cells whose ACC change is NOT significant (p > cutoff).
    not_significant = (significance[key] > P_VALUE_CUTOFF).astype(int)
    add_hatching(diff_ax, not_significant, x, y)

axs[3, 0].set_xticks(x_centers, labels=["J", "F", "M", "A", "M", "J", "J", "A", "S", "O", "N", "D"])

# Not needed by this cell; kept so any cell relying on the name still works.
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Colorbars go in the empty top-right slot.
cbar_ax1 = fig.add_axes([0.55, 0.85, 0.35, 0.02])
cbar_ax2 = fig.add_axes([0.55, 0.76, 0.35, 0.02])
fig.colorbar(acc_mesh, cax=cbar_ax1, orientation='horizontal', label='ACC')
fig.colorbar(diff_mesh, cax=cbar_ax2, orientation='horizontal', extend='max', label=r'$\Delta$ ACC')
fig.subplots_adjust(hspace=0.4)

# Label the visible panels column by column (left a-d, right e-g) so every
# label is unique; the deleted top-right axes is skipped.
visible_axes = [axs[row, col] for col in range(2) for row in range(4) if axs[row, col] in fig.axes]
for panel_label, panel_ax in zip(string.ascii_lowercase, visible_axes):
    panel_ax.annotate(
        f"{panel_label})",
        xy=(0, 1), xycoords="axes fraction",
        xytext=(-0.1, 1.03), textcoords="axes fraction",
        ha="left", va="bottom",
        fontsize=11, fontweight="bold"
    )

fig.savefig("figures/cesm_new/exp2_ACC.pdf", bbox_inches='tight')


## Plot example predictions (vol1 vs. vol4 UNet)

In [None]:
from src.experiment_configs.exp2_data_volume import vol1
from src.experiment_configs.exp2_data_volume import vol4

# NOTE(review): load_globals, get_broadcast_climatology and ModelDiagnostics are
# not imported anywhere in this notebook, so this cell raises NameError on a
# fresh kernel. They presumably live in src.models.models_util (or a sibling
# module) — TODO add the explicit import.

names = ["vol1", "vol4"]

# NOTE: this rebinds `acc` (the aggregated-ACC dict from the "Compute ACC"
# section) to per-sample ACC. Re-run that section before redrawing the ACC figure.
acc = {}
targets = {}
predictions = {}
save_dir = config_cesm.ANALYSIS_RESULTS_DIRECTORY

for key, config in zip(names, [vol1, vol4]):
    print(f"Loading broadcasted climatology for {key}")
    cdict = load_globals(config)
    climatology_broadcast = get_broadcast_climatology(cdict, "test")

    print("Calculating acc...")
    diagnostics = ModelDiagnostics(cdict, True, climatology_broadcast)
    # ACC over x/y only; the result is later indexed by start_prediction_month,
    # member_id and lead_time (see plot_pred_example).
    acc[key] = diagnostics.calculate_acc(dim=("x", "y"), aggregate=False)
    predictions[key] = diagnostics.predictions_anomaly
    targets[key] = diagnostics.targets_anomaly

    # Release the large intermediates before loading the next experiment.
    del climatology_broadcast, diagnostics
    gc.collect()

In [None]:
# Samples where ACC(vol4) most exceeds ACC(vol1), using each sample's largest
# gain over lead times.
# NOTE: plot_pred_example is defined in a later cell; run that cell first or
# this raises NameError on a fresh kernel.
# The explicit transpose pins the dim order that the positional unravel below
# maps onto (start_prediction_month, member_id).
data = (acc["vol4"] - acc["vol1"]).max("lead_time").transpose("start_prediction_month", "member_id")

k = 10
flat_values = data.values.ravel()
# NaN would sort as the largest value in argpartition/argsort; rank it last instead.
ranking = np.where(np.isnan(flat_values), -np.inf, flat_values)
flat_indices = np.argpartition(ranking, -k)[-k:]
topk_sorted = flat_indices[np.argsort(ranking[flat_indices])[::-1]]
x_inds, y_inds = np.unravel_index(topk_sorted, data.shape)
coords = list(zip(data['start_prediction_month'].values[x_inds], data['member_id'].values[y_inds]))

for i in range(k):
    plot_pred_example(predictions["vol1"], predictions["vol4"], targets["vol1"], x_inds[i], y_inds[i])
print("Indices:", list(zip(x_inds, y_inds)))
print("Coordinate pairs:", coords)


In [None]:
# Samples where the vol4 model scores highest at the longest lead
# (lead_time index 5, titled "Lead 6" by plot_pred_example).
# The explicit transpose pins the dim order the positional unravel relies on.
data = acc["vol4"].isel(lead_time=5).transpose("start_prediction_month", "member_id")

k = 10
flat_values = data.values.ravel()
# NaN would sort as the largest value in argpartition/argsort; rank it last instead.
ranking = np.where(np.isnan(flat_values), -np.inf, flat_values)
flat_indices = np.argpartition(ranking, -k)[-k:]
topk_sorted = flat_indices[np.argsort(ranking[flat_indices])[::-1]]
x_inds, y_inds = np.unravel_index(topk_sorted, data.shape)
coords = list(zip(data['start_prediction_month'].values[x_inds], data['member_id'].values[y_inds]))

def plot_pred_example(da_pred1, da_pred2, da_truth, start_prediction_month, member_id, savepath=None,
                      acc1=None, acc2=None):
    """Plot two models' anomaly predictions next to the truth for one sample.

    Draws a 3 x 6 grid of South Polar Stereographic maps. Row 0 shows
    ``da_pred1`` (labelled "Worse (vol1) UNet"), row 1 shows ``da_pred2``
    (labelled "Better (vol4) UNet"), and row 2 shows ``da_truth``. Columns are
    lead times 1-6. Each prediction panel is annotated with its sample ACC.

    Parameters
    ----------
    da_pred1, da_pred2, da_truth : xarray.DataArray
        Fields indexed by ``start_prediction_month``, ``member_id`` and
        ``lead_time``; the remaining 2-D field is presumably on
        ``reference_grid`` (plotted against its ``lon`` / ``lat``).
    start_prediction_month, member_id : int
        Positional (``isel``) indices, not coordinate labels.
    savepath : str, optional
        If given, the figure is also written there at 250 dpi.
    acc1, acc2 : xarray.DataArray, optional
        Per-sample ACC shown on rows 0 and 1. Default to the notebook globals
        ``acc["vol1"]`` and ``acc["vol4"]``.
    """
    if acc1 is None:
        acc1 = acc["vol1"]
    if acc2 is None:
        acc2 = acc["vol4"]

    months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
              'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
    lon = reference_grid.lon.data
    lat = reference_grid.lat.data

    sample = dict(start_prediction_month=start_prediction_month, member_id=member_id)
    start_pred_mon = da_truth.start_prediction_month.isel(start_prediction_month=start_prediction_month)
    # Read the calendar month from the coordinate itself. Deriving it from the
    # positional index (index % 12) only works if the record starts in January.
    init_month0 = int(start_pred_mon.dt.month) - 1

    fig, axs = plt.subplots(figsize=(12, 6), nrows=3, ncols=6, sharex=True, sharey=True,
                            subplot_kw={'projection': ccrs.SouthPolarStereo()})
    mesh_kw = dict(transform=ccrs.PlateCarree(), vmin=-0.5, vmax=0.5, cmap="RdBu_r")

    for i in range(6):
        pred1 = da_pred1.isel(**sample, lead_time=i)
        pred2 = da_pred2.isel(**sample, lead_time=i)
        truth = da_truth.isel(**sample, lead_time=i)

        axs[0, i].pcolormesh(lon, lat, pred1, **mesh_kw)
        axs[1, i].pcolormesh(lon, lat, pred2, **mesh_kw)
        axs[2, i].pcolormesh(lon, lat, truth, **mesh_kw)

        # Annotate each prediction row with its ACC for this sample and lead.
        acc_1 = acc1.isel(**sample, lead_time=i).values.round(2)
        acc_2 = acc2.isel(**sample, lead_time=i).values.round(2)
        axs[0, i].text(25, -60, acc_1, transform=ccrs.PlateCarree())
        axs[1, i].text(25, -60, acc_2, transform=ccrs.PlateCarree())
        axs[0, i].set_title(f"Lead {i+1} ({months[(init_month0 + i + 1) % 12]})")

    for ax in axs.flat:
        ax.coastlines()
    # Keep the left y-axis visible (without ticks) so the row labels render.
    for j in range(3):
        axs[j, 0].yaxis.set_visible(True)
        axs[j, 0].set_yticks([])
    axs[0, 0].set_ylabel("Worse (vol1) UNet")
    axs[1, 0].set_ylabel("Better (vol4) UNet")
    axs[2, 0].set_ylabel("Truth")

    member_id_val = da_truth.member_id.isel(member_id=member_id).values
    fig.suptitle(f"Init month: {start_pred_mon.values.astype('datetime64[M]')}, CESM member: {member_id_val}")
    if savepath is not None:
        fig.savefig(savepath, dpi=250, bbox_inches='tight')
    plt.show()
    # Callers draw many of these in loops; release each figure explicitly.
    plt.close(fig)

# Plot and save each of the top-k samples, best first.
for rank in range(k):
    month_idx, member_idx = x_inds[rank], y_inds[rank]
    plot_pred_example(
        predictions["vol1"], predictions["vol4"], targets["vol1"],
        month_idx, member_idx,
        savepath=f"figures/cesm/sample_predictions/exp2_vol4_bestacc_{rank}.jpg",
    )
print("Indices:", list(zip(x_inds, y_inds)))
print("Coordinate pairs:", coords)


In [None]:
# Samples initialised in `init_month` (8 = August) where the vol4 model scores
# highest at lead 6 (lead_time index 5). plot_pred_example labels lead 6 of an
# August init as February, hence the "best_feb" file names.
init_month = 8
acc_vol4 = acc["vol4"]
data = (
    acc_vol4.where(acc_vol4.start_prediction_month.dt.month == init_month, drop=True)
    .isel(lead_time=5)
    .transpose("start_prediction_month", "member_id")  # order assumed by the unravel below
)

k = 10
flat_values = data.values.ravel()
# NaN would sort as the largest value in argpartition/argsort; rank it last instead.
ranking = np.where(np.isnan(flat_values), -np.inf, flat_values)
flat_indices = np.argpartition(ranking, -k)[-k:]
topk_sorted = flat_indices[np.argsort(ranking[flat_indices])[::-1]]
x_inds, y_inds = np.unravel_index(topk_sorted, data.shape)
selected_months = data['start_prediction_month'].values[x_inds]
coords = list(zip(selected_months, data['member_id'].values[y_inds]))
# plot_pred_example takes positions on the FULL start_prediction_month axis.
# Look the selected dates up by value: `x_inds * 12 + init_month - 1` only
# holds when the record starts in January and has no gaps.
x_inds = acc_vol4.indexes["start_prediction_month"].get_indexer(selected_months)

for i in range(k):
    plot_pred_example(predictions["vol1"], predictions["vol4"], targets["vol1"], x_inds[i], y_inds[i],
                      savepath=f"figures/cesm/sample_predictions/exp2_vol4_best_feb_acc_{i}.jpg")

print("Indices:", list(zip(x_inds, y_inds)))
print("Coordinate pairs:", coords)


In [None]:

# Random (unranked) samples initialised in `init_month` (8 = August), as a
# baseline for the "best" examples. Seeded so re-runs draw the same samples.
init_month = 8
RANDOM_SEED = 0
acc_vol4 = acc["vol4"]
# Only the shape and coordinates of this selection are used, not its values.
data = (
    acc_vol4.where(acc_vol4.start_prediction_month.dt.month == init_month, drop=True)
    .isel(lead_time=5)
    .transpose("start_prediction_month", "member_id")  # order assumed by the unravel below
)

k = 10
rng = np.random.default_rng(RANDOM_SEED)
flat_indices = rng.choice(data.size, size=k, replace=False)
x_inds, y_inds = np.unravel_index(flat_indices, data.shape)
selected_months = data['start_prediction_month'].values[x_inds]
coords = list(zip(selected_months, data['member_id'].values[y_inds]))
# Map the chosen dates back to positions on the full start_prediction_month
# axis by value rather than assuming a gap-free record starting in January.
x_inds = acc_vol4.indexes["start_prediction_month"].get_indexer(selected_months)

for i in range(k):
    plot_pred_example(predictions["vol1"], predictions["vol4"], targets["vol1"], x_inds[i], y_inds[i],
                      savepath=f"figures/cesm/sample_predictions/random_feb_{i}.jpg")

print("Indices:", list(zip(x_inds, y_inds)))
print("Coordinate pairs:", coords)
