In [None]:
# Install the third-party packages that the import cell below relies on
# (lipd, cartopy, pywt, sklearn, eofs).
!pip install lipd cartopy PyWavelets scikit-learn eofs

In [None]:
import pickle
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
import math
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import seaborn as sns
import os
import lipd
import geopandas as gpd
from shapely.geometry import Point
from google.colab import drive
import os
import pywt
from sklearn.decomposition import PCA, KernelPCA
import random
from eofs.standard import Eof
from sklearn.metrics.pairwise import rbf_kernel
import json

# OPTIONAL: Load data from local or online source
# If running on Colab and using Google Drive, uncomment below:

# from google.colab import drive
# drive.mount('/content/drive')
# os.chdir('/content/drive/MyDrive/your_project_folder')

# Otherwise, place your data in the same folder or set the correct relative path

In [None]:
# Load the BAM results from a JSON file in the current working directory.
# The functions defined below read each record's 'archiveType' and its
# 'bam_output' entry (keys 'tp' and 'Xp_normalized').
with open('bam_results_normalized.json') as f:
    bam_results = json.load(f)

In [None]:
# Clone DINEOF repo (do this only once)
# NOTE: `%cd` changes the kernel's working directory for every later cell, so
# relative paths used after this point resolve inside DINEOF-Python, and
# re-running this cell clones/enters a nested copy.
!git clone https://github.com/zhouweichen1992110/DINEOF-Python.git
%cd DINEOF-Python

# `utils` is presumably the module shipped in the cloned repository; it is
# importable because the working directory is now the repository root.
from utils import eof_core, init_missing, normalize, denormalize

In [None]:
# 🚀 Functions for Workflow
# -------------------------------------------
def filter_by_archive_type(bam_results, archive_types):
    """Keep only the records whose ``archiveType`` matches ``archive_types``.

    Matching is case-insensitive. A record's ``archiveType`` may be a string
    or a list; for a list only its first element is compared. Records whose
    ``archiveType`` is missing, an empty list, or not a string are dropped.

    Parameters
    ----------
    bam_results : dict
        Mapping of record ID -> record dict (read via ``.get('archiveType', '')``).
    archive_types : str or iterable of str
        Archive type(s) to keep. A single string is treated as one type.

    Returns
    -------
    dict
        The matching ``{record_id: record}`` pairs, in their original order.
    """
    # A bare string would otherwise be iterated character by character and
    # never match anything.
    if isinstance(archive_types, str):
        wanted = {archive_types.lower()}
    else:
        wanted = {a.lower() for a in archive_types}

    filtered_results = {}

    for k, v in bam_results.items():
        archive_type = v.get('archiveType', '')

        # Handle list or str; an empty list has no first element to compare.
        if isinstance(archive_type, list):
            archive_type = archive_type[0] if archive_type else ''

        if isinstance(archive_type, str) and archive_type.lower() in wanted:
            filtered_results[k] = v

    print(f"✅ Total {archive_types} records found: {len(filtered_results)}")
    return filtered_results

def get_year_range_from_filtered_results(filtered_results):
    """Return the earliest and latest year found across all records.

    Every record's ``['bam_output']['tp']`` values are pooled; the minimum
    and maximum are converted with ``int()``.

    Returns
    -------
    tuple
        ``(start_year, end_year)`` as ints, or ``(None, None)`` when no
        years were found at all.
    """
    pooled_years = [
        year
        for record in filtered_results.values()
        for year in record['bam_output']['tp']
    ]

    if len(pooled_years) == 0:
        print("⚠️ No years found in filtered results!")
        return None, None

    first_year = int(min(pooled_years))
    last_year = int(max(pooled_years))

    print(f"📅 Records cover from {first_year} to {last_year}")
    return first_year, last_year

def sample_one_realization(results_dict):
    """Draw one randomly chosen ensemble member per record into a long table.

    For every record, one column of ``bam_output['Xp_normalized']`` is picked
    with ``random.randint`` (so the draw is controlled by ``random.seed``) and
    paired with ``bam_output['tp']``.

    Parameters
    ----------
    results_dict : dict
        Mapping of record ID -> record dict holding ``'bam_output'`` (with
        ``'tp'`` and ``'Xp_normalized'``) and optionally ``'archiveType'``
        (string, or list whose first element is used).

    Returns
    -------
    pandas.DataFrame
        Columns ``Year``, ``Value``, ``ID``, ``archiveType``. An empty frame
        with those columns is returned when ``results_dict`` is empty.
    """
    realization_records = []

    for rec_id, rec_data in results_dict.items():
        bam_output = rec_data['bam_output']
        tp = np.array(bam_output['tp'])
        Xp = np.array(bam_output['Xp_normalized'])
        # Columns of Xp are indexed as ensemble members below.
        n_ensembles = Xp.shape[1]
        random_ens_idx = random.randint(0, n_ensembles - 1)

        archive_type = rec_data.get('archiveType', '')

        # A list such as ['coral'] contributes its first element; an empty
        # list falls back to the same '' default used for a missing key.
        if isinstance(archive_type, list):
            archive_type = archive_type[0] if archive_type else ''

        series_df = pd.DataFrame({
            'Year': tp,
            'Value': Xp[:, random_ens_idx],
            'ID': rec_id,
            'archiveType': archive_type
        })

        realization_records.append(series_df)

    # pd.concat raises on an empty list, e.g. after a filter matched nothing.
    if not realization_records:
        return pd.DataFrame(columns=['Year', 'Value', 'ID', 'archiveType'])

    realization_df = pd.concat(realization_records, ignore_index=True)
    return realization_df

def bin_and_aggregate(df, year_start, year_end, bin_size=5):
    """Average ``Value`` into fixed-width year bins, one full bin axis per record.

    Each row is assigned to the bin
    ``((Year - year_start) // bin_size) * bin_size + year_start``; values are
    averaged per ``(Year_Bin, ID)``. Every record is then expanded onto the
    complete bin axis so bins without data appear with ``Value`` = NaN.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs columns ``Year``, ``Value``, ``ID``, ``archiveType``. It is not
        modified.
    year_start, year_end : number
        Bounds used to build the bin axis
        ``np.arange(year_start, year_end + bin_size, bin_size)``; when the span
        is not a multiple of ``bin_size`` the last bin starts after ``year_end``.
    bin_size : number, default 5
        Bin width in years.

    Returns
    -------
    pandas.DataFrame
        Columns ``Year_Bin``, ``ID``, ``Value``, ``archiveType``; empty (with
        those columns) when ``df`` has no rows.
    """
    # Step 1: Create the complete bin range
    all_bins = np.arange(year_start, year_end + bin_size, bin_size)

    # Step 2: Perform binning on a new frame so the caller's data (which may
    # be a groupby slice) does not silently gain a 'Year_Bin' column.
    binned = df.assign(
        Year_Bin=((df['Year'] - year_start) // bin_size) * bin_size + year_start
    )
    aggregated_df = binned.groupby(['Year_Bin', 'ID']).agg({
        'Value': 'mean',
        'archiveType': 'first'
    }).reset_index()

    # Step 3: Reindex for each record ID to ensure NaNs for missing bins
    complete_df_list = []

    for record_id, group in aggregated_df.groupby('ID'):
        complete_bins_df = pd.DataFrame({
            'Year_Bin': all_bins,
            'ID': record_id
        })

        # Left merge drops bins outside the axis and leaves NaN for empty bins.
        merged = pd.merge(complete_bins_df, group, on=['Year_Bin', 'ID'], how='left')

        # Bins without data lost their archiveType in the merge; restore it.
        merged['archiveType'] = merged['archiveType'].fillna(group['archiveType'].iloc[0])

        complete_df_list.append(merged)

    # pd.concat raises on an empty list (input without rows).
    if not complete_df_list:
        return pd.DataFrame(columns=['Year_Bin', 'ID', 'Value', 'archiveType'])

    final_complete_df = pd.concat(complete_df_list, ignore_index=True)

    return final_complete_df

def build_intervals(year_start, year_end, interval_length, overlap, min_last_interval_length):
    """Split ``[year_start, year_end]`` into overlapping, inclusive intervals.

    Full-length intervals start every ``interval_length - overlap`` years.
    Whatever remains after the last full interval becomes a final shorter
    interval ending at ``year_end``; if that remainder is shorter than
    ``min_last_interval_length`` the last full interval is stretched to
    ``year_end`` instead. If not even one full interval fits, the whole range
    is returned as a single interval.

    Parameters
    ----------
    year_start, year_end : int
        Inclusive bounds of the range to cover.
    interval_length : int
        Length in years of each full interval.
    overlap : int
        Number of years shared by consecutive intervals.
    min_last_interval_length : int
        Minimum length allowed for a trailing partial interval.

    Returns
    -------
    list of tuple
        ``(start, end)`` pairs, both inclusive, in ascending order.

    Raises
    ------
    ValueError
        If ``overlap >= interval_length``: the start would never advance and
        the loop below would not terminate.
    """
    step = interval_length - overlap
    if step <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than interval_length ({interval_length})"
        )

    intervals = []

    start = year_start
    while start + interval_length - 1 <= year_end:
        intervals.append((start, start + interval_length - 1))
        start += step

    # If no full interval fits, just make one full-range interval
    if not intervals:
        return [(year_start, year_end)]

    # Handle the remainder; `start` is where the next interval would begin.
    last_start, last_end = intervals[-1]
    if last_end < year_end:
        remainder_length = year_end - start + 1

        # If it's too short, merge with the previous one
        if remainder_length < min_last_interval_length:
            intervals[-1] = (last_start, year_end)
        else:
            intervals.append((start, year_end))

    return intervals

def find_best_interval(final_binned_df, year_start, year_end, bin_size, min_last_interval_length=40, overlap=10):
    """Score interval lengths 50..200 (step 10) and report the best one.

    For each candidate length the range is split with ``build_intervals``.
    A record counts as valid in an interval when its number of non-NaN
    ``Value`` bins is at least ``max(1, int(2/3 * years / bin_size))``.
    The score is the share of intervals with at least one valid record plus
    the number of distinct valid records over all intervals. On ties the
    shortest length wins (strict ``>`` comparison).

    Returns
    -------
    tuple
        ``(best_interval_length, results_df)`` where ``results_df`` holds one
        row per candidate length.
    """
    summary_columns = [
        'Interval_Length', 'Intervals_with_Valid_IDs', 'Unique_IDs', 'Valid_Interval_Proportion', 'Score'
    ]
    top_score = -1
    winning_length = None
    summary_rows = []

    for candidate_length in range(50, 201, 10):
        windows = build_intervals(year_start, year_end, candidate_length, overlap, min_last_interval_length)

        ids_seen = set()
        per_window_counts = {}

        for win_start, win_end in windows:
            year_bins = final_binned_df['Year_Bin']
            window_df = final_binned_df[(year_bins >= win_start) & (year_bins <= win_end)]

            n_years = win_end - win_start + 1
            min_bins = max(1, int((2 / 3) * (n_years / bin_size)))

            # Records with enough filled bins inside this window.
            qualifying = set()
            for rec_id, rec_rows in window_df.groupby('ID'):
                if rec_rows['Value'].notna().sum() >= min_bins:
                    qualifying.add(rec_id)

            per_window_counts[f"{win_start}-{win_end}"] = len(qualifying)
            ids_seen |= qualifying

        populated_windows = sum(1 for count in per_window_counts.values() if count > 0)
        populated_share = populated_windows / len(windows)
        candidate_score = populated_share + len(ids_seen)

        summary_rows.append(
            (candidate_length, populated_windows, len(ids_seen), populated_share, candidate_score)
        )

        if candidate_score > top_score:
            top_score = candidate_score
            winning_length = candidate_length

    results_df = pd.DataFrame(summary_rows, columns=summary_columns)

    best_row = results_df[results_df['Interval_Length'] == winning_length].iloc[0]
    print(f"\n✅ Best Interval Length: {winning_length} years (with {overlap}-year overlap)")
    print(f" - Max Intervals with Valid IDs: {best_row['Intervals_with_Valid_IDs']}")
    print(f" - Unique Valid IDs: {best_row['Unique_IDs']}")
    print(f" - Proportion of Valid Intervals: {best_row['Valid_Interval_Proportion']:.2f}")
    print(f" - Best Score: {best_row['Score']:.2f}\n")

    print("🏅 Top 5 Interval Lengths by Score:")
    print(results_df.sort_values(by='Score', ascending=False).head(5))

    return winning_length, results_df

def summarize_interval_coverage(
    final_binned_df,
    selected_interval_length,
    year_start,
    year_end,
    bin_size,
    min_last_interval_length=40,
    include_empty_intervals=True,
    overlap=10
):
    """Print how many records satisfy the coverage rule in each interval.

    Intervals come from ``build_intervals``. A record is valid in an interval
    when its count of non-NaN ``Value`` bins is at least
    ``max(1, int(2/3 * years / bin_size))``.

    Parameters
    ----------
    final_binned_df : pandas.DataFrame
        Needs columns ``Year_Bin``, ``ID`` and ``Value``.
    selected_interval_length, year_start, year_end, overlap, min_last_interval_length
        Forwarded to ``build_intervals``.
    bin_size : number
        Bin width used to translate interval years into a required bin count.
    include_empty_intervals : bool, default True
        When False, intervals with zero valid records are left out of the
        per-interval listing. The default prints every interval.

    Returns
    -------
    None
        The summary is only printed.
    """
    intervals = build_intervals(year_start, year_end, selected_interval_length, overlap, min_last_interval_length)

    interval_counts_dict = {}
    # Only the IDs are needed for the totals, so collect them in a set
    # instead of copying each record's rows.
    all_valid_ids = set()

    for start, end in intervals:
        interval_label = f"{start}-{end}"
        interval_data = final_binned_df[
            (final_binned_df['Year_Bin'] >= start) & (final_binned_df['Year_Bin'] <= end)
        ]

        interval_years = end - start + 1
        required_bins = max(1, int((2 / 3) * (interval_years / bin_size)))

        valid_ids = {
            record_id for record_id, group in interval_data.groupby('ID')
            if group['Value'].notna().sum() >= required_bins
        }

        interval_counts_dict[interval_label] = len(valid_ids)
        all_valid_ids.update(valid_ids)

    total_unique_ids = len(all_valid_ids)

    print(f"\n📊 Summary for {selected_interval_length}-year Interval with {overlap}-year Overlap")
    print("=" * 50)
    print(f"🔹 Total Unique IDs Across All Intervals: {total_unique_ids}")
    print(f"🔹 Last Interval (After Merge Check): {intervals[-1][0]}-{intervals[-1][1]}\n")

    print("\n📌 Unique ID Count Per Interval:")
    for label, count in interval_counts_dict.items():
        if count == 0 and not include_empty_intervals:
            continue
        print(f"  - {label}: {count} IDs")
    print("=" * 50)

def filter_by_interval_coverage(
    final_binned_df,
    selected_interval_length,
    year_start,
    year_end,
    bin_size,
    min_last_interval_length=40,
    overlap=10
):
    """Keep, per interval, only the records with enough filled bins.

    Intervals come from ``build_intervals``. Within each interval a record is
    kept when its count of non-NaN ``Value`` bins is at least
    ``max(1, int(2/3 * years / bin_size))``. Because intervals overlap, a row
    can appear once per interval it belongs to.

    Parameters
    ----------
    final_binned_df : pandas.DataFrame
        Needs columns ``Year_Bin``, ``ID`` and ``Value``.
    selected_interval_length, year_start, year_end, overlap, min_last_interval_length
        Forwarded to ``build_intervals``.
    bin_size : number
        Bin width used to translate interval years into a required bin count.

    Returns
    -------
    pandas.DataFrame
        The kept rows plus an ``Interval`` column labelled ``"start-end"``.
        Rows keep the order they have in ``final_binned_df`` within each
        interval. A column-less empty frame is returned when nothing qualifies.
    """
    intervals = build_intervals(year_start, year_end, selected_interval_length, overlap, min_last_interval_length)

    coverage_results = []

    for start, end in intervals:
        interval_data = final_binned_df[
            (final_binned_df['Year_Bin'] >= start) & (final_binned_df['Year_Bin'] <= end)
        ]

        interval_years = end - start + 1
        required_bins = max(1, int((2 / 3) * (interval_years / bin_size)))

        valid_ids = {
            record_id for record_id, group in interval_data.groupby('ID')
            if group['Value'].notna().sum() >= required_bins
        }

        if not valid_ids:
            continue

        # One vectorised selection instead of one boolean filter per record;
        # it also avoids ordering rows by set iteration order, which is not
        # reproducible across sessions for string IDs.
        selected = interval_data[interval_data['ID'].isin(valid_ids)].copy()
        selected['Interval'] = f"{start}-{end}"
        coverage_results.append(selected)

    return pd.concat(coverage_results, ignore_index=True) if coverage_results else pd.DataFrame()


def run_dineof_on_intervals(interval_valid_df, stopping=0.01, rounds=500, bin_size=5):
    """Fill missing bins inside every interval and return a long table.

    Each interval is pivoted to a ``Year_Bin x ID`` matrix on a regular time
    axis (step ``bin_size``). Intervals with a single record are filled by
    linear interpolation in both directions; intervals with two or more
    records are passed through ``normalize`` / ``init_missing`` / ``eof_core``
    / ``denormalize`` (imported from the DINEOF ``utils`` module; their exact
    semantics are not visible here).

    Parameters
    ----------
    interval_valid_df : pandas.DataFrame
        Needs columns ``Year_Bin``, ``ID``, ``Value`` and ``Interval``.
    stopping : float, default 0.01
        Passed to ``eof_core`` as ``stop_criterion``.
    rounds : int, default 500
        Passed to ``eof_core`` as ``max_iterations``.
    bin_size : number, default 5
        Spacing of the regular time axis.

    Returns
    -------
    pandas.DataFrame
        Columns ``Year_Bin``, ``ID``, ``Value_filled``, ``Interval`` (column
        order follows the first interval processed); a column-less empty frame
        when no interval produced output. If every filled value is NaN, the
        raw ``Value`` of the input is used instead.
    """
    dineof_results = []

    for interval_name, group in interval_valid_df.groupby('Interval'):
        pivot_df = group.pivot(index='Year_Bin', columns='ID', values='Value')

        # Ensure time axis is aligned with bin_size
        all_years = np.arange(
            pivot_df.index.min(),
            pivot_df.index.max() + bin_size,
            bin_size
        )
        pivot_df = pivot_df.reindex(all_years)
        # Both branches below rely on the index being called 'Year_Bin'.
        pivot_df.index.name = 'Year_Bin'

        # Convert to NumPy array for DINEOF processing
        data_matrix = pivot_df.to_numpy()
        valid_id_count = pivot_df.shape[1]

        # Case 1: If only one record ID, do linear interpolation
        if valid_id_count == 1:
            filled_series = pivot_df.iloc[:, 0].interpolate(method='linear', limit_direction='both')

            # Convert to long format
            filled_long_df = filled_series.reset_index().rename(
                columns={'index': 'Year_Bin', pivot_df.columns[0]: 'Value_filled'}
            )
            filled_long_df['ID'] = pivot_df.columns[0]
            filled_long_df['Interval'] = interval_name

            dineof_results.append(filled_long_df)

        # Case 2: Multiple records -> run DINEOF
        elif valid_id_count >= 2:
            max_possible_eof = min(data_matrix.shape)

            if max_possible_eof <= 1:
                print(f"⚠️ Not enough data for DINEOF in interval {interval_name}. Skipping.")
                continue

            # Index handed to eof_core, capped at 2.
            best_eof_index = min(2, max_possible_eof - 1)

            # Normalize and initialize missing values
            dataNorm, norm_params = normalize(data_matrix, 'meanrows', 'stdrows')
            dataInit, mask = init_missing(dataNorm, 'column')

            # Apply DINEOF
            filled_data, _, _ = eof_core(
                dataInit,
                mask,
                best_eof_index,
                stop_criterion=stopping,
                max_iterations=rounds
            )

            # Denormalize back to original scale
            filled_data = denormalize(filled_data, 'std', norm_params[1], 'mean', norm_params[0])

            # Convert back to DataFrame
            filled_df = pd.DataFrame(filled_data, index=pivot_df.index, columns=pivot_df.columns)
            filled_long_df = filled_df.reset_index().melt(
                id_vars='Year_Bin',
                var_name='ID',
                value_name='Value_filled'
            )
            filled_long_df['Interval'] = interval_name

            dineof_results.append(filled_long_df)

    # Combine results into one dataframe
    dineof_df = pd.concat(dineof_results, ignore_index=True) if dineof_results else pd.DataFrame()

    # Fallback if DINEOF failed completely. The combined frame has no 'Value'
    # column, so the raw values are looked up in the input by key.
    if not dineof_df.empty and dineof_df['Value_filled'].isna().all():
        print("⚠️ DINEOF failed completely! Using raw values instead.")
        key_columns = ['Year_Bin', 'ID', 'Interval']
        raw_values = interval_valid_df[key_columns + ['Value']]
        dineof_df = dineof_df.merge(raw_values, on=key_columns, how='left')
        dineof_df['Value_filled'] = dineof_df['Value']
        dineof_df = dineof_df.drop(columns='Value')

    return dineof_df

def perform_standard_pca(dineof_filled_df, realization_num):
    """Run a standard PCA on each interval's ``Year_Bin x ID`` matrix.

    Columns are standardised (mean removed, divided by the pandas sample
    standard deviation) before ``sklearn.decomposition.PCA`` is fitted.
    Intervals holding a single record skip PCA and expose that record's
    values as ``PC1``.

    Parameters
    ----------
    dineof_filled_df : pandas.DataFrame
        Needs columns ``Year_Bin``, ``ID``, ``Value_filled`` and ``Interval``.
    realization_num : int
        Stored in the ``Realization`` column of the output.

    Returns
    -------
    tuple
        ``(combined_pca_df, explained_variances)``: the component scores of
        all intervals sorted by ``Year_Bin`` (a column-less empty frame if no
        interval produced output), and a dict mapping each multi-record
        interval to its ``explained_variance_ratio_``.
    """
    pca_results = []
    explained_variances = {}

    for interval, group in dineof_filled_df.groupby('Interval'):
        pivot_df = group.pivot(index='Year_Bin', columns='ID', values='Value_filled')

        # Skip intervals with no valid records
        if pivot_df.shape[1] == 0:
            continue

        # Single-record interval: keep raw values as PC1. No reset_index here:
        # the frame already has a default index, and resetting it would add a
        # stray 'index' column that the multi-record branch does not have.
        if pivot_df.shape[1] == 1:
            pca_df = pd.DataFrame({
                'Year_Bin': pivot_df.index,
                'PC1': pivot_df.iloc[:, 0].values,
                'Interval': interval,
                'Realization': realization_num
            })
            pca_results.append(pca_df)
            continue

        # Standard PCA for multiple records. A constant column has zero
        # spread; dividing by 1 instead keeps it at 0 rather than NaN,
        # which PCA would reject.
        column_std = pivot_df.std().replace(0, 1)
        standardized_data = (pivot_df - pivot_df.mean()) / column_std

        pca = PCA()
        pcs = pca.fit_transform(standardized_data)

        explained_variances[interval] = pca.explained_variance_ratio_

        pca_df = pd.DataFrame(pcs, columns=[f"PC{i+1}" for i in range(pcs.shape[1])])
        pca_df['Year_Bin'] = pivot_df.index
        pca_df['Interval'] = interval
        pca_df['Realization'] = realization_num

        pca_results.append(pca_df)

    # Sorting a column-less empty frame by 'Year_Bin' would raise KeyError.
    if not pca_results:
        return pd.DataFrame(), explained_variances

    combined_pca_df = pd.concat(pca_results, ignore_index=True)
    combined_pca_df = combined_pca_df.sort_values(by='Year_Bin')

    return combined_pca_df, explained_variances


def perform_kernel_pca(dineof_filled_df, realization_num, gamma=0.1):
    """Run an RBF Kernel PCA on each interval's ``Year_Bin x ID`` matrix.

    Columns are standardised (mean removed, divided by the pandas sample
    standard deviation) before ``KernelPCA(kernel='rbf')`` is fitted.
    Explained-variance ratios are derived separately from the eigenvalues of
    the double-centred RBF kernel matrix. Intervals holding a single record
    skip the decomposition and expose that record's values as ``KPC1``.

    Parameters
    ----------
    dineof_filled_df : pandas.DataFrame
        Needs columns ``Year_Bin``, ``ID``, ``Value_filled`` and ``Interval``.
    realization_num : int
        Stored in the ``Realization`` column of the output.
    gamma : float, default 0.1
        RBF kernel coefficient, used for both ``KernelPCA`` and ``rbf_kernel``.

    Returns
    -------
    tuple
        ``(combined_kpca_df, explained_variances)``: the component scores of
        all intervals sorted by ``Year_Bin`` (a column-less empty frame if no
        interval produced output), and a dict mapping each multi-record
        interval to its eigenvalue-based variance ratios.
    """
    kernel_pca_results = []
    explained_variances = {}

    for interval, group in dineof_filled_df.groupby('Interval'):
        pivot_df = group.pivot(index='Year_Bin', columns='ID', values='Value_filled')

        # Skip intervals with no valid records
        if pivot_df.shape[1] == 0:
            continue

        # Single-record interval: keep raw values as KPC1
        if pivot_df.shape[1] == 1:
            # .copy() so the column assignments below act on an independent frame.
            kpca_df = pivot_df.reset_index()[['Year_Bin']].copy()
            kpca_df['KPC1'] = pivot_df.iloc[:, 0].values
            kpca_df['Interval'] = interval
            kpca_df['Realization'] = realization_num
            kernel_pca_results.append(kpca_df)
            continue

        # Kernel PCA for multiple records. A constant column has zero spread;
        # dividing by 1 instead keeps it at 0 rather than NaN.
        column_std = pivot_df.std().replace(0, 1)
        standardized_data = (pivot_df - pivot_df.mean()) / column_std

        kpca = KernelPCA(
            n_components=min(standardized_data.shape),
            kernel='rbf',
            gamma=gamma,
            fit_inverse_transform=True
        )
        pcs = kpca.fit_transform(standardized_data)

        # Manually calculate explained variance ratio from the centred kernel.
        K = rbf_kernel(standardized_data, gamma=gamma)
        n_samples = K.shape[0]
        one_n = np.ones((n_samples, n_samples)) / n_samples
        K_centered = K - one_n @ K - K @ one_n + one_n @ K @ one_n

        eigvals, _ = np.linalg.eigh(K_centered)
        eigvals_sorted = np.sort(eigvals)[::-1]

        explained_variance_ratio = eigvals_sorted / eigvals_sorted.sum()
        explained_variances[interval] = explained_variance_ratio

        kpca_df = pd.DataFrame(pcs, columns=[f"KPC{i+1}" for i in range(pcs.shape[1])])
        kpca_df['Year_Bin'] = pivot_df.index
        kpca_df['Interval'] = interval
        kpca_df['Realization'] = realization_num

        kernel_pca_results.append(kpca_df)

    # Sorting a column-less empty frame by 'Year_Bin' would raise KeyError.
    if not kernel_pca_results:
        return pd.DataFrame(), explained_variances

    combined_kpca_df = pd.concat(kernel_pca_results, ignore_index=True)
    combined_kpca_df = combined_kpca_df.sort_values(by='Year_Bin')

    return combined_kpca_df, explained_variances


def perform_wavelet_pca(dineof_filled_df, realization_num, wavelet='db4', level=1):
    """Run PCA on wavelet-smoothed records for each interval.

    Every NaN-free record is decomposed with ``pywt.wavedec``, all detail
    coefficients are zeroed, and the series is rebuilt from the approximation
    coefficients only. The smoothed columns are standardised with NumPy's
    population standard deviation and passed to ``PCA``. Records containing
    NaN are left out. Intervals holding a single record skip the
    decomposition and expose that record's values as ``WPC1``.

    Parameters
    ----------
    dineof_filled_df : pandas.DataFrame
        Needs columns ``Year_Bin``, ``ID``, ``Value_filled`` and ``Interval``.
    realization_num : int
        Stored in the ``Realization`` column of the output.
    wavelet : str, default 'db4'
        Wavelet name passed to PyWavelets.
    level : int, default 1
        Decomposition level passed to ``pywt.wavedec``.

    Returns
    -------
    tuple
        ``(combined_wavelet_pca_df, explained_variances)``: the component
        scores of all intervals sorted by ``Year_Bin`` (a column-less empty
        frame if no interval produced output), and a dict mapping each
        processed multi-record interval to its ``explained_variance_ratio_``.
    """
    wavelet_pca_results = []
    explained_variances = {}

    for interval, group in dineof_filled_df.groupby('Interval'):
        pivot_df = group.pivot(index='Year_Bin', columns='ID', values='Value_filled')

        # Skip intervals with no valid records
        if pivot_df.shape[1] == 0:
            continue

        # Single-record interval: keep raw values as WPC1
        if pivot_df.shape[1] == 1:
            # .copy() so the column assignments below act on an independent frame.
            wavelet_pca_df = pivot_df.reset_index()[['Year_Bin']].copy()
            wavelet_pca_df['WPC1'] = pivot_df.iloc[:, 0].values
            wavelet_pca_df['Interval'] = interval
            wavelet_pca_df['Realization'] = realization_num
            wavelet_pca_results.append(wavelet_pca_df)
            continue

        # Wavelet PCA for multiple records
        wavelet_filtered = []
        for col in pivot_df.columns:
            ts = pivot_df[col].values
            if np.isnan(ts).any():
                continue

            coeffs = pywt.wavedec(ts, wavelet, level=level)
            # Keep only the approximation; zero every detail level.
            coeffs_filtered = [coeffs[0]] + [np.zeros_like(c) for c in coeffs[1:]]
            # Trim to the input length before stacking.
            reconstructed_ts = pywt.waverec(coeffs_filtered, wavelet)[:len(ts)]
            wavelet_filtered.append(reconstructed_ts)

        # Every record contained NaN: there is no matrix to decompose.
        if not wavelet_filtered:
            print(f"⚠️ No NaN-free records for wavelet PCA in interval {interval}. Skipping.")
            continue

        wavelet_matrix = np.array(wavelet_filtered).T

        # A constant column has zero spread; dividing by 1 instead keeps it
        # at 0 rather than NaN, which PCA would reject.
        column_std = wavelet_matrix.std(axis=0)
        column_std[column_std == 0] = 1
        standardized_data = (wavelet_matrix - wavelet_matrix.mean(axis=0)) / column_std

        pca = PCA()
        pcs = pca.fit_transform(standardized_data)

        explained_variances[interval] = pca.explained_variance_ratio_

        wavelet_pca_df = pd.DataFrame(pcs, columns=[f"WPC{i+1}" for i in range(pcs.shape[1])])
        wavelet_pca_df['Year_Bin'] = pivot_df.index
        wavelet_pca_df['Interval'] = interval
        wavelet_pca_df['Realization'] = realization_num

        wavelet_pca_results.append(wavelet_pca_df)

    # Sorting a column-less empty frame by 'Year_Bin' would raise KeyError.
    if not wavelet_pca_results:
        return pd.DataFrame(), explained_variances

    combined_wavelet_pca_df = pd.concat(wavelet_pca_results, ignore_index=True)
    combined_wavelet_pca_df = combined_wavelet_pca_df.sort_values(by='Year_Bin')

    return combined_wavelet_pca_df, explained_variances

def set_global_seed(seed):
    """Seed Python's ``random`` module and NumPy's global generator with ``seed``."""
    for seeder in (random.seed, np.random.seed):
        seeder(seed)

def run_archive_realizations(
    archive_results,
    n_realizations=200,
    selected_interval_length=None,
    year_start=None,
    year_end=None,
    bin_size=5,
    seed_value=42,
    overlap=10
):
    """Run the sample -> bin -> filter -> fill -> PCA chain for many realizations.

    Each realization is seeded with ``seed_value + realization_num``, draws
    one ensemble member per record, bins it, keeps records with enough
    coverage per interval, fills gaps with ``run_dineof_on_intervals`` and
    then applies standard, kernel and wavelet PCA. Realizations that end up
    without data after filtering or filling are skipped.

    Parameters
    ----------
    archive_results : dict
        Records passed to ``sample_one_realization``.
    n_realizations : int, default 200
        Number of realizations, numbered from 1.
    selected_interval_length : int
        Interval length in years; required despite the ``None`` default.
    year_start, year_end : int
        Analysis range; both required despite the ``None`` defaults.
    bin_size : number, default 5
        Bin width in years.
    seed_value : int, default 42
        Base seed for the random generators.
    overlap : int, default 10
        Overlap in years between consecutive intervals.

    Returns
    -------
    dict
        Keys ``pca_results``, ``kpca_results``, ``wavelet_pca_results`` (lists
        of DataFrames) and ``pca_evrs``, ``kpca_evrs``, ``wavelet_pca_evrs``
        (lists of per-interval explained-variance dicts), one entry per
        completed realization.

    Raises
    ------
    ValueError
        If ``year_start``, ``year_end`` or ``selected_interval_length`` is None.
    """
    set_global_seed(seed_value)

    if year_start is None or year_end is None:
        raise ValueError("year_start and year_end must be specified!")

    # Without this check the None default only fails later, inside
    # build_intervals, with an unrelated-looking TypeError.
    if selected_interval_length is None:
        raise ValueError("selected_interval_length must be specified!")

    interval_length_to_use = selected_interval_length

    all_pca_results = []
    all_kpca_results = []
    all_wavelet_pca_results = []

    all_pca_evrs = []
    all_kpca_evrs = []
    all_wavelet_pca_evrs = []

    for realization_num in range(1, n_realizations + 1):
        # Re-seed per realization so each draw is reproducible on its own.
        set_global_seed(seed_value + realization_num)

        print(f"\n🎲 Realization {realization_num}/{n_realizations} | Interval Length: {interval_length_to_use} years | Overlap: {overlap} years")

        sampled_df = sample_one_realization(archive_results)

        # bin_and_aggregate already groups by ID internally, so a single call
        # replaces the per-ID loop and avoids handing it groupby slices.
        final_binned_df = bin_and_aggregate(
            sampled_df, year_start=year_start, year_end=year_end, bin_size=bin_size
        )

        interval_valid_df = filter_by_interval_coverage(
            final_binned_df,
            selected_interval_length=interval_length_to_use,
            year_start=year_start,
            year_end=year_end,
            bin_size=bin_size,
            overlap=overlap
        )

        if interval_valid_df.empty:
            print("⚠️ No valid data after interval filtering. Skipping realization.")
            continue

        dineof_filled_df = run_dineof_on_intervals(interval_valid_df, bin_size=bin_size)

        if dineof_filled_df.empty:
            print("⚠️ DINEOF failed to fill data. Skipping realization.")
            continue

        pca_df, pca_evrs = perform_standard_pca(dineof_filled_df, realization_num)
        all_pca_results.append(pca_df)
        all_pca_evrs.append(pca_evrs)

        kpca_df, kpca_evrs = perform_kernel_pca(dineof_filled_df, realization_num)
        all_kpca_results.append(kpca_df)
        all_kpca_evrs.append(kpca_evrs)

        wavelet_df, wavelet_evrs = perform_wavelet_pca(dineof_filled_df, realization_num)
        all_wavelet_pca_results.append(wavelet_df)
        all_wavelet_pca_evrs.append(wavelet_evrs)

        print(f"✅ Finished Realization {realization_num}")

    return {
        "pca_results": all_pca_results,
        "kpca_results": all_kpca_results,
        "wavelet_pca_results": all_wavelet_pca_results,
        "pca_evrs": all_pca_evrs,
        "kpca_evrs": all_kpca_evrs,
        "wavelet_pca_evrs": all_wavelet_pca_evrs
    }

In [None]:
# Use every loaded record; no archive-type filter is applied here.
filtered_results = bam_results

# Analysis window and binning configuration shared by the cells below.
year_start = 1551
year_end = 2016
bin_size = 1
overlap = 10

# Draw a single realization and bin it, then score the candidate interval
# lengths on that one draw.
realization_df = sample_one_realization(filtered_results)
binned_df = bin_and_aggregate(realization_df, year_start=year_start, year_end=year_end, bin_size=bin_size)
best_interval, interval_results_df = find_best_interval(
    binned_df,
    year_start=year_start,
    year_end=year_end,
    bin_size=bin_size,
    overlap=overlap
)

In [None]:
# Interval length is set by hand; `best_interval` from the previous cell is
# not used here.
selected_interval_length = 150

# Print the per-interval record counts for the chosen length.
summarize_interval_coverage(
    final_binned_df=binned_df,
    selected_interval_length=selected_interval_length,
    year_start=year_start,
    year_end=year_end,
    bin_size=bin_size,
    include_empty_intervals=True,
    overlap=overlap
)

# Rows of the single binned realization that pass the coverage rule,
# labelled with their interval.
interval_valid_df = filter_by_interval_coverage(
    binned_df,
    selected_interval_length=selected_interval_length,
    year_start=year_start,
    year_end=year_end,
    bin_size=bin_size,
    overlap=overlap
)

In [None]:
# Ensemble size and base seed; run_archive_realizations re-seeds with
# seed_value + realization number inside its loop.
n_realizations = 200
seed_value = 42
set_global_seed(seed_value)


# Full pipeline over all realizations; returns lists of per-realization
# PCA / Kernel PCA / Wavelet PCA frames and explained-variance dicts.
results = run_archive_realizations(
    archive_results=filtered_results,
    n_realizations=n_realizations,
    selected_interval_length=selected_interval_length,
    year_start=year_start,
    year_end=year_end,
    bin_size=bin_size,
    seed_value=seed_value,
    overlap=overlap
)

In [None]:
def flip_signs_for_realizations(results_list, pc_column='PC1', verbose=True, min_overlap=3):
    """Make component signs consistent between overlapping intervals.

    Within each realization, intervals are visited from the most recent start
    year to the oldest. Each interval is compared with the previously visited
    one (already sign-corrected) on their shared ``Year_Bin`` values; if the
    correlation of ``pc_column`` over that overlap is negative, the interval's
    ``pc_column`` is multiplied by -1. Input frames are not modified.

    Parameters
    ----------
    results_list : list of pandas.DataFrame
        One frame per realization with columns ``Interval`` (labels of the
        form ``"start-end"``), ``Year_Bin`` and ``pc_column``.
    pc_column : str, default 'PC1'
        Column whose sign is aligned.
    verbose : bool, default True
        Print the step-by-step diagnostics. Set to False for large ensembles.
    min_overlap : int, default 3
        Minimum number of shared ``Year_Bin`` values needed to compare two
        intervals; with fewer, the interval is left as is.

    Returns
    -------
    list of pandas.DataFrame
        Sign-aligned copies, in the same order as ``results_list``. Frames
        that are empty or lack an ``Interval`` column are copied unchanged.
    """
    log = print if verbose else (lambda *args, **kwargs: None)
    aligned_results = []

    for realization_idx, realization_df in enumerate(results_list):
        log(f"\n🧠 Processing Realization {realization_idx + 1}")
        df = realization_df.copy()

        # Nothing to align; indexing the interval list below would fail.
        if df.empty or 'Interval' not in df.columns:
            aligned_results.append(df)
            continue

        interval_order = sorted(df['Interval'].unique(), key=lambda x: int(x.split('-')[0]), reverse=True)

        # Initialize reference with the most recent interval
        ref_interval = interval_order[0]
        ref_df = df[df['Interval'] == ref_interval].set_index('Year_Bin')
        log(f"📌 Reference interval: {ref_interval}")

        for i in range(1, len(interval_order)):
            curr_interval = interval_order[i]
            curr_df = df[df['Interval'] == curr_interval].set_index('Year_Bin')

            overlap_years = ref_df.index.intersection(curr_df.index)
            overlap_years = sorted(overlap_years)

            log(f"\n🔍 Comparing {curr_interval} against reference {ref_interval}")
            log(f"   Overlapping years: {overlap_years}")

            if len(overlap_years) >= min_overlap:
                ref_values = ref_df.loc[overlap_years, pc_column].values
                curr_values = curr_df.loc[overlap_years, pc_column].values

                log(f"   {ref_interval} {pc_column} values: {ref_values}")
                log(f"   {curr_interval} {pc_column} values: {curr_values}")

                corr = np.corrcoef(ref_values, curr_values)[0, 1]
                log(f"   Correlation: {corr:.3f}")

                # A NaN correlation (e.g. constant or missing values) is
                # treated as "no flip".
                if not np.isnan(corr) and corr < 0:
                    log(f"   🔁 Flipping signs for {curr_interval}")
                    df.loc[df['Interval'] == curr_interval, pc_column] *= -1
                else:
                    log(f"   ✅ No flip needed for {curr_interval}")
            else:
                log(f"   ⚠️ Not enough overlapping years ({len(overlap_years)}). Skipping.")

            # Update reference for next iteration; re-read from df so the
            # reference carries any flip that was just applied.
            ref_interval = curr_interval
            ref_df = df[df['Interval'] == curr_interval].set_index('Year_Bin')

        aligned_results.append(df)

    return aligned_results


def stitch_intervals_keep_recent(df_list, pc_column='PC1'):
    """Merge overlapping intervals into one series per realization.

    Intervals are processed from the most recent start year to the oldest
    (start year = integer before the first ``-`` of the ``Interval`` label).
    A ``Year_Bin`` already supplied by a more recent interval is dropped from
    older ones, so overlaps keep the most recent interval's value.

    Returns
    -------
    pandas.DataFrame
        Columns ``Year_Bin``, ``Realization`` and ``pc_column`` for all frames
        in ``df_list``, sorted by ``Realization`` then ``Year_Bin``.
    """
    kept_columns = ['Year_Bin', 'Realization', pc_column]
    per_realization = []

    for realization_df in df_list:
        newest_first = sorted(
            realization_df['Interval'].unique(),
            key=lambda label: int(label.split('-')[0]),
            reverse=True,
        )

        claimed_years = set()
        pieces = []

        for label in newest_first:
            rows = realization_df[realization_df['Interval'] == label].copy()

            # Drop years a more recent interval has already provided.
            fresh_rows = rows[~rows['Year_Bin'].isin(claimed_years)]
            claimed_years.update(fresh_rows['Year_Bin'].tolist())

            pieces.append(fresh_rows[kept_columns])

        per_realization.append(pd.concat(pieces, ignore_index=True))

    stitched = pd.concat(per_realization, ignore_index=True)
    return stitched.sort_values(by=['Realization', 'Year_Bin'])

def normalize_column_by_realization(df, value_col):
    """Add ``<value_col>_norm``: ``value_col`` z-scored within each realization.

    The z-score uses the population standard deviation (``ddof=0``). The
    column is added to ``df`` in place and the same object is returned.
    """
    def _zscore(values):
        centered = values - values.mean()
        return centered / values.std(ddof=0)

    per_realization = df.groupby('Realization')[value_col]
    df[value_col + '_norm'] = per_realization.transform(_zscore)
    return df

In [None]:
# Align the sign of the first component across overlapping intervals,
# separately for each decomposition method.
aligned_pca = flip_signs_for_realizations(results['pca_results'], pc_column='PC1')
aligned_kpca = flip_signs_for_realizations(results['kpca_results'], pc_column='KPC1')
aligned_wpca = flip_signs_for_realizations(results['wavelet_pca_results'], pc_column='WPC1')

In [None]:
# Collapse the overlapping intervals into one series per realization
# (the most recent interval wins where intervals overlap).
stitched_pca_df = stitch_intervals_keep_recent(aligned_pca, pc_column='PC1')
stitched_kpca_df = stitch_intervals_keep_recent(aligned_kpca, pc_column='KPC1')
stitched_wpca_df = stitch_intervals_keep_recent(aligned_wpca, pc_column='WPC1')

# Add a '<column>_norm' z-score column, computed within each realization.
stitched_pca_df = normalize_column_by_realization(stitched_pca_df, 'PC1')
stitched_kpca_df = normalize_column_by_realization(stitched_kpca_df, 'KPC1')
stitched_wpca_df = normalize_column_by_realization(stitched_wpca_df, 'WPC1')

In [None]:
# Merge and compute percentiles
# `unique()` returns values in first-appearance order, so years that only
# occur in later realizations would land after the last year of the first
# one and make the plotted line and band fold back. Sort the axis explicitly.
merged = pd.DataFrame({'Year_Bin': np.sort(stitched_pca_df['Year_Bin'].unique())})
for df, label in zip(
    [stitched_pca_df, stitched_kpca_df, stitched_wpca_df],
    ['PC1_norm', 'KPC1_norm', 'WPC1_norm']
):
    # 5th / 50th / 95th percentile across realizations for every year bin.
    temp = df.groupby('Year_Bin')[label].quantile([0.05, 0.5, 0.95]).unstack()
    temp.columns = [f"{label}_p5", f"{label}_p50", f"{label}_p95"]
    merged = pd.merge(merged, temp, left_on='Year_Bin', right_index=True, how='left')

# Plot
plt.figure(figsize=(16, 8))
for label, color in zip(['PC1_norm', 'KPC1_norm', 'WPC1_norm'], ['red', 'green', 'blue']):
    x = merged['Year_Bin'].astype(float).values
    y_median = merged[f'{label}_p50'].astype(float).values
    y_low = merged[f'{label}_p5'].astype(float).values
    y_high = merged[f'{label}_p95'].astype(float).values

    plt.plot(x, y_median, label=label.split('_')[0] + ' Median', color=color, linewidth=2)
    plt.fill_between(x, y_low, y_high, color=color, alpha=0.2)

plt.title('Median and 5th–95th Percentile Range of PC1 from PCA, Kernel PCA, and Wavelet PCA')
plt.xlabel('Year')
plt.ylabel('Normalized PC1 Value')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
# ✅ Extract explained variance ratios from results dictionary
# Each entry is a list with one {interval: ratios} dict per completed realization.
all_pca_evrs = results['pca_evrs']
all_kpca_evrs = results['kpca_evrs']
all_wavelet_pca_evrs = results['wavelet_pca_evrs']

# ✅ Define a helper function to compute mean explained variance for PC1
def compute_mean_pc1_variance(evrs_list, method_name='PCA'):
    """Average the first explained-variance ratio over all intervals and realizations.

    ``evrs_list`` is a list of ``{interval: ratios}`` dicts. Only non-empty
    ``ratios`` that are a ``list`` or ``numpy.ndarray`` contribute their first
    element. The mean is printed and returned; ``numpy.nan`` is returned when
    nothing contributed.
    """
    leading_ratios = [
        ratios[0]
        for per_interval in evrs_list
        for ratios in per_interval.values()
        if isinstance(ratios, (list, np.ndarray)) and len(ratios) > 0
    ]

    if not leading_ratios:
        print(f"⚠️ No explained variance data found for {method_name}.")
        return np.nan

    average_ratio = np.mean(leading_ratios)
    print(f"✅ Mean explained variance ratio for {method_name} (Component 1): {average_ratio:.2%}")
    return average_ratio

# ✅ Run for each PCA method
# Each call prints the mean first-component ratio and returns it (NaN if none).
mean_pc1_variance   = compute_mean_pc1_variance(all_pca_evrs, method_name='Standard PCA')
mean_kpc1_variance  = compute_mean_pc1_variance(all_kpca_evrs, method_name='Kernel PCA')
mean_wpc1_variance  = compute_mean_pc1_variance(all_wavelet_pca_evrs, method_name='Wavelet PCA')
