In [None]:
import json
import os
import yaml
import importlib
from pathlib import Path
import datetime
import shutil

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import pandas as pd



from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve, auc


from torch import utils
import lightning.pytorch as pl


from deepS2S.model import ViTLSTM, IndexLSTM

from deepS2S.utils.utils import statics_from_config
from deepS2S.utils.utils_data import generate_clim_pred, load_data
from deepS2S.utils.utils_model import test_model_and_data, best_model_folder
from deepS2S.utils.utils_evaluation import *
from deepS2S.utils.utils_plot import *

In [None]:
# Diagnostic only: report the directory matplotlib reads its configuration from.
import matplotlib as mpl
mpl.get_configdir()

'/mnt/beegfs/home/bommer1/.config/matplotlib'

In [53]:
# Apply the seaborn-style white-grid theme to every figure created below.
plt.style.use('seaborn-v0_8-whitegrid')


## Model and Prediction Function

# Results

In [None]:
# Set hyperparameters.
arch_type = 'Index_LSTM' # 'LSTM' # 'ViT'
cm_list = ['#7fbf7b','#1b7837','#762a83','#9970ab','#c2a5cf']
regimes = ['SB', 'NAO-', 'AR', 'NAO+']

# `cfd` is the parent of the working directory; the Data folder lives one level above it.
cfd = Path(os.getcwd()).parent.absolute()
root_path = str(Path(cfd).parent.absolute()) + '/Data'

data_path = f"{root_path}/"
res_path = f"{root_path}/Results/"


def _load_config(file_name):
    """Read ``<cfd>/config/<file_name>`` and attach the shared Data roots.

    Parameters
    ----------
    file_name : str
        Name of a YAML file inside the ``config`` folder below ``cfd``.

    Returns
    -------
    dict
        The parsed configuration with ``net_root``, ``root`` and ``data_root``
        set to the Data/Network, Data/Network/Sweeps and Data folders.
    """
    # The context manager closes the file handle once the YAML has been parsed.
    with open(f'{cfd}/config/{file_name}') as handle:
        # NOTE(review): consider yaml.safe_load if the configs only hold plain YAML types.
        config = yaml.load(handle, Loader=yaml.FullLoader)

    project_root = str(cfd.parent.absolute())
    config['net_root'] = project_root + '/Data/Network/'
    config['root'] = project_root + '/Data/Network/Sweeps/'
    config['data_root'] = project_root + '/Data'
    return config


# NOTE: strt_yr, trial_num, norm_opt, arch and temp_scaling are re-bound in each of
# the three sections below; after this cell they hold the Index-LSTM values.

# --- ViT-LSTM ---------------------------------------------------------------
config_vit = _load_config('config_vit_lstm.yaml')

strt_yr = config_vit.get('strt','')
trial_num = config_vit.get('version', '')
norm_opt = config_vit.get('norm_opt','')
arch = config_vit.get('arch', 'ViT-LSTM')
tropics_vit = config_vit.get('tropics', '')
temp_scaling = config_vit.get('temp_scaling', False)

# Trailing slash added so that f"{stat_dir}version_..." resolves inside the
# architecture folder, matching stat_dir_lstm / stat_dir_index below.
stat_dir = config_vit['net_root'] + f'Statistics/{arch}/'
result_path = f'{data_path}Results/Statistics/{arch}/'
results_directory = Path(f'{result_path}version_{strt_yr}{trial_num}_{norm_opt}{tropics_vit}/')
os.makedirs(results_directory, exist_ok=True)

mod_name = 'ViT_LSTM'
architecture = ViTLSTM.ViT_LSTM

# --- LSTM -------------------------------------------------------------------
arch_lstm = 'LSTM'
config_lstm = _load_config('config_lstm.yaml')

strt_yr = config_lstm.get('strt','')
trial_num = config_lstm.get('version', '')
norm_opt = config_lstm.get('norm_opt','')
arch = config_lstm.get('arch', 'ViT')
tropics = config_lstm.get('tropics', '')
temp_scaling = config_lstm.get('temp_scaling', False)

stat_dir_lstm = config_lstm['net_root'] + f'Statistics/{arch_lstm}/'
result_pat_lstm = f'{res_path}Statistics/{arch_lstm}/'
results_dir_lstm = Path(f'{result_pat_lstm}version_{strt_yr}{trial_num}_{norm_opt}{tropics}/')
os.makedirs(results_dir_lstm, exist_ok=True)

# --- Index-LSTM -------------------------------------------------------------
arch_index = 'Index_LSTM'
config_index = _load_config('config_index_lstm.yaml')

strt_yr = config_index.get('strt','')
trial_num = config_index.get('version', '')
norm_opt = config_index.get('norm_opt','')
arch = config_index.get('arch', 'ViT')
tropics = config_index.get('tropics', '')
temp_scaling = config_index.get('temp_scaling', False)

stat_dir_index = config_index['net_root'] + f'Statistics/{arch_index}/'
result_pat_index = f'{res_path}Statistics/{arch_index}/'
results_dir_index = Path(f'{result_pat_index}version_{strt_yr}{trial_num}_{norm_opt}{tropics}/')
os.makedirs(results_dir_index, exist_ok=True)
mod_name = 'Index_LSTM'
architecture_index = IndexLSTM.Index_LSTM

# --- Data -------------------------------------------------------------------
# The test data and the static dataset information are taken from the ViT-LSTM config.
test_loader, data_set, cls_wt, test_set, infos = load_data(config_vit)

var_comb_index = config_index['var_comb']

data_info, seasons = statics_from_config(config_vit)

~/WiOSTNN/Version1/data/ERA5/datasets/z_500_1.40625deg_1980-2023_northern_hemi_2d_NAEregimes.nc
~/WiOSTNN/Version1/data/ERA5/datasets/z_500_1.40625deg_1980-2023_northern_hemi_2d_NAEregimes.nc
~/WiOSTNN/Version1/data/ERA5/datasets/z_500_1.40625deg_1980-2023_northern_hemi_2d_NAEregimes.nc
~/WiOSTNN/Version1/data/ERA5/datasets/z_500_1.40625deg_1980-2023_northern_hemi_2d_NAEregimes.nc
~/WiOSTNN/Version1/data/ERA5/datasets/z_500_1.40625deg_1980-2023_northern_hemi_2d_NAEregimes.nc


In [None]:
# Load the collected ViT-LSTM ensemble outputs.
# The experiment folder carries the same version tag as `results_directory`, which was
# built from the ViT-LSTM config. The shared strt_yr/trial_num/norm_opt/temp_scaling
# globals are NOT used here because the setup cell re-binds them from the Index-LSTM
# config afterwards. Path joining also copes with `stat_dir` lacking a trailing slash.
exp_dir = f"{Path(stat_dir) / results_directory.name}/"
pths = [xs for xs in Path(exp_dir).iterdir() if xs.is_dir()]

# File names carry the number of sub-folders minus one.
# NOTE(review): presumably one sub-folder is not a model run - confirm.
n_models_vit = len(pths) - 1
scale_suffix_vit = '_temp_scale' if config_vit.get('temp_scaling', False) else ''

data_collect = np.load(f'{results_directory}/collected_loop_data_{n_models_vit}{scale_suffix_vit}.npz')
data_result = np.load(f'{results_directory}/accuracy_{n_models_vit}model{scale_suffix_vit}.npz')

persistance = data_collect['persistance']
dates = data_collect['dates']
daytimes = data_collect['daytimes']
loop_probabilities_vit = data_collect['loop_probabilities']
loop_classes_vit = data_collect['loop_classes']
targets = data_collect['targets']
predictions_baseline = data_collect['predictions_baseline']



In [58]:
# Load the collected LSTM ensemble outputs.
# The experiment folder uses the version tag of `results_dir_lstm` (built from the
# LSTM config) instead of the shared strt_yr/trial_num/norm_opt globals, which hold
# the Index-LSTM values after the setup cell, and instead of the ViT `tropics_vit`.
exp_dir_lstm = f"{stat_dir_lstm}{results_dir_lstm.name}/"
pths_lstm = [xs for xs in Path(exp_dir_lstm).iterdir() if xs.is_dir()]

n_dirs_lstm = len(pths_lstm)

# NOTE(review): the ViT and Index cells use `len(pths) - 1` for every file, while this
# cell uses the plain count except for the temperature-scaled accuracy file. The
# original counts are kept as they were; confirm which file names actually exist.
if config_lstm.get('temp_scaling', False):
    data_collect_lstm = np.load(f'{results_dir_lstm}/collected_loop_data_{n_dirs_lstm}_temp_scale.npz')
    data_result_lstm = np.load(f'{results_dir_lstm}/accuracy_{n_dirs_lstm - 1}model_temp_scale.npz')
else:
    data_collect_lstm = np.load(f'{results_dir_lstm}/collected_loop_data_{n_dirs_lstm}.npz')
    data_result_lstm = np.load(f'{results_dir_lstm}/accuracy_{n_dirs_lstm}model.npz')

# Kept under its own name so the ViT `persistance` (scored against the ViT `targets`
# further down) is not silently replaced by the LSTM copy.
persistance_lstm = data_collect_lstm['persistance']
loop_probabilities_lstm = data_collect_lstm['loop_probabilities']
loop_classes_lstm = data_collect_lstm['loop_classes']
targets_lstm = data_collect_lstm['targets']

In [59]:
# Load the collected Index-LSTM ensemble outputs.
# strt_yr / trial_num / norm_opt / temp_scaling are the values last bound in the
# setup cell.
exp_dir_index = f"{stat_dir_index}version_{strt_yr}{trial_num}_{norm_opt}/"
pths_index = [entry for entry in Path(exp_dir_index).iterdir() if entry.is_dir()]

# The stored file names carry the number of sub-folders minus one.
if temp_scaling:
    collect_file_index = f'{results_dir_index}/collected_loop_data_{len(pths_index)-1}_temp_scale.npz'
    result_file_index = f'{results_dir_index}/accuracy_{len(pths_index)-1}model_temp_scale.npz'
else:
    collect_file_index = f'{results_dir_index}/collected_loop_data_{len(pths_index)-1}.npz'
    result_file_index = f'{results_dir_index}/accuracy_{len(pths_index)-1}model.npz'

data_collect_index = np.load(collect_file_index)
data_result_index = np.load(result_file_index)

loop_probabilities_index = data_collect_index['loop_probabilities']
loop_classes_index = data_collect_index['loop_classes']
targets_index = data_collect_index['targets']


In [60]:
# Collect the second input element of every sample in `data_set`.
# The loop variables are named so that the builtin `input` is not shadowed for the
# rest of the kernel session; the target, weeks and days entries are not needed here.
input_reg = []
for sample_input, _sample_target, _sample_weeks, _sample_days in data_set:
    if arch_type == 'Index_LSTM':
        # Keep only the last four entries along the final axis.
        # NOTE(review): presumably these are the four regime channels - confirm.
        input_reg.append(sample_input[1][None,:,-4:].numpy().squeeze())
    else:
        input_reg.append(np.array(sample_input[1]).squeeze())

# Arrange as (first two dimensions of predictions_baseline, 4).
input_reg = np.concatenate(input_reg).reshape((predictions_baseline.shape[0],
                                               predictions_baseline.shape[1], 4))

In [61]:
# Load the climatological probabilities and turn them into a baseline prediction.
smoothing = 7
num_m = 14
dtset_name = config_vit['data']['dataset_name2']

# NOTE(review): absolute, user-specific path - the notebook only runs on this file system.
clim_file = f'/mnt/beegfs/home/bommer1/WiOSTNN/Version1/data/{dtset_name}/climatology/NAE_{num_m}eofs_prob_{smoothing}days_climatology_1980_2009.nc'
clim_prob = xr.load_dataarray(clim_file)

# Climatological prediction for `dates`; the class is the arg-max along axis 2.
predictions_clim = generate_clim_pred(clim_prob, dates)
predictions_clim_classes = np.argmax(predictions_clim, axis=2)


In [62]:
# Load the hindcast score tables (CSV) and sample them at the selected rows.
# `leads` are positional row indices into each table.
leads = [6,13,20,27,34,40]

# NOTE(review): absolute, user-specific folder - consider deriving it from the config.
hindcast_dir = '/mnt/beegfs/home/bommer1/WiOSTNN/Data/Results/Statistics/ViT/version_1980_calibrated_olr'
df_csi = pd.read_csv(f'{hindcast_dir}/csi_results_2501.csv')
df_acc = pd.read_csv(f'{hindcast_dir}/acc.csv')
df_prec = pd.read_csv(f'{hindcast_dir}/precision_hindcast.csv')

# CSI: weighted mean plus one series per regime; the array is (leads, regimes).
hindcast_csi_mean_values = df_csi['weighted mean'].values
hindcast_mean_csi = hindcast_csi_mean_values[leads]
hindcast_csi = {r: df_csi[r].values[leads] for r in regimes}
hindcast_csi_array = np.array([hindcast_csi[r] for r in regimes]).T

# Precision: weighted mean plus one series per regime.
hindcast_mean_precision = df_prec['weighted mean'].values
hindcast_mean_prec = hindcast_mean_precision[leads]
hindcast_precision = {r: df_prec[r].values[leads] for r in regimes}
# NOTE(review): unlike the CSI and accuracy arrays this one is not transposed, so it is
# (regimes, leads). Left unchanged - confirm the layout expected by the consumer.
hindcast_precision_array = np.array([hindcast_precision[r] for r in regimes])

# Accuracy: weighted mean plus one series per regime; the array is (leads, regimes).
hindcast_accuracy = df_acc['weighted mean'].values[leads]
hindcast_acc = {r: df_acc[r].values[leads] for r in regimes}
hindcast_acc_array = np.array([hindcast_acc[r] for r in regimes]).T


In [63]:
# Load Logistic Regression.
# NOTE(review): `vs_logistic` is read but not used in this cell - possibly meant to
# build `lr_dir`; confirm.
vs_logistic = config_vit['logistic']
# NOTE(review): `lr_dir` is not defined in any cell shown above; on a fresh kernel this
# line raises NameError unless a star-import provides it. `strt_yr` / `trial_num` hold
# the values last bound in the setup cell.
lr_data = np.load(f"{lr_dir}_{strt_yr}{trial_num}/results.npz")

target_lr = lr_data['targets']
predictions_lr = lr_data['probs']  


# Truncate the probabilities to the first-axis length of `predictions_baseline`.
predicted_lr = predictions_lr[:predictions_baseline.shape[0],:,:]
# NOTE(review): the classes are taken from the untruncated `predictions_lr`, not from
# `predicted_lr` above - confirm which array is intended.
predictions_lr_classes = np.argmax(predictions_lr, axis=-1)

## Calculate ACC, CSI and Precision

### ACC

In [64]:
# Score every ensemble member against its own targets, then aggregate over members.
# The member count is taken from the ViT-LSTM ensemble; the LSTM and Index-LSTM
# arrays are indexed with the same member ids.
member_ids = range(loop_classes_vit.shape[0])

# Class-wise accuracy, stacked along a new leading (member) axis.
loop_acc_cw_vit = np.stack([classwise_ACC(loop_classes_vit[m], targets) for m in member_ids])
loop_acc_cw_lstm = np.stack([classwise_ACC(loop_classes_lstm[m], targets_lstm) for m in member_ids])
loop_acc_cw_index = np.stack([classwise_ACC(loop_classes_index[m], targets_index) for m in member_ids])

# Overall accuracy: first element returned by balanced_accuracy, stacked per member.
loop_acc_vit = np.stack([balanced_accuracy(loop_classes_vit[m], targets)[0] for m in member_ids])
loop_acc_lstm = np.stack([balanced_accuracy(loop_classes_lstm[m], targets_lstm)[0] for m in member_ids])
loop_acc_index = np.stack([balanced_accuracy(loop_classes_index[m], targets_index)[0] for m in member_ids])

# Ensemble mean over the member axis.
acc_cw_vit = loop_acc_cw_vit.mean(axis=0)
acc_cw_lstm = loop_acc_cw_lstm.mean(axis=0)
acc_cw_index = loop_acc_cw_index.mean(axis=0)

acc_vit = loop_acc_vit.mean(axis=0)
acc_lstm = loop_acc_lstm.mean(axis=0)
acc_index = loop_acc_index.mean(axis=0)

# Ensemble spread (standard deviation over members) of the overall accuracy.
std_acc_vit = loop_acc_vit.std(axis=0)
std_acc_lstm = loop_acc_lstm.std(axis=0)
std_acc_index = loop_acc_index.std(axis=0)

# Persistence baseline, scored class-wise against the ViT targets.
acc_pers_cw = classwise_ACC(persistance, targets)

In [None]:
# Load previously stored accuracy statistics of the ablation experiments.
# NOTE(review): absolute, user-specific root. It contains `Project3`, whereas the other
# hard-coded paths in this notebook do not - confirm this is the intended location.
stats_root = '/mnt/beegfs/home/bommer1/Project3/WiOSTNN/Data/Results/Statistics'

lstmIndex_stats = np.load(f'{stats_root}/Index_LSTM_9cat/version_1980_calibrated/accuracy_99model.npz')
lstmIndex_pv = np.load(f'{stats_root}/Index_LSTM_9cat/version_1980_calibrated_pv/accuracy_99model.npz')
lstmIndex_mjo = np.load(f'{stats_root}/Index_LSTM_9cat/version_1980_calibrated_mjo/accuracy_99model.npz')

# Note the folder each handle reads: `stnn_olr` <- *_olr_sst, `stnn_stats` <- *_olr_prob.
stnn_spv = np.load(f'{stats_root}/ViT/version_1980_calibrated_olr_pv/accuracy_99model.npz')
stnn_olr = np.load(f'{stats_root}/ViT/version_1980_calibrated_olr_sst/accuracy_99model.npz')
stnn_stats = np.load(f'{stats_root}/ViT/version_1980_calibrated_olr_prob/accuracy_99model.npz')

# Mean and standard deviation of the accuracy as stored in each archive.
mean_vit, std_vit = stnn_stats['mean_acc'], stnn_stats['std_acc']
mean_vit_spv, std_vit_spv = stnn_spv['mean_acc'], stnn_spv['std_acc']
mean_vit_olr, std_vit_olr = stnn_olr['mean_acc'], stnn_olr['std_acc']
mean_index, std_index = lstmIndex_stats['mean_acc'], lstmIndex_stats['std_acc']
mean_index_pv, std_index_pv = lstmIndex_pv['mean_acc'], lstmIndex_pv['std_acc']
mean_index_mjo, std_index_mjo = lstmIndex_mjo['mean_acc'], lstmIndex_mjo['std_acc']

# Baselines: first element returned by balanced_accuracy for persistence and climatology.
overall_accuracy_persist, _ = balanced_accuracy(persistance, targets)
overall_accuracy_clim, _ = balanced_accuracy(predictions_clim_classes, targets)


In [None]:
# Define the models and their corresponding mean and std values
models_mean = ['Index-LSTM', 'ViT-LSTM', 'ViT-LSTM-SPV', 'ViT-LSTM-OLR', 'Index-LSTM', 'Index-LSTM-SPV', 'Index-LSTM-MJO']
mean_values = [acc_index, acc_vit, mean_vit_spv, mean_vit_olr, mean_index, mean_index_pv, mean_index_mjo]
models_std = ['LSTM-Index', 'ViT', 'probabilistic ViT-LSTM', 'ViT-LSTM-SPV', 'ViT-LSTM-OLR', 'Index-LSTM', 'Index-LSTM-SPV', 'Index-LSTM-MJO']
std_values = [ std_acc_index, std_acc_vit, std_vit, std_vit_spv, std_vit_olr, std_index, std_index_pv, std_index_mjo]

# Create DataFrames for mean and std values
mean_df = pd.DataFrame(mean_values, index=models_mean, columns=[f'Timestep {i+1}' for i in range(6)])
std_df = pd.DataFrame(std_values, index=models_std, columns=[f'Timestep {i+1}' for i in range(6)])

In [67]:
# Show the table of mean accuracies (rows: models, columns: time steps).
mean_df

Unnamed: 0,Timestep 1,Timestep 2,Timestep 3,Timestep 4,Timestep 5,Timestep 6
Index-LSTM,0.261042,0.304715,0.319151,0.288064,0.275478,0.258697
ViT-LSTM,0.279344,0.292949,0.318645,0.330891,0.340463,0.291559
probabilistic ViT-LSTM,0.255795,0.251397,0.27696,0.304632,0.310341,0.280104
ViT-LSTM-SPV,0.289764,0.313361,0.298223,0.299186,0.286805,0.251502
ViT-LSTM-OLR,0.243612,0.249794,0.246421,0.241659,0.247328,0.259278


In [68]:
# Show the table of accuracy standard deviations (rows: models, columns: time steps).
std_df

Unnamed: 0,Timestep 1,Timestep 2,Timestep 3,Timestep 4,Timestep 5,Timestep 6
LSTM-Index,0.018401,0.015235,0.020049,0.016744,0.017393,0.013297
ViT,0.014425,0.017293,0.014274,0.017619,0.017395,0.019868
probabilistic ViT-LSTM,0.016561,0.016754,0.018566,0.018172,0.018931,0.019261
ViT-LSTM-SPV,0.0151,0.017915,0.017124,0.016476,0.016338,0.015159
ViT-LSTM-OLR,0.013416,0.015271,0.017163,0.019425,0.021216,0.023964


## Classwise Performance

### CSI

In [69]:
# CSI of every ensemble member against its own targets, then the ensemble mean.
# The member count is taken from the ViT-LSTM ensemble; the LSTM and Index-LSTM
# arrays are indexed with the same member ids.
member_ids = range(loop_classes_vit.shape[0])

# Class-wise CSI, stacked along a new leading (member) axis.
loop_csi_cw_vit = np.stack([classwise_CSI(loop_classes_vit[m], targets) for m in member_ids])
loop_csi_cw_lstm = np.stack([classwise_CSI(loop_classes_lstm[m], targets_lstm) for m in member_ids])
loop_csi_cw_index_lstm = np.stack([classwise_CSI(loop_classes_index[m], targets_index) for m in member_ids])

# Multi-class CSI, stacked per member.
loop_csi_vit = np.stack([CSI_multiclass(loop_classes_vit[m], targets) for m in member_ids])
loop_csi_lstm = np.stack([CSI_multiclass(loop_classes_lstm[m], targets_lstm) for m in member_ids])
loop_csi_index_lstm = np.stack([CSI_multiclass(loop_classes_index[m], targets_index) for m in member_ids])

# Ensemble mean over the member axis.
csi_vit = loop_csi_vit.mean(axis=0)
csi_lstm = loop_csi_lstm.mean(axis=0)
csi_index_lstm = loop_csi_index_lstm.mean(axis=0)
csi_cw_vit = loop_csi_cw_vit.mean(axis=0)
csi_cw_lstm = loop_csi_cw_lstm.mean(axis=0)
csi_cw_index_lstm = loop_csi_cw_index_lstm.mean(axis=0)

# Persistence baseline, scored against the ViT targets.
persistance_cw_csi = classwise_CSI(persistance, targets)
persistance_csi = CSI_multiclass(persistance, targets)


## Ablation: regime-wise accuracy

In [None]:
# Load collected pv-only data.
exp_dir_pv =  f"{stat_dir}version_{strt_yr}{trial_num}_{norm_opt}{tropics_vit}_pv/"
pths_pv = [xs for xs in Path(exp_dir_pv).iterdir() if xs.is_dir()]

data_collect_pv = np.load(f'{results_directory}_pv/collected_loop_data_{len(pths_pv)-1}.npz')
    
loop_probabilities_vit_pv = data_collect_pv['loop_probabilities']
loop_classes_vit_pv = data_collect_pv['loop_classes']

# Load collected mjo-only data.
exp_dir_mjo =  f"{stat_dir_pv}version_{strt_yr}{trial_num}_{norm_opt}{tropics_vit}_sst/"
pths_mjo = [xs for xs in Path(exp_dir_mjo).iterdir() if xs.is_dir()]

data_collect_mjo = np.load(f'{results_directory_pv}_sst/collected_loop_data_{len(pths_mjo)-1}.npz')
    
loop_probabilities_vit_mjo = data_collect_mjo['loop_probabilities']
loop_classes_vit_mjo = data_collect_mjo['loop_classes']


In [None]:
# Load the collected Index-LSTM ablation outputs.
# strt_yr / trial_num / norm_opt are the values last bound in the setup cell.
# NOTE(review): the experiment folders omit the `tropics` part that `results_dir_index`
# contains; the two only agree when the Index-LSTM `tropics` entry is empty.

# Load collected pv-only data.
exp_dir_index_pv = f"{stat_dir_index}version_{strt_yr}{trial_num}_{norm_opt}_pv/"
pths_index_pv = [xs for xs in Path(exp_dir_index_pv).iterdir() if xs.is_dir()]

data_collect_index_pv = np.load(f'{results_dir_index}_pv/collected_loop_data_{len(pths_index_pv)-1}.npz')

loop_probabilities_index_pv = data_collect_index_pv['loop_probabilities']
loop_classes_index_pv = data_collect_index_pv['loop_classes']

# Load collected mjo-only data.
exp_dir_index_mjo = f"{stat_dir_index}version_{strt_yr}{trial_num}_{norm_opt}_mjo/"
# Count the runs in the mjo folder itself (the pv folder was listed here before).
pths_index_mjo = [xs for xs in Path(exp_dir_index_mjo).iterdir() if xs.is_dir()]

data_collect_index_mjo = np.load(f'{results_dir_index}_mjo/collected_loop_data_{len(pths_index_mjo)-1}.npz')

loop_probabilities_index_mjo = data_collect_index_mjo['loop_probabilities']
loop_classes_index_mjo = data_collect_index_mjo['loop_classes']

In [None]:
# Accuracy of the ablation ensembles: score each member, then aggregate over members.
# The member count is taken from the full ViT-LSTM ensemble; all ablation arrays are
# indexed with the same member ids.
member_ids = range(loop_classes_vit.shape[0])

# Class-wise accuracy, stacked along a new leading (member) axis.
loop_acc_cw_vit_pv = np.stack([classwise_ACC(loop_classes_vit_pv[m], targets) for m in member_ids])
loop_acc_cw_index_pv = np.stack([classwise_ACC(loop_classes_index_pv[m], targets_index) for m in member_ids])
loop_acc_cw_vit_mjo = np.stack([classwise_ACC(loop_classes_vit_mjo[m], targets) for m in member_ids])
loop_acc_cw_index_mjo = np.stack([classwise_ACC(loop_classes_index_mjo[m], targets_index) for m in member_ids])

# Overall accuracy: first element returned by balanced_accuracy, stacked per member.
loop_acc_vit_pv = np.stack([balanced_accuracy(loop_classes_vit_pv[m], targets)[0] for m in member_ids])
loop_acc_index_pv = np.stack([balanced_accuracy(loop_classes_index_pv[m], targets_index)[0] for m in member_ids])
loop_acc_vit_mjo = np.stack([balanced_accuracy(loop_classes_vit_mjo[m], targets)[0] for m in member_ids])
loop_acc_index_mjo = np.stack([balanced_accuracy(loop_classes_index_mjo[m], targets_index)[0] for m in member_ids])

# Ensemble mean over the member axis.
acc_cw_vit_pv = loop_acc_cw_vit_pv.mean(axis=0)
acc_cw_index_pv = loop_acc_cw_index_pv.mean(axis=0)
acc_cw_vit_mjo = loop_acc_cw_vit_mjo.mean(axis=0)
acc_cw_index_mjo = loop_acc_cw_index_mjo.mean(axis=0)

acc_vit_pv = loop_acc_vit_pv.mean(axis=0)
acc_index_pv = loop_acc_index_pv.mean(axis=0)
acc_vit_mjo = loop_acc_vit_mjo.mean(axis=0)
acc_index_mjo = loop_acc_index_mjo.mean(axis=0)

# Ensemble spread (standard deviation over members) of the overall accuracy.
std_acc_vit_pv = loop_acc_vit_pv.std(axis=0)
std_acc_index_pv = loop_acc_index_pv.std(axis=0)
std_acc_vit_mjo = loop_acc_vit_mjo.std(axis=0)
std_acc_index_mjo = loop_acc_index_mjo.std(axis=0)

## Plots

In [None]:
# Plot overall ('All') and class-wise accuracy over time steps for the full models
# and their ablations: one panel for the overall score plus one panel per regime.
num_time_steps, num_classes = acc_pers_cw.shape
fig, axes = plt.subplots(1, num_classes+1, figsize=(16,4))

index_color = '#aa3dbd'
vit_color = cm_list[2]
steps = range(num_time_steps)
step_labels = range(1, num_time_steps + 1)

# One entry per curve. The same label and line style are used in every panel so the
# shared legend (taken from the last panel) also describes the 'All' panel.
series = [
    dict(label='Index-LSTM', overall=acc_index, classwise=acc_cw_index,
         style=dict(color=index_color, marker='o', linewidth=1, alpha=0.8)),
    dict(label='ViT-LSTM', overall=acc_vit, classwise=acc_cw_vit,
         style=dict(color=vit_color, marker='s', linewidth=1, alpha=0.8)),
    dict(label='SPV-Index-LSTM', overall=acc_index_pv, classwise=acc_cw_index_pv,
         style=dict(color=index_color, marker='x', linewidth=1.5, linestyle='--')),
    dict(label='MJO-Index-LSTM', overall=acc_index_mjo, classwise=acc_cw_index_mjo,
         style=dict(color=index_color, marker='|', linewidth=1.5, linestyle=':')),
    dict(label='u10-ViT-LSTM', overall=acc_vit_pv, classwise=acc_cw_vit_pv,
         style=dict(color=vit_color, marker='x', linewidth=1.5, linestyle='--')),
    dict(label='olr-ViT-LSTM', overall=acc_vit_mjo, classwise=acc_cw_vit_mjo,
         style=dict(color=vit_color, marker='|', linewidth=1.5, linestyle=':')),
]

# Overall accuracy.
overall_ax = axes[0]
for entry in series:
    overall_ax.plot(steps, entry['overall'], label=entry['label'], **entry['style'])
overall_ax.set_xticks(steps)
overall_ax.set_xticklabels(step_labels)
overall_ax.set_title('All', fontsize=16)
overall_ax.set_ylim(-0.05, 0.9)
overall_ax.set_ylabel('Accuracy', fontsize=14)
overall_ax.set_xlabel('Lag [weeks]', fontsize=14)

# Class-wise accuracy: one panel per regime.
for class_idx, ax in enumerate(axes[1:]):
    for entry in series:
        ax.plot(steps, entry['classwise'][:, class_idx], label=entry['label'], **entry['style'])

    ax.set_xticks(steps)
    ax.set_xticklabels(step_labels)
    ax.set_title(f'{regimes[class_idx]}', fontsize=16)
    ax.set_ylim(-0.05, 1.1)
    ax.set_xlabel('Lag [weeks]', fontsize=14)

# Shared legend below the panels, built from the last panel's curves.
handles, labels = axes[-1].get_legend_handles_labels()
fig.legend(handles, labels, fontsize=14, loc='lower center', bbox_to_anchor=(0.5, -0.2), ncol=3)
fig.savefig(f"{results_directory}/ablation_accuracy.png", dpi=600, bbox_inches='tight')