In [2]:
import pickle
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import os
from typing import Tuple
from pydantic import BaseModel, Field
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

In [3]:
# Embed TrueType (Type 42) fonts in PDF/PS output instead of matplotlib's
# default Type 3 fonts, so text in the saved figures remains editable.
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

# load results

In [4]:
# pydantic model

class PreprocessingSteps(BaseModel):
    """
    Pydantic container for one set of fitted preprocessing transformers.

    ``transformers`` is a pair of tuples:
      - (PCA, StandardScaler, np.ndarray, np.ndarray)
      - (float, float)

    NOTE(review): the (float, float) pair is presumably the [train, val]
    explained variance that downstream code reads at index -1 — confirm.
    """
    # Required field (``...``): a 2-tuple of (transformer tuple, float pair).
    transformers: Tuple[
        Tuple[PCA, StandardScaler, np.ndarray, np.ndarray],
        Tuple[float, float]
    ] = Field(
        ..., 
        description="Preprocessing steps including PCA, StandardScaler, and arrays"
    )

    class Config:
        # Needed so pydantic accepts PCA, StandardScaler and np.ndarray,
        # which have no built-in pydantic validators.
        arbitrary_types_allowed = True
        # JSON serialisation hooks: arrays become nested lists, sklearn
        # estimators become their str() form (fitted state is presumably lost).
        # NOTE(review): `class Config` / `json_encoders` are pydantic v1-style
        # and deprecated in pydantic v2 — confirm the installed version.
        json_encoders = {
            np.ndarray: lambda v: v.tolist(),
            PCA: lambda v: str(v),
            StandardScaler: lambda v: str(v)
        }


In [5]:
def load_data(dataset: str, activation: str = "relu") -> dict:
    """
    Load every pickle file found under a dataset's activation directory.

    The directory searched (recursively) is
    ``<base_dir>/<dataset>/activation=<activation>``, where ``base_dir`` is
    ``../results/datasets/synthetic_data`` for the synthetic datasets
    ("circles", "spheres", "alternate_stripes") and
    ``../results/datasets/real_world_data/`` otherwise. Note that this layout
    has no ``k=<k>`` level (unlike ``read_all_nested_files``).

    Args:
        dataset (str): Name of the dataset or specific subdirectory.
        activation (str): Activation name used for the ``activation=...``
            directory level (default ``"relu"``).

    Returns:
        dict: A dictionary with full file paths as keys and loaded data as
        values. Files ending in ``.pkl`` or ``.p`` that cannot be opened or
        unpickled are reported with a message and skipped.

    Warning:
        ``pickle.load`` can execute arbitrary code; only point this at
        trusted result files.
    """
    base_dir = "../results/datasets/real_world_data/"
    if dataset in ("circles", "spheres", "alternate_stripes"):
        base_dir = "../results/datasets/synthetic_data"

    full_path = os.path.join(base_dir, dataset, f"activation={activation}")
    print(full_path)

    data_dict = {}
    # Walk through the directory tree starting from the specified dataset path
    for root, _dirs, files in os.walk(full_path):
        for file in files:
            if not file.endswith((".pkl", ".p")):
                continue
            full_file_path = os.path.join(root, file)
            try:
                with open(full_file_path, "rb") as f:
                    data_dict[full_file_path] = pickle.load(f)
            except (OSError, EOFError, pickle.UnpicklingError) as e:
                # EOFError covers empty/truncated pickles (e.g. an interrupted
                # write); skip them instead of aborting the whole walk.
                print(f"Error loading {full_file_path}: {e}")

    return data_dict

In [6]:
def read_all_nested_files(dataset: str, activation: str = "relu", k: int = 2):
    """
    Collect every ``results_list.p`` pickle below a dataset's results folder.

    Layout searched (recursively):
        <base>/<dataset>/k=<k>/activation=<activation>/...
    where <base> is ``../results/datasets/synthetic_data`` for the synthetic
    datasets and ``../results/datasets/real_world_data/`` otherwise.

    Returns:
        dict mapping each ``results_list.p`` path to its unpickled contents.
        Files that fail to load are reported with a message and left out.
    """
    synthetic_names = ("circles", "spheres", "alternate_stripes")
    root_dir = (
        "../results/datasets/synthetic_data"
        if dataset in synthetic_names
        else "../results/datasets/real_world_data/"
    )
    search_dir = os.path.join(root_dir, dataset, f"k={k}", f"activation={activation}")

    contents = {}
    for dirpath, _subdirs, names in os.walk(search_dir):
        for name in names:
            # Training curves live only in results_list.p; ignore everything else.
            if name != "results_list.p":
                continue
            path = os.path.join(dirpath, name)
            try:
                with open(path, "rb") as handle:
                    contents[path] = pickle.load(handle)
            except Exception as err:
                print(f"Could not read {path}: {err}")
    return contents


# functions to collate the data

In [7]:
# get the explained variance objectives for train and validation

def extract_explained_variance(data: "list[list[PreprocessingSteps]]") -> np.ndarray:
    """
    Extract explained-variance curves from a list of runs into one
    NaN-padded array.

    Input structure expectation (per run):
    - data[j] is an iterable of time steps for run j
    - each step holds the explained variance as its last item (index -1),
      typically a 1D array-like of length 2: [train, val]. A scalar is
      treated as length 1; higher-rank values are flattened.
    - steps whose last item cannot be indexed are skipped, and runs with no
      usable steps are dropped.
    - NOTE(review): the annotation names PreprocessingSteps, but steps are
      indexed with [-1], which a pydantic model does not support; steps are
      presumably tuples shaped like PreprocessingSteps.transformers — confirm.

    Returns:
    - np.ndarray of shape (n_runs, max_time, d), where max_time is the length
      of the longest run and d is the widest per-step vector across *all*
      runs (not just the first), so no run's components are truncated.
      Missing time steps and missing components are np.nan.
    - An empty (0, 0, 0) array when no run has a usable step.
    """
    runs: list[list[np.ndarray]] = []

    for run in data:
        seq: list[np.ndarray] = []
        for step in run:
            try:
                last = step[-1]
            except Exception:
                # If structure is unexpected, skip this step
                continue
            # Normalise to a flat float vector (scalar -> length 1).
            vec = np.atleast_1d(np.asarray(last).astype(float, copy=False)).ravel()
            seq.append(vec)

        if seq:
            runs.append(seq)

    if not runs:
        return np.empty((0, 0, 0))

    max_len = max(len(seq) for seq in runs)
    # Width over every step of every run: tolerates steps/runs of differing
    # width instead of crashing (vstack) or silently truncating.
    d = max(vec.size for seq in runs for vec in seq)

    padded = np.full((len(runs), max_len, d), np.nan, dtype=float)
    for i, seq in enumerate(runs):
        for t, vec in enumerate(seq):
            padded[i, t, :vec.size] = vec

    return padded

# visualise

In [8]:
def plot_percentiles(partial_data, full_data, dataset: str, k: int = 2, activation: str | None = None):
    """
    Plot median explained-variance curves with a 20th-80th percentile band
    for the partial- and full-contribution objectives, then save as PDF.

    Args:
        partial_data: NaN-padded array of shape (n_runs, n_steps, d), as
            produced by ``extract_explained_variance``; None or empty arrays
            are reported and skipped.
        full_data: same layout as ``partial_data``, for the full objective.
        dataset: dataset name; picks the default activation and names the
            output file.
        k: number of eigenvalues summed in the y-label (lambda_1 + ... +
            lambda_k); also part of the output path.
        activation: activation label for the legend/path; defaults to 'cos'
            for 'alternate_stripes' and 'relu' otherwise.

    Side effects:
        Prints mean/std summaries and writes
        ../results/plots/training_curve/k=<k>/activation=<activation>/<dataset>.pdf
    """
    colors = {'partial': 'darkred', 'full': 'darkblue'}
    # Column 0 (train) dashed, column 1 (validation) solid; cycled if d > 2.
    linestyles = ['--', '-']

    activation_fn = activation if activation is not None else (
        'cos' if dataset == 'alternate_stripes' else 'relu'
    )

    fig, ax = plt.subplots(figsize=(10, 6))

    for data, contrib_type in ((partial_data, 'partial'), (full_data, 'full')):
        if data is None or not hasattr(data, 'size') or data.size == 0 or data.shape[1] == 0:
            print(f"FOR contrib type {contrib_type} -> no data to plot")
            continue

        # Summary over all runs and all time steps of the last column;
        # NaN padding from ragged runs is ignored.
        with np.errstate(all='ignore'):
            mean_last = np.nanmean(data[:, :, -1])
            std_last = np.nanstd(data[:, :, -1])
        print(f"FOR contrib type {contrib_type}")
        print(f"mean {mean_last}")
        print(f"std {std_last}")

        # Percentiles across runs (axis 0), ignoring NaN padding; each has
        # shape (n_steps, d).
        p20, p50, p80 = np.nanpercentile(data, [20, 50, 80], axis=0)

        # Valid time steps: at least one non-NaN across runs for the first column
        valid_mask = ~np.all(np.isnan(data[:, :, 0]), axis=0)
        x = np.arange(data.shape[1])[valid_mask]

        # Always index [mask, dim] so every series is 1-D, including d == 1
        # (fill_between rejects 2-D inputs).
        for dim in range(p50.shape[1]):
            ax.plot(
                x, p50[valid_mask, dim],
                color=colors[contrib_type],
                label=f"h={activation_fn}, objective={contrib_type}" if dim == 0 else None,
                linestyle=linestyles[dim % len(linestyles)],
                linewidth=1.0,
            )
            ax.fill_between(
                x, p20[valid_mask, dim], p80[valid_mask, dim],
                alpha=0.2, color=colors[contrib_type],
            )

    ax.set_xlabel('Time steps', fontsize=14)
    lambda_sum = "+".join(rf"\lambda_{{{i}}}" for i in range(1, k + 1))
    ax.set_ylabel(rf"Proportion of explained variance for ${lambda_sum}$", fontsize=14)

    # Only draw a legend when something labelled was plotted.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc='upper left', fontsize=12)

    ax.tick_params(axis='both', labelsize=12)
    ax.grid(True, linestyle='--', alpha=0.5)
    fig.tight_layout()

    # Output path groups plots by k and activation; it does not include the
    # synthetic/real-world split used by the results folders.
    path_to_save = os.path.join(
        "..", "results", "plots", "training_curve", f"k={k}", f"activation={activation_fn}",
        f"{dataset}.pdf",
    )
    os.makedirs(os.path.dirname(path_to_save), exist_ok=True)
    fig.savefig(path_to_save, dpi=300)
    plt.close(fig)


In [16]:
# Main loop: for each dataset, load the training-curve results, split them by
# objective (partial vs full contribution) and plot both on one figure.
ALL_DATASETS = [
    "alternate_stripes",
    "circles",
    "spheres",
    "wine",
    "heart-statlog",
    "ionosphere",
    "breast_cancer",
    "german_credit",
]

# Datasets processed in this run; set to ALL_DATASETS for the full sweep.
dataset_list = ["heart-statlog"]
# Plots explained variance of lambda_1 + ... + lambda_k; selects the k=<k> results folder.
k = 1

RESULTS_SUFFIX = os.path.sep + "results_list.p"

for dataset in dataset_list:
    # Same rule plot_percentiles applies when no activation is given.
    activation = "cos" if dataset == "alternate_stripes" else "relu"

    data_dictionary = read_all_nested_files(dataset, activation=activation, k=k)

    # The objective variant is encoded in each run's directory path.
    partial_contrib_data = [
        value for key, value in data_dictionary.items()
        if "partial_contrib=True" in key and key.endswith(RESULTS_SUFFIX)
    ]
    full_contrib_data = [
        value for key, value in data_dictionary.items()
        if "partial_contrib=False" in key and key.endswith(RESULTS_SUFFIX)
    ]

    print(f"{dataset}: {len(partial_contrib_data)} partial-contribution runs, "
          f"{len(full_contrib_data)} full-contribution runs")

    explained_variance_partial = extract_explained_variance(partial_contrib_data)
    explained_variance_full = extract_explained_variance(full_contrib_data)

    # Plot both contributions on the same figure
    plot_percentiles(explained_variance_partial, explained_variance_full, dataset, k=k, activation=activation)

WE ARE HERE
15 15
FOR contrib type partial
mean 0.25032740498553685
std 0.031069413332638785
FOR contrib type full
mean 0.22127476913073468
std 0.040256273942448746
