In [None]:
from helper.plots import plot_phate_animations

# Embedding Trajectories
Compare embedding trajectories

In [None]:
dataset_name = 'mnist'

run_ids = [
#    "run-0011-CNN_mnist_32_0.9776",
#    "run-0012-CNN_mnist_32_0.9768"
#    "run-0007-CNN_mnist_128_0.9851",

#    "run-0016-CNN_cifar10_128_0.8093", # Seed 42, SAM, Residual
    "run-0017-CNN_cifar10_128_0.8072", # Seed 42, SAM
#    "run-0018-CNN_cifar10_128_0.8499", # Seed 42, Residual
    "run-0019-CNN_cifar10_128_0.8487", # Seed 42

#    "run-0020-CNN_cifar10_128_0.8079", # Seed 11, SAM, Residual
    "run-0021-CNN_cifar10_128_0.8054", # Seed 11, SAM
#    "run-0022-CNN_cifar10_128_0.8519", # Seed 11, Residual
    "run-0023-CNN_cifar10_128_0.8509", # Seed 11

#    "run-0024-CNN_cifar10_128_0.8062",
    "run-0025-CNN_cifar10_128_0.8062",
#    "run-0026-CNN_cifar10_128_0.8504",
    "run-0027-CNN_cifar10_128_0.8503",
]

In [None]:
titles = [
#    "Seed 42, SAM, Residual 0.8093",
    "Seed 42, SAM, 0.8072",
#    "Seed 42, SGD, Residual 0.8499",
    "Seed 42, SGD, 0.8487",
#    "Seed 11, SAM, Residual 0.8079",
    "Seed 11, SAM, 0.8054",
#    "Seed 11, SGD, Residual 0.8519",
    "Seed 11, SGD, 0.8509",
    
#    "Seed 6, SAM, Residual 0.0.8062",
    "Seed 6, SAM, 0.8062",
#    "Seed 6, SGD, Residual 0.8504",
    "Seed 6, SGD, 0.8503",
]

In [None]:
from helper.visualization import Run

runs = []
for run_id in run_ids:
    runs.append(Run(run_id, dataset_name))

### Trainings:

In [None]:
for run in runs:
    run.plot_training_records()

## PHATE Embedding Trajectories

In [None]:
from helper.visualization import mphate_on_runs

animations = mphate_on_runs(runs, titles)

In [None]:
from helper.plots import soft_smooth 

In [None]:
%matplotlib ipympl
%matplotlib widget

plot_phate_animations(animations, smooth_window=7, smooth_alpha=0.85, start_epoch=30)

In [None]:
denoised = animations[0].denoise(window_size=5, blend=0.7, do_embedding_drift=True, do_cka_similarities=False)

In [None]:
print(len(animations[0].projections))
animations[0].projections[0].shape

In [None]:
print(len(denoised.projections))
denoised.projections[0].shape

In [None]:
denoised.evaluate()

In [None]:
n = 0
print(f"{titles[n]}\n")
animations[n].evaluate()

In [None]:
list_drifts, list_cka = [], [] 

for i, run in enumerate(titles):
    print(f"{titles[i]}\n")
    drift, cka = animations[i].evaluate(verbose=False)
    list_drifts.append(drift); list_cka.append(cka)
    print(f"Mean Drift Similarity: {drift}, Mean Similarity to CKA: {cka}")

# Prediction Similarity

In [None]:
from helper.visualization import compute_prediction_similarities
from helper.plots import plot_prediction_similarity_heatmap

In [None]:
similarities = compute_prediction_similarities(runs, similarity="cosine")
plot_prediction_similarity_heatmap(similarities, run_titles=titles)

In [None]:
import matplotlib
matplotlib.pyplot.close()

In [None]:
similarities = compute_prediction_similarities(runs[::-2], similarity="cosine")
plot_prediction_similarity_heatmap(similarities, run_titles=titles[::-2], figsize=(6,5))

In [None]:
from helper.visualization import mphate_on_predictions

In [None]:
pred_animations = mphate_on_predictions(runs, titles=titles)

In [None]:
plot_phate_animations(pred_animations, start_epoch=5, legend_dist=-1.0)

# PHATE on Model weights

In [None]:
from helper.visualization import mphate_on_model_weights

animations_model = mphate_on_model_weights(runs, titles=titles)

In [None]:
# Optional: smooth for visualization only
plot_phate_animations(animations_model, smooth_window=5, smooth_alpha=0.9, start_epoch=0)

# Linear Mode Connectivity (LMC)

In [None]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [91]:
#from helper.neuro_viz import compute_lmc_loss_path
from helper.visualization import linear_mode_connectivity_path

def compute_LMC(runs, idx_1, idx_2, titles, model, dataset_name, device='cpu', num_points=10, ):
    print(f"Compute LMC between ({titles[idx_1]}) and ({titles[idx_2]})")

    info1 = runs[idx_1].results["model_info"]
    info2 = runs[idx_2].results["model_info"]
    
    assert info1 == info2, f"Model architectures differ:\n{info1}\n{info2}"
    print("Same architecture:")
    print(runs[idx_1].results["model_info"])

    assert repr(model_arch) == runs[idx_1].results["model_info"], "Wrong model"

    path = linear_mode_connectivity_path(runs[idx_1], runs[idx_2], num_points=num_points)

    losses = compute_lmc_loss_path(path, model, dataset_name, device=device)

    return losses

In [92]:
from helper.vision_classification import init_mlp_for_dataset, init_cnn_for_dataset

model_arch = init_cnn_for_dataset(dataset_name, conv_dims=[64, 128, 256], kernel_sizes=[5, 3, 3], hidden_dims=[256, 128], dropout=0.2, residual=False).to(device)

In [93]:
lmc_results = {}

In [94]:
for i in range(6):
    for j in range(6):
        if i == j:
            continue
        # skip if already computed in either order
        if (i, j) in lmc_results or (j, i) in lmc_results:
            continue
        print(f"Computing LMC for runs {i} → {j}")
        lmc_results[(i, j)] = compute_LMC(
            runs, i, j, titles, model_arch, dataset_name, device, num_points=3
        )
        break

Computing LMC for runs 0 → 1
Compute LMC between (Seed 42, SAM, 0.8072) and (Seed 42, SGD, 0.8487)
Same architecture:
CNN(conv_dims=[64, 128, 256], kernel_sizes=[5, 3, 3], hidden_dims=[256, 128], dropout=0.2, residual=False)


Compute LMC losses:  33%|███▎      | 1/3 [00:42<01:24, 42.19s/it]

tensor(2.3025)


Compute LMC losses:  33%|███▎      | 1/3 [00:53<01:46, 53.46s/it]


KeyboardInterrupt: 

In [None]:
import numpy as np
print(np.array(lmc_results[(0,1)]).min(), np.array(lmc_results[(0,1)]).max())

In [None]:
import matplotlib.pyplot as plt

# lmc_results is a dict mapping (i,j) → list of losses
fig, ax = plt.subplots(figsize=(10, 6))

for (i, j), losses in lmc_results.items():
    ax.plot(
        losses,
        label=f"{titles[i]} ↔ {titles[j]}",
        alpha=0.8
    )

ax.set_xlabel("Interpolation index")
ax.set_ylabel("Loss")
ax.set_title("LMC loss paths for all run‐pairs")
ax.set_ylim(bottom=0)
ax.legend(
    title="Run pairs",
    loc="upper right",
    bbox_to_anchor=(1.3, 1),
    ncol=1,
    frameon=False
)
plt.tight_layout()
plt.show()

In [87]:
for i, run in enumerate(runs):
    print(titles[i], ": ", run.results["val_losses"][-1])

Seed 42, SAM, 0.8072 :  0.58981925
Seed 42, SGD, 0.8487 :  0.46490496
Seed 11, SAM, 0.8054 :  0.59002
Seed 11, SGD, 0.8509 :  0.46150622
Seed 6, SAM, 0.8062 :  0.5910221
Seed 6, SGD, 0.8503 :  0.45945826


In [90]:
import torch
from copy import deepcopy
from tqdm import tqdm
from helper.neuro_viz import Loss, repopulate_model_fixed

def compute_lmc_loss_path(
    weight_path: list[np.ndarray],
    model: torch.nn.Module,
    dataset_name: str,
    device: str = 'cpu',
    loss_name: str = 'test_loss',
    whichloss: str = 'crossentropy',
    bn_recal_batches: int = 100
):
    loss_obj = Loss(dataset_name, device)
    losses = []
    model.to(device)

    for flat in tqdm(weight_path, desc="Compute LMC losses"):
        weights = torch.as_tensor(flat, dtype=torch.float32, device=device)
        with torch.no_grad():
            model = repopulate_model_fixed(weights.clone(), model)
        loss = loss_obj.get_loss(model, loss_name, whichloss).detach()
        print(loss)
        losses.append(loss.item())

    return losses