In [7]:
from tqdm.auto import tqdm 
import pandas as pd
import numpy as np
import xarray as xr
import netCDF4 as nf
from netCDF4 import Dataset
%matplotlib inline
import glob
import seaborn as sns
import matplotlib.pyplot as plt
import ast,gc,pickle
from copy import deepcopy
import os
# Custom packages
import read_config
from util.data_process import read_vars, proc_dataset, miss
from util.models import performance_scores,train_baseline,causal_settings,train_PC1

In [8]:
# Read configuration file
config_set = read_config.read_config()

# Define Target: map the configured lag onto the name of the target variable.
# A lookup table replaces the former chain of independent `if` statements so
# that an unsupported lag fails immediately instead of leaving `target`
# undefined (or silently stale from an earlier run in the same kernel).
TARGET_BY_LAG = {
    20: 'delv120',
    16: 'delv96',
    12: 'delv72',
    8: 'delv48',
    4: 'delv24',
}
target_lag = int(config_set['target_lag'])
if target_lag not in TARGET_BY_LAG:
    raise ValueError(
        f"Unsupported target_lag {target_lag!r}; expected one of {sorted(TARGET_BY_LAG)}"
    )
target = TARGET_BY_LAG[target_lag]

# One entry per seed/fold whose stored results are analysed below.
seeds = np.arange(0,7,1) #np.arange(100,131,1)

In [9]:
target

'delv24'

### This tutorial script is for creating Figures 2 and 3 from the paper and analysing the results*.pkl files

In [10]:
var_names =  performance_scores.scores_seeds(seed=0,target=target,lag=int(config_set['target_lag']),exp='SHIPSERA5_noassum').read_stored()['var_names']

In [11]:
len(var_names)

215

In [12]:
list(var_names)

['delv24',
 'pmin',
 'wind10',
 'out_t250',
 'out_t200',
 'spdx',
 'out_mean_midrhum',
 'POT',
 'POT2',
 'PER',
 'VPER',
 'SHDC',
 'VSHR',
 'LHRD',
 'EPOS',
 'clat',
 'tadv',
 'sdir',
 'd200',
 'z850',
 'twnd850',
 'div_100',
 'div_250',
 'div_300',
 'div_400',
 'div_500',
 'div_700',
 'div_850',
 'div_1000',
 'eqt1000',
 'eqt300',
 'eqt400',
 'eqt500',
 'eqt700',
 'eqt850',
 'sst',
 'vort_100',
 'vort_150',
 'vort_200',
 'vort_250',
 'vort_300',
 'vort_400',
 'vort_500',
 'vort_700',
 'vort_1000',
 'pvor_100',
 'pvor_150',
 'pvor_200',
 'pvor_250',
 'pvor_300',
 'pvor_400',
 'pvor_500',
 'pvor_700',
 'pvor_850',
 'pvor_1000',
 'rhum_100',
 'rhum_150',
 'rhum_200',
 'rhum_250',
 'rhum_300',
 'rhum_400',
 'rhum_850',
 'rhum_1000',
 'gpot_100',
 'gpot_150',
 'gpot_200',
 'gpot_250',
 'gpot_300',
 'gpot_400',
 'gpot_500',
 'gpot_700',
 'gpot_800',
 'gpot_850',
 'gpot_1000',
 'temp_100',
 'temp_150',
 'temp_300',
 'temp_400',
 'temp_500',
 'temp_700',
 'temp_850',
 'temp_1000',
 'outdiv_10

# With causal

In [13]:
# Collect the causal feature-selection results, one entry per seed.
score_causal = []
for seed in tqdm(seeds):
    scorer = performance_scores.scores_seeds(
        seed=seed,
        target=target,
        lag=int(config_set['target_lag']),
        exp='SHIPSERA5_noassum',
    )
    score_causal.append(scorer.run_score_causalFS())
    # Drop the temporary and force a collection before the next seed.
    del scorer
    gc.collect()

  0%|          | 0/7 [00:00<?, ?it/s]

In [14]:
# Per-seed lists of train/valid/test R² and of the number of test-set columns,
# one inner entry per result stored for that seed.
r2_train_causalFS, r2_valid_causalFS, r2_test_causalFS, shapez_causalFS = [],[],[],[]
for seed_results in score_causal:
    r2_train_causalFS.append([res['scoreboard']['train']['r2'] for res in seed_results])
    r2_valid_causalFS.append([res['scoreboard']['valid']['r2'] for res in seed_results])
    r2_test_causalFS.append([res['scoreboard']['test']['r2'] for res in seed_results])
    shapez_causalFS.append([res['X']['test'].shape[1] for res in seed_results])

In [19]:
# pc_alpha thresholds (24 values). plot_one_fold_with_abacus pairs entry `a` of
# this list with result `a` of the plotted fold, so the order here must match
# the order in which the results were stored -- TODO confirm against the
# code that produced the results files.
alpha_levels = [0.0001, 0.00015 ,0.001,0.0015,0.01,0.02,0.03,0.04,0.05,0.06,0.07,0.08,0.09,0.1,
                          0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6]
# NOTE(review): `expname` is not referenced by any cell visible here; the
# plotting calls below pass their own experiment_name strings instead.
expname='SHIPS_noassumptions+ERA5'

## Figure 2

In [20]:
num_folds = len(score_causal)
num_alphas = len(score_causal[0])

# Validation R² arranged as a (fold, alpha) matrix.
r2_valid = np.zeros((num_folds, num_alphas))
for f, fold_result in enumerate(score_causal):
    for a, result in enumerate(fold_result):
        r2_valid[f, a] = result['scoreboard']['valid']['r2']

# The "best" fold is the one whose highest validation R² is the largest.
best_fold_idx = np.argmax(r2_valid.max(axis=1))
best_r2 = r2_valid[best_fold_idx].max()
print(f"Best fold (highest max valid R²): {best_fold_idx} with R² = {best_r2:.4f}")


Best fold (highest max valid R²): 0 with R² = 0.1955


#### Example of feature selection behavior across PC significance thresholds in one cross-validation fold.

Panel A: R² values for train (blue), validation (green), and test (red) datasets as a function of the number of selected variables. The vertical dashed line marks the point with the best validation R².

Panel B: “Abacus” plot showing which variables are selected at different thresholds. Each row corresponds to a variable, with colors indicating physical groupings (e.g., SHIPS predictors, shear, humidity, temperature, geopotential, vorticity). Highlighted markers indicate variables consistently identified as important across settings.

Extra axis: The secondary x-axis connects the number of selected predictors to the corresponding pc_alpha values, directly linking statistical significance thresholds to predictor set size.

Together, the panels and axis illustrate how causal feature selection balances predictive skill (Panel A), variable inclusion patterns (Panel B), and the statistical testing level (pc_alpha). This highlights both robust predictors and the sensitivity of selection to the independence test.

In [21]:
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
from collections import defaultdict
import matplotlib.ticker as mticker

plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = ['Nimbus Roman', 'Times', 'C059-Roman', 'P052-Roman', 'DejaVu Serif']
plt.rcParams['mathtext.fontset'] = 'stix'

def assign_group(var):
    """Return the display group of a variable name.

    The name is lower-cased and tested in a fixed order: exact membership in
    the SHIPS set, then a ``shear`` prefix, then substring matches for
    humidity, temperature, geopotential and vorticity. Anything that matches
    none of these is labelled ``'Other'``.
    """
    name = var.lower()
    ships_vars = {
        'pmin', 'wind10', 'out_t250', 'out_t200', 'spdx', 'out_mean_midrhum',
        'pot', 'pot2', 'per', 'vper', 'shdc', 'vshr', 'lhrd', 'epos',
        'clat', 'tadv', 'sdir', 'd200', 'z850', 'twnd850'
    }
    if name in ships_vars:
        return 'SHIPS'
    if name.startswith('shear'):
        return 'Shear'

    # Order matters: the first group with a matching substring wins.
    substring_groups = (
        ('Humidity', ('rhum', 'rh_', 'rh')),
        ('Temperature', ('temp', 'tanom', 'tgrad', 'eqt')),
        ('Geopotential', ('gpot', 'geop')),
        ('Vorticity', ('vort', 'pvor', 'outpvor', 'outvort')),
    )
    for group, needles in substring_groups:
        if any(needle in name for needle in needles):
            return group
    return 'Other'

# Colour used for each variable group; the keys are exactly the labels that
# assign_group() can return.
group_colors = {
    'SHIPS':      '#2A5D8F',
    'Shear':      '#bcbd22',
    'Humidity':   '#00BFFF',
    'Temperature':'#B03060',
    'Geopotential':'#9467BD',
    'Vorticity':  '#3B7A57',
    'Other':      '#A6761D'
}

# Variables that plot_one_fold_with_abacus() additionally outlines with a black
# ring in the abacus panel. The list is hand-maintained.
# NOTE(review): plot_union_variable_counts() uses a local variable that is also
# called `best_vars`; it is function-local and does not modify this set.
best_vars = {
    "shear_1000_850",
    "outrhum_1000",
    "shear_850_300",
    "rh_0_500_1000",
    "shear_1000_850.1",
    "outpvor_500"
}

def rank_variables_by_importance(score_causal, fold_idx):
    """Order the variables of one fold by their mean rank-based importance.

    For every result of ``score_causal[fold_idx]`` the ``'corrrank'`` sequence
    is read; a variable at position ``i`` of a sequence of length ``n``
    receives the score ``n - i``. Scores are averaged over the results in
    which the variable appears, and the variable names are returned from the
    highest to the lowest average (ties keep first-seen order).
    """
    importance_by_var = defaultdict(list)
    for result in score_causal[fold_idx]:
        ranking = result['corrrank']
        total = len(ranking)
        for position, name in enumerate(ranking):
            importance_by_var[name].append(total - position)

    mean_importance = {name: np.mean(vals) for name, vals in importance_by_var.items()}
    # Stable sort on the negated mean, i.e. descending importance.
    return sorted(mean_importance, key=lambda name: -mean_importance[name])

def get_varlist_lengths(score_causal, fold_idx):
    """Return an integer array with the length of ``'corrrank'`` for each
    result stored in ``score_causal[fold_idx]``, in stored order."""
    fold_results = score_causal[fold_idx]
    return np.array([len(result['corrrank']) for result in fold_results], dtype=int)

def plot_one_fold_with_abacus(score_causal, alpha_levels, target_name, experiment_name,
                              fold_idx=None, r2_ylim=(0.1, 0.2)):
    """Save a three-panel figure (R² curves, variable "abacus", pc_alpha row) for one fold.

    Parameters
    ----------
    score_causal : sequence
        ``score_causal[fold][a]`` must provide ``['scoreboard'][split]['r2']``
        for ``split`` in ``'train'``, ``'valid'``, ``'test'`` and a
        ``['corrrank']`` sequence of selected variable names.
    alpha_levels : sequence of float
        pc_alpha value belonging to result ``a`` of the plotted fold (paired by
        position). Surplus entries on either side are ignored.
    target_name, experiment_name : str
        Only used to build the output file name.
    fold_idx : int, optional
        Fold to plot. When ``None`` the fold whose best validation R² is the
        highest is chosen.
    r2_ylim : tuple of float or None, optional
        y-limits of the R² panel. The default reproduces the previously
        hard-coded ``(0.1, 0.2)``; pass ``None`` to let matplotlib autoscale.

    Returns
    -------
    None
        The figure is written to ``figures/`` and closed; progress is printed.
    """
    os.makedirs("figures", exist_ok=True)

    # Select best fold if not provided
    if fold_idx is None:
        num_folds = len(score_causal)
        num_alphas = len(score_causal[0])
        r2_valid = np.zeros((num_folds, num_alphas))
        for f in range(num_folds):
            for a in range(num_alphas):
                r2_valid[f, a] = score_causal[f][a]['scoreboard']['valid']['r2']
        fold_idx = np.argmax(np.max(r2_valid, axis=1))
    print(f"Plotting fold: {fold_idx}")

    # Everything below iterates over the plotted fold itself rather than
    # assuming it has as many results as fold 0.
    fold_results = score_causal[fold_idx]

    num_vars_list = get_varlist_lengths(score_causal, fold_idx)
    print("num_vars_list:", num_vars_list)

    r2_train = [res['scoreboard']['train']['r2'] for res in fold_results]
    r2_valid_fold = [res['scoreboard']['valid']['r2'] for res in fold_results]
    r2_test = [res['scoreboard']['test']['r2'] for res in fold_results]

    best_valid_idx = np.argmax(r2_valid_fold)
    best_num_vars = num_vars_list[best_valid_idx]

    # Rows of the abacus: variables in the order in which they are first
    # selected. (A previous call to rank_variables_by_importance was dropped:
    # its result was only used as an always-true membership filter.)
    ordered_vars = []
    seen = set()
    for res in fold_results:
        for v in res['corrrank']:
            if v not in seen:
                ordered_vars.append(v)
                seen.add(v)
    var_colors = {v: group_colors.get(assign_group(v), 'grey') for v in ordered_vars}

    # Membership sets avoid rescanning each 'corrrank' list for every row.
    selected_per_alpha = [set(res['corrrank']) for res in fold_results]

    fig = plt.figure(figsize=(13, 17))
    gs = gridspec.GridSpec(3, 1, height_ratios=[3, 6, 1], hspace=0.15)
    ax1 = fig.add_subplot(gs[0])
    ax2 = fig.add_subplot(gs[1], sharex=ax1)
    ax3 = fig.add_subplot(gs[2], sharex=ax1)

    # === Panel A: R² Curves ===
    ax1.plot(num_vars_list, r2_train, '-o', color='blue', label='Train')
    ax1.plot(num_vars_list, r2_valid_fold, '-o', color='green', label='Validation')
    ax1.plot(num_vars_list, r2_test, '-o', color='red', label='Test')
    ax1.axvline(best_num_vars, color='purple', linestyle='--', lw=1.8, label='Best Valid R²')
    ax1.set_ylabel('R²')
    if r2_ylim is not None:
        ax1.set_ylim(*r2_ylim)
    ax1.legend(title=f'Fold {fold_idx}')
    ax1.grid(True, linestyle=':', alpha=0.5)

    # === Panel B: Variable Abacus ===
    label_x = max(num_vars_list) + 1
    for i, var in enumerate(ordered_vars):
        y_offset = (i + 1) * 0.7
        for x, selected in zip(num_vars_list, selected_per_alpha):
            if var in selected:
                ax2.plot(x, y_offset, 'o', color=var_colors[var], markersize=7, alpha=0.8)
                if var in best_vars:
                    # Black ring around the hand-picked highlight variables.
                    ax2.plot(x, y_offset, 'o', markerfacecolor='none', markeredgecolor='black',
                             markeredgewidth=1.8, markersize=7)
        ax2.text(label_x, y_offset, var, va='center',
                 fontsize=10, fontweight='bold', color=var_colors[var])
    ax2.axvline(best_num_vars, color='purple', linestyle='--', lw=1.8)
    ax2.set_yticklabels([])
    ax2.grid(False)

    # Add legend handles for variable groups here:
    group_set = sorted(set(assign_group(v) for v in ordered_vars))
    handles = [
        plt.Line2D([0], [0], marker='o', color='w', label=g,
                   markerfacecolor=group_colors[g], markersize=8)
        for g in group_set
    ]
    ax2.legend(handles=handles, loc='upper left', fontsize=10, title='Variable Groups', title_fontsize=12)

    # === Panel C: pc_alpha values ===
    # Several alphas can yield the same number of variables; keep the largest
    # alpha per x position. zip() stops at the shorter sequence, so a longer
    # alpha_levels no longer raises IndexError.
    var_to_alpha = {}
    for alpha_val, x in zip(alpha_levels, num_vars_list):
        if (x not in var_to_alpha) or (alpha_val > var_to_alpha[x]):
            var_to_alpha[x] = alpha_val
    for x, alpha_val in sorted(var_to_alpha.items()):
        ax3.plot(x, 0, 's', color='black', markersize=8)
        ax3.text(x, 0.3, f'{alpha_val:.3g}', ha='center', fontsize=9, rotation=45)
    ax3.set_ylim(-1, 0.5)
    ax3.axis('off')
    ax3.set_title('pc_alpha values w.r.t Number of Variables')

    # Flexible xlim for all panels (extra room on the right for the row labels)
    x_min = min(num_vars_list)
    x_max = max(num_vars_list)
    tick_step = 2
    start_tick = int(np.floor(x_min / tick_step) * tick_step)
    end_tick = int(np.ceil(x_max / tick_step) * tick_step)
    for ax in [ax1, ax2, ax3]:
        ax.set_xlim([x_min - 1, x_max + 4.8])
        ax.set_xticks(range(start_tick, end_tick + tick_step, tick_step))
        ax.xaxis.set_major_formatter(mticker.FormatStrFormatter('%d'))

    fig.subplots_adjust(left=0.1, right=0.95, top=0.93, bottom=0.12)
    outpath = f"figures/{target_name}_{experiment_name}_fold{fold_idx}_R2_abacus_pcalpha.png"
    # Save and close this specific figure instead of whatever pyplot considers current.
    fig.savefig(outpath, dpi=300)
    plt.close(fig)
    print(f"Saved: {outpath}")


In [22]:
plot_one_fold_with_abacus(score_causal, alpha_levels,target_name='DELV24',experiment_name='SHIPS_noassumptionERA5',fold_idx=3)

Plotting fold: 3
num_vars_list: [ 5  5  5  5  7  8  8 10 11 11 11 11 12 13 13 14 14 15 16 18 19 19 22 23]
Saved: figures/DELV24_SHIPS_noassumptionERA5_fold3_R2_abacus_pcalpha.png


## Figure 3a

#### Variable Ranking Across Folds

This bar plot shows how often each predictor (excluding standard SHIPS variables) is selected across folds at the best pc_alpha. Taller bars indicate variables consistently chosen, with the red dashed line marking a cutoff of ≥3 folds to highlight robust predictors.

In [23]:
import matplotlib.pyplot as plt
from collections import Counter
import os

# === AMS-style serif font ===
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = ['Nimbus Roman', 'Times', 'C059-Roman', 'P052-Roman', 'DejaVu Serif']
plt.rcParams['mathtext.fontset'] = 'stix'

def plot_union_variable_counts(score_causal, alpha_levels, target_name, experiment_name):
    """Save a bar plot of how many folds select each non-SHIPS variable.

    For every fold the result with the highest validation R² is taken, its
    ``'corrrank'`` variables (minus the excluded SHIPS/target names) are
    collected, and each variable is counted once per fold in which it appears.
    Only variables selected in at least two folds are drawn.

    Parameters
    ----------
    score_causal : sequence
        ``score_causal[fold][a]`` must provide ``['scoreboard']['valid']['r2']``
        and ``['corrrank']``.
    alpha_levels : sequence
        Accepted for signature compatibility; not used by this function.
    target_name, experiment_name : str
        Used for the output file name; ``target_name`` (in any letter case) is
        also excluded from the counted variables.

    Returns
    -------
    None
        Writes a PNG below ``figures/``; prints a message and returns early
        when nothing qualifies for plotting (including empty input).
    """
    excluded_vars = set([
        'DELV24', 'pmin', 'wind10', 'out_t250', 'out_t200', 'spdx',
        'out_mean_midrhum', 'POT', 'POT2', 'PER', 'VPER', 'SHDC',
        'VSHR', 'LHRD', 'EPOS', 'clat', 'tadv', 'sdir', 'd200',
        'z850', 'twnd850',
    ])
    # The notebook's variable list spells the target in lower case ('delv24'),
    # which the upper-case literal above does not match; exclude the target
    # regardless of case.
    target_label = str(target_name)
    excluded_vars |= {target_label, target_label.lower(), target_label.upper()}

    num_folds = len(score_causal)

    # Count, per variable, the number of folds whose best-validation result selects it.
    var_fold_counts = Counter()
    for fold_results in score_causal:
        best_result = max(fold_results, key=lambda res: res['scoreboard']['valid']['r2'])
        # A set, so a variable counts at most once per fold.
        var_fold_counts.update(set(best_result['corrrank']) - excluded_vars)

    # Filter out variables with counts < 2
    filtered_counts = {var: count for var, count in var_fold_counts.items() if count >= 2}
    if not filtered_counts:
        print("No variables with count >= 2 to plot.")
        return

    # Most frequently selected first; alphabetical within equal counts.
    sorted_vars = sorted(filtered_counts.items(), key=lambda x: (-x[1], x[0]))
    vars_list, counts = zip(*sorted_vars)

    # Very light color mapping for fold counts
    count_color_map = {
        7: '#B3E5FC',   # Very light cyan
        6: '#ADD8E6',   # Light blue
        5: '#B0C4DE',   # Light steel blue
        4: '#D3D3F3',   # Very light pastel blue
        3: '#E6E6FA',   # Lavender
        2: '#FFDDEE',   # Light pink
        1: '#FFE4E1',   # Misty rose
    }
    bar_colors = [count_color_map.get(count, '#DDDDDD') for count in counts]  # default light gray

    fig, ax = plt.subplots(figsize=(max(10, len(vars_list)*0.6), 6))
    bars = ax.bar(range(len(vars_list)), counts, color=bar_colors, edgecolor='grey')

    # Hide x-axis tick labels (names are written inside the bars instead)
    ax.set_xticks(range(len(vars_list)))
    ax.set_xticklabels([])

    # Axis labels
    ax.set_ylabel('Number of Folds', fontsize=16)
    ax.set_xlabel('Variables (Union of best pc_alpha per fold, excluding SHIPS)', fontsize=16)

    # Annotate variable names inside bars, vertically
    for bar, var in zip(bars, vars_list):
        ax.text(
            bar.get_x() + bar.get_width()/2,
            bar.get_height()/2,           # Middle of the bar
            var,                          # Variable name
            ha='center', va='center',
            fontsize=13, fontweight='bold', rotation=90, color='black'
        )

    # Horizontal cutoff line. The line sits at 3, so bars that reach it are
    # included: the label says "at least 3", not "more than 3".
    ax.axhline(y=3, color='red', linestyle='--', linewidth=1.5,
               label=r'Selection cutoff: $\geq$ 3 folds')
    ax.legend(fontsize=14)

    ax.set_ylim(0, num_folds + 0.5)
    ax.grid(axis='y', linestyle='--', alpha=0.6)

    fig.tight_layout()

    savepath = f"figures/{target_name}_{experiment_name}_variable_fold_counts_barplot.png"
    os.makedirs(os.path.dirname(savepath), exist_ok=True)
    # Save and close this specific figure instead of relying on pyplot's current figure.
    fig.savefig(savepath, dpi=300)
    plt.close(fig)
    print(f"Saved bar plot: {savepath}")


In [24]:
plot_union_variable_counts(score_causal, alpha_levels, target_name='DELV24', experiment_name='SHIPS_noassumption_ERA5')

Saved bar plot: figures/DELV24_SHIPS_noassumption_ERA5_variable_fold_counts_barplot.png


## Find Variable Ranking

#### Finding Variable Rankings Across Feature Selection Methods

For each feature selection method (Causal PC1, correlation-based, or XAI-based), we identify the best set of variables per fold using the R²-maximizing configuration.

find_best_var_per_seed_causal extracts the top variables from the causal PC1 results.

find_best_var_per_seed_corrXAI and find_best_var_per_seed_XAI extract top predictors from correlation-based and XAI-based rankings, respectively.

count_all_varlists aggregates the variable lists across folds, computing how often each variable is selected.

This gives a fold-robust ranking of predictors for each feature selection method.

In [29]:
def find_best_var_per_seed_causal(r2,score_causal):
    """For each seed, return the ``'corrrank'`` entry of the result whose
    score in ``r2[seed]`` is the largest (first maximum on ties)."""
    return [
        seed_results[np.asarray(r2[k]).argmax()]['corrrank']
        for k, seed_results in enumerate(score_causal)
    ]

def find_best_var_per_seed_corrXAI(r2,score_corr):
    """For each seed, return the top-ranked variables of the correlation ranking.

    ``score_corr[seed]['corrrank']`` is sorted in descending order and its
    first ``k`` index labels are kept, where ``k - 1`` is the position of the
    largest score in ``r2[seed]``.

    The function now reads from its ``score_corr`` argument; it previously
    ignored it and used the notebook-level ``score_correlation`` variable,
    which raised ``NameError`` (or used stale data) unless that global existed.

    Returns a list with one pandas Index of variable names per seed.
    """
    corr_varlists = []
    for i in range(len(score_corr)):
        n_best = np.asarray(r2[i]).argmax() + 1
        ranked = score_corr[i]['corrrank'].sort_values(ascending=False)
        corr_varlists.append(ranked.index[0:n_best])
    return corr_varlists

def find_best_var_per_seed_XAI(r2,score_corr):
    """For each seed, sort ``score_corr[seed]['XAIrank']`` in descending order
    and keep its first ``k`` index labels, where ``k - 1`` is the position of
    the largest score in ``r2[seed]``."""
    return [
        seed_result['XAIrank'].sort_values(ascending=False).index[0:np.asarray(r2[k]).argmax()+1]
        for k, seed_result in enumerate(score_corr)
    ]

def count_all_varlists(varlists):
    """Count how often each variable occurs across a collection of variable lists.

    Parameters
    ----------
    varlists : iterable of iterables
        One sequence of variable names per fold/seed.

    Returns
    -------
    pandas.Series
        Integer counts indexed by variable name, in first-seen order. (The
        order used to follow set iteration and could therefore differ between
        kernel sessions, which also changed the order of ties after
        ``sort_values``.) Empty input yields an empty int64 Series.
    """
    from collections import Counter

    totals = Counter()
    for varlist in varlists:
        totals.update(varlist)
    return pd.Series(dict(totals), dtype='int64')

In [30]:
causal_varlists = find_best_var_per_seed_causal(r2_valid_causalFS,score_causal)
count_causallists = count_all_varlists(causal_varlists)

In [31]:
bestvars=count_causallists.sort_values(ascending=False)[:40] #4

In [32]:
count_causallists.sort_values(ascending=False)[:40] #12

wind10              7
sdir                7
tadv                7
outrhum_1000        7
outtemp_150         6
z850                5
rh_0_500_1000       5
shear_1000_850.1    5
div_100             3
pvor_850            3
vort_1000           3
tgrad_200_800       3
out_mean_midrhum    3
shear_1000_850      3
rhum_300            3
shear_1000_300      3
div_1000            3
div_400             2
shear_1000_200.1    2
tanom250            2
LHRD                2
shear_850_300       2
vort_700            2
vort_500            2
eqt850              2
outtemp_500         2
div_0_1000_150      2
vort_200            1
rhum_150            1
rh_0_500_400        1
outgpot_400         1
outrhum_100         1
shear_1000_500      1
outdiv_300          1
div_250             1
div_850             1
pvor_700            1
shear_850_250.1     1
outgpot_850         1
pvor_1000           1
dtype: int64

In [33]:
count_causallists.sort_values(ascending=False)[:40] #8

wind10              7
sdir                7
tadv                7
outrhum_1000        7
outtemp_150         6
z850                5
rh_0_500_1000       5
shear_1000_850.1    5
div_100             3
pvor_850            3
vort_1000           3
tgrad_200_800       3
out_mean_midrhum    3
shear_1000_850      3
rhum_300            3
shear_1000_300      3
div_1000            3
div_400             2
shear_1000_200.1    2
tanom250            2
LHRD                2
shear_850_300       2
vort_700            2
vort_500            2
eqt850              2
outtemp_500         2
div_0_1000_150      2
vort_200            1
rhum_150            1
rh_0_500_400        1
outgpot_400         1
outrhum_100         1
shear_1000_500      1
outdiv_300          1
div_250             1
div_850             1
pvor_700            1
shear_850_250.1     1
outgpot_850         1
pvor_1000           1
dtype: int64

In [35]:
import os
# NOTE(review): the "4" in this path looks like the target lag; if so it will
# not follow config_set['target_lag'] when the configuration changes -- confirm.
save_dir= '../2024_causalML_results/results/4/XAI_noassum/'

# Load the stored XAI results, one pickle per seed, in seed order.
# NOTE(review): pickle.load executes code embedded in the file; only load
# results files that were produced locally / come from a trusted source.
score_corr = []
for seed in seeds:
    with open(os.path.join(save_dir, f'xai_results_fold_{seed}.pkl'), 'rb') as f:
        score_corr.append(pickle.load(f))

In [36]:
# Per-seed lists of train/valid/test R², one inner entry per scoreboard row.
r2_train_XAIFS, r2_valid_XAIFS, r2_test_XAIFS = [],[],[]
for seed_result in score_corr:
    scoreboard = seed_result['scoreboard']
    rows = range(len(scoreboard))
    r2_train_XAIFS.append([scoreboard[j]['train']['r2'] for j in rows])
    r2_valid_XAIFS.append([scoreboard[j]['valid']['r2'] for j in rows])
    r2_test_XAIFS.append([scoreboard[j]['test']['r2'] for j in rows])

In [37]:
xai_varlists = find_best_var_per_seed_XAI(r2_valid_XAIFS,score_corr)
count_xailists = count_all_varlists(xai_varlists)

In [38]:
count_xailists.sort_values(ascending=False)[:40] #4

pvor_700            7
pwat_0_200          7
wind10              7
sdir                7
shear_1000_300      7
shear_1000_850      7
pvor_850            7
outpvor_1000        6
shear_1000_500      5
pvor_400            5
shear_850_300       5
shear_1000_850.1    5
pwat_200_400        4
tanom250            4
eqt850              4
pwat_600_800        4
pwat_800_1000       4
gpot_1000           4
outpvor_100         3
clat                3
vort_1000           3
pwat_400_600        2
spdx                2
tanom300            2
tanom500            2
temp_1000           2
outpvor_150         2
outeqt850           2
outtemp_850         2
shear_1000_300.1    2
shear_1000_700      2
POT2                2
outrhum_400         1
outpvor_500         1
eqt300              1
vort_500            1
outeqt700           1
out_t200            1
temp_400            1
POT                 1
dtype: int64

## With Correlation

In [39]:
# Collect the correlation-based feature-selection results, one entry per seed.
# `shapez` is the largest feature count seen in the causal results.
score_correlation = []
for seed in tqdm(seeds):
    scorer = performance_scores.scores_seeds(
        seed=seed,
        target=target,
        lag=int(config_set['target_lag']),
        exp='SHIPSERA5_noassum',
    )
    score_correlation.append(
        scorer.run_score_corrFS(shapez=np.asarray(miss.flatten(shapez_causalFS)).max())
    )
    # Drop the temporary and force a collection before the next seed.
    del scorer
    gc.collect()

  0%|          | 0/7 [00:00<?, ?it/s]

In [40]:
# Per-seed lists of train/valid/test R², one inner entry per scoreboard row.
r2_train_corrFS, r2_valid_corrFS, r2_test_corrFS = [],[],[]
for seed_result in score_correlation:
    scoreboard = seed_result['scoreboard']
    rows = range(len(scoreboard))
    r2_train_corrFS.append([scoreboard[j]['train']['r2'] for j in rows])
    r2_valid_corrFS.append([scoreboard[j]['valid']['r2'] for j in rows])
    r2_test_corrFS.append([scoreboard[j]['test']['r2'] for j in rows])

In [41]:
def get_best_fold_and_alpha(score_correlation, alpha_levels):
    """Locate the scoreboard entry with the highest validation R² over all folds.

    Parameters
    ----------
    score_correlation : sequence
        ``score_correlation[fold]['scoreboard']`` is iterated; every entry must
        provide ``['valid']['r2']``.
    alpha_levels : sequence
        Values looked up with the winning scoreboard position.

    Returns
    -------
    tuple
        ``(best_fold_idx, best_alpha_idx, best_alpha_value)``. The first
        maximum wins on ties.

    Raises
    ------
    ValueError
        If ``alpha_levels`` is empty or no scoreboard entry has a comparable
        validation R² (previously an opaque ``TypeError`` /
        ``ZeroDivisionError``).

    Notes
    -----
    ``best_alpha_idx`` is wrapped modulo ``len(alpha_levels)`` as before so the
    return value is unchanged, but a warning is now issued when wrapping
    happens: the scoreboard then has more rows than there are alpha levels and
    the reported alpha does not correspond to the winning row.
    """
    import warnings

    if len(alpha_levels) == 0:
        raise ValueError("alpha_levels must not be empty")

    best_val_r2 = -np.inf
    best_fold_idx = None
    best_alpha_idx = None

    for fold_idx, fold_result in enumerate(score_correlation):
        scoreboard_list = fold_result["scoreboard"]  # list over alphas
        for alpha_idx, result in enumerate(scoreboard_list):
            val_r2 = result["valid"]["r2"]
            if val_r2 > best_val_r2:
                best_val_r2 = val_r2
                best_fold_idx = fold_idx
                best_alpha_idx = alpha_idx

    if best_alpha_idx is None:
        raise ValueError("score_correlation contains no usable validation R² values")

    if best_alpha_idx >= len(alpha_levels):
        warnings.warn(
            f"best index {best_alpha_idx} exceeds the {len(alpha_levels)} alpha levels; "
            "the reported alpha value is wrapped around and may be meaningless",
            stacklevel=2,
        )
    best_alpha_value = alpha_levels[best_alpha_idx % len(alpha_levels)]
    print(f"Best Fold Index     : {best_fold_idx}")
    print(f"Best Alpha Index    : {best_alpha_idx}")
    print(f"Best Alpha Value    : {best_alpha_value}")
    print(f"Best Validation R²  : {best_val_r2:.4f}")

    return best_fold_idx, best_alpha_idx, best_alpha_value



In [42]:
best_fold, best_alpha_idx, best_alpha_val = get_best_fold_and_alpha(score_correlation, alpha_levels)

Best Fold Index     : 0
Best Alpha Index    : 27
Best Alpha Value    : 0.0015
Best Validation R²  : 0.2024


In [43]:
corr_varlists = find_best_var_per_seed_corrXAI(r2_valid_corrFS,score_correlation)
count_corrlists = count_all_varlists(corr_varlists)

In [44]:
count_corrlists.sort_values(ascending=False)[:50] #4

pvor_850            7
wind10              7
pmin                7
shear_1000_850      7
shear_850_250       7
pvor_1000           7
shear_850_300       7
gpot_850            7
pvor_700            7
gpot_800            7
div_1000            6
vort_700            6
shear_1000_700      6
shear_1000_500      6
tanom1000           6
gpot_1000           6
POT2                5
shear_1000_850.1    5
shear_1000_300      5
gpot_700            5
pvor_400            5
pvor_500            4
tanom850            4
tanom250            4
tanom700            3
LHRD                3
shear_850_300.1     3
shear_1000_200      3
vort_1000           2
shear_850_250.1     2
vort_500            2
shear_1000_200.1    1
clat                1
shear_1000_500.1    1
shear_1000_300.1    1
tanom500            1
dtype: int64