In [None]:
import numpy as np

from helper.plots import plot_phate_animations, plot_prediction_similarity_heatmap

# Embedding Trajectories
Compare embedding trajectories

In [None]:
# Dataset shared by all runs below; also passed to Run(...) and to the LMC helpers.
dataset_name = 'cifar10'

# Identifiers of the runs to load. Order matters: `titles` is a separate list that is
# paired with these runs by position.
run_ids = [
#    "run-0011-CNN_mnist_32_0.9776",
#    "run-0012-CNN_mnist_32_0.9768"
#    "run-0007-CNN_mnist_128_0.9851",

    "run-0016-CNN_cifar10_128_0.8093", # Seed 42, SAM, Residual
    "run-0017-CNN_cifar10_128_0.8072", # Seed 42, SAM
    "run-0018-CNN_cifar10_128_0.8499", # Seed 42, Residual
    "run-0019-CNN_cifar10_128_0.8487", # Seed 42

    "run-0020-CNN_cifar10_128_0.8079", # Seed 11, SAM, Residual
    "run-0021-CNN_cifar10_128_0.8054", # Seed 11, SAM
    "run-0022-CNN_cifar10_128_0.8519", # Seed 11, Residual
    "run-0023-CNN_cifar10_128_0.8509", # Seed 11

    "run-0024-CNN_cifar10_128_0.8062", # NOTE(review): presumably Seed 6, SAM, Residual (per `titles`) - confirm
    "run-0025-CNN_cifar10_128_0.8062", # NOTE(review): presumably Seed 6, SAM (per `titles`) - confirm
    "run-0026-CNN_cifar10_128_0.8504", # NOTE(review): presumably Seed 6, Residual (per `titles`) - confirm
    "run-0027-CNN_cifar10_128_0.8503", # NOTE(review): presumably Seed 6 (per `titles`) - confirm
]

In [None]:
# Display titles, one per entry of run_ids and in the same order.
# The trailing number of each title repeats the numeric suffix of the matching run id.
titles = [
    "Seed 42, SAM, Residual 0.8093",
    "Seed 42, SAM, 0.8072",
    "Seed 42, SGD, Residual 0.8499",
    "Seed 42, SGD, 0.8487",

    "Seed 11, SAM, Residual 0.8079",
    "Seed 11, SAM, 0.8054",
    "Seed 11, SGD, Residual 0.8519",
    "Seed 11, SGD, 0.8509",

    "Seed 6, SAM, Residual 0.8062",
    "Seed 6, SAM, 0.8062",
    "Seed 6, SGD, Residual 0.8504",
    "Seed 6, SGD, 0.8503",
]

In [None]:
from helper.visualization import Run

# Load one Run per configured id, preserving the order of run_ids.
runs = [Run(run_id, dataset_name) for run_id in run_ids]

### Training records

In [None]:
# Plot the training records of every loaded run, one call per run.
for run in runs:
    run.plot_training_records()

## PHATE Embedding Trajectories

In [None]:
from helper.visualization import mphate_on_runs

# `titles` is passed alongside `runs`; the result is later handed to plot_phate_animations.
animations = mphate_on_runs(runs, titles)

In [None]:
%matplotlib ipympl
%matplotlib widget

plot_phate_animations(animations, smooth_window=7, smooth_alpha=0.85, start_epoch=30)

In [None]:
# Denoise the first animation only; the result is inspected and evaluated in the following cells.
denoised = animations[0].denoise(window_size=5, blend=0.7, do_embedding_drift=True, do_cka_similarities=False)

In [None]:
# Number of projections in the first animation, then the shape of its first projection
# (shown as the cell output).
print(len(animations[0].projections))
animations[0].projections[0].shape

In [None]:
# Same inspection for the denoised animation, for comparison with the cell above.
print(len(denoised.projections))
denoised.projections[0].shape

In [None]:
# Evaluate the denoised animation with the is_trajectory flag set.
denoised.evaluate(is_trajectory=True)

In [None]:
# Index of the animation to inspect; its title is printed before the evaluation output.
n = 0
print(f"{titles[n]}\n")
animations[n].evaluate()

In [None]:
# Collect the two values returned by evaluate() for every animation, in title order.
list_drifts, list_cka = [], []

for idx, title in enumerate(titles):
    print(f"{title}\n")
    drift, cka = animations[idx].evaluate(verbose=False)
    list_drifts.append(drift)
    list_cka.append(cka)
    print(f"Mean Drift Similarity: {drift}, Mean Similarity to CKA: {cka}")

# Embedding Space Similarities

In [None]:
from helper.visualization import compute_cka
from tqdm.notebook import tqdm

def embedding_space_similarities(
    runs: list,
    method: str = 'cka',
    epoch_step: int = 5,
) -> list[np.ndarray]:
    """
    Compute pairwise similarity matrices between the embedding spaces of several runs.

    Only every ``epoch_step``-th epoch is evaluated (starting at epoch 0). For each of
    those epochs a symmetric (n_runs x n_runs) matrix is built whose entry (i, j)
    compares the embeddings of run i and run j at that epoch.

    Args:
        runs: list of Run objects, each with .embeddings: List[np.ndarray] of shape (n_points, emb_dim).
            The number of epochs is taken from runs[0]; every other run must provide
            at least as many epochs. For the non-CKA methods the embeddings of two
            runs are flattened and must therefore contain the same number of values.
        method: 'cka' | 'cosine' | 'euclidean' | 'manhattan'
        epoch_step: stride between evaluated epochs (default 5, the previously
            hard-coded value). Use 1 to evaluate every epoch.
    Returns:
        sims: List with one (n_runs x n_runs) array per evaluated epoch, i.e.
            ceil(n_epochs / epoch_step) entries. 'cka' stores whatever compute_cka
            returns, 'cosine' lies in [-1, 1], and 'euclidean'/'manhattan' are mapped
            to (0, 1] via 1 / (1 + distance). An empty ``runs`` yields an empty list.
    Raises:
        ValueError: if ``method`` is not one of the supported names or ``epoch_step`` < 1.
    """
    # Fail fast on bad arguments instead of deep inside the nested loops.
    if method not in ('cka', 'cosine', 'euclidean', 'manhattan'):
        raise ValueError(f"Unknown method {method!r}")
    if epoch_step < 1:
        raise ValueError(f"epoch_step must be a positive integer, got {epoch_step!r}")
    if not runs:
        return []

    n_runs = len(runs)
    n_epochs = len(runs[0].embeddings)
    sims = []

    for ep in tqdm(range(0, n_epochs, epoch_step)):
        M = np.zeros((n_runs, n_runs), dtype=float)
        for i in range(n_runs):
            Xi = runs[i].embeddings[ep]
            # Only the upper triangle is computed; the matrix is mirrored below.
            for j in range(i, n_runs):
                Yj = runs[j].embeddings[ep]
                if method == 'cka':
                    s = compute_cka(Xi, Yj)
                else:
                    # flatten to vectors
                    v1 = Xi.reshape(-1)
                    v2 = Yj.reshape(-1)
                    if method == 'cosine':
                        s = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
                    else:
                        # L2 norm for 'euclidean', L1 norm for 'manhattan'
                        p = 2 if method == 'euclidean' else 1
                        d = np.linalg.norm(v1 - v2, ord=p)
                        # map distance in [0, inf) to similarity in (0, 1]
                        s = 1.0 / (1.0 + d)
                M[i, j] = M[j, i] = s
        sims.append(M)

    return sims

In [None]:
# CKA-based similarity matrices between all runs (computed on every 5th epoch).
emb_sims = embedding_space_similarities(runs, method='cka')

In [None]:
# Heatmap of the embedding-space similarity matrices, labelled with the run titles.
plot_prediction_similarity_heatmap(
    similarities=emb_sims,
    run_titles=titles,
    cmap='viridis',
    figsize=(8, 6),
    title="Embedding Space Similarity"
)

# Prediction Similarity

In [None]:
from helper.visualization import compute_prediction_similarities
from helper.plots import plot_prediction_similarity_heatmap

In [None]:
# Cosine-based prediction similarities for all runs, shown as a heatmap labelled with the run titles.
similarities = compute_prediction_similarities(runs, similarity="cosine")
plot_prediction_similarity_heatmap(similarities, run_titles=titles, figsize=(11, 7))

In [None]:
# Import the pyplot submodule explicitly: a bare `import matplotlib` does not load it,
# so `matplotlib.pyplot` would only resolve if another module had imported it already.
import matplotlib.pyplot

# Close the current figure before the next plot is created.
matplotlib.pyplot.close()

In [None]:
# Same analysis on every second run, walking backwards from the last run (slice [::-2]).
# `titles` is sliced identically so the labels stay aligned with the selected runs.
similarities = compute_prediction_similarities(runs[::-2], similarity="cosine")
plot_prediction_similarity_heatmap(similarities, run_titles=titles[::-2], figsize=(6,5))

In [None]:
from helper.visualization import mphate_on_predictions

In [None]:
# Prediction-based counterpart of `animations`; plotted in the next cell.
pred_animations = mphate_on_predictions(runs, titles=titles)

In [None]:
# Render the prediction-based animations (no smoothing arguments are passed here).
plot_phate_animations(pred_animations, start_epoch=5, legend_dist=-1.0)

# PHATE on Model weights

In [None]:
from helper.visualization import mphate_on_model_weights

# Weight-space counterpart of `animations`; plotted in the next cell.
animations_model = mphate_on_model_weights(runs, titles=titles)

In [None]:
# Optional: the smoothing arguments (smooth_window, smooth_alpha) are for visualization only.
plot_phate_animations(animations_model, smooth_window=5, smooth_alpha=0.9, start_epoch=0)

### Weight similarities

In [None]:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity, pairwise_distances

def model_weight_similarities(runs: list, metric: str = 'cosine'):
    """
    Compute epoch-wise similarity matrices between runs in weight-space.

    Args:
        runs: list of Run, each with get_flattened_weights() -> np.ndarray, shape (n_epochs, n_features).
            All runs must return arrays of identical shape so they can be stacked.
        metric: one of 'cosine', 'euclidean', 'manhattan'

    Returns:
        similarities: list of length n_epochs, each an (n_runs x n_runs) np.ndarray.
            'cosine' yields cosine similarities in [-1, 1]; 'euclidean' and 'manhattan'
            distances d are converted to similarities in (0, 1] via 1 / (1 + d), so
            larger always means "more similar" regardless of the metric.

    Raises:
        ValueError: if ``metric`` is not one of the supported names.
    """
    # Validate before the (potentially large) stacking step.
    if metric not in ('cosine', 'euclidean', 'manhattan'):
        raise ValueError(f"Unsupported metric: {metric!r}")

    # stack to shape (n_runs, n_epochs, n_features)
    all_flat = np.stack([run.get_flattened_weights() for run in runs], axis=0)
    _, n_epochs, _ = all_flat.shape

    sims = []
    for ep in range(n_epochs):
        X = all_flat[:, ep, :]  # shape (n_runs, n_features)

        if metric == 'cosine':
            S = cosine_similarity(X)
        else:
            D = pairwise_distances(X, metric=metric)
            # Previously the raw distances were returned, contradicting the documented
            # "similarities in (0,1]" contract; map them the same way as
            # embedding_space_similarities does.
            S = 1.0 / (1.0 + D)

        sims.append(S)

    return sims

In [None]:
# Choose the metric: model_weight_similarities accepts 'cosine', 'euclidean' or 'manhattan'.
weight_sims = model_weight_similarities(runs, metric='cosine')

In [None]:
# First heatmap: runs at even indices (0, 2, 4, ...). Rows and columns of every matrix
# and the titles are sub-sampled with the same stride so labels stay aligned.
plot_prediction_similarity_heatmap(
    similarities=[sim[::2,::2] for sim in weight_sims],
    run_titles=titles[::2],
    cmap='viridis', # 'viridis_r' or 'viridis'
    figsize=(7, 4)
)

# Second heatmap: every second run counted from the last one, in reverse order ([::-2]).
plot_prediction_similarity_heatmap(
    similarities=[sim[::-2,::-2] for sim in weight_sims],
    run_titles=titles[::-2],
    cmap='viridis', # 'viridis_r' or 'viridis'
    figsize=(7, 4)
)

# Linear Mode Connectivity (LMC)

In [None]:
import torch
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
#from helper.neuro_viz import compute_lmc_loss_path
from helper.visualization import linear_mode_connectivity_path

def compute_LMC(runs, idx_1, idx_2, titles, model, dataset_name, device='cpu', num_points=10):
    """
    Compute the loss along the linear interpolation path between two runs.

    Args:
        runs: sequence of run objects exposing ``results["model_info"]``.
        idx_1, idx_2: positions in ``runs`` (and ``titles``) of the two end points.
        titles: display titles, indexed like ``runs``; only used for the log line.
        model: model whose ``repr`` must equal the runs' ``model_info``; it is passed
            on to ``compute_lmc_loss_path``.
        dataset_name: forwarded to ``compute_lmc_loss_path``.
        device: forwarded to ``compute_lmc_loss_path``.
        num_points: forwarded to ``linear_mode_connectivity_path``.

    Returns:
        Whatever ``compute_lmc_loss_path`` returns for the interpolation path.

    Raises:
        AssertionError: if the two runs report different ``model_info`` or if
            ``model`` does not match that ``model_info``.

    Note:
        ``compute_lmc_loss_path`` is resolved at call time, so it must be defined
        (or imported) before this function is called.
    """
    print(f"Compute LMC between ({titles[idx_1]}) and ({titles[idx_2]})")

    info1 = runs[idx_1].results["model_info"]
    info2 = runs[idx_2].results["model_info"]

    assert info1 == info2, f"Model architectures differ:\n{info1}\n{info2}"
    print("Same architecture:")
    print(info1)

    # Check the model that was actually passed in (the original referenced an
    # undefined name, `model_arch`, which raised NameError).
    assert repr(model) == info1, "Wrong model"

    path = linear_mode_connectivity_path(runs[idx_1], runs[idx_2], num_points=num_points)

    losses = compute_lmc_loss_path(path, model, dataset_name, device=device)

    return losses

In [None]:
from helper.vision_classification import init_mlp_for_dataset, init_cnn_for_dataset

# Non-residual CNN moved to `device`; it is later handed to compute_LMC together with the runs.
model = init_cnn_for_dataset(
    dataset_name,
    conv_dims=[64, 128, 256],
    kernel_sizes=[5, 3, 3],
    hidden_dims=[256, 128],
    dropout=0.2,
    residual=False,
).to(device)

In [None]:
# Cache of LMC loss paths keyed by (i, j) run-index pairs; re-running this cell clears it.
lmc_results = {}

In [None]:
# Compute LMC loss paths for pairs among the first 6 runs and cache them in lmc_results.
for i in range(6):
    for j in range(6):
        if i == j:
            continue
        # skip if already computed in either order
        if (i, j) in lmc_results or (j, i) in lmc_results:
            continue
        print(f"Computing LMC for runs {i} → {j}")
        lmc_results[(i, j)] = compute_LMC(
            runs, i, j, titles, model, dataset_name, device, num_points=3
        )
        # NOTE(review): this break leaves the inner loop after the first newly computed
        # pair, so one execution of the cell adds at most one pair per i - confirm this
        # is intended rather than a leftover.
        break

In [None]:
import numpy as np
# Smallest and largest loss along the cached path between runs 0 and 1.
pair_losses = np.array(lmc_results[(0, 1)])
print(pair_losses.min(), pair_losses.max())

In [None]:
import matplotlib.pyplot as plt

# One curve per cached run pair: lmc_results maps (i, j) to the list of losses along the path.
fig, ax = plt.subplots(figsize=(10, 6))

for (first, second), pair_losses in lmc_results.items():
    pair_label = f"{titles[first]} ↔ {titles[second]}"
    ax.plot(pair_losses, label=pair_label, alpha=0.8)

ax.set_xlabel("Interpolation index")
ax.set_ylabel("Loss")
ax.set_title("LMC loss paths for all run‐pairs")
ax.set_ylim(bottom=0)

# Anchor the legend's upper-right corner at x=1.3 (axes coordinates), past the right edge.
legend_options = dict(
    title="Run pairs",
    loc="upper right",
    bbox_to_anchor=(1.3, 1),
    ncol=1,
    frameon=False,
)
ax.legend(**legend_options)
plt.tight_layout()
plt.show()

In [None]:
# Last recorded validation loss of every run, prefixed with the run's title.
for idx, run in enumerate(runs):
    final_val_loss = run.results["val_losses"][-1]
    print(titles[idx], ": ", final_val_loss)

In [None]:
import torch
from copy import deepcopy
from tqdm import tqdm
from helper.neuro_viz import Loss, repopulate_model_fixed

def compute_lmc_loss_path(
    weight_path: list[np.ndarray],
    model: torch.nn.Module,
    dataset_name: str,
    device: str = 'cpu',
    loss_name: str = 'test_loss',
    whichloss: str = 'crossentropy',
    bn_recal_batches: int = 100
):
    """
    Evaluate the loss at every point of a path through weight space.

    Each flattened weight vector in ``weight_path`` is converted to a float32 tensor,
    loaded into ``model`` through ``repopulate_model_fixed`` and scored with a ``Loss``
    object created for ``dataset_name``. Every loss tensor is also printed.

    Args:
        weight_path: flattened weight vectors, one per point on the path.
        model: module that receives the weights; it is moved to ``device`` first.
            NOTE(review): presumably its parameters are overwritten in place by
            repopulate_model_fixed - confirm in the helper.
        dataset_name: passed to ``Loss``.
        device: device for the model, the weight tensors and ``Loss``.
        loss_name: passed to ``Loss.get_loss``.
        whichloss: passed to ``Loss.get_loss``.
        bn_recal_batches: accepted but not used by this function.

    Returns:
        list with one Python scalar (``tensor.item()``) per entry of ``weight_path``, in order.
    """
    evaluator = Loss(dataset_name, device)
    model.to(device)
    path_losses = []

    for flat_weights in tqdm(weight_path, desc="Compute LMC losses"):
        weight_tensor = torch.as_tensor(flat_weights, dtype=torch.float32, device=device)
        # Loading weights must not be tracked by autograd.
        with torch.no_grad():
            model = repopulate_model_fixed(weight_tensor.clone(), model)
        point_loss = evaluator.get_loss(model, loss_name, whichloss).detach()
        print(point_loss)
        path_losses.append(point_loss.item())

    return path_losses

In [None]:
# Single-point check: load the first entry of run 0's flattened weights into `model`
# and print the resulting test loss. This rebinds the notebook-level `model`.
flat = runs[0].get_flattened_weights()[0]
# NOTE(review): unlike compute_lmc_loss_path, no dtype=torch.float32 is passed here -
# confirm that the dtype of `flat` is what the model expects.
weights = torch.as_tensor(flat, device=device)
loss_obj = Loss(dataset_name, device)
with torch.no_grad():
    model = repopulate_model_fixed(weights.clone(), model)
loss = loss_obj.get_loss(model, 'test_loss', 'crossentropy').detach()
print(loss)

In [None]:
# Peek at the first 10 values of the flattened weight vector loaded above.
flat[:10]