In [None]:
import matplotlib
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np

In [None]:
import phate, m_phate, umap # Should throw no error in proper environment

#### Embedding Visualization TODOs
Data
- [x] PCA
- [ ] Use ALL train data for PCA?
- [x] PCA Denoising eval on denoised Embedding drift

- [x] t-SNE
- [ ] Visualize t-SNE training Steps
- [ ] Refine with new paper

- [x] UMAP
- [x] UMAP Parameters

- [x] Copy PHATE stuff
- [x] Evaluate

Visualization
- [x] Live during training
- [x] 3D
- [x] With trajectory

**Environment:** Please use another environment (`phate-env`) here

# Embedding Visualization

## Import Data

In [None]:
# ==== MNIST ========
dataset = "mnist"

run_id = "run-0011-CNN_mnist_32_0.9776"
#run_id = "run-0012-CNN_mnist_32_0.9768"
#run_id = "run-0013-CNN_mnist_32_0.9797"
#run_id = "run-0014-CNN_mnist_32_0.9744"

In [None]:
# ==== CIFAR 10 ========
# NOTE: this cell reassigns `dataset` and `run_id`; when the notebook is run
# top to bottom it overrides the MNIST selection made in the previous cell.
dataset = "cifar10"

# Residual
run_id = "run-0016-CNN_cifar10_128_0.8093" # Seed 42, SAM
# run_id = "run-0018-CNN_cifar10_128_0.8499" # Seed 42
# run_id = "run-0020-CNN_cifar10_128_0.8079" # Seed 11, SAM
# run_id = "run-0022-CNN_cifar10_128_0.8519" # Seed 11
    
# No Residual
# run_id = "run-0017-CNN_cifar10_128_0.8072" # Seed 42, SAM
# run_id = "run-0019-CNN_cifar10_128_0.8487" # Seed 42
# run_id = "run-0021-CNN_cifar10_128_0.8054" # Seed 11, SAM
# run_id = "run-0023-CNN_cifar10_128_0.8509" # Seed 11

In [None]:
from helper.visualization import Run
run = Run(run_id, dataset)

## Recap: The Training

In [None]:
fig = run.plot_training_records()

### Confusion Matrix Development

In [None]:
%matplotlib ipympl
%matplotlib widget
_ = run.confusion_matrix(annotate=True)

## Embedding Drift

The evaluation measure used in this work is the embedding drift:
- **Multi-scale skips**: for each snapshot index `i`, compare its embedding `E_i` to earlier snapshots `E_{i - 2**n}` for `n = 0,1,…,4` (skip lengths 1, 2, 4, 8, 16).
- **Mean Euclidean distance**:
  ```python
  drift = np.linalg.norm(current_snapshot - previous_snapshot, axis=1).mean()
  ```
- **Result:** a dict mapping each skip length to a time series of drift values, showing how rapidly—and at what scales—the embedding space is evolving.

With euclidean distance:

In [None]:
_ = run.plot_embedding_drifts() #(y_lim=6)

In [None]:
_ = run.plot_embedding_drift_multi()

#### Manhattan distance
The Manhattan distance gives **almost the same** result as the (scaled) Euclidean distance.

In [None]:
run.plot_embedding_drifts_manhattan()

## CKA Similarities

This plot shows **1 − CKA similarity** over time, representing the **structural change** in the embedding space.
Lower values indicate high similarity (stable structure), while higher values reflect greater representational drift.
It allows direct comparison with Euclidean embedding drift and helps identify when and how much the internal structure evolves during training.

In [None]:
run.plot_cka_similarities(y_lim=0.3)

## Eigenvalue development

This plot shows the **10 top PCA eigenvalues** of the embedding space over training time.
Each curve represents the variance explained by a principal direction.
Changes in the eigenvalue spectrum reveal how the dimensional structure of the embeddings evolves — e.g., early compression, later expansion, or stabilization of representational capacity.

In [None]:
run.eigenvalues(figsize=(8, 5))

### Mean Trajectory Curvature
This plot shows the mean trajectory curvature of sample embeddings over training time, with shaded bands indicating ±1 standard deviation across samples.

In the early epochs, curvature is relatively low, indicating smooth and directional changes in the embedding space. Later, the curvature rises and stabilizes in a higher range (~1.5–2.5 radians), suggesting that sample trajectories become increasingly erratic and less coherent.

This indicates that the model transitions from a stable learning phase into a regime where embeddings frequently change direction, reflecting either ongoing internal restructuring, unstable optimization, or lack of convergence in representation space. The high standard deviation shows that this behavior is not uniform across samples, with some embeddings changing more drastically than others.

In [None]:
run.plot_curvature_distribution()

# Visualizations PCA

### PCA visualizations can be based on different bases

In [None]:
from helper.visualization import generate_pca_animation

In [None]:
ani_pca_first = generate_pca_animation(run, fit_basis='first')
ani_pca_last = generate_pca_animation(run, fit_basis='last')
ani_pca_all = generate_pca_animation(run, fit_basis='all')
ani_pca_window = generate_pca_animation(run, fit_basis='window', window_size=16)

### Visualization

The 2D visualization here has a slider for epochs. You can also press Play, Pause and Stop

Optionally, an interpolated transition between the noisy steps can be activated to make it easier to track points that move far.

In [None]:
# CIFAR100 Legend
from helper.plots import show_cifar100_legend
if dataset == "cifar100":
    show_cifar100_legend(cmap = "tab20")

In [None]:
%matplotlib ipympl
%matplotlib widget
from helper.visualization import show_animations

show_animations(
    animations=[ani_pca_first, ani_pca_last, ani_pca_all, ani_pca_window],
    interpolate=False, # TRANSLATION via linear interpolation
    steps_per_transition=2, # interpolation steps
    figsize_per_plot=(4, 4),
    alpha=0.8,
    dot_size=6, # 12
    cols=2
)

In [None]:
show_animations(
    animations=[ani_pca_first, ani_pca_last, ani_pca_all, ani_pca_window],
    figsize_per_plot=(4, 4),
    add_confusion_matrix=True
)

In [None]:
#ani_pca_window.save_as_gif()

In [None]:
ani_pca_first.evaluate()
ani_pca_last.evaluate()
ani_pca_all.evaluate()
ani_pca_window.evaluate()

In [None]:
%matplotlib ipympl
%matplotlib widget
ani_pca_all.scatter_movements()

In [None]:
ani_pca_all.evaluate_movements()

In [None]:
ani_pca_window.evaluate_movements()

### Denoising

As the embedding snapshots during training are made within one epoch at fixed, but arbitrary intervals, with varying samples and potentially augmented images, they are very noisy.

As a result, the values can only be seen as an indicator, not as an exact measurement of the embedding.

Therefore, we can apply denoising to get a better overall picture

In [None]:
ani_pca_first_denoised = ani_pca_first.denoise(window_size=15, blend=0.9, mode='window', do_embedding_drift=False, do_cka_similarities=False)
ani_pca_last_denoised = ani_pca_last.denoise(window_size=15, blend=0.9, mode='window', do_embedding_drift=False, do_cka_similarities=False)
ani_pca_all_denoised = ani_pca_all.denoise(window_size=15, blend=0.9, mode='window', do_embedding_drift=False, do_cka_similarities=False)
ani_pca_window_denoised = ani_pca_window.denoise(window_size=15, blend=0.9, mode='window', do_embedding_drift=False, do_cka_similarities=False)

In [None]:
ani_pca_window_denoised.plot(
    interpolate=True,
    steps_per_transition=1,
    alpha=1,
)

In [None]:
show_animations(
    animations=[
        ani_pca_first_denoised,
        ani_pca_last_denoised,
        ani_pca_all_denoised,
        ani_pca_window_denoised],
)

In [None]:
ani_pca_all_denoised.evaluate()

#### Now compared to denoised Embeddings...

In [None]:
ani_pca_first_denoised = ani_pca_first.denoise(window_size=15, blend=0.9, mode='window', do_embedding_drift=True)
ani_pca_last_denoised = ani_pca_last.denoise(window_size=15, blend=0.9, mode='window', do_embedding_drift=True)
ani_pca_all_denoised = ani_pca_all.denoise(window_size=15, blend=0.9, mode='window', do_embedding_drift=True)
ani_pca_window_denoised = ani_pca_window.denoise(window_size=15, blend=0.9, mode='window', do_embedding_drift=True)

In [None]:
ani_pca_first_denoised.evaluate()
ani_pca_last_denoised.evaluate()
ani_pca_all_denoised.evaluate()
ani_pca_window_denoised.evaluate()

In [None]:
ani_pca_all.evaluate_movements()
ani_pca_all_denoised.evaluate_movements()

In [None]:
ani_pca_window.evaluate_movements()
ani_pca_window_denoised.evaluate_movements()

In [None]:
# Close interactivity of plots before.
# `close()` without arguments only closes the *current* figure; the cells above
# opened several interactive figures, so close all of them.
matplotlib.pyplot.close("all")

### 3D
We can also visualize 3D...

In [None]:
animation_3D = generate_pca_animation(
    run,
    fit_basis='window',
    out_dim=3 #3D
)

In [None]:
animation_3D = animation_3D.denoise(do_embedding_drift=True)

In [None]:
from helper.plots import show_with_slider_3d

show_with_slider_3d(
    animation_3D.projections,
    labels=animation_3D.labels,
    interpolate=False,
    steps_per_transition=1,
    alpha=0.7,
    dataset=animation_3D.run.dataset,
    show_legend=False,
    dot_size=10, #20
)

In [None]:
animation_3D.evaluate()

### Other Denoising Strategies

In [None]:
#projections = projections_pca_window
animation = ani_pca_all

In [None]:
animation.evaluate()

To smooth the low-dimensional projections we use two denoising modes:

- **Exponential (causal) blending**  
  Recursively mix each frame $P_i$ with the previous denoised output $D_{i-1}$:  
  $$D_i = (1-\alpha)\,P_i + \alpha\,D_{i-1}$$  
  Reacts quickly while damping high-frequency noise.

- **Window (moving-average) blending**  
  Compute the mean of the last $w$ raw projections:  
  $$\overline{P}_i = \frac{1}{w}\sum_{j=i-w+1}^{i}P_j$$  
  and blend it with $P_i$:  
  $$D_i = (1-\alpha)\,P_i + \alpha\,\overline{P}_i$$
  Uses surrounding frames for stronger smoothing at the cost of lag.


In [None]:
denoised_window = animation.denoise(window_size=15, blend=0.9, mode='window')
denoised_exponential = animation.denoise(blend=0.8, mode='exponential')

In [None]:
show_animations(
    animations=[
        animation,
        denoised_window,
        denoised_exponential],
    custom_titles=["PCA", "PCA denoised window", "PCA denoised exponential"]
)

In [None]:
denoised_window.evaluate()
denoised_exponential.evaluate()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Sweep grid: every blend factor is combined with every window size, and
# additionally evaluated once in exponential mode.
window_sizes = [1, 2, 4, 8, 10, 15, 20, 30]
blend_values = np.linspace(0, 1, 11)

# window size -> list of scores (one per blend value, in `blend_values` order)
correlation_results = dict((size, []) for size in window_sizes)
# scores of the exponential mode, one per blend value
exponentials = []

for blend in blend_values:
    # Window mode: one score per window size for this blend factor
    for ws in window_sizes:
        windowed = animation.denoise(window_size=ws, blend=blend, mode='window')
        correlation_results[ws].append(windowed.evaluate(verbose=False))

    # Exponential mode has no window size, so it is evaluated once per blend
    smoothed = animation.denoise(blend=blend, mode='exponential')
    exponentials.append(smoothed.evaluate(verbose=False))

In [None]:
# One curve per window size plus a thicker curve for the exponential mode
blend_fig, blend_ax = plt.subplots(figsize=(8, 4))
for ws in window_sizes:
    blend_ax.plot(blend_values, correlation_results[ws], label='window_size=' + str(ws))
blend_ax.plot(blend_values, exponentials, label='exponential', linewidth=3)
blend_ax.set(
    xlabel="Blend",
    ylabel="Correlation",
    title="Correlation vs. Blend for different denoise calculations",
)
blend_ax.legend()
blend_ax.grid(True)
plt.show()

In [None]:
# Sweep the window size at a fixed blend factor.
# Note: this reuses (and overwrites) `window_sizes` / `correlation_results`
# from the blend sweep above.
blend = 0.9
window_sizes = range(1, 30)

# One score per window size, in `window_sizes` order
correlation_results = [
    animation.denoise(window_size=ws, blend=blend, mode='window').evaluate(verbose=False)
    for ws in window_sizes
]

# Score as a function of the window size
window_fig, window_ax = plt.subplots(figsize=(8, 4))
window_ax.plot(window_sizes, correlation_results)
window_ax.set(
    xlabel="Window Size",
    ylabel="Correlation",
    title="Correlation vs. Window Size",
)
window_ax.grid(True)
plt.show()

# t-SNE Visualization

In [None]:
# Close interactivity of plots before.
# `close()` without arguments only closes the *current* figure; close every
# figure that the previous section left open.
matplotlib.pyplot.close("all")

## Standard t-SNE


In [None]:
# Restore original values
#run = Run(run_id, dataset)
#run.print_info()

In [None]:
# t-SNE is computationally very intense
#run.subsample(point_step=5, snapshot_step=1)
#run.print_info()

In [None]:
# As a comparison
from helper.visualization import  generate_pca_animation
pca_animation = generate_pca_animation(run, fit_basis='all').denoise(blend=0.8, mode='exponential')

In [None]:
from helper.visualization import generate_tsne_animation, show_animations

tsne_animation = generate_tsne_animation(run)

In [None]:
%matplotlib ipympl
%matplotlib widget

show_animations(
    [tsne_animation, pca_animation],
    interpolate=True,
    steps_per_transition=5
)

In [None]:
tsne_animation.evaluate()

In [None]:
#tsne_animation.save_as_gif(frame_interval=100, steps_per_transition=5)

In [None]:
denoised_tsne = tsne_animation.denoise(blend=0.8, mode='exponential')

In [None]:
show_animations([tsne_animation, denoised_tsne, pca_animation])

In [None]:
denoised_tsne.evaluate()

In [None]:
from helper.visualization import show_projections_and_drift

show_animations([pca_animation, denoised_tsne, tsne_animation], with_drift=True)

## t-SNE with update blending

After computing t-SNE for frame $i$, you blend it with the previous projection.
`tsne_update` is a weight in $[0,1]$:
- If `tsne_update=1`, it's exactly the original method (no blending).
- If `tsne_update=0`, you freeze to the previous frame (no update at all).
- Values like `0.2`–`0.3` should result in a smooth interpolation between old and new.

Result: Extra smoothing over time.
- Avoids too fast movements
- Avoids flips of dense clusters
- Can introduce lag or “stickiness,” but animations look steadier.

In [None]:
tsne_avg = generate_tsne_animation(
    run,
    tsne_update=0.3
)

In [None]:
show_animations(
    [tsne_avg, tsne_animation],
    #interpolate=True,
    #steps_per_transition=5,
)

In [None]:
# Denoise the evaluation
tsne_avg.denoise(do_projections=False, blend=0.7, mode='exponential').evaluate()

## t-SNE with backwards computation
Starts with the last frame
- Ensures convergence
- Reduces noise in early frames, as the basis comes from a later and better picture

In [None]:
tsne_reverse = generate_tsne_animation(
    run,
    reverse_computation=True
)

In [None]:
show_animations(
    [tsne_reverse, tsne_animation],
    interpolate=True,
    steps_per_transition=5
)

In [None]:
tsne_reverse_2 = generate_tsne_animation(
    run,
    reverse_computation=True,
    tsne_update=0.2
)

In [None]:
show_animations(
    [tsne_reverse_2, tsne_reverse, tsne_animation],
    interpolate=True,
    steps_per_transition=5,
    #with_drift=True
)

In [None]:
tsne_reverse.evaluate()

In [None]:
tsne_reverse_2.denoise(do_projections=False, blend=0.7, mode='exponential').evaluate()

## t-SNE with Cosine metric

In [None]:
tsne_cosine = generate_tsne_animation(
    run,
    tsne_update=0.2,
    metric='cosine'
)

In [None]:
show_animations(
    [tsne_cosine, tsne_avg],
    interpolate=True,
    steps_per_transition=1,
)

In [None]:
show_projections_and_drift(
    projections_list = [tsne_cosine, tsne_animation],
    interpolate=True,
    steps_per_transition=5,
)

In [None]:
tsne_cosine.denoise(do_projections=False, blend=0.7, mode='exponential').evaluate()

## t-SNE with random Seed
Doesn't affect the animation

In [None]:
tsne_random = generate_tsne_animation(
    run,
    random_state=1106,
    tsne_init='random'
)

In [None]:
show_animations(
    [tsne_random, tsne_animation],
    interpolate=True,
    steps_per_transition=5
)

In [None]:
tsne_random.evaluate()

## t-SNE Perplexity
Compare 5 - 10 - 30 - 50

In [None]:
tsne_p_5 = generate_tsne_animation(run, tsne_perplexity=5, metric='cosine')
tsne_p_10 = generate_tsne_animation(run, tsne_perplexity=10, metric='cosine')
tsne_p_30 = generate_tsne_animation(run, tsne_perplexity=30, metric='cosine')
tsne_p_50 = generate_tsne_animation(run, tsne_perplexity=50, metric='cosine')

In [None]:
show_animations(
    [
        tsne_p_5,
        tsne_p_10,
        tsne_p_30,
        tsne_p_50
    ],
    custom_titles=[
        "t-SNE Perplexity 5",
        "t-SNE Perplexity 10",
        "t-SNE Perpl. 30 (standard)",
        "t-SNE Perplexity 50"
    ],
    shared_axes=False
)

# Dynamic t-SNE
This is an implementation of Rauber et al.
https://github.com/paulorauber/thesne/

In [None]:
import torch
import torch.nn.functional as F
from torch.optim import Adam
import time


class DynamicTSNE:
    """Joint t-SNE embedding of a sequence of snapshots.

    Usage is two-step: ``compute_affinities`` builds one high-dimensional
    affinity matrix per snapshot, then ``fit`` optimises the low-dimensional
    layouts of all snapshots together, with an optional penalty that ties
    consecutive layouts to each other.
    """

    def __init__(
            self,
            output_dims=2,
            verbose=True,
    ):
        """Store the output dimensionality and pick CUDA when it is available."""
        self.output_dims = output_dims
        self.verbose = verbose
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Filled in by compute_affinities(); fit() refuses to run without them.
        self.Xs = None
        self.Ps = None

    def compute_affinities(self, Xs, perplexity=30.0, k_neighbors=90):
        """Compute one symmetric affinity matrix per snapshot.

        Args:
            Xs: non-empty sequence of 2-D array-likes, one per snapshot. All
                snapshots must contain the same number of rows because the
                resulting matrices are stacked.
            perplexity: target perplexity of each point's neighbour distribution.
            k_neighbors: number of nearest neighbours that receive a non-zero
                affinity. Capped at ``n_points - 1``.

        Raises:
            ValueError: if ``Xs`` is empty or a snapshot has fewer than two points.

        Side effects:
            Sets ``self.Xs`` (list of float32 tensors on ``self.device``) and
            ``self.Ps`` (tensor of shape ``(n_snapshots, n_points, n_points)``).
        """
        def Hbeta(D, beta):
            # Entropy H and normalised probabilities P for precision `beta`.
            P = torch.exp(-D * beta)
            sumP = torch.sum(P)
            sumP = torch.clamp(sumP, min=1e-8)  # guard against division by zero
            H = torch.log(sumP) + beta * torch.sum(D * P) / sumP
            P = torch.clamp(P / sumP, min=1e-8)
            return H, P

        def compute_P(X, init_beta=None):
            t0 = time.time()

            n = X.shape[0]
            if n < 2:
                raise ValueError("each snapshot needs at least two points")
            # topk below asks for k + 1 entries (the point itself is dropped
            # afterwards), so k must not exceed n - 1.
            k = min(k_neighbors, n - 1)

            D = torch.cdist(X, X, p=2).pow(2)
            P = torch.zeros((n, n), device=X.device)
            # Warm-start the per-point precisions from the previous snapshot.
            beta = init_beta.clone() if init_beta is not None else torch.ones(n, device=X.device)
            logU = torch.log(torch.tensor(perplexity, device=X.device))
            all_tries = 0

            for i in range(n):
                distances = D[i]
                topk = torch.topk(distances, k=k + 1, largest=False)
                idx = topk.indices[topk.indices != i][:k]
                Di = torch.clamp(distances[idx], max=1e3)

                # Binary search for the precision that matches the target entropy.
                betamin, betamax = None, None
                H, thisP = Hbeta(Di, beta[i])
                tries = 0
                while torch.abs(H - logU) > 1e-5 and tries < 50:
                    if H > logU:
                        betamin = beta[i].clone()
                        beta[i] = beta[i] * 2 if betamax is None else (beta[i] + betamax) / 2
                    else:
                        betamax = beta[i].clone()
                        beta[i] = beta[i] / 2 if betamin is None else (beta[i] + betamin) / 2
                    H, thisP = Hbeta(Di, beta[i])
                    tries += 1
                all_tries += tries
                P[i, idx] = thisP

            if self.verbose:
                print(f"Total affinity computation time: {time.time() - t0:.2f}s, {all_tries / n} Tries")

            # Symmetrise and normalise.
            P = (P + P.T) / (2 * n)
            return P, beta

        # Cast explicitly: float64 inputs would otherwise clash with the
        # float32 buffers (`P`, `beta`) allocated in compute_P.
        X_tensor = [torch.tensor(X, dtype=torch.float32, device=self.device) for X in Xs]
        if not X_tensor:
            raise ValueError("Xs must contain at least one snapshot")
        self.Xs = X_tensor

        Ps = []
        prev_beta = None
        for X in X_tensor:
            P, prev_beta = compute_P(X, prev_beta)
            Ps.append(P)

        self.Ps = torch.stack(Ps)
        assert not torch.isnan(self.Ps).any(), "Affinity matrix has NaN"

    def fit(self, n_epochs=1000, exaggeration=12.0, exaggeration_epochs=250, lr=200.0, lambd=0.1):
        """Optimise the low-dimensional layouts of all snapshots jointly.

        Args:
            n_epochs: number of optimiser steps.
            exaggeration: factor applied to the affinities during the first
                ``exaggeration_epochs`` steps.
            exaggeration_epochs: length of the exaggeration phase.
            lr: initial Adam learning rate (cosine-annealed over ``n_epochs``).
            lambd: weight of the penalty between consecutive layouts.

        Returns:
            List with one ``(n_points, output_dims)`` NumPy array per snapshot.

        Raises:
            RuntimeError: if ``compute_affinities`` has not been called.
        """
        if self.Xs is None or self.Ps is None:
            raise RuntimeError("call compute_affinities() before fit()")

        # Imported here so the class does not depend on another notebook cell
        # having imported PCA first.
        from sklearn.decomposition import PCA

        T = len(self.Xs)
        n = self.Xs[0].shape[0]

        # Initialise every layout with the PCA projection of its snapshot.
        Y_init = []
        for X in self.Xs:
            X_cpu = X.detach().cpu().numpy()
            pca = PCA(n_components=self.output_dims)
            Y_pca = pca.fit_transform(X_cpu)
            Y_init.append(torch.tensor(Y_pca, device=self.device, dtype=torch.float32))

        Y = torch.stack(Y_init)
        Y.requires_grad_()

        optimizer = Adam([Y], lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

        for epoch in range(n_epochs):
            optimizer.zero_grad()
            total_loss = 0

            if epoch < exaggeration_epochs:
                P_use = self.Ps * exaggeration
            else:
                # NOTE(review): the temporal penalty is switched off for all
                # remaining epochs once exaggeration ends - confirm intended.
                lambd = 0
                P_use = self.Ps

            for t in range(T):
                Qt, _ = self._compute_lowdim_affinities(Y[t])
                loss = self._kl_divergence(P_use[t], Qt)
                if t > 0:
                    # NOTE(review): mse_loss averages over all coordinates,
                    # i.e. this is not a sum of squared distances - confirm
                    # the resulting penalty scale is the intended one.
                    loss += (lambd / (2 * n)) * F.mse_loss(Y[t], Y[t - 1])
                total_loss += loss

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_([Y], max_norm=10.0)
            optimizer.step()
            scheduler.step()

            if self.verbose and (epoch % 100 == 0 or epoch == n_epochs - 1):
                print(f"Epoch {epoch}, Loss: {total_loss.item():.4f}")

        return [Y[t].detach().cpu().numpy() for t in range(T)]

    def _compute_lowdim_affinities(self, Y):
        """Return the normalised Student-t affinities ``Q`` and the raw kernel ``num``."""
        num = 1 / (1 + torch.cdist(Y, Y, p=2).pow(2))
        num.fill_diagonal_(0.0)
        # NOTE(review): entries below 1e-5 are clamped and therefore receive no
        # gradient; for large n many entries may fall below that - confirm.
        Q = torch.clamp(num / num.sum(), min=1e-5)
        return Q, num

    def _kl_divergence(self, P, Q):
        """KL divergence between P and Q, with 1e-8 added for numerical stability."""
        return torch.sum(P * torch.log((P + 1e-8) / (Q + 1e-8)))

In [None]:
tsne = DynamicTSNE()
tsne.compute_affinities(run.embeddings, perplexity=5.0, k_neighbors=150)

In [None]:
from sklearn.decomposition import PCA
projections = tsne.fit(lr=200, lambd=0.1, n_epochs=1000, exaggeration_epochs=250, exaggeration=22.0)

In [None]:
%matplotlib widget
from helper.plots import show_with_slider

show_with_slider(
    projections,
    labels=run.labels,
    interpolate=True,
    steps_per_transition=4,
)

In [None]:
# `close()` without arguments only closes the current figure; close all
# interactive figures that are still open.
matplotlib.pyplot.close("all")

In [None]:
# NOTE(review): `visualization_drift_vs_embedding_drift` is not imported in any
# of the preceding cells - confirm where it lives and import it explicitly,
# otherwise this cell raises NameError on a fresh kernel.
visualization_drift_vs_embedding_drift(projections, run.embedding_drifts)

# Modern Dynamic t-SNE

In [None]:
# NOTE: `ModernDynamicTSNE` is defined in a cell further down. Run that cell
# first (or move it above this one); otherwise this cell raises NameError when
# the notebook is executed top to bottom.
modern_dynamic_tsne = ModernDynamicTSNE(
    n_epochs=500,
    perplexity=50,
)
# One 2-D projection per snapshot in `run.embeddings`
projections_3 = modern_dynamic_tsne.fit_transform(run.embeddings)

In [None]:
from helper.plots import show_multiple_projections_with_slider

# `tsne_p_5_blend` is not created in any earlier cell, so the original call
# raised NameError on a fresh kernel. The perplexity-5 t-SNE animation from the
# "t-SNE Perplexity" section is used as the reference instead.
# NOTE(review): swap in a blended variant here if that was the intended baseline.
show_multiple_projections_with_slider(
    projections_list=[tsne_p_5.projections, projections_3],
    labels=run.labels,
    titles=["t-SNE", "Dynamic t-SNE"],
    interpolate=False,
    figsize_per_plot=(4, 4),
    dataset=dataset,
    shared_axes=False
)

In [None]:
import numpy as np
from sklearn.decomposition import PCA
import torch
import torch.nn.functional as F
from torch.optim import SGD


class ModernDynamicTSNE:
    """Dynamic t-SNE over a sequence of snapshots, optimised with SGD + momentum.

    All snapshots are embedded jointly: the cost is the sum of the per-snapshot
    KL divergences plus ``lmbda`` times a smoothness term that penalises
    squared movement of points between consecutive snapshots.
    """

    def __init__(
        self,
        perplexity=30,
        n_epochs=1000,
        output_dims=2,
        initial_lr=2400,
        final_lr=200,
        lr_switch=250,
        init_stdev=1e-4,
        initial_momentum=0.5,
        final_momentum=0.8,
        momentum_switch=250,
        lmbda=0.0,
        sigma_iters=50,
        verbose=True,
        device=None
    ):
        """Store the hyper-parameters.

        Args:
            perplexity: target perplexity of the high-dimensional affinities.
            n_epochs: number of optimiser steps.
            output_dims: dimensionality of the embedding.
            initial_lr / final_lr / lr_switch: the learning rate is set to
                ``final_lr`` at epoch ``lr_switch``.
            init_stdev: stored but not used by the code in this class.
            initial_momentum / final_momentum / momentum_switch: the momentum
                is set to ``final_momentum`` at epoch ``momentum_switch``.
            lmbda: weight of the smoothness term between consecutive snapshots.
            sigma_iters: maximum number of binary-search steps per point.
            verbose: print the loss every 100 epochs and at the last epoch.
            device: torch device; defaults to CUDA when available, else CPU.
        """
        self.perplexity = perplexity
        self.n_epochs = n_epochs
        self.output_dims = output_dims
        self.initial_lr = initial_lr
        self.final_lr = final_lr
        self.lr_switch = lr_switch
        self.init_stdev = init_stdev
        self.initial_momentum = initial_momentum
        self.final_momentum = final_momentum
        self.momentum_switch = momentum_switch
        self.lmbda = lmbda
        self.sigma_iters = sigma_iters
        self.verbose = verbose
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    def _hbeta(self, D, beta):
        """Return entropy ``H`` and normalised probabilities ``P`` for precision ``beta``."""
        P = torch.exp(-D * beta)
        sumP = torch.sum(P)
        sumP = torch.clamp(sumP, min=1e-8)  # guard against division by zero
        H = torch.log(sumP) + beta * torch.sum(D * P) / sumP
        P = P / sumP
        return H, P

    def _binary_search_perplexity(self, D, tol=1e-5):
        """Build the symmetric affinity matrix for one squared-distance matrix ``D``.

        For every row a binary search over the precision matches the entropy of
        the conditional distribution to ``log(perplexity)`` (at most
        ``sigma_iters`` steps, tolerance ``tol``).
        """
        n = D.shape[0]
        sigmas = torch.ones(n, device=self.device)
        P = torch.zeros((n, n), device=self.device)
        # Built once on D's device; the per-row "everything except i" mask is
        # derived from it instead of re-creating a CPU arange twice per row.
        indices = torch.arange(n, device=D.device)

        logU = np.log(self.perplexity)
        for i in range(n):
            betamin = None
            betamax = None
            beta = sigmas[i]
            others = indices != i
            Di = D[i][others]
            H, thisP = self._hbeta(Di, beta)

            tries = 0
            while torch.abs(H - logU) > tol and tries < self.sigma_iters:
                if H > logU:
                    # Entropy too high -> distribution too flat -> raise precision
                    betamin = beta
                    beta = beta * 2 if betamax is None else (beta + betamax) / 2
                else:
                    betamax = beta
                    beta = beta / 2 if betamin is None else (beta + betamin) / 2
                H, thisP = self._hbeta(Di, beta)
                tries += 1
            P[i, others] = thisP
        # Symmetrise and normalise
        return (P + P.T) / (2 * n)

    def _precompute_Ps(self, Xs):
        """Return one affinity matrix per snapshot tensor in ``Xs``."""
        Ps = []
        for X in Xs:
            D = torch.cdist(X, X).pow(2)
            P = self._binary_search_perplexity(D)
            Ps.append(P)
        return Ps

    def _compute_cost(self, Ys, Ps):
        """Total cost: summed KL divergences plus the weighted smoothness term."""
        total_kl = 0
        for Y, P in zip(Ys, Ps):
            # Student-t kernel in the embedding, self-affinity removed
            Q_num = 1 / (1 + torch.cdist(Y, Y).pow(2))
            Q_num.fill_diagonal_(0)
            Q = Q_num / Q_num.sum()
            kl = torch.sum(P * torch.log((P + 1e-8) / (Q + 1e-8)))
            total_kl += kl
        # Squared movement of every point between consecutive snapshots
        smoothness = sum((Ys[i] - Ys[i + 1]).pow(2).sum() for i in range(len(Ys) - 1))
        return total_kl + self.lmbda * smoothness / (2 * Ys[0].shape[0])

    def fit_transform(self, Xs_np):
        """Embed all snapshots jointly.

        Args:
            Xs_np: sequence of 2-D array-likes, one per snapshot.

        Returns:
            List with one ``(n_points, output_dims)`` NumPy array per snapshot.
        """
        Xs = [torch.tensor(X, device=self.device, dtype=torch.float32) for X in Xs_np]

        # Init Ys with PCA
        Ys = [
            torch.tensor(PCA(n_components=self.output_dims).fit_transform(X.cpu().numpy()),
                         device=self.device, dtype=torch.float32, requires_grad=True)
            for X in Xs
        ]

        # Precompute all P matrices once
        Ps = self._precompute_Ps(Xs)

        optimizer = SGD(Ys, lr=self.initial_lr, momentum=self.initial_momentum)

        for epoch in range(self.n_epochs):
            # Step schedules for learning rate and momentum
            if epoch == self.lr_switch:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = self.final_lr
            if epoch == self.momentum_switch:
                for param_group in optimizer.param_groups:
                    param_group['momentum'] = self.final_momentum

            optimizer.zero_grad()
            loss = self._compute_cost(Ys, Ps)
            loss.backward()
            optimizer.step()

            if self.verbose and (epoch % 100 == 0 or epoch == self.n_epochs - 1):
                print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

        return [Y.detach().cpu().numpy() for Y in Ys]

# UMAP

In [None]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn.utils.deprecation")

In [None]:
# Restore original values
#run = Run(run_id, dataset)
#run.print_info()

In [None]:
#run.subsample(point_step=5, snapshot_step=5)
#run.print_info()

In [None]:
from helper.visualization import generate_pca_animation, generate_umap_animation

pca_animation = generate_pca_animation(run, fit_basis='all', ).denoise(blend=0.8, mode='exponential')

print(f"{len(run.embeddings)} - 1 should be {len(run.embedding_drifts[1])}")

In [None]:
umap_animation = generate_umap_animation(
    run,
    fit_basis='all_n', # 'all_n' includes every n-th embedding
    fit_basis_n=10,
)

In [None]:
%matplotlib ipympl
%matplotlib widget
from helper.visualization import show_animations

show_animations(
    [
        umap_animation,
        pca_animation,
    ],
    shared_axes=False,
    add_confusion_matrix=True
)

In [None]:
umap_animation.evaluate(y_lim=4)

### UMAP Hyperparameters

In [None]:
#run.subsample(point_step=5, snapshot_step=5)
#run.print_info()

In [None]:
# Current standard
umap_ani = umap_animation

#### Cosine Metric

In [None]:
umap_cosine = generate_umap_animation(
    run,
    metric='cosine'
)

In [None]:
show_animations([
        umap_ani,
        umap_cosine,
    ], shared_axes=False)

In [None]:
umap_cosine.evaluate()

In [None]:
umap_ani.denoise().evaluate()
umap_cosine.denoise().evaluate()

#### Number of Neighbors
Local vs global structure

This determines the number of neighboring points used in local approximations of manifold structure. Larger values will result in more global structure being preserved at the loss of detailed local structure. In general this parameter should often be in the range 5 to 50, with a choice of 10 to 15 being a sensible default. (From documentation)

In [None]:
umap_neighbors_5 = generate_umap_animation(
    run,
    metric='cosine',
    fit_basis_n=10,
    n_neighbors=5
)
umap_neighbors_15 = umap_cosine
umap_neighbors_30 = generate_umap_animation(
    run,
    metric='cosine',
    fit_basis_n=10,
    n_neighbors=30
)
umap_neighbors_50 = generate_umap_animation(
    run,
    metric='cosine',
    fit_basis_n=10,
    n_neighbors=50
)

In [None]:
# NOTE: this repeats the `umap_neighbors_30` computation of the previous cell
# with identical arguments; running it only recomputes the same animation.
umap_neighbors_30 = generate_umap_animation(
    run,
    metric='cosine',
    fit_basis_n=10,
    n_neighbors=30
)

In [None]:
%matplotlib ipympl
%matplotlib widget
from helper.visualization import show_animations

show_animations(
    [
        umap_neighbors_5,
        umap_neighbors_15,
        umap_neighbors_30,
        umap_neighbors_50,
    ],
    interpolate=True,
    shared_axes=False,
    cols=4
)

In [None]:
umap_neighbors_15.evaluate()
umap_neighbors_50.evaluate()

In [None]:
umap_n_05_denoised = umap_neighbors_5.denoise(window_size=15, blend=0.99, mode='window')
umap_n_15_denoised = umap_neighbors_15.denoise(window_size=15, blend=0.99, mode='window')
umap_n_50_denoised = umap_neighbors_50.denoise(window_size=15, blend=0.99, mode='window')

In [None]:
show_animations(
    [
        umap_n_05_denoised,
        umap_n_15_denoised,
        umap_n_50_denoised,
    ],
    shared_axes=False)

In [None]:
umap_n_15_denoised.evaluate()
umap_n_50_denoised.evaluate()

#### Min Dist
This controls how tightly the embedding is allowed to compress points together. Larger values ensure embedded points are more evenly distributed, while smaller values allow the algorithm to optimise more accurately with regard to local structure. Sensible values are in the range 0.001 to 0.5, with 0.1 being a reasonable default. (From documentation)

In [None]:
umap_dist_001 = generate_umap_animation(
    run,
    metric='cosine',
    min_dist=0.001
)
umap_dist_01 = generate_umap_animation(
    run,
    metric='cosine',
    min_dist=0.01
)
umap_dist_1 = umap_cosine
umap_dist_5 = generate_umap_animation(
    run,
    metric='cosine',
    min_dist=0.5
)

In [None]:
show_animations([
        umap_dist_001,
        umap_dist_01,
        umap_dist_1,
        umap_dist_5
    ], cols=4, shared_axes=False)

In [None]:
umap_dist_5.evaluate(y_lim=2)

In [None]:
umap_dist_5.denoise().evaluate(y_lim=2)

#### Educated Guess choice
- Neighbors? ~15 shows best dynamics
- Large min_dist [0.1, 0.5] reveals inner-cluster dynamics most
- Cosine seems to emphasize different changes and builds different-shaped clusters - the Scores for Embedding and visualization drift are also better.

In [None]:
umap_cosine_n20 = generate_umap_animation(
    run,
    fit_basis='all_n',
    n_neighbors=20,
    metric='cosine',
    min_dist=0.2
)

In [None]:
show_animations(
    [
        umap_cosine_n20,
        umap_cosine,
    ],
    shared_axes=False,
    add_confusion_matrix=True,
    annotate_confusion_matrix=True,
)

In [None]:
umap_cosine_n20.evaluate()

In [None]:
umap_cosine_n20.denoise().evaluate(y_lim=2)

# PHATE

In [None]:
# `close()` without arguments only closes the current figure; close all
# interactive figures left open by the UMAP section.
matplotlib.pyplot.close("all")

In [None]:
# Fresh Run object, subsampled (every 5th point, every 2nd snapshot) for PHATE.
# NOTE(review): `run_sub` is bound to whatever `print_info()` returns - this is
# only a Run if `subsample()` and `print_info()` both return the run object; confirm.
run_sub = Run(run_id, dataset).subsample(point_step=5, snapshot_step=2).print_info()

In [None]:
def generate_phate(
        run: Run,
        max_frames=None,
        out_dim=2,
        knn=5,
        decay=40,
        t=20,
        n_jobs=-1,
        random_state=42,
        window=0  # number of frames before to include for fitting
):
    """Project every embedding snapshot of ``run`` with PHATE.

    For frame ``i`` a fresh PHATE operator is fitted on the concatenation of
    frames ``i - window .. i``; only the rows belonging to frame ``i`` are kept.
    Each projection is then rotated/reflected onto the previous one with
    orthogonal Procrustes so that consecutive frames stay comparable.

    Args:
        run: object whose ``embeddings`` attribute holds one array per snapshot.
        max_frames: number of leading snapshots to process; ``None`` (or 0)
            means all. Values larger than the number of snapshots are capped.
        out_dim: number of PHATE components.
        knn, decay, t, n_jobs, random_state: forwarded to ``phate.PHATE``.
        window: number of preceding frames included when fitting; must be >= 0.

    Returns:
        An ``Animation`` built from the list of per-frame projections.

    Raises:
        ValueError: if ``window`` is negative.
    """
    import phate
    import numpy as np
    from tqdm import tqdm
    from scipy.linalg import orthogonal_procrustes

    if window < 0:
        raise ValueError(f"window must be >= 0, got {window}")

    embeddings_list = run.embeddings.copy()
    # Cap at the number of available snapshots so indexing below cannot overrun.
    max_frames = min(max_frames or len(embeddings_list), len(embeddings_list))
    projections = []

    title = f'PHATE (knn={knn}, decay={decay}, t={t}, window={window})'

    prev_projection = None

    for i in tqdm(range(max_frames), desc="PHATE frames"):
        # Fit on the current frame plus up to `window` preceding frames
        start = max(0, i - window)
        fit_data = np.concatenate(embeddings_list[start:i + 1], axis=0)

        # Fit PHATE on the windowed data
        phate_op = phate.PHATE(
            n_components=out_dim,
            knn=knn,
            decay=decay,
            t=t,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=False
        )
        fit_projection = phate_op.fit_transform(fit_data)

        # The current frame is the last chunk of the concatenated data; skip
        # the rows that belong to the preceding frames of the window.
        current_len = len(embeddings_list[i])
        offset = sum(len(embeddings_list[j]) for j in range(start, i))
        projection = fit_projection[offset:offset + current_len]

        # Flip and align using orthogonal Procrustes
        # (requires the same number of points in consecutive frames)
        if prev_projection is not None:
            R, _ = orthogonal_procrustes(projection, prev_projection)
            projection = projection @ R

        projections.append(projection)
        prev_projection = projection.copy()

    # NOTE(review): `Animation` is not imported in any of the preceding cells -
    # confirm where it lives and import it explicitly.
    return Animation(projections=projections, title=title, run=run)

In [None]:
phate_animation = generate_phate(run_sub, knn=30, decay=50, t=30, window=5)

In [None]:
phate_animation.plot(interpolate=True, steps_per_transition=1)

In [None]:
phate_animation.evaluate()

In [None]:
#phate_animation.save_as_gif()

In [None]:
# The denoised result gets its own name so the raw `phate_animation` stays
# available for comparison.
# NOTE(review): presumably denoise() returns a new Animation instead of mutating in place — confirm.
phate_denoised = phate_animation.denoise(window_size=5, blend=0.99, mode='window')

In [None]:
# Same plot arguments as for the raw animation, so both views are comparable.
phate_denoised.plot(interpolate=True, steps_per_transition=1)

In [None]:
#phate_denoised.save_as_gif()

In [None]:
# Same evaluation as for the raw animation, now on the denoised frames.
phate_denoised.evaluate()

In [None]:
import phate
import matplotlib.pyplot as plt

# Choose epoch e (1-based). Clamped to the snapshots that exist so the cell
# also runs for training runs shorter than 249 epochs.
total_epochs = len(run.embeddings)
e = min(249, total_epochs)

emb = run.embeddings[e - 1]
labels = run.labels[0] # [e-1] or [0] doesn't matter

# Apply PHATE (seeded, so re-running the cell reproduces the same figure)
phate_op = phate.PHATE(random_state=42)
emb_phate = phate_op.fit_transform(emb)

# Plot; the dataset name in the title follows the `dataset` selected at the top
# of the notebook instead of being hard-coded.
fig, ax = plt.subplots(figsize=(8, 6))
points = ax.scatter(emb_phate[:, 0], emb_phate[:, 1], c=labels, cmap='tab10', s=5)
ax.set_title(f"Epoch {e}/{total_epochs}: Embedding space colored by {dataset.upper()} class (PHATE)")
fig.colorbar(points, ax=ax, label="Class")
plt.show()

# Takes ~ 5 seconds

### 2) Development of data points through epochs

In [None]:
from helper.visualization import compute_mphate_embeddings, mphate_to_animation

# Computed on `run` (not the subsampled `run_sub`); the result is reused by the
# animation and plotting cells further down.
mphate_emb = compute_mphate_embeddings(run)
# Takes ~10 min (scales on samples x epochs)

In [None]:
# Wrap the M-PHATE embedding as an animation for `run`; it is plotted,
# evaluated and denoised in the cells below.
m_phate_animation = mphate_to_animation(mphate_emb, run)

In [None]:
# Select the interactive ipympl backend before plotting.
# NOTE(review): `widget` appears to be an alias of `ipympl`, so one of the two
# magics should be sufficient — confirm against the IPython/ipympl docs.
%matplotlib ipympl
%matplotlib widget
m_phate_animation.plot()

In [None]:
# Baseline evaluation of the M-PHATE animation before denoising.
m_phate_animation.evaluate()

In [None]:
# Windowed denoising stored under a separate name; the commented-out line is the
# exponential-mode alternative.
m_phate_denoised = m_phate_animation.denoise(window_size=20, blend=0.9, mode='window')
#m_phate_denoised = m_phate_animation.denoise(blend=0.8, mode='exponential')

In [None]:
# Denoised counterpart of the M-PHATE plot above.
m_phate_denoised.plot()

In [None]:
#m_phate_denoised.save_as_gif()

In [None]:
# Same evaluation as before denoising, now on the denoised frames.
m_phate_denoised.evaluate()

### M-PHATE Hyperparameter Exploration


This section explores the effect of key M-PHATE hyperparameters on the temporal embedding structure. We vary one parameter at a time while keeping all others fixed, and visualize the resulting animations side by side.

**Parameters explored:**
- **Diffusion time `t`**: Controls how far information diffuses across the affinity graph. Higher values lead to smoother, more global embeddings. We test `t ∈ {10, 20, 30, 'auto'}`.
- **Interslice KNN `interslice_knn`**: Determines how strongly embeddings are connected across time steps. Larger values enforce stronger temporal smoothness. We test `interslice_knn ∈ {10, 20, 30}`.
- **Distance potential `gamma`**: Modifies the information distance used in the diffusion process. `gamma=0` corresponds to the default square-root potential; higher values approximate PHATE’s log potential. We test `gamma ∈ {0.0, 0.05, 0.1, 0.2}`.

Each variant is visualized using `show_animations`, optionally with confusion matrices to aid interpretation.


In [None]:
phate_animation = {}

In [None]:
for t in [10, 20, 30, 'auto']:
    emb = compute_mphate_embeddings(run, verbose=False, t=t)
    title = f"M-PHATE (t={t})"
    phate_animation[f"t={t}"] = mphate_to_animation(emb, run, title=title)

In [None]:
from helper.visualization import show_animations

show_animations(
    animations=[
        phate_animation["t=10"],
        phate_animation["t=20"],
        phate_animation["t=30"],
        phate_animation["t=auto"]
    ],
    figsize_per_plot=(4, 4),
    add_confusion_matrix=True
)

In [None]:
phate_animation["t=10"].evaluate()
phate_animation["t=20"].evaluate()
phate_animation["t=30"].evaluate()
phate_animation["t=auto"].evaluate()

In [None]:
for knn in [10, 20, 30]:
    emb = compute_mphate_embeddings(run, verbose=False, interslice_knn=knn)
    title = f"M-PHATE (interslice_knn={knn})"
    phate_animation[f"knn={knn}"] = mphate_to_animation(emb, run, title=title)

In [None]:
show_animations(
    animations=[
        phate_animation["knn=10"],
        phate_animation["knn=20"],
        phate_animation["knn=30"]
    ],
    figsize_per_plot=(4, 4),
    add_confusion_matrix=True
)

In [None]:
phate_animation["knn=10"].evaluate(figsize=(10, 5))
phate_animation["knn=20"].evaluate(figsize=(10, 5))
phate_animation["knn=30"].evaluate(figsize=(10, 5))

In [None]:
from helper.plots import show_phate_graphs

show_phate_graphs(
    animations=[
        phate_animation["knn=10"],
        phate_animation["knn=20"],
        phate_animation["knn=30"]
    ],
    figsize_per_plot=(5, 5),
    start_epoch=150,
    end_epoch=250,
    show_class_coloring=False,
    show_epoch_coloring=True,
    point_size=2
)

In [None]:
for gamma in [0.0, 0.05, 0.1, 0.2]:
    emb = compute_mphate_embeddings(run, verbose=False, gamma=gamma)
    title = f"M-PHATE (gamma={gamma})"
    phate_animation[f"gamma={gamma}"] = mphate_to_animation(emb, run, title=title)

In [None]:
show_animations(
    animations=[
        phate_animation["gamma=0.0"],
        phate_animation["gamma=0.05"],
        phate_animation["gamma=0.1"],
        phate_animation["gamma=0.2"]
    ],
    figsize_per_plot=(4, 4),
    add_confusion_matrix=True
)

In [None]:
phate_animation["gamma=0.0"].evaluate()
phate_animation["gamma=0.05"].evaluate()
phate_animation["gamma=0.1"].evaluate()
phate_animation["gamma=0.2"].evaluate()

In [None]:
from helper.plots import show_phate_graphs

show_phate_graphs(
    animations=[
        phate_animation["gamma=0.0"],
        phate_animation["gamma=0.05"],
        phate_animation["gamma=0.1"],
        phate_animation["gamma=0.2"]
    ],
    figsize_per_plot=(5, 5),
    start_epoch=10,
    end_epoch=30,
    show_class_coloring=False,
    show_epoch_coloring=True,
    point_size=1
)

### Or as single plots:

In [None]:
from helper.plots import plot_mphate_over_time

In [None]:
# Class-coloured view of the M-PHATE embedding starting at epoch 50.
# NOTE(review): end_epoch=None presumably means "through the last epoch" —
# confirm in helper.plots.
plot_mphate_over_time(
    mphate_emb,
    run,
    start_epoch=50,
    end_epoch=None,
    show_class_coloring=True,
    show_epoch_coloring=False
)

In [None]:
# Same embedding from epoch 100 on, coloured by epoch instead of by class.
plot_mphate_over_time(
    mphate_emb,
    run,
    start_epoch=100,
    end_epoch=None,
    show_class_coloring=False,
    show_epoch_coloring=True
)

### Class mean

In [None]:
class_cmap = plt.cm.tab10

# ---- Define your epoch range ----
start_epoch = 0
end_epoch = 200

# ---- Slice data ----
# Clamp the upper bound to the epochs that exist, so the title below reports
# the range that is actually plotted rather than the requested one.
end_epoch = min(end_epoch, len(mphate_emb))
selected_emb = np.asarray(mphate_emb[start_epoch:end_epoch])

# Read the labels from `run` here instead of relying on a `labels` variable
# left behind by the single-epoch PHATE cell much further up.
labels = np.asarray(run.labels[0])
unique_classes = np.unique(labels)

# --- Compute class centroids over time ---
# For every class, average its points within each epoch; stacking along axis 1
# gives shape (epochs, classes, dims).
# assumes mphate_emb is indexed as (epoch, point, dim) — TODO confirm
class_centroids = np.stack(
    [selected_emb[:, labels == c].mean(axis=1) for c in unique_classes],
    axis=1,
)

# --- Plot centroids as lines per class ---
# Colour and legend entry are looked up by the class label itself (not by its
# position in unique_classes), so they stay correct if a class is missing.
# NOTE(review): `class_names` is not defined in this cell; it must already exist
# in the kernel and be indexable by integer class label — confirm.
fig, ax = plt.subplots(figsize=(10, 8))
for position, c in enumerate(unique_classes):
    ax.plot(
        class_centroids[:, position, 0],
        class_centroids[:, position, 1],
        color=class_cmap(int(c)),
        label=class_names[int(c)]
    )

ax.set_title(f"Class centroid trajectories (epochs {start_epoch}-{end_epoch})")
ax.set_xlabel("M-PHATE dim 1")
ax.set_ylabel("M-PHATE dim 2")
ax.legend(title="Classes", bbox_to_anchor=(1.05, 1), loc='upper left')
fig.tight_layout()
plt.show()