In [None]:
from helper.plots import plot_phate_animations

# Embedding Trajectories
Compare the embedding trajectories of several training runs.

In [None]:
# NOTE(review): every active run id below is a cifar10 run, while `dataset_name` is 'mnist'
# (which matches only the commented-out runs) — confirm which value `Run(run_id, dataset_name)`
# is supposed to receive.
dataset_name = 'mnist'

# Runs to compare. The `titles` list is written in the same order, so keep both in sync.
run_ids = [
#    "run-0011-CNN_mnist_32_0.9776",
#    "run-0012-CNN_mnist_32_0.9768"
#    "run-0007-CNN_mnist_128_0.9851",

    "run-0016-CNN_cifar10_128_0.8093", # Seed 42, SAM, Residual
    "run-0017-CNN_cifar10_128_0.8072", # Seed 42, SAM
    "run-0018-CNN_cifar10_128_0.8499", # Seed 42, Residual
    "run-0019-CNN_cifar10_128_0.8487", # Seed 42

    "run-0020-CNN_cifar10_128_0.8079", # Seed 11, SAM, Residual
    "run-0021-CNN_cifar10_128_0.8054", # Seed 11, SAM
    "run-0022-CNN_cifar10_128_0.8519", # Seed 11, Residual
    "run-0023-CNN_cifar10_128_0.8509", # Seed 11

    # Presumably Seed 6, in the same SAM/Residual pattern as above (inferred from the
    # matching `titles` entries) — TODO confirm.
    "run-0024-CNN_cifar10_128_0.8062",
    "run-0025-CNN_cifar10_128_0.8062",
    "run-0026-CNN_cifar10_128_0.8504",
    "run-0027-CNN_cifar10_128_0.8503",
]

In [None]:
# Plot labels, one per entry of `run_ids` and in the same order. The trailing number is the
# accuracy that also appears at the end of the corresponding run id.
titles = [
    "Seed 42, SAM, Residual 0.8093",
    "Seed 42, SAM, 0.8072",
    "Seed 42, SGD, Residual 0.8499",
    "Seed 42, SGD, 0.8487",
    "Seed 11, SAM, Residual 0.8079",
    "Seed 11, SAM, 0.8054",
    "Seed 11, SGD, Residual 0.8519",
    "Seed 11, SGD, 0.8509",

    "Seed 6, SAM, Residual 0.8062",
    "Seed 6, SAM, 0.8062",
    "Seed 6, SGD, Residual 0.8504",
    "Seed 6, SGD, 0.8503",
]

In [None]:
from helper.visualization import Run

# One Run object per configured run id; all of them share `dataset_name`.
runs = [Run(run_id, dataset_name) for run_id in run_ids]

### Trainings:

In [None]:
# Show the training records of every loaded run, in `runs` order.
for run in runs:
    run.plot_training_records()

## PHATE Embedding Trajectories

In [None]:
from helper.visualization import mphate_on_runs

# Build the PHATE animations for the loaded runs. `titles` is passed along as the labels
# (presumably one per run, in `runs` order — TODO confirm against `mphate_on_runs`).
animations = mphate_on_runs(runs, titles)

In [None]:
from helper.plots import soft_smooth 

In [None]:
%matplotlib ipympl
%matplotlib widget

plot_phate_animations(animations, smooth_window=7, smooth_alpha=0.85, start_epoch=30)

In [None]:
# Index of the run to inspect; `titles` and `animations` are indexed with the same value,
# so the printed title belongs to the evaluated animation.
n = 4
print(f"{titles[n]}\n")
animations[n].evaluate()

# Prediction Similarity

In [None]:
from helper.visualization import compute_prediction_similarities
from helper.plots import plot_prediction_similarity_heatmap

In [None]:
# Compute the prediction similarities for all runs (similarity="cosine") and plot them as a
# heatmap labelled with `titles`.
similarities = compute_prediction_similarities(runs, similarity="cosine")
plot_prediction_similarity_heatmap(similarities, run_titles=titles)

In [None]:
# Close the current figure before drawing the next one.
# `import matplotlib` on its own does not guarantee that the `pyplot` submodule is loaded,
# so the submodule is imported explicitly (this still binds the name `matplotlib`).
import matplotlib.pyplot

matplotlib.pyplot.close()

In [None]:
# Same heatmap for a subset of the runs: `[::-2]` walks the list backwards in steps of two,
# i.e. with the 12 configured runs it picks indices 11, 9, ..., 1 — per `titles`, the
# entries without "Residual", in reverse order.
# The identical slice is applied to `titles` so the labels stay aligned with the runs.
similarities = compute_prediction_similarities(runs[::-2], similarity="cosine")
plot_prediction_similarity_heatmap(similarities, run_titles=titles[::-2], figsize=(6,5))

In [None]:
from helper.visualization import mphate_on_predictions

In [None]:
# Animations built by `mphate_on_predictions` for all runs, labelled with `titles`.
pred_animations = mphate_on_predictions(runs, titles=titles)

In [None]:
# NOTE(review): `start_epoch=5` and `legend_dist=-1.0` are unexplained magic values; their
# exact meaning is defined by `plot_phate_animations` — consider naming them in a config cell.
plot_phate_animations(pred_animations, start_epoch=5, legend_dist=-1.0)

# PHATE on Model Weights

In [None]:
from NeuroVisualizer.neuro_aux.utils import get_files

# Directory (relative to "trainings/") in which each run stored its flattened weights.
# A dedicated name is used on purpose: rebinding `run_ids` here would overwrite the run ids
# from the configuration cell, and re-running the `Run(run_id, dataset_name)` cell afterwards
# would then receive directory names instead of run ids.
weights_dirs = [run.results["ll_flattened_weights_dir"] for run in runs]

# One list of checkpoint files per run, in `runs` order.
pt_files = []
for weights_dir in weights_dirs:
    model_folder = f"trainings/{weights_dir}"
    pt_files.append(get_files(model_folder, prefix="model-"))
    print(f"Found {len(pt_files[-1])} checkpoint files.")

# All checkpoint paths of all runs in a single flat list.
pt_files_flat = [path for sublist in pt_files for path in sublist]

In [None]:
import torch
import numpy as np

from tqdm import tqdm

def load_flattened_weights(pt_file_paths, device="cpu"):
    """
    Load already-flattened model weights (using weights_only=True) from .pt files.

    Args:
        pt_file_paths: iterable of checkpoint paths; the rows of the result follow this order.
        device: value passed to ``torch.load`` as ``map_location``. The tensors are moved to
            the CPU afterwards in any case, because they are converted to NumPy arrays.

    Returns:
        np.ndarray with one entry per checkpoint, stacked along a new first axis,
        i.e. shape (n_checkpoints, total_weights) for 1-D tensors.

    Raises:
        ValueError: if no path is given, if ``torch.load`` does not accept
            ``weights_only=True``, if a file does not contain a ``torch.Tensor``, or if the
            tensors do not all have the same shape.
    """
    flattened = []
    for path in tqdm(pt_file_paths, desc="Loading model checkpoints"):
        try:
            tensor = torch.load(path, map_location=device, weights_only=True)
        except TypeError as err:
            # Keep the original TypeError attached as the explicit cause.
            raise ValueError(
                f"torch.load(..., weights_only=True) is not supported for {path}."
            ) from err

        if not isinstance(tensor, torch.Tensor):
            raise ValueError(f"Expected a flattened tensor in {path}, got {type(tensor)}")

        weights = tensor.detach().cpu().numpy()

        # Fail with the offending path instead of a generic np.stack shape error.
        if flattened and weights.shape != flattened[0].shape:
            raise ValueError(
                f"Checkpoint {path} has shape {weights.shape}, "
                f"expected {flattened[0].shape} like the first checkpoint."
            )

        flattened.append(weights)

    # Checked after the loop so that generators are supported as input, too.
    if not flattened:
        raise ValueError("No checkpoint files were given; nothing to load.")

    return np.stack(flattened)  # shape: (n_checkpoints, total_weights)

In [None]:
import m_phate
from helper.visualization import Animation

def mphate_on_model_weights(pt_files_by_run, runs, titles=None):
    """
    Apply M-PHATE to flattened model weights from multiple runs.

    Args:
        pt_files_by_run: list of list of checkpoint paths (one list per run). All runs must
            provide the same number of checkpoints with the same number of weights, because
            the data is stacked into one (runs, epochs, features) array.
        runs: run objects, index-aligned with ``pt_files_by_run``; ``runs[i]`` is handed to
            the i-th Animation.
        titles: optional list of labels per run; "Run <idx>" is used when omitted or empty.

    Returns:
        animations: list of Animation objects, one per run

    Raises:
        ValueError: if ``pt_files_by_run`` is empty, if ``runs`` or ``titles`` has fewer
            entries than ``pt_files_by_run``, or if the runs' weight arrays differ in shape.
    """
    pt_files_by_run = list(pt_files_by_run)
    n_runs = len(pt_files_by_run)

    # Validate the cheap things first: the M-PHATE fit below is expensive, and a length
    # mismatch would otherwise only surface afterwards, while building the animations.
    if n_runs == 0:
        raise ValueError("pt_files_by_run is empty; at least one run is required.")
    if len(runs) < n_runs:
        raise ValueError(
            f"Got {n_runs} checkpoint lists but only {len(runs)} runs."
        )
    if titles and len(titles) < n_runs:
        raise ValueError(
            f"Got {n_runs} checkpoint lists but only {len(titles)} titles."
        )

    # Each entry has shape (epochs, features).
    all_run_flattened = [load_flattened_weights(file_list) for file_list in pt_files_by_run]

    shapes = {run_weights.shape for run_weights in all_run_flattened}
    if len(shapes) > 1:
        raise ValueError(
            "All runs must have the same number of checkpoints and weights, "
            f"got shapes {sorted(shapes)}."
        )

    all_run_flattened = np.stack(all_run_flattened)  # shape: (n_runs, n_epochs, features)
    combined_weights = np.transpose(all_run_flattened, (1, 0, 2))  # shape: (epochs, runs, features)

    mphate_op = m_phate.M_PHATE(knn_dist="cosine", mds_dist="cosine")
    mphate_emb = mphate_op.fit_transform(combined_weights)  # shape: (epochs * runs, n_components)

    # Undo the flattening of the (epochs, runs) axes; the embedding width is taken from the
    # result instead of being hard-coded.
    n_epochs = combined_weights.shape[0]
    mphate_emb = mphate_emb.reshape(n_epochs, n_runs, -1)
    mphate_trajectories = np.transpose(mphate_emb, (1, 0, 2))  # (runs, epochs, n_components)

    animations = []
    for idx in range(n_runs):
        title = titles[idx] if titles else f"Run {idx}"
        anim = Animation(
            projections=mphate_trajectories[idx],
            title=title,
            run=runs[idx]
        )
        animations.append(anim)

    return animations


In [None]:
# `pt_files` is a list of lists of checkpoint paths (one inner list per run), in the same
# order as `runs` and `titles`.
animations_model = mphate_on_model_weights(pt_files, runs, titles=titles)

In [None]:
# Plot the weight-space animations. The smoothing arguments (window 5, alpha 0.9) are
# optional and meant for visualization only.
plot_phate_animations(animations_model, smooth_window=5, smooth_alpha=0.9)