In [9]:
import pickle
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

import os
from typing import Tuple
from pydantic import BaseModel, Field
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

In [10]:
# Embed fonts as Type 42 (TrueType) instead of matplotlib's default Type 3
# in both vector back ends (PDF and PostScript).
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# load results

In [11]:
# pydantic model

class PreprocessingSteps(BaseModel):
    """
    Pydantic model wrapping one record of fitted preprocessing objects.

    The single required field ``transformers`` is a 2-tuple:

    * element 0: ``(PCA, StandardScaler, np.ndarray, np.ndarray)``
    * element 1: ``(float, float)``

    NOTE(review): the meaning of the two arrays and of the two floats is not
    visible in this notebook. Presumably the float pair is the
    (train, validation) explained variance that the curve-extraction code
    reads from the last item of each step -- TODO confirm against the code
    that writes ``results_list.p``.
    """

    # Required field (``...`` means "no default").
    transformers: Tuple[
        Tuple[PCA, StandardScaler, np.ndarray, np.ndarray],
        Tuple[float, float]
    ] = Field(
        ..., 
        description="Preprocessing steps including PCA, StandardScaler, and arrays"
    )

    class Config:
        # PCA, StandardScaler and np.ndarray are not pydantic-native types;
        # without this flag the field annotation above would be rejected.
        arbitrary_types_allowed = True
        # Custom JSON encoding: arrays become nested lists, the sklearn
        # estimators are reduced to their ``str(...)`` representation
        # (lossy -- fitted parameters are not serialised).
        json_encoders = {
            np.ndarray: lambda v: v.tolist(),
            PCA: lambda v: str(v),
            StandardScaler: lambda v: str(v)
        }


In [12]:
def load_data(dataset: str, activation: str = "relu") -> dict:
    """
    Load every pickle file stored for one dataset / activation combination.

    The directory ``<base_dir>/<dataset>/activation=<activation>`` is walked
    recursively and every ``*.pkl`` / ``*.p`` file found is unpickled.
    ``base_dir`` is the synthetic-data results folder for the datasets
    "circles", "spheres" and "alternate_stripes", and the real-world results
    folder otherwise. The resolved directory is printed before loading.

    Files that cannot be read or unpickled are reported on stdout and
    skipped, so one corrupt or truncated file does not abort the whole load.

    Args:
        dataset (str): Name of the dataset or specific subdirectory.
        activation (str): Activation name used in the ``activation=<...>``
            directory component. Defaults to "relu".

    Returns:
        dict: A dictionary with full file paths as keys and loaded data as
        values. Empty if the directory does not exist or holds no pickles.
    """
    base_dir = "../results/datasets/real_world_data/"
    if dataset in ["circles", "spheres", "alternate_stripes"]:
        base_dir = "../results/datasets/synthetic_data"

    full_path = os.path.join(base_dir, dataset, f"activation={activation}")

    print(full_path)

    # Besides UnpicklingError, the pickle documentation lists AttributeError,
    # EOFError, ImportError and IndexError as errors that unpickling may
    # raise (e.g. EOFError for an empty or truncated file).
    load_errors = (
        OSError,
        pickle.UnpicklingError,
        EOFError,
        AttributeError,
        ImportError,
        IndexError,
    )

    data_dict = {}

    # Walk through the directory tree starting from the specified dataset path
    for root, _dirs, files in os.walk(full_path):
        for file in files:
            # Only pickle files are of interest
            if not file.endswith((".pkl", ".p")):
                continue

            full_file_path = os.path.join(root, file)
            try:
                # NOTE(security): pickle.load can execute arbitrary code;
                # only point this at result files you produced yourself.
                with open(full_file_path, "rb") as f:
                    data_dict[full_file_path] = pickle.load(f)
            except load_errors as e:
                print(f"Error loading {full_file_path}: {e}")

    return data_dict

In [13]:
def read_all_nested_files(dataset: str, activation: str = "relu", k: int = 2):
    """
    Collect every ``results_list.p`` stored for one dataset / k / activation.

    The search starts at ``<base_dir>/<dataset>/k=<k>/activation=<activation>``
    and descends into all sub-directories. ``base_dir`` is the synthetic-data
    results folder for "circles", "spheres" and "alternate_stripes", and the
    real-world results folder for everything else.

    Args:
        dataset: Dataset name (directory below the results folder).
        activation: Value of the ``activation=<...>`` directory component.
        k: Value of the ``k=<...>`` directory component.

    Returns:
        dict mapping the path of each ``results_list.p`` to its unpickled
        content. Files that fail to load are reported on stdout and omitted.
    """
    is_synthetic = dataset in ["circles", "spheres", "alternate_stripes"]
    results_root = (
        "../results/datasets/synthetic_data"
        if is_synthetic
        else "../results/datasets/real_world_data/"
    )

    # Layout on disk: .../<dataset>/k=<k>/activation=<activation>/...
    search_dir = os.path.join(results_root, dataset, f"k={k}", f"activation={activation}")

    loaded = {}
    for current_dir, _subdirs, names in os.walk(search_dir):
        for name in names:
            # Only the per-run training-curve file is relevant here.
            if name != "results_list.p":
                continue

            candidate = os.path.join(current_dir, name)
            try:
                with open(candidate, "rb") as handle:
                    loaded[candidate] = pickle.load(handle)
            except Exception as err:
                print(f"Could not read {candidate}: {err}")

    return loaded


# functions to collate the data

In [14]:
# get the explained variance objectives for train and validation

# The annotation is quoted so this cell can be defined without
# PreprocessingSteps already being present in the kernel namespace.
def extract_explained_variance(data: "list[list[PreprocessingSteps]]") -> np.ndarray:
    """
    Extract the explained-variance curves from a list of runs and return them
    as one NaN-padded array.

    Input structure expectation (per run):
    - data[j] is an iterable of time steps for run j
    - each step is indexable and carries the explained variance as its last
      item (index -1), typically a 1D array-like of length 2: [train, val];
      scalars and higher-dimensional values are flattened to 1D
    - steps whose last item cannot be accessed are skipped; runs without any
      usable step are dropped

    Returns:
    - np.ndarray of floats with shape (n_runs, max_time, d), where max_time
      is the length of the longest run and d the widest per-step vector over
      all runs. Missing entries (shorter runs, narrower steps) are np.nan.
      The shape therefore does not depend on the order of the runs.
    - np.empty((0, 0, 0)) when no run yields any data.
    """
    runs: list[list[np.ndarray]] = []

    for run in data:
        steps: list[np.ndarray] = []
        for step in run:
            try:
                last = step[-1]
            except Exception:
                # If structure is unexpected, skip this step
                continue
            # Scalars become length-1 vectors, nested values are flattened.
            steps.append(np.asarray(last, dtype=float).reshape(-1))

        if steps:
            runs.append(steps)

    if not runs:
        return np.empty((0, 0, 0))

    max_len = max(len(steps) for steps in runs)
    # Use the widest vector seen anywhere (e.g. 2 for [train, val]) so no
    # column is silently dropped because a narrower run happened to be first.
    d = max(vec.size for steps in runs for vec in steps)

    padded = np.full((len(runs), max_len, d), np.nan, dtype=float)
    for i, steps in enumerate(runs):
        for t, vec in enumerate(steps):
            padded[i, t, :vec.size] = vec

    return padded

# visualise

In [15]:
def plot_percentiles(
        partial_data_by_k: dict[int, np.ndarray],
        full_data_by_k: dict[int, np.ndarray],
        dataset: str,
        ks: tuple[int, ...] = (1, 2),
        activation: str | None = None,
        ax: plt.Axes | None = None,
        y_n_ticks: int = 5,
):
    """
    Plot median explained-variance curves with a 20th-80th percentile band.

    For every k in ``ks`` and for both contribution types ("partial",
    "full"), the percentiles are taken across runs (axis 0) at each step and
    drawn on ``ax``: the median as a line, the 20-80 range as a shaded band.
    Colour encodes the contribution type, line style encodes k.

    Args:
        partial_data_by_k: Maps k to an array of shape (n_runs, n_steps, d)
            for partial contributions. NaN entries are ignored. Arrays of
            shape (n_runs, n_steps) are accepted as well.
        full_data_by_k: Same structure for full contributions.
        dataset: Dataset name, used as the subplot title.
        ks: Values of k to draw. Entries that are missing or empty in the
            dictionaries are reported on stdout and skipped.
        activation: Activation name to return. If None, "cos" is used for
            "alternate_stripes" and "relu" otherwise.
        ax: Axes to draw on; a new 10x6 figure is created when None.
        y_n_ticks: ``nbins`` passed to the y-axis MaxNLocator.

    Returns:
        str: The activation name (given or inferred). It does not influence
        what is drawn.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    # Colors for different contribution types
    colors = {"partial": "darkorange", "full": "darkblue"}
    linestyles = {1: "--", 2: "-"}  # dashed for k=1, solid for k=2

    if activation is None:
        activation_fn = "cos" if dataset == "alternate_stripes" else "relu"
    else:
        activation_fn = activation

    for k in ks:
        for data_by_k, contrib_type in [(partial_data_by_k, "partial"), (full_data_by_k, "full")]:
            data = data_by_k.get(k, None)

            if data is None or not hasattr(data, "size") or data.size == 0 or data.shape[1] == 0:
                print(f"FOR k={k}, contrib type {contrib_type} -> no data to plot")
                continue

            # Accept (n_runs, n_steps) input by giving it a single-column last axis.
            if data.ndim == 2:
                data = data[:, :, np.newaxis]

            # Steps at which no run has a value in column 0 are not drawn.
            valid_mask = ~np.all(np.isnan(data[:, :, 0]), axis=0)
            x = np.arange(data.shape[1])[valid_mask]

            # Plot column 1 (validation) when available, otherwise column 0.
            val_dim = 1 if data.shape[2] > 1 else 0

            # Percentiles across runs, ignoring NaNs. Selecting one column
            # first keeps every result 1-D, which fill_between requires even
            # when the last axis has width 1.
            p20, p50, p80 = np.nanpercentile(data[:, :, val_dim], [20, 50, 80], axis=0)

            ax.plot(
                x,
                p50[valid_mask],
                color=colors[contrib_type],
                linestyle=linestyles.get(k, "-"),
                linewidth=2.2,
            )
            ax.fill_between(
                x, p20[valid_mask], p80[valid_mask], alpha=0.3, color=colors[contrib_type]
            )

    ax.set_title(dataset, fontsize=14)
    ax.set_xlabel("Generations", fontsize=12)
    ax.set_ylabel(r"Proportion of explained variance", fontsize=12)
    ax.tick_params(axis="both", labelsize=11)
    ax.grid(True, linestyle="--", alpha=0.5)

    ax.yaxis.set_major_locator(MaxNLocator(nbins=y_n_ticks))

    return activation_fn

# -------------------------------
# Main plotting workflow
# -------------------------------
# Datasets to plot, in subplot order (eight entries for the 4x2 grid below).
dataset_list = [
    "alternate_stripes",
    "circles",
    "spheres",
    "wine",
    "heart-statlog",
    "ionosphere",
    "breast_cancer",
    "german_credit",
]

ks_to_plot = (1, 2)
y_n_ticks = 5  # number of ticks per subplot

# One subplot per dataset; each keeps its own y-range (sharey=False).
fig, axes = plt.subplots(4, 2, figsize=(14, 18), sharey=False)
axes = axes.ravel()  # flatten the 4x2 grid so axes[i] pairs with dataset_list[i]
fig.subplots_adjust(hspace=0.35)


for i, dataset in enumerate(dataset_list):
    ax = axes[i]
    # "alternate_stripes" results are read from activation=cos, all others from activation=relu.
    activation = "cos" if dataset == "alternate_stripes" else "relu"

    # Only k=1 is loaded for "spheres" -- presumably no k=2 results exist for it; TODO confirm.
    ks_for_dataset = (1,) if dataset == "spheres" else ks_to_plot

    partial_by_k: dict[int, np.ndarray] = {}
    full_by_k: dict[int, np.ndarray] = {}

    # -------------------------------
    # Load data for each k
    # -------------------------------
    for k in ks_for_dataset:
        data_dictionary = read_all_nested_files(dataset, activation=activation, k=k)

        # Split the runs by the "partial_contrib=<bool>" component of their path.
        # The endswith(...) test is redundant: read_all_nested_files only
        # returns files named results_list.p.
        partial_contrib_data = [
            value for key, value in data_dictionary.items()
            if "partial_contrib=True" in key and key.endswith(os.path.sep + "results_list.p")
        ]

        full_contrib_data = [
            value for key, value in data_dictionary.items()
            if "partial_contrib=False" in key and key.endswith(os.path.sep + "results_list.p")
        ]

        # Arrays of shape (n_runs, max_time, d), NaN-padded for shorter runs.
        partial_by_k[k] = extract_explained_variance(partial_contrib_data)
        full_by_k[k] = extract_explained_variance(full_contrib_data)

        # -------------------------------
        # Compute and print mean/std for last generation
        # -------------------------------
        # The full-contribution results are labelled "Global" in this printout.
        for contrib_type, data in [("Partial", partial_by_k[k]), ("Global", full_by_k[k])]:
            if data is None or data.size == 0:
                continue
            # Last index of the padded time axis, i.e. the final step of the
            # longest run. Shorter runs hold NaN there and are ignored by the
            # nan-aware statistics below.
            last_gen_idx = data.shape[1] - 1
            last_gen_data = data[:, last_gen_idx, ...]
            # NOTE(review): no axis is given, so mean/std pool every column of
            # the last axis (train and val together), whereas plot_percentiles
            # draws only one column -- confirm this is the intended statistic.
            mean_val = np.nanmean(last_gen_data)
            std_val = np.nanstd(last_gen_data)
            print(
                f"Dataset={dataset}, k={k}, contrib={contrib_type}, "
                f"last generation mean={mean_val:.4f}, std={std_val:.4f}"
            )

    # -------------------------------
    # Plot percentiles
    # -------------------------------
    # The returned activation name is not needed here and is discarded.
    plot_percentiles(
        partial_by_k,
        full_by_k,
        dataset,
        ks=ks_for_dataset,
        activation=activation,
        ax=ax,
        y_n_ticks=y_n_ticks,
    )

# -------------------------------
# Adjust y-axis: fixed number of ticks per subplot
# -------------------------------
# Place exactly y_n_ticks ticks on every subplot, spanning its current y-limits.
# NOTE(review): this runs after plot_percentiles installed a MaxNLocator on the
# same axis; set_yticks presumably supersedes it with fixed positions -- confirm.
for ax in axes:
    ymin, ymax = ax.get_ylim()
    ticks = np.linspace(ymin, ymax, y_n_ticks)  # evenly spaced including bottom/top
    # Round to two decimals for readable labels; after rounding the outermost
    # ticks need not coincide exactly with ymin / ymax.
    ticks = [np.round(val, 2) for val in ticks]
    ax.set_yticks(ticks)

# -------------------------------
# Create beautiful custom legend
# -------------------------------
# NOTE(review): imported here rather than in the import cell at the top of the notebook.
from matplotlib.lines import Line2D

# Define colors and linestyles
# (same colour / line-style values as used inside plot_percentiles)
colors = {"Partial": "darkorange", "Full": "darkblue"}
linestyles = {"k=1": "--", "k=2": "-"}

# Create custom legend handles
# The plotted lines carry no labels, so the legend is built from proxy
# Line2D objects: one per (contribution type, k) combination. All four
# entries are listed regardless of which k values were actually drawn.
legend_elements = []
for contrib_type in ["Partial", "Full"]:
    for k_label in ["k=1", "k=2"]:
        legend_elements.append(
            Line2D([0], [0],
                   color=colors[contrib_type],
                   linestyle=linestyles[k_label],
                   linewidth=2.2,
                   label=f"{contrib_type}, {k_label}")
        )

# Add legend above the plots in 2x2 layout
fig.legend(
    handles=legend_elements,
    loc="upper center",
    ncol=2,
    fontsize=11,
    frameon=True,
    fancybox=True,
    shadow=True,
    bbox_to_anchor=(0.5, 0.99)
)

# Lay the subplots out within the lower 94% of the figure height; the legend
# is anchored above that region.
fig.tight_layout(rect=[0, 0, 1, 0.94])
# Write the figure to disk and close it.
fig.savefig("multi_dataset_plot.png", dpi=300, bbox_inches='tight')
plt.close(fig)


Dataset=alternate_stripes, k=1, contrib=Partial, last generation mean=0.6934, std=0.1035
Dataset=alternate_stripes, k=1, contrib=Global, last generation mean=0.7150, std=0.0730
FOR k=2, contrib type partial -> no data to plot
FOR k=2, contrib type full -> no data to plot
Dataset=circles, k=1, contrib=Partial, last generation mean=0.7356, std=0.0775
Dataset=circles, k=1, contrib=Global, last generation mean=0.7141, std=0.0734
FOR k=2, contrib type partial -> no data to plot
FOR k=2, contrib type full -> no data to plot
Dataset=spheres, k=1, contrib=Partial, last generation mean=0.5956, std=0.0208
Dataset=spheres, k=1, contrib=Global, last generation mean=0.5174, std=0.0391
Dataset=wine, k=1, contrib=Partial, last generation mean=0.4478, std=0.0279
Dataset=wine, k=1, contrib=Global, last generation mean=0.3635, std=0.0345
Dataset=wine, k=2, contrib=Partial, last generation mean=0.6203, std=0.0216
Dataset=wine, k=2, contrib=Global, last generation mean=0.5436, std=0.0279
Dataset=heart-sta