In [None]:
import matplotlib
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np

In [None]:
import phate, m_phate # Should throw no error in proper environment

#### Embedding Visualization TODOs
Data
- [x] PCA
- [ ] Use ALL train data for PCA?
- [ ] PCA Denoising eval on denoised Embedding drift

- [x] t-SNE
- [ ] Visualize t-SNE training Steps
- [ ] Refine with new paper

- [x] UMAP
- [x] UMAP Parameters

- [ ] Copy PHATE stuff
- [ ] Evaluate

Visualization
- [x] Live during training
- [x] 3D
- [ ] With trajectory

**Environment:** Please use another environment (`phate-env`) here

# Embedding Visualization

## Import Data

In [None]:
# ==== MNIST ========
dataset = "mnist"

run = "run-0011-CNN_mnist_32_0.9776"
#run = "run-0012-CNN_mnist_32_0.9768"
#run = "run-0013-CNN_mnist_32_0.9797"
#run = "run-0014-CNN_mnist_32_0.9744"

In [None]:
# ==== CIFAR 10 ========
dataset = "cifar10"

# Residual
#run = "run-0016-CNN_cifar10_128_0.8093" # Seed 42, SAM
# run = "run-0018-CNN_cifar10_128_0.8499" # Seed 42
# run = "run-0020-CNN_cifar10_128_0.8079" # Seed 11, SAM
run = "run-0022-CNN_cifar10_128_0.8519" # Seed 11
    
# No Residual
# run = "run-0017-CNN_cifar10_128_0.8072" # Seed 42, SAM
# run = "run-0019-CNN_cifar10_128_0.8487" # Seed 42
# run = "run-0021-CNN_cifar10_128_0.8054" # Seed 11, SAM
# run = "run-0023-CNN_cifar10_128_0.8509" # Seed 11

In [None]:
run_id = run

In [None]:
from helper.visualization import Run
run = Run(run_id, dataset)

## Recap: The Training

In [None]:
run.plot_training_records()

### Confusion Matrix Development

In [None]:
%matplotlib ipympl
%matplotlib widget
run.confusion_matrix(annotate=True)

## Embedding Drift

The evaluation measure used in this work:
- **Multi-scale skips**: for each snapshot index `i`, compare its embedding `E_i` to earlier snapshots `E_{i - 2**n}` for `n = 0,1,…,4` (skip lengths 1, 2, 4, 8, 16).
- **Mean Euclidean distance**:
  ```python
  drift = np.linalg.norm(current_snapshot - previous_snapshot, axis=1).mean()
  ```
- **Result:** a dict mapping each skip length to a time series of drift values, showing how rapidly—and at what scales—the embedding space is evolving.

In [None]:
run.plot_embedding_drifts()

# Visualizations PCA

### PCA visualizations can be based on different bases

In [None]:
from helper.visualization import generate_projections, visualization_drift_vs_embedding_drift, denoise_projections
from helper.plots import show_multiple_projections_with_slider

In [None]:
ani_pca_first = generate_projections(
    run,
    method='pca',
    pca_fit_basis='first',
)
ani_pca_last = generate_projections(
    run,
    method='pca',
    pca_fit_basis='last',
)
ani_pca_all = generate_projections(
    run,
    method='pca',
    pca_fit_basis='all',
)

In [None]:
ani_pca_window = generate_projections(
    run,
    method='pca',
    pca_fit_basis='window',
    window_size=16,
)

### Visualization

The 2D visualization here has a slider for epochs; you can also press Play, Pause and Stop.

Optionally, linear interpolation between consecutive (noisy) steps can be enabled (`interpolate=True`) to make it easier to track points that move long distances.

In [None]:
# CIFAR100 Legend
from helper.plots import show_cifar100_legend
if dataset == "cifar100":
    show_cifar100_legend(cmap = "tab20")

In [None]:
%matplotlib ipympl
%matplotlib widget
from helper.visualization import show_animations

show_animations(
    animations=[ani_pca_first, ani_pca_last, ani_pca_all, ani_pca_window],
    interpolate=False, # TRANSLATION via linear interpolation
    steps_per_transition=2, # interpolation steps
    figsize_per_plot=(4, 4),
    alpha=0.8,
    dot_size=6 # 12
)

In [None]:
show_animations(
    animations=[ani_pca_first, ani_pca_last, ani_pca_all, ani_pca_window],
    figsize_per_plot=(4, 4),
    add_confusion_matrix=True
)

In [None]:
#ani_pca_window.save_as_gif()

In [None]:
ani_pca_first.evaluate()
ani_pca_last.evaluate()
ani_pca_all.evaluate()
ani_pca_window.evaluate()

In [None]:
# projections_filtered, labels_filtered = filter_classes(projections_pca_window, labels, [4, 30, 55, 72, 95])

### Denoising

Because the embedding snapshots are taken during training at fixed but arbitrary intervals within an epoch, on varying samples and potentially augmented images, they are very noisy.

As a result, the values can only be seen as an indicator, not as an exact measurement of the embedding.

Therefore, we can apply denoising to get a better overall picture.

In [None]:
ani_pca_first_denoised = ani_pca_first.denoise(window_size=15, blend=0.9, mode='window', do_embedding_drift=False)
ani_pca_last_denoised = ani_pca_last.denoise(window_size=15, blend=0.9, mode='window', do_embedding_drift=False)
ani_pca_all_denoised = ani_pca_all.denoise(window_size=15, blend=0.9, mode='window', do_embedding_drift=False)
ani_pca_window_denoised = ani_pca_window.denoise(window_size=15, blend=0.9, mode='window', do_embedding_drift=False)

In [None]:
ani_pca_window_denoised.plot(
    interpolate=True,
    steps_per_transition=1,
    alpha=1,
)

In [None]:
show_animations(
    animations=[
        ani_pca_first_denoised,
        ani_pca_last_denoised,
        ani_pca_all_denoised,
        ani_pca_window_denoised],
)

In [None]:
ani_pca_window_denoised.evaluate()

#### Now compared to denoised Embeddings...

In [None]:
ani_pca_first_denoised = ani_pca_first.denoise(window_size=15, blend=0.9, mode='window', do_embedding_drift=True)
ani_pca_last_denoised = ani_pca_last.denoise(window_size=15, blend=0.9, mode='window', do_embedding_drift=True)
ani_pca_all_denoised = ani_pca_all.denoise(window_size=15, blend=0.9, mode='window', do_embedding_drift=True)
ani_pca_window_denoised = ani_pca_window.denoise(window_size=15, blend=0.9, mode='window', do_embedding_drift=True)

In [None]:
ani_pca_first_denoised.evaluate()
ani_pca_last_denoised.evaluate()
ani_pca_all_denoised.evaluate()
ani_pca_window_denoised.evaluate()

In [None]:
# Close all interactive figures opened above (plain `close()` would only close the current one)
matplotlib.pyplot.close('all')

### 3D
We can also visualize the projections in 3D.

In [None]:
animation_3D = generate_projections(
    run,
    method='pca',
    pca_fit_basis='window',
    out_dim=3 #3D
)

In [None]:
animation_3D = animation_3D.denoise(do_embedding_drift=True)

In [None]:
from helper.plots import show_with_slider_3d

show_with_slider_3d(
    animation_3D.projections,
    labels=animation_3D.labels,
    interpolate=False,
    steps_per_transition=1,
    alpha=0.7,
    dataset=animation_3D.run.dataset,
    show_legend=False,
    dot_size=10, #20
)

In [None]:
animation_3D.evaluate()

### Other Denoising Strategies

In [None]:
#projections = projections_pca_window
animation = ani_pca_all

In [None]:
animation.evaluate()

To smooth the low-dimensional projections we use two denoising modes:

- **Exponential (causal) blending**  
  Recursively mix each frame $P_i$ with the previous denoised output $D_{i-1}$:  
  $$D_i = (1-\alpha)\,P_i + \alpha\,D_{i-1}$$  
  Reacts quickly while damping high-frequency noise.

- **Window (moving-average) blending**  
  Compute the mean of the last $w$ raw projections:  
  $$\overline{P}_i = \frac{1}{w}\sum_{j=i-w+1}^{i}P_j$$  
  and blend it with $P_i$:  
  $$D_i = (1-\alpha)\,P_i + \alpha\,\overline{P}_i$$
  Uses surrounding frames for stronger smoothing at the cost of lag.


In [None]:
denoised_window = animation.denoise(window_size=15, blend=0.9, mode='window')
denoised_exponential = animation.denoise(blend=0.8, mode='exponential')

In [None]:
show_animations(
    animations=[
        animation,
        denoised_window,
        denoised_exponential],
    custom_titles=["PCA", "PCA denoised window", "PCA denoised exponential"]
)

In [None]:
denoised_window.evaluate()
denoised_exponential.evaluate()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Define parameters
window_sizes = [1, 2, 4, 8, 10, 15, 20, 30]
blend_values = np.linspace(0, 1, 11)
correlation_results = {ws: [] for ws in window_sizes}
exponentials = []

# Run correlations
for blend in blend_values:
    for ws in window_sizes:
        corr = animation.denoise(window_size=ws, blend=blend, mode='window').evaluate(verbose=False)
        correlation_results[ws].append(corr)

    corr = animation.denoise(blend=blend, mode='exponential').evaluate(verbose=False)
    exponentials.append(corr)

In [None]:
# Plotting
plt.figure(figsize=(8, 4))
for ws in window_sizes:
    plt.plot(blend_values, correlation_results[ws], label=f'window_size={ws}')
plt.plot(blend_values, exponentials, label=f'exponential', linewidth=3)
plt.xlabel("Blend")
plt.ylabel("Correlation")
plt.title("Correlation vs. Blend for different denoise calculations")
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Define parameters
window_sizes = range(1, 30)
correlation_results = []
blend = 0.9

# Run correlations
for ws in window_sizes:
    corr = animation.denoise(window_size=ws, blend=blend, mode='window').evaluate(verbose=False)
    correlation_results.append(corr)

# Plotting
plt.figure(figsize=(8, 4))
plt.plot(window_sizes, correlation_results)
plt.xlabel("Window Size")
plt.ylabel("Correlation")
plt.title("Correlation vs. Window Size")
plt.grid(True)
plt.show()

# t-SNE Visualization

In [None]:
# Close all interactive figures opened above (plain `close()` would only close the current one)
matplotlib.pyplot.close('all')

## Standard t-SNE


In [None]:
# Restore original values
run = Run(run_id, dataset)
run.print_info()

In [None]:
run.subsample(point_step=4)
run.print_info()

In [None]:
from helper.visualization import generate_projections, denoise_projections

pca_animation = generate_projections(
    run,
    method='pca',
    pca_fit_basis='all',
    window_size=16,
).denoise(blend=0.8, mode='exponential')

print(f"{len(run.embeddings)} - 1 should be {len(run.embedding_drifts[1])}")

In [None]:
from helper.visualization import generate_projections, show_animations
from helper.plots import show_multiple_projections_with_slider

tsne_animation = generate_projections(
    run,
    method='tsne',
)

In [None]:
%matplotlib ipympl
%matplotlib widget

show_animations(
    [tsne_animation, pca_animation],
    interpolate=True,
    steps_per_transition=5
)

In [None]:
tsne_animation.evaluate()

In [None]:
from helper.visualization import animate_projections

def save_as_gif(
        animation,
        frame_interval=50,
        figsize=(4, 4),
        dot_size=5,
        alpha=0.6,
        cmap='tab10',
        axis_lim=None,
        interpolate=True,
        steps_per_transition=1,
):
    """Render an animation's projections and write them to a GIF file.

    The file is written to ``plots/animations/{run_id}_{title}.gif``; the
    directory is created if it does not exist yet.

    Args:
        animation: Object exposing ``projections``, ``labels``, ``run_id`` and
            ``title`` (as produced by ``generate_projections``).
        frame_interval, figsize, dot_size, alpha, cmap, axis_lim,
        interpolate, steps_per_transition: Forwarded to ``animate_projections``.
    """
    from pathlib import Path

    print("Generating plot...")
    ani = animate_projections(
        animation.projections,
        animation.labels,
        frame_interval=frame_interval,
        interpolate=interpolate,
        steps_per_transition=steps_per_transition,
        figsize=figsize,
        dot_size=dot_size,
        alpha=alpha,
        cmap=cmap,
        axis_lim=axis_lim
    )

    print("Saving file...")
    # ani.save fails with FileNotFoundError if the target directory is missing.
    Path("plots/animations").mkdir(parents=True, exist_ok=True)
    filename = f"plots/animations/{animation.run_id}_{animation.title}.gif"

    try:
        ani.save(filename, writer='pillow', dpi=150)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(ani._fig)

    print(filename)

In [None]:
save_as_gif(tsne_animation, frame_interval=100, steps_per_transition=5)

In [None]:
denoised_tsne = tsne_animation.denoise(blend=0.8, mode='exponential')

In [None]:
show_animations([tsne_animation, denoised_tsne, pca_animation])

In [None]:
denoised_tsne.evaluate()

In [None]:
from helper.visualization import show_projections_and_drift

show_animations([pca_animation, denoised_tsne, tsne_animation], with_drift=True)

## t-SNE with update blending

After computing t-SNE for frame $i$, you blend it with the previous projection.
`tsne_update` is a weight in $[0,1]$:
- If `tsne_update=1`, it's exactly the original method (no blending).
- If `tsne_update=0`, you freeze to the previous frame (no update at all).
- Values around `0.2`–`0.3` should give a smooth interpolation between the old and new projections.

Result: Extra smoothing over time.
- Avoids too fast movements
- Avoids flips of dense clusters
- Can introduce lag or “stickiness,” but animations look steadier.

In [None]:
tsne_avg = generate_projections(
    run,
    method='tsne',
    tsne_update=0.3
)

In [None]:
show_animations(
    [tsne_avg, tsne_animation],
    interpolate=True,
    steps_per_transition=5,
)

## t-SNE with backwards computation
Starts with the last frame
- Ensures convergence
- Reduces noise in early frames, since the basis comes from a later, better-converged picture

In [None]:
tsne_reverse = generate_projections(
    run,
    method='tsne',
    reverse_computation=True
)

In [None]:
show_animations(
    [tsne_reverse, tsne_animation],
    interpolate=True,
    steps_per_transition=5
)

In [None]:
tsne_reverse_2 = generate_projections(
    run,
    method='tsne',
    reverse_computation=True,
    tsne_update=0.2
)

In [None]:
show_animations(
    [tsne_reverse_2, tsne_reverse, tsne_animation],
    interpolate=True,
    steps_per_transition=5,
    with_drift=True
)

## t-SNE with Cosine metric

In [None]:
tsne_cosine = generate_projections(
    run,
    method='tsne',
    metric='cosine'
)

In [None]:
show_projections_and_drift(
    projections_list = [tsne_cosine, tsne_animation],
    interpolate=True,
    steps_per_transition=5,
)

## t-SNE with random Seed

In [None]:
tsne_random = generate_projections(
    run,
    method='tsne',
    random_state=1106,
    tsne_init='random'
)

In [None]:
show_animations(
    [tsne_random, tsne_animation],
    interpolate=True,
    steps_per_transition=5
)

## t-SNE Perplexity
Compare perplexities 5, 10, 30 and 50.

In [None]:
tsne_p_5 = generate_projections(
    run,
    method='tsne',
    tsne_perplexity=5
)

tsne_p_10 = generate_projections(
    run,
    method='tsne',
    tsne_perplexity=10
)

tsne_p_50 = generate_projections(
    run,
    method='tsne',
    tsne_perplexity=50
)

tsne_p_30 = tsne_animation

In [None]:
show_animations(
    [tsne_p_5, tsne_p_10, tsne_p_30, tsne_p_50],
    custom_titles=["t-SNE Perplexity 5",
                   "t-SNE Perplexity 10",
            "t-SNE Perpl. 30 (standard)",
            "t-SNE Perplexity 50"],
)

## Combinations
Educated guesses for a good t-SNE Visualization

In [None]:
tsne_p_5_blend = generate_projections(
    run,
    method='tsne',
    tsne_perplexity=5,
    tsne_update=0.2
)
tsne_p_5_blend_3 = generate_projections(
    run,
    method='tsne',
    tsne_perplexity=5,
    tsne_update=0.3
)

In [None]:
show_animations(
    [
        tsne_p_5_blend,
        tsne_p_5_blend_3,
        tsne_p_5,
        tsne_avg,
        tsne_animation
    ],
    custom_titles=[
        "t-SNE P5 Blending 0.2",
        "t-SNE P5 Blending 0.3",
        "t-SNE Perplexity 5",
        "t-SNE Blending",
        "t-SNE"
    ],
)

# td-SNE

In [None]:
import torch
import torch.nn.functional as F
from torch.optim import Adam
import time


class DynamicTSNE:
    def __init__(
            self,
            output_dims=2,
            verbose=True,
    ):
        self.output_dims = output_dims
        self.verbose = verbose
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def compute_affinities(self, Xs, perplexity=30.0, k_neighbors=90):
        def Hbeta(D, beta):
            P = torch.exp(-D * beta)
            sumP = torch.sum(P)
            sumP = torch.clamp(sumP, min=1e-8)
            H = torch.log(sumP) + beta * torch.sum(D * P) / sumP
            P = torch.clamp(P / sumP, min=1e-8)
            return H, P

        def compute_P(X, init_beta=None):
            t0 = time.time()

            n = X.shape[0]
            D = torch.cdist(X, X, p=2).pow(2)
            P = torch.zeros((n, n), device=X.device)
            beta = init_beta.clone() if init_beta is not None else torch.ones(n, device=X.device)
            logU = torch.log(torch.tensor(perplexity, device=X.device))
            all_tries = 0

            for i in range(n):
                distances = D[i]
                topk = torch.topk(distances, k=k_neighbors + 1, largest=False)
                idx = topk.indices[topk.indices != i][:k_neighbors]
                Di = torch.clamp(distances[idx], max=1e3)

                betamin, betamax = None, None
                H, thisP = Hbeta(Di, beta[i])
                tries = 0
                while torch.abs(H - logU) > 1e-5 and tries < 50:
                    if H > logU:
                        betamin = beta[i].clone()
                        beta[i] = beta[i] * 2 if betamax is None else (beta[i] + betamax) / 2
                    else:
                        betamax = beta[i].clone()
                        beta[i] = beta[i] / 2 if betamin is None else (beta[i] + betamin) / 2
                    H, thisP = Hbeta(Di, beta[i])
                    tries += 1
                all_tries += tries
                P[i, idx] = thisP

            if self.verbose:
                print(f"Total affinity computation time: {time.time() - t0:.2f}s, {all_tries / n} Tries")

            P = (P + P.T) / (2 * n)
            return P, beta

        X_tensor = [torch.tensor(X, device=self.device) for X in Xs]
        self.Xs = X_tensor

        Ps = []
        prev_beta = None
        for X in X_tensor:
            P, prev_beta = compute_P(X, prev_beta)
            Ps.append(P)

        self.Ps = torch.stack(Ps)
        assert not torch.isnan(self.Ps).any(), "Affinity matrix has NaN"

    def fit(self, n_epochs=1000, exaggeration=12.0, exaggeration_epochs=250, lr=200.0, lambd=0.1):
        T = len(self.Xs)
        n = self.Xs[0].shape[0]

        Y_init = []
        for X in self.Xs:
            X_cpu = X.detach().cpu().numpy()
            pca = PCA(n_components=self.output_dims)
            Y_pca = pca.fit_transform(X_cpu)
            Y_init.append(torch.tensor(Y_pca, device=self.device, dtype=torch.float32))
        
        Y = torch.stack(Y_init)
        Y.requires_grad_()

        optimizer = Adam([Y], lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

        for epoch in range(n_epochs):
            optimizer.zero_grad()
            total_loss = 0

            if epoch < exaggeration_epochs:
                P_use = self.Ps * exaggeration
            else:
                lambd = 0
                P_use = self.Ps
            
            for t in range(T):
                Qt, _ = self._compute_lowdim_affinities(Y[t])
                loss = self._kl_divergence(P_use[t], Qt)
                if t > 0:
                    loss += lambd * F.mse_loss(Y[t], Y[t - 1])
                total_loss += loss

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_([Y], max_norm=10.0)
            optimizer.step()
            scheduler.step()

            if self.verbose and (epoch % 100 == 0 or epoch == n_epochs - 1):
                print(f"Epoch {epoch}, Loss: {total_loss.item():.4f}")

        return [Y[t].detach().cpu().numpy() for t in range(T)]

    def _compute_lowdim_affinities(self, Y):
        num = 1 / (1 + torch.cdist(Y, Y, p=2).pow(2))
        num.fill_diagonal_(0.0)
        Q = torch.clamp(num / num.sum(), min=1e-5)
        return Q, num

    def _kl_divergence(self, P, Q):
        return torch.sum(P * torch.log((P + 1e-8) / (Q + 1e-8)))

In [None]:
subset = [embedding_list[i] for i in range(10, len(embedding_list), 30)]
len(subset)
#samples_per_class = 100
#classes = 10  # assuming 1000 samples total and 100 per class
#indices = np.concatenate([np.arange(c * 100, c * 100 + samples_per_class) for c in range(classes)])
#subset = [emb[:][indices] for emb in subset]
#label_subset = [emb[:][indices] for emb in results["subset_labels"]]

In [None]:
tsne = DynamicTSNE()
tsne.compute_affinities(subset, perplexity=30.0, k_neighbors=250)

In [None]:
from sklearn.decomposition import PCA
projections = tsne.fit(lr=200, lambd=0.1, n_epochs=1000, exaggeration_epochs=250, exaggeration=22.0)

In [None]:
%matplotlib widget

show_with_slider(
    projections,
    labels=results["subset_labels"],
    interpolate=False,
    steps_per_transition=4,
)

In [None]:
# Close all interactive figures opened above (plain `close()` would only close the current one)
matplotlib.pyplot.close('all')

In [None]:
visualization_drift_vs_embedding_drift(thesne, embedding_drift_subset)

In [None]:
Y = tsne_with_live_callback(embedding_list[10],
                        labels=results["subset_labels"][0],
                        perplexity=30,
                        lr=200,
                        n_iter=10000,
                        interval=5)

In [None]:
from openTSNE import TSNE
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display, clear_output

def tsne_with_live_callback(X, labels=None, perplexity=30, lr=200, n_iter=1000, interval=50):
    """Run openTSNE on ``X`` and redraw the embedding while it optimises.

    Args:
        X: Input data, shape (n_samples, n_features).
        labels: Optional per-sample labels used to colour the scatter (tab10).
        perplexity: openTSNE ``perplexity``.
        lr: openTSNE ``learning_rate``.
        n_iter: openTSNE ``n_iter``.
        interval: Redraw period in iterations (also ``callbacks_every_iters``).

    Returns:
        The fitted embedding. If the run is interrupted (KeyboardInterrupt /
        "Interrupt kernel"), the last embedding seen by the callback is returned
        as a numpy copy instead, or ``None`` if no callback had fired yet.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    # Latest intermediate embedding, kept so an interrupt can return it.
    latest = {"Y": None}

    def callback(iteration, error, Y):
        latest["Y"] = np.array(Y, copy=True)
        if iteration % interval == 0 or iteration == n_iter - 1:
            ax.clear()
            if labels is not None:
                ax.scatter(Y[:, 0], Y[:, 1], c=labels, cmap='tab10', s=5, alpha=0.7)
            else:
                ax.scatter(Y[:, 0], Y[:, 1], s=5, alpha=0.7)
            ax.set_title(f"t-SNE at iter {iteration}")
            ax.set_xticks([])
            ax.set_yticks([])
            clear_output(wait=True)
            display(fig)

    tsne = TSNE(
        n_components=2,
        perplexity=perplexity,
        learning_rate=lr,
        n_iter=n_iter,
        initialization="pca",
        callbacks=callback,
        callbacks_every_iters=interval,
        verbose=False,
    )

    try:
        Y = tsne.fit(X)
    except KeyboardInterrupt:
        print("Interrupted, returning current state.")
        return latest["Y"]
    finally:
        # The figure has already been shown via display(); closing it prevents
        # a second static copy from being rendered at the end of the cell.
        plt.close(fig)

    return Y

# Dynamic t-SNE with openTSNE

In [None]:
import numpy as np
from openTSNE import TSNE
from sklearn.decomposition import PCA


class DynamicTSNE_2:
    """Sequential openTSNE over a list of frames, optionally warm-started.

    Each frame is embedded by its own openTSNE run. ``init`` selects how each
    run is initialised:
      - 'pca':      2-D PCA of the current frame, for every frame.
      - 'random':   openTSNE's random initialisation, for every frame.
      - 'previous': PCA for the first frame, then the previous frame's
                    embedding (keeps consecutive layouts aligned).
    """

    def __init__(self, perplexity=30, n_iter=1000, init='pca', random_state=None, verbose=True):
        """
        Args:
            perplexity: openTSNE perplexity.
            n_iter: openTSNE iterations per frame.
            init: 'pca', 'random', or 'previous' (see class docstring).
            random_state: Seed forwarded to openTSNE.
            verbose: Print the frame index and enable openTSNE's verbose output.
        """
        assert init in ['pca', 'random', 'previous']
        self.perplexity = perplexity
        self.n_iter = n_iter
        self.init = init
        self.random_state = random_state
        self.verbose = verbose

    def _initialization(self, X, previous_embedding):
        """Return the openTSNE ``initialization`` argument for one frame."""
        if self.init == 'random':
            return 'random'
        if self.init == 'previous' and previous_embedding is not None:
            return previous_embedding
        # 'pca' for every frame, and the first frame of 'previous'
        return PCA(n_components=2).fit_transform(X)

    def fit_transform(self, Xs):
        """
        Xs: List of np.ndarray (each shape: [n_samples, n_features])
        Returns: List of np.ndarray (each shape: [n_samples, 2])
        """
        embeddings = []
        previous_embedding = None

        for i, X in enumerate(Xs):
            if self.verbose:
                print(i)

            tsne = TSNE(
                n_jobs=-1,
                perplexity=self.perplexity,
                n_iter=self.n_iter,
                initialization=self._initialization(X, previous_embedding),
                random_state=self.random_state,
                verbose=self.verbose
            )
            embedding = tsne.fit(X)
            embeddings.append(embedding)
            previous_embedding = embedding

        return embeddings

In [None]:
dynamic_tsne = DynamicTSNE_2(n_iter=500, init='previous', random_state=42)
projections_2 = dynamic_tsne.fit_transform(subset)

In [None]:
%matplotlib widget

show_with_slider(
    projections_2,
    labels=results["subset_labels"],
    interpolate=False,
    steps_per_transition=4,
)

# Modern Dynamic t-SNE

In [None]:
import numpy as np

# Every 5th snapshot, keeping the first `samples_per_class` points of each class.
# Assumes points are stored grouped by class, `points_per_class` consecutive
# points per class (100 classes x 10 points = 1000 points) — TODO confirm.
subset = [embedding_list[i] for i in range(0, len(embedding_list), 5)]
samples_per_class = 4
classes = 100
points_per_class = 10
indices = np.concatenate([
    np.arange(c * points_per_class, c * points_per_class + samples_per_class)
    for c in range(classes)
])
subset = [emb[indices] for emb in subset]
label_subset = [lbl[indices] for lbl in results["subset_labels"]]
len(subset)

In [None]:
label_subset[0]

In [None]:
modern_dynamic_tsne = ModernDynamicTSNE(
    n_epochs=500,
    perplexity=30,
)
projections_3 = modern_dynamic_tsne.fit_transform(subset)

In [None]:
projections_tsne = generate_projections(
    embeddings_list=subset,
    method='tsne',
)

In [None]:
modern_dynamic_tsne

In [None]:
show_multiple_projections_with_slider(
    projections_list=[projections_tsne, projections_3],
    labels=results["subset_labels"],
    titles=["t-SNE", "Dynamic t-SNE"],
    interpolate=False,
    figsize_per_plot=(4, 4),
    dataset=dataset
)

In [None]:
show_with_slider(
    projections_3,
    labels=label_subset,
    interpolate=False,
    steps_per_transition=4,
    dataset=dataset,
    alpha=0.7
)

In [None]:
import numpy as np
from sklearn.decomposition import PCA
import torch
import torch.nn.functional as F
from torch.optim import SGD


class ModernDynamicTSNE:
    """Dynamic t-SNE that jointly optimises one layout per frame.

    Every frame ``t`` gets exact (dense) input affinities ``P_t``. The cost is
    ``sum_t KL(P_t || Q_t) + lmbda * sum_t ||Y_t - Y_{t+1}||^2 / (2N)``. It is
    minimised with momentum SGD. Learning rate and momentum switch from their
    initial to their final values at ``lr_switch`` / ``momentum_switch``.
    """

    def __init__(
        self,
        perplexity=30,
        n_epochs=1000,
        output_dims=2,
        initial_lr=2400,
        final_lr=200,
        lr_switch=250,
        init_stdev=1e-4,
        initial_momentum=0.5,
        final_momentum=0.8,
        momentum_switch=250,
        lmbda=0.0,
        sigma_iters=50,
        verbose=True,
        device=None
    ):
        """
        Args:
            perplexity: Target perplexity for every frame's affinities.
            n_epochs: Number of optimisation steps.
            output_dims: Dimensionality of each layout.
            initial_lr, final_lr, lr_switch: SGD learning rate before / from
                epoch ``lr_switch``.
            init_stdev: Stored but not used by this class. Layouts are
                initialised with per-frame PCA.
            initial_momentum, final_momentum, momentum_switch: SGD momentum
                before / from epoch ``momentum_switch``.
            lmbda: Weight of the temporal smoothness term (0 disables it).
            sigma_iters: Max binary-search steps per point when matching the
                perplexity.
            verbose: Print the loss every 100 epochs and at the end.
            device: Torch device; defaults to CUDA when available.
        """
        self.perplexity = perplexity
        self.n_epochs = n_epochs
        self.output_dims = output_dims
        self.initial_lr = initial_lr
        self.final_lr = final_lr
        self.lr_switch = lr_switch
        self.init_stdev = init_stdev
        self.initial_momentum = initial_momentum
        self.final_momentum = final_momentum
        self.momentum_switch = momentum_switch
        self.lmbda = lmbda
        self.sigma_iters = sigma_iters
        self.verbose = verbose
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    def _hbeta(self, D, beta):
        """Return entropy H and normalised Gaussian affinities for precision ``beta``."""
        P = torch.exp(-D * beta)
        sumP = torch.sum(P)
        sumP = torch.clamp(sumP, min=1e-8)
        H = torch.log(sumP) + beta * torch.sum(D * P) / sumP
        P = P / sumP
        return H, P

    def _binary_search_perplexity(self, D, tol=1e-5):
        """Fit a per-point Gaussian precision so each row matches the perplexity.

        Args:
            D: (n, n) squared-distance matrix on ``self.device``.
            tol: Tolerance on ``|H - log(perplexity)|``.

        Returns:
            Symmetrised joint affinities ``(P + P^T) / (2n)``, shape (n, n),
            with a zero diagonal.
        """
        n = D.shape[0]
        P = torch.zeros((n, n), device=self.device)
        # Off-diagonal mask built once on the working device. Previously a CPU
        # `arange` mask was rebuilt twice per row and transferred on each use.
        off_diagonal = ~torch.eye(n, dtype=torch.bool, device=self.device)

        logU = np.log(self.perplexity)
        for i in range(n):
            row_mask = off_diagonal[i]
            Di = D[i][row_mask]
            beta = torch.ones((), device=self.device)
            betamin = None
            betamax = None
            H, thisP = self._hbeta(Di, beta)

            tries = 0
            while torch.abs(H - logU) > tol and tries < self.sigma_iters:
                if H > logU:
                    betamin = beta
                    beta = beta * 2 if betamax is None else (beta + betamax) / 2
                else:
                    betamax = beta
                    beta = beta / 2 if betamin is None else (beta + betamin) / 2
                H, thisP = self._hbeta(Di, beta)
                tries += 1
            P[i, row_mask] = thisP
        return (P + P.T) / (2 * n)

    def _precompute_Ps(self, Xs):
        """Affinity matrix for every frame (list of (n, n) tensors)."""
        return [self._binary_search_perplexity(torch.cdist(X, X).pow(2)) for X in Xs]

    def _compute_cost(self, Ys, Ps):
        """Sum of per-frame KL divergences plus the weighted temporal smoothness term."""
        total_kl = 0
        for Y, P in zip(Ys, Ps):
            Q_num = 1 / (1 + torch.cdist(Y, Y).pow(2))
            Q_num.fill_diagonal_(0)
            Q = Q_num / Q_num.sum()
            kl = torch.sum(P * torch.log((P + 1e-8) / (Q + 1e-8)))
            total_kl += kl
        smoothness = sum((Ys[i] - Ys[i + 1]).pow(2).sum() for i in range(len(Ys) - 1))
        return total_kl + self.lmbda * smoothness / (2 * Ys[0].shape[0])

    def fit_transform(self, Xs_np):
        """Optimise all frames jointly.

        Args:
            Xs_np: Sequence of arrays, one per frame, each (N, d) with the same N.

        Returns:
            List of numpy arrays of shape (N, output_dims), one per frame.
        """
        Xs = [torch.tensor(X, device=self.device, dtype=torch.float32) for X in Xs_np]

        # Independent PCA initialisation per frame
        Ys = [
            torch.tensor(PCA(n_components=self.output_dims).fit_transform(X.cpu().numpy()),
                         device=self.device, dtype=torch.float32, requires_grad=True)
            for X in Xs
        ]

        # Precompute all P matrices once
        Ps = self._precompute_Ps(Xs)

        optimizer = SGD(Ys, lr=self.initial_lr, momentum=self.initial_momentum)

        for epoch in range(self.n_epochs):
            if epoch == self.lr_switch:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = self.final_lr
            if epoch == self.momentum_switch:
                for param_group in optimizer.param_groups:
                    param_group['momentum'] = self.final_momentum

            optimizer.zero_grad()
            loss = self._compute_cost(Ys, Ps)
            loss.backward()
            optimizer.step()

            if self.verbose and (epoch % 100 == 0 or epoch == self.n_epochs - 1):
                print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

        return [Y.detach().cpu().numpy() for Y in Ys]

# PHATE

In [None]:
import tphate
import numpy as np

def generate_tphate_projection(
    embeddings,
    out_dim=2,
    knn=5,
    decay=40,
    t=2,
    n_pca=None,
    random_state=42,
    **tphate_kwargs
):
    """Fit a T-PHATE operator on ``embeddings`` and return its projection.

    Args:
        embeddings: Array handed unchanged to ``TPHATE.fit_transform``;
            intended to be (n_points, embedding_dim).
        out_dim: Number of output dimensions (``n_components``).
        knn, decay, t, n_pca, random_state: Forwarded to ``tphate.TPHATE``.
        **tphate_kwargs: Any further ``tphate.TPHATE`` keyword arguments.

    Returns:
        Whatever ``fit_transform`` produces — expected (n_points, out_dim).
    """
    operator = tphate.TPHATE(
        n_components=out_dim,
        knn=knn,
        decay=decay,
        t=t,
        n_pca=n_pca,
        random_state=random_state,
        **tphate_kwargs,
    )
    return operator.fit_transform(embeddings)


In [None]:
projections_tphate = generate_tphate_projection(
    embeddings=np.array(results["subset_embeddings"]),
    t=2
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# projections_tphate_new: shape (n_epochs, 2)
epochs = np.arange(projections_tphate_new.shape[0])

plt.figure(figsize=(6, 6))
plt.scatter(
    projections_tphate_new[:, 0],
    projections_tphate_new[:, 1],
    c=epochs,
    cmap='viridis',
    s=30,
    alpha=0.8
)
plt.xlabel('T-PHATE 1')
plt.ylabel('T-PHATE 2')
plt.title('T-PHATE Projection of Embedding Evolution')
plt.colorbar(label='Epoch')
plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# projections_tphate: (n_epochs, n_samples, 2)
# Choose a sample index to track
sample_idx = 0  # or any index you want to visualize
traj = projections_tphate[:, sample_idx, :]  # shape: (n_epochs, 2)
epochs = np.arange(traj.shape[0])

plt.figure(figsize=(6, 6))
plt.plot(traj[:, 0], traj[:, 1], '-o', c='blue', alpha=0.8, label=f'Sample {sample_idx}')
sc = plt.scatter(traj[:, 0], traj[:, 1], c=epochs, cmap='viridis', s=60, edgecolor='k')
plt.xlabel('T-PHATE 1')
plt.ylabel('T-PHATE 2')
plt.title(f'Trajectory of Sample {sample_idx} in T-PHATE Space')
plt.colorbar(sc, label='Epoch')
plt.legend()
plt.tight_layout()
plt.show()

In [None]:
# projections_tphate: (n_epochs, n_samples, 2)
# labels: (n_samples,) — constant for all epochs — or (n_epochs, n_samples).
# 1-D labels are broadcast over epochs so the per-epoch mask always matches
# the sample axis of projections_tphate.

labels = np.array(results["subset_labels"])
n_frames = projections_tphate.shape[0]
labels_per_epoch = labels if labels.ndim == 2 else np.broadcast_to(labels, (n_frames, labels.shape[-1]))
unique_classes = np.unique(labels)
colors = plt.cm.tab10(np.linspace(0, 1, len(unique_classes)))

plt.figure(figsize=(7, 7))
for color, cls in zip(colors, unique_classes):
    # Mean position of this class at each epoch
    class_traj = np.stack([
        projections_tphate[epoch, labels_per_epoch[epoch] == cls, :].mean(axis=0)
        for epoch in range(n_frames)
    ])
    plt.plot(class_traj[:, 0], class_traj[:, 1], '-o', color=color, label=f'Class {cls}', alpha=0.85)
    # Arrows from each epoch's mean to the next show the direction of movement
    for start, end in zip(class_traj[:-1], class_traj[1:]):
        plt.arrow(start[0], start[1],
                  end[0] - start[0],
                  end[1] - start[1],
                  head_width=0.05, head_length=0.07, color=color, alpha=0.6, length_includes_head=True)

plt.xlabel('T-PHATE 1')
plt.ylabel('T-PHATE 2')
plt.title('Class Mean Trajectories in T-PHATE Space')
plt.legend()
plt.tight_layout()
plt.show()

In [None]:
labels = np.array(results["subset_labels"])  # shape should be (n_samples,)
print(labels.shape)

In [None]:
import numpy as np

embeddings = np.array(results["subset_embeddings"])
n_epochs, n_samples, embedding_dim = embeddings.shape
labels = np.array(results["subset_labels"])  # (n_epochs, n_samples), or (n_samples,) if constant
if labels.ndim == 1:
    # Same labels for every epoch
    labels = np.broadcast_to(labels, (n_epochs, n_samples))
unique_classes = np.unique(labels)
n_classes = len(unique_classes)

# NaN marks a class with no samples in a given epoch. A mean of an empty
# slice would give the same NaN, but with a RuntimeWarning.
class_means = np.full((n_epochs, n_classes, embedding_dim), np.nan)

for epoch in range(n_epochs):
    labels_epoch = labels[epoch]  # shape (n_samples,)
    for i, cls in enumerate(unique_classes):
        mask = labels_epoch == cls
        if mask.any():
            class_means[epoch, i, :] = embeddings[epoch, mask, :].mean(axis=0)

In [None]:
flat_class_means = class_means.reshape(-1, embedding_dim)  # shape: (n_epochs * n_classes, embedding_dim)

projections_class_means = generate_tphate_projection(
    embeddings=flat_class_means,
    t=2
)  # shape: (n_epochs * n_classes, 2)

# Reshape back for visualization: (n_epochs, n_classes, 2)
projections_class_means = projections_class_means.reshape(n_epochs, n_classes, -1)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# projections_class_means: shape (n_epochs, n_classes, 2)
# unique_classes: array of class labels
# n_classes: number of classes

# tab20 repeats colours beyond 20 classes
cmap = plt.cm.tab10 if n_classes <= 10 else plt.cm.tab20
colors = cmap(np.linspace(0, 1, n_classes))
epochs = np.arange(projections_class_means.shape[0])

plt.figure(figsize=(9, 8))
sc = None
for i, cls in enumerate(unique_classes):
    traj = projections_class_means[:, i, :]
    # Class-coloured line carries the legend entry; the markers encode the epoch
    plt.plot(traj[:, 0], traj[:, 1], '-', color=colors[i], alpha=0.85, label=f'Class {cls}')
    sc = plt.scatter(traj[:, 0], traj[:, 1], c=epochs, cmap='viridis', s=60, edgecolor='k', zorder=3)
    # Direction arrows, skipping near-zero steps for clarity
    for j in range(1, len(traj)):
        dx = traj[j, 0] - traj[j-1, 0]
        dy = traj[j, 1] - traj[j-1, 1]
        if np.hypot(dx, dy) > 1e-4:
            plt.arrow(traj[j-1, 0], traj[j-1, 1], dx, dy,
                      head_width=0.04, head_length=0.07, color=colors[i], alpha=0.6, length_includes_head=True)

plt.xlabel('T-PHATE 1')
plt.ylabel('T-PHATE 2')
plt.title('Class Mean Trajectories in T-PHATE Space')
if sc is not None:
    plt.colorbar(sc, label='Epoch')
plt.legend(title="Class", loc="best")
plt.tight_layout()
plt.show()