In [None]:
from helper.plots import plot_phate_animations

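# Compare Runs
Compare embedding trajectories, embedding CKA similarities, and prediction similarities across training runs.
Run exactly one of the three configuration cells below. Each one overwrites `dataset_name`, `run_ids`, and `titles`, so only the last one executed takes effect.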

In [None]:
# Configuration: CIFAR-10, plain CNN (SAM vs. SGD, seeds 42 / 11 / 6).
# Only the configuration cell executed last takes effect.
dataset_name = 'cifar10'

# Each run id sits next to its plot title so the two lists cannot drift apart.
run_ids, titles = map(list, zip(*[
    ("run-0017-CNN_cifar10_128_0.8072", "Seed 42, SAM, 0.8072"),
    ("run-0019-CNN_cifar10_128_0.8487", "Seed 42, SGD, 0.8487"),
    ("run-0021-CNN_cifar10_128_0.8054", "Seed 11, SAM, 0.8054"),
    ("run-0023-CNN_cifar10_128_0.8509", "Seed 11, SGD, 0.8509"),
    ("run-0025-CNN_cifar10_128_0.8062", "Seed 6, SAM, 0.8062"),
    ("run-0027-CNN_cifar10_128_0.8503", "Seed 6, SGD, 0.8503"),
]))

In [None]:
# Configuration: CIFAR-10, residual CNN (SAM vs. SGD, seeds 42 / 11 / 6).
# Only the configuration cell executed last takes effect.
dataset_name = 'cifar10'

run_ids = [
    "run-0016-CNN_cifar10_128_0.8093", # Seed 42, SAM, Residual
    "run-0018-CNN_cifar10_128_0.8499", # Seed 42, SGD, Residual
    "run-0020-CNN_cifar10_128_0.8079", # Seed 11, SAM, Residual
    "run-0022-CNN_cifar10_128_0.8519", # Seed 11, SGD, Residual
    "run-0024-CNN_cifar10_128_0.8062", # Seed 6, SAM, Residual
    "run-0026-CNN_cifar10_128_0.8504", # Seed 6, SGD, Residual
]

# The trailing number of each title repeats the score encoded at the end of the
# run id (presumably validation accuracy).
titles = [
    "Seed 42, SAM, Residual 0.8093",
    "Seed 42, SGD, Residual 0.8499",
    "Seed 11, SAM, Residual 0.8079",
    "Seed 11, SGD, Residual 0.8519",
    "Seed 6, SAM, Residual 0.8062",
    "Seed 6, SGD, Residual 0.8504",
]

# Guard against id/title mismatches such as a mistyped score in a title.
assert len(run_ids) == len(titles)
assert all(t.rsplit(" ", 1)[-1] == r.rsplit("_", 1)[-1] for r, t in zip(run_ids, titles))

In [None]:
# Configuration: MNIST, seed 42 only (SGD / SAM, with and without residual connections).
# NOTE(review): the run order here (SGD, SAM, SGD Residual, SAM Residual) differs from
# the CIFAR-10 cells (SAM, SGD, ...), so the hard-coded pair labels passed as
# `extra_title` in the cross-epoch comparison cells do not describe these pairs.
# Only the configuration cell executed last takes effect.
dataset_name = 'mnist'

# Each run id sits next to its plot title so the two lists cannot drift apart.
run_ids, titles = map(list, zip(*[
    ("run-0011-CNN_mnist_32_0.9776", "Seed 42, SGD, 0.9776"),
    ("run-0012-CNN_mnist_32_0.9768", "Seed 42, SAM, 0.9768"),
    ("run-0013-CNN_mnist_32_0.9797", "Seed 42, SGD, Residual 0.9797"),
    ("run-0014-CNN_mnist_32_0.9744", "Seed 42, SAM, Residual 0.9744"),
]))

In [None]:
from helper.visualization import Run

# One Run object per configured run id, all loaded for the same dataset.
runs = [Run(run_id, dataset_name) for run_id in run_ids]

### Training curves

In [None]:
# Only the first N_TRAINING_PLOTS runs get training-record plots;
# set it to len(runs) to plot every run.
N_TRAINING_PLOTS = 2
for run in runs[:N_TRAINING_PLOTS]:
    run.plot_training_records()

## PHATE Embedding Trajectories

In [None]:
from helper.visualization import mphate_on_runs

# M-PHATE on the embedding trajectories of all configured runs, labelled with `titles`.
# NOTE(review): presumably returns one animation per run — confirm in helper.visualization.mphate_on_runs.
animations = mphate_on_runs(runs, titles)

In [None]:
from helper.plots import soft_smooth 

In [None]:
%matplotlib ipympl
%matplotlib widget

plot_phate_animations(animations, smooth_window=7, smooth_alpha=0.85, start_epoch=30)

## CKA Similarities

### All runs compared at each epoch

In [None]:
# Epoch-wise CKA of embedding spaces
from helper.visualization import compute_epochwise_embedding_cka
from helper.plots import plot_prediction_similarity_heatmap

# CKA between the embedding spaces of all runs, computed per epoch and plotted with the
# per-epoch title template below (the prediction-similarity heatmap helper is reused for CKA).
cka_mats = compute_epochwise_embedding_cka(
    runs,
    skip=1 # NOTE(review): presumably the epoch stride (1 = every epoch); larger values sample fewer epochs and compute faster — confirm in the helper
)
plot_prediction_similarity_heatmap(
    cka_mats,
    run_titles=titles,
    cmap="magma",
    title="Embedding CKA at Epoch {}"  # "{}" presumably receives the epoch number
)

### Compare two training runs
Cross-epoch similarity between a pair of runs, so the comparison includes the time (epoch) dimension.

In [None]:
from helper.visualization import compute_cross_epoch_similarity
from helper.plots import plot_cross_epoch_similarity_heatmap
# Pair: SAM vs. SGD with the same seed (label assumes the CIFAR-10 run order).
i, j = 0, 1

# 1) Compute the cross-epoch CKA similarity between the two runs' embeddings.
S, ix, iy = compute_cross_epoch_similarity(
    runs[i],
    runs[j],
    similarity="cka",
    mode="embeddings",
    skip=10,
    desc_prefix=f"{titles[i]} vs {titles[j]} — ",
)

In [None]:
# 2) Plot the cross-epoch heatmap for the pair chosen in the previous cell.
plot_cross_epoch_similarity_heatmap(
    S, ix, iy,
    title_x=titles[i],
    title_y=titles[j],
    similarity="cka",
    cmap="magma",
    figsize=(7.5, 6.5),
    vmin_from_data=True,
    show_percent=True,
    extra_title=" - SAM vs SGD on same Seed",
)

In [None]:
# Pair: SAM vs. SGD with different seeds (label assumes the CIFAR-10 run order).
i, j = 0, 3

# 1) Compute the cross-epoch CKA similarity between the two runs' embeddings.
S, ix, iy = compute_cross_epoch_similarity(
    runs[i],
    runs[j],
    similarity="cka",
    mode="embeddings",
    skip=10,
    desc_prefix=f"{titles[i]} vs {titles[j]} — ",
)

In [None]:
# 2) Plot the cross-epoch heatmap for the pair chosen in the previous cell.
plot_cross_epoch_similarity_heatmap(
    S, ix, iy,
    title_x=titles[i],
    title_y=titles[j],
    similarity="cka",
    cmap="magma",
    figsize=(7.5, 6.5),
    vmin_from_data=True,
    show_percent=True,
    extra_title=" - SAM vs SGD on different Seeds",
)

In [None]:
# Pair: both SAM, different seeds (label assumes the CIFAR-10 run order).
i, j = 0, 2

# 1) Compute the cross-epoch CKA similarity between the two runs' embeddings.
#    An optional start_epoch=40 was tried here and is currently disabled.
S, ix, iy = compute_cross_epoch_similarity(
    runs[i],
    runs[j],
    similarity="cka",
    mode="embeddings",
    skip=10,
    desc_prefix=f"{titles[i]} vs {titles[j]} — ",
)

# 2) Plot the cross-epoch heatmap.
plot_cross_epoch_similarity_heatmap(
    S, ix, iy,
    title_x=titles[i],
    title_y=titles[j],
    similarity="cka",
    cmap="magma",
    figsize=(7.5, 6.5),
    vmin_from_data=True,
    show_percent=True,
    extra_title=" - Both SAM, different Seeds",
)

In [None]:
# Pair: both SGD, different seeds (label assumes the CIFAR-10 run order).
i, j = 1, 3

# 1) Compute the cross-epoch CKA similarity between the two runs' embeddings.
#    An optional start_epoch=40 was tried here and is currently disabled.
S, ix, iy = compute_cross_epoch_similarity(
    runs[i],
    runs[j],
    similarity="cka",
    mode="embeddings",
    skip=10,
    desc_prefix=f"{titles[i]} vs {titles[j]} — ",
)

# 2) Plot the cross-epoch heatmap.
plot_cross_epoch_similarity_heatmap(
    S, ix, iy,
    title_x=titles[i],
    title_y=titles[j],
    similarity="cka",
    cmap="magma",
    figsize=(7.5, 6.5),
    vmin_from_data=True,
    show_percent=True,
    extra_title=" - Both SGD, different Seeds",
)

## Prediction Similarity

### All runs compared at each epoch

In [None]:
from helper.visualization import compute_prediction_similarities
from helper.plots import plot_prediction_similarity_heatmap

In [None]:
%matplotlib ipympl
%matplotlib widget
similarities = compute_prediction_similarities(runs, similarity="cosine")
plot_prediction_similarity_heatmap(similarities, run_titles=titles)

### Prediction similarity between two runs across epochs

In [None]:
from helper.visualization import compute_prediction_cross_epoch_similarity
from helper.plots import plot_prediction_cross_epoch_heatmap 

# Pair: both SGD, different seeds (label assumes the CIFAR-10 run order).
i, j = 1, 3

# 1) Cross-epoch cosine similarity of the two runs' predictions. No epoch-skipping or
#    start-epoch argument is passed here, so the helper's defaults apply.
S_pred, ix, iy = compute_prediction_cross_epoch_similarity(
    runs[i], runs[j],
    metric="cosine",
    desc_prefix=f"{titles[i]} vs {titles[j]} — "
)
# 2) Plot the heatmap with the colour range taken from the data.
plot_prediction_cross_epoch_heatmap(
    S_pred, ix, iy,
    title_x=titles[i], title_y=titles[j],
    metric="cosine", show_percent=False, percent_decimals=1,
    vmin_from_data=True
)

### M-PHATE on predictions

In [None]:
# `import matplotlib` alone does not load the `pyplot` submodule, so `matplotlib.pyplot`
# only resolved if another module had already imported it. Import it explicitly
# (this still binds the name `matplotlib`), then close the current figure.
import matplotlib.pyplot
matplotlib.pyplot.close()

In [None]:
# Every second run, counting back from the last one (with the CIFAR-10 configs these are
# the SGD runs). Titles are sliced the same way so the labels stay aligned with the runs.
similarities = compute_prediction_similarities(runs[::-2], similarity="cosine")
plot_prediction_similarity_heatmap(
    similarities,
    figsize=(6, 5),
    run_titles=titles[::-2],
)

In [None]:
from helper.visualization import mphate_on_predictions

In [None]:
from helper.visualization import Animation

def mphate_on_predictions(runs, titles=None, skip_epochs=10):
    """
    Apply M-PHATE to the prediction distributions (``val_distributions``) across runs.

    The predictions of every run are flattened per epoch and stacked into an array of
    shape (epochs, runs, flat_dim). M-PHATE embeds that array jointly in 2-D, so all
    runs' trajectories share one coordinate system.

    Note: this local definition shadows the ``mphate_on_predictions`` imported from
    ``helper.visualization`` earlier in the notebook.

    Args:
        runs: Run objects. Each ``run.results["val_distributions"]`` is a per-epoch
            sequence of prediction arrays (presumably (samples, classes) — confirm).
            All runs must have the same number of epochs and the same array sizes,
            otherwise ``np.stack`` raises.
        titles: Optional per-run titles, indexed like ``runs``. Defaults to
            ``run.results["train_config"]``.
        skip_epochs: Number of leading epochs dropped before embedding. The default of
            10 matches the previously hard-coded value.

    Returns:
        list of Animation, one per run. Each has ``projections`` of shape
        (n_epochs - skip_epochs, 2).

    Raises:
        ValueError: if ``runs`` is empty.
    """
    # Local imports: numpy is not imported anywhere else in this notebook, and
    # m_phate is only needed here.
    import numpy as np
    import m_phate

    if not runs:
        raise ValueError("mphate_on_predictions needs at least one run")

    per_run = []
    for run in runs:
        preds_per_epoch = run.results["val_distributions"]
        # One flat vector per epoch -> (epochs, flat_dim); drop the first skip_epochs epochs.
        flattened = np.stack([np.asarray(pred).reshape(-1) for pred in preds_per_epoch])
        per_run.append(flattened[skip_epochs:])

    # (runs, epochs, flat_dim) -> (epochs, runs, flat_dim): M-PHATE takes time steps first.
    combined_pred = np.transpose(np.stack(per_run), (1, 0, 2))

    mphate_op = m_phate.M_PHATE(knn_dist="cosine", mds_dist="cosine")
    mphate_emb = mphate_op.fit_transform(combined_pred)  # (epochs * runs, 2), epoch-major

    n_epochs, n_runs = combined_pred.shape[:2]
    # Back to per-run trajectories: (runs, epochs, 2).
    trajectories = np.transpose(mphate_emb.reshape(n_epochs, n_runs, 2), (1, 0, 2))

    animations = []
    for idx, run in enumerate(runs):
        title = titles[idx] if titles is not None else run.results["train_config"]
        animations.append(Animation(projections=trajectories[idx], title=title, run=run))
    return animations

In [None]:
# M-PHATE trajectories of the prediction distributions, one animation per run.
pred_animations = mphate_on_predictions(runs=runs, titles=titles)

In [None]:
# NOTE(review): mphate_on_predictions drops the first 10 epochs before embedding, so
# start_epoch=5 is presumably relative to that trimmed trajectory — confirm the indexing
# in plot_phate_animations.
plot_phate_animations(pred_animations, start_epoch=5, legend_dist=-1.0)