# Embedding Visualization

## Import Data

In [None]:
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np

In [None]:
# Run selection: only the uncommented `dataset` / `run` assignments below take
# effect; `run` is passed to load_training_data and `dataset` to the plotting
# helpers further down.
# NOTE(review): run names look like "<dataset>_<model>_<size>_<accuracy>" — confirm.
# ==== MNIST ========
# dataset = "mnist"

#run = "mnist_MLP_32_0.9595"
# run = "mnist_MLP_128_0.9719"
#run = "mnist_CNN_64_0.9901"
#run = "mnist_ViT_64_0.9724"


# ==== CIFAR 10 ========
# dataset = "cifar10"
#run = "cifar10_MLP_128_0.5672"

#run = "cifar10_CNN_128_0.8977"

#run = "cifar10_ViT_128_0.6311"
#run = "cifar10_ViT_128_0.6530_noisy"

# ==== CIFAR 100 ========
dataset = "cifar100"
#run = "cifar100_CNN_256_0.6060"
# run = "cifar100_CNN_256_0.6710"

#run = "cifar100_ViT_192_0.4760"
run = "cifar100_ViT_192_0.5429"

In [None]:
from data_manager import load_training_data

# Load the selected run, then re-key the drift table with integer keys in
# ascending numeric order.
results = load_training_data(run)
results["embedding_drifts"] = {
    int(name): drift
    for name, drift in sorted(
        results["embedding_drifts"].items(),
        key=lambda item: int(item[0]),
    )
}

## Visualize Training

In [None]:
from train_viz import _plot_loss_accuracy, _plot_gradients, _plot_embedding_drift, _plot_scheduled_lr

# 2x2 dashboard: loss/accuracy, gradient statistics, LR schedule (only when a
# scheduler history was recorded) and embedding drift.
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))
epochs = len(results["train_losses"])

_plot_loss_accuracy(
    axs[0, 0],
    epochs - 1,
    epochs,
    results["train_losses"],
    results["val_losses"],
    results["train_accuracies"],
    results["val_accuracies"],
)
_plot_gradients(
    axs[0, 1],
    range(len(results["gradient_norms"])),
    results["gradient_norms"],
    results["max_gradients"],
    results["grad_param_ratios"],
    20,
)
if "scheduler_history" in results:
    _plot_scheduled_lr(axs[1, 0], results["scheduler_history"])
_plot_embedding_drift(axs[1, 1], results["embedding_drifts"])


In [None]:
%matplotlib widget
from train_viz import _plot_embedding_drift

embedding_drifts = results["embedding_drifts"].copy()
fig, axs = plt.subplots(1, 1, figsize=(10, 4))

# Faint reference curves: each selected drift series multiplied by 2.
# The legend label is derived from the drift key so that every curve is
# identified by the series it actually shows.
for drift_key, curve_color in ((1, "green"), (2, "blue"), (4, "orange"), (8, "red")):
    series = np.asarray(embedding_drifts[drift_key])
    axs.plot(
        range(1, len(series) + 1),
        series * 2,
        color=curve_color,
        label=f"2x Drift {drift_key}",
        alpha=0.3,
    )
_plot_embedding_drift(axs, embedding_drifts, max_multiply=1.5)

plt.legend()
plt.show()

# Visualizations PCA

In [None]:
from train_viz import generate_projections, animate_projections, show_with_slider, show_multiple_projections_with_slider, visualization_drift_vs_embedding_drift, denoise_projections

In [None]:
# Three PCA projections of the same embedding snapshots; the calls differ only
# in the `pca_fit_basis` option ('first', 'last', 'all').
projections_pca_first, projections_pca_last, projections_pca_all = [
    generate_projections(
        embeddings_list=results["subset_embeddings"],
        method='pca',
        pca_fit_basis=fit_basis,
    )
    for fit_basis in ('first', 'last', 'all')
]

In [None]:
# PCA projection with pca_fit_basis='window' and an explicit window_size of 5.
projections_pca_window = generate_projections(
    embeddings_list=results["subset_embeddings"],
    method='pca',
    pca_fit_basis='window',
    window_size=5,
)

In [None]:
# projections_filtered, labels_filtered = filter_classes(projections_pca_window, results["subset_labels"], [4, 30, 55, 72, 95])

In [None]:
from train_viz import show_cifar100_legend

# The standalone CIFAR-100 legend is shown only when the cifar100 dataset is
# selected.
if dataset == "cifar100":
    show_cifar100_legend(cmap = "tab20")

In [None]:
# Interactive view of the window-PCA projection. The inline legend is switched
# off for cifar100 and on for every other dataset.
show_with_slider(
    projections_pca_window,
    labels=results["subset_labels"],
    dataset=dataset,
    interpolate=True,
    steps_per_transition=3,
    alpha=1,
    show_legend=(dataset != "cifar100"),
)

In [None]:
# Denoise all four PCA projections with the same settings.
# NOTE: each variable is overwritten with its denoised version, so re-running
# this cell denoises already-denoised data, and every later use of these
# names sees the denoised projections rather than the raw ones.
projections_pca_first = denoise_projections(projections_pca_first, window_size=15, blend=0.9, mode='window')
projections_pca_last = denoise_projections(projections_pca_last, window_size=15, blend=0.9, mode='window')
projections_pca_all = denoise_projections(projections_pca_all, window_size=15, blend=0.9, mode='window')
projections_pca_window = denoise_projections(projections_pca_window, window_size=15, blend=0.9, mode='window')

In [None]:
# Side-by-side comparison of the four PCA variants; `titles` is given in the
# same order as `projections_list`.
show_multiple_projections_with_slider(
    projections_list=[projections_pca_first, projections_pca_last, projections_pca_all, projections_pca_window],
    labels=results["subset_labels"],
    titles=["PCA on first", "PCA on last", "PCA on all", "PCA sliding window"],
    interpolate=False,
    steps_per_transition=2,
    figsize_per_plot=(4, 4),
    dataset=dataset,
    alpha=0.8,
    dot_size=12
)

In [None]:
# Relate the window-PCA projection to the recorded embedding drifts.
# NOTE(review): presumably compares frame-to-frame movement in the projection
# with the stored drifts — confirm against train_viz.
visualization_drift_vs_embedding_drift(projections_pca_window, embedding_drifts)

In [None]:
# Close the current pyplot figure. The bare `matplotlib` name imported here is
# also used by a later cell that calls matplotlib.pyplot.close().
import matplotlib
matplotlib.pyplot.close()

# 3D

In [None]:
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import ipywidgets as widgets
from IPython.display import display
from vision_classification import get_text_labels
from train_viz import _interpolate_projections, _prepare_cifar100_plot_config, _create_cifar100_legend

def _prepare_cifar100_plot_config(class_names):
    """Build per-fine-class plot styling for CIFAR-100.

    Every coarse class gets one colour from a resampled 'gist_rainbow'
    colormap; the fine classes inside a coarse class share that colour and are
    told apart by marker shape (markers repeat after ten entries).

    Args:
        class_names: iterable of fine class names; the position of a name is
            used as its fine class index.

    Returns:
        (plot_config, fine_to_coarse), where plot_config maps a fine class
        index to {'color', 'marker', 'coarse'} and fine_to_coarse is the
        value returned by get_cifar100_fine_to_coarse_labels().

    Raises:
        KeyError: if a fine label of the coarse-to-fine mapping is missing
            from class_names.
    """
    from vision_classification import get_cifar100_coarse_to_fine_labels, get_cifar100_fine_to_coarse_labels

    markers = ('o', 's', '^', 'v', '<', '>', 'P', '*', 'X', 'D')

    superclass_members = get_cifar100_coarse_to_fine_labels()
    fine_to_coarse = get_cifar100_fine_to_coarse_labels()

    superclasses = list(superclass_members.keys())
    palette = plt.colormaps.get_cmap('gist_rainbow').resampled(len(superclasses))
    index_of = {label: position for position, label in enumerate(class_names)}

    plot_config = {}
    for rank, superclass in enumerate(superclasses):
        shade = palette(rank)
        for slot, fine_label in enumerate(superclass_members[superclass]):
            plot_config[index_of[fine_label]] = {
                'color': shade,
                'marker': markers[slot % len(markers)],
                'coarse': superclass,
            }
    return plot_config, fine_to_coarse

def show_with_slider_3d(
    projections,
    labels,
    dot_size=5,
    alpha=0.6,
    interpolate=False,
    steps_per_transition=10,
    dataset=None,
    show_legend=False  # placeholder for symmetry
):
    """Show a sequence of 3D projections as an interactive scatter plot.

    Displays frame, rotation and tilt sliders (each with a play button) plus
    an "Auto-Rotate" checkbox. The camera is only repositioned from the
    rotation/tilt sliders while the checkbox is ticked.

    Args:
        projections: sequence of per-frame arrays indexed as [points, 3].
        labels: sequence of per-frame label arrays; only ``labels[0]`` is used
            for colouring (labels are assumed constant across frames).
        dot_size: scatter marker size (doubled for cifar100).
        alpha: marker opacity.
        interpolate: if True, frames are expanded with
            ``_interpolate_projections``.
        steps_per_transition: interpolation steps; also shortens the frame
            play interval.
        dataset: dataset name passed to ``get_text_labels``; ``"cifar100"``
            enables per-class colour/marker styling.
        show_legend: accepted for signature symmetry; not used.
    """
    class_names = range(0, 100) if dataset is None else get_text_labels(dataset)
    projections = np.array(projections)
    if interpolate:
        projections = _interpolate_projections(projections, steps_per_transition)

    is_cifar100 = dataset == "cifar100"
    if is_cifar100:
        fine_index_to_plot_config, _ = _prepare_cifar100_plot_config(class_names)

    unique_labels = np.unique(np.concatenate(labels))
    # np.asarray makes the `==` masks below element-wise even if a list is passed
    label_frame = np.asarray(labels[0])  # assumed constant

    # Set up 3D plot with symmetric, tick-free axes covering every frame
    fig = plt.figure(figsize=(7, 7))
    ax = fig.add_subplot(111, projection='3d')

    all_proj = np.concatenate(projections, axis=0)
    max_abs = np.max(np.abs(all_proj))
    ax.set_xlim3d(-max_abs, max_abs)
    ax.set_ylim3d(-max_abs, max_abs)
    ax.set_zlim3d(-max_abs, max_abs)
    ax.set_xticks([]); ax.set_yticks([]); ax.set_zticks([])

    # Labels are constant, so the per-class masks are computed once instead of
    # once per class on every frame update.
    class_masks = {fine_idx: label_frame == fine_idx for fine_idx in unique_labels}

    # Create initial scatter per fine class
    scatter_dict = {}
    sc = None
    if is_cifar100:
        for fine_idx, idxs in class_masks.items():
            config = fine_index_to_plot_config.get(fine_idx, {})
            scatter_dict[fine_idx] = ax.scatter(
                projections[0][idxs, 0],
                projections[0][idxs, 1],
                projections[0][idxs, 2],
                c=[config.get('color', 'gray')],
                marker=config.get('marker', 'o'),
                alpha=alpha, edgecolors='none', s=dot_size * 2,
            )
    else:
        sc = ax.scatter(projections[0][:, 0],
                        projections[0][:, 1],
                        projections[0][:, 2],
                        c=label_frame, cmap='tab10', alpha=alpha, s=dot_size)

    # Update function
    def update(frame_idx, azim_angle, elev_angle, auto_rotate):
        frame = projections[frame_idx]
        if is_cifar100:
            for fine_idx, idxs in class_masks.items():
                scatter_dict[fine_idx]._offsets3d = (
                    frame[idxs, 0],
                    frame[idxs, 1],
                    frame[idxs, 2],
                )
        else:
            sc._offsets3d = (frame[:, 0], frame[:, 1], frame[:, 2])
            sc.set_array(np.array(label_frame))

        # Only update view if auto-rotate is enabled
        if auto_rotate:
            ax.view_init(elev=elev_angle, azim=azim_angle)

        fig.canvas.draw_idle()

    toggle_auto_rotate = widgets.Checkbox(value=True, description="Auto-Rotate")

    def on_azim_change(change):
        # Touching the rotation control re-enables slider-driven camera updates
        toggle_auto_rotate.value = True

    # Sliders
    last_frame = len(projections) - 1
    slider_frame = widgets.IntSlider(min=0, max=last_frame, step=1, description="Frame")
    slider_azim = widgets.IntSlider(min=0, max=360, step=1, description="Rotation")
    slider_azim.observe(on_azim_change, names='value')
    slider_elev = widgets.IntSlider(min=-90, max=90, step=1, value=30, description="Tilt")

    # Play controls (interval is an integer number of milliseconds, at least 1)
    play_frame = widgets.Play(interval=max(1, int(150 / steps_per_transition)), value=0,
                              min=0, max=last_frame, step=1, description="▶")
    play_azim = widgets.Play(interval=100, value=0, min=0, max=360, step=2, description="↻")
    play_azim.observe(on_azim_change, names='value')
    play_azim.loop = True
    # Start at the slider's 30 degrees: a jslink pushes the source (play) value
    # into the target (slider) when the link is created.
    play_elev = widgets.Play(interval=100, value=30, min=-90, max=90, step=1, description="↕")
    play_elev.loop = True

    # Link each play widget to its slider (exactly one link per pair)
    widgets.jslink((play_elev, 'value'), (slider_elev, 'value'))
    widgets.jslink((play_frame, 'value'), (slider_frame, 'value'))
    widgets.jslink((play_azim, 'value'), (slider_azim, 'value'))

    out = widgets.interactive_output(update, {
        'frame_idx': slider_frame,
        'azim_angle': slider_azim,
        'elev_angle': slider_elev,
        'auto_rotate': toggle_auto_rotate
    })

    display(widgets.VBox([
        widgets.HBox([play_frame, slider_frame]),
        widgets.HBox([play_azim, slider_azim]),
        widgets.HBox([play_elev, slider_elev]),
        toggle_auto_rotate,
        out
    ]))


In [None]:
# Window-basis PCA projected to three output dimensions. No window_size is
# passed here, so generate_projections' own default applies.
projections_3d = generate_projections(
    embeddings_list=results["subset_embeddings"],
    method='pca',
    pca_fit_basis='window',
    out_dim=3 #3D
)

In [None]:
# Overwrites projections_3d with its denoised version; re-running this cell
# denoises the already-denoised data again.
projections_3d = denoise_projections(projections_3d, window_size=15, blend=0.9, mode='window')

In [None]:
# Interactive 3D view of the denoised 3D projection, without interpolation.
show_with_slider_3d(
    projections_3d,
    labels=results["subset_labels"],
    interpolate=False,
    steps_per_transition=1,
    alpha=0.7,
    dataset=dataset,
    show_legend=False,
    dot_size=20,
)

In [None]:
# Apply the same window denoising to the raw embeddings, then compare them
# with the 3D projection. embeddings=True is passed because the second
# argument holds embeddings rather than the stored drift table.
denoised_embeddings = denoise_projections(results["subset_embeddings"], window_size=15, blend=0.9, mode='window')
visualization_drift_vs_embedding_drift(projections_3d, denoised_embeddings, embeddings=True)

# Denoising

In [None]:
# Projection used by the denoising experiments below.
# NOTE: if the earlier denoising cell has run, projections_pca_window already
# holds denoised data, so this is not the raw PCA output.
projections = projections_pca_window
#projections = projections_pca_all

In [None]:
# Baseline comparison of the selected projection against the recorded
# embedding drifts, with verbose output enabled.
visualization_drift_vs_embedding_drift(projections, embedding_drifts, verbose=True)

In [None]:
# Two denoised variants of the same projection: mode='window' (window_size=15,
# blend=0.9) and mode='exponential' (blend=0.8).
denoised_window = denoise_projections(projections, window_size=15, blend=0.9, mode='window')
denoised_exponential = denoise_projections(projections, blend=0.8, mode='exponential')

In [None]:
# Side-by-side view of the selected projection and its two denoised variants;
# `titles` follows the order of `projections_list`.
show_multiple_projections_with_slider(
    projections_list=[projections, denoised_window, denoised_exponential],
    labels=results["subset_labels"],
    titles=["PCA", "denoised window", "denoised exponential"],
    interpolate=False,
    figsize_per_plot=(4, 4),
    dataset=dataset
)

In [None]:
# Window-denoise the embeddings with the same settings as denoised_window and
# compare the two (embeddings=True: second argument holds embeddings).
denoised_embeddings = denoise_projections(results["subset_embeddings"], window_size=15, blend=0.9, mode='window')
visualization_drift_vs_embedding_drift(denoised_window, denoised_embeddings, verbose=True, embeddings=True)

In [None]:
# Exponential denoising (blend=0.8) applied to both the projection and the
# embeddings before comparing them. denoised_exponential is recomputed here
# with the same arguments as in the earlier cell.
denoised_exponential = denoise_projections(projections, blend=0.8, mode='exponential')
denoised_embeddings = denoise_projections(results["subset_embeddings"], blend=0.8, mode='exponential')

visualization_drift_vs_embedding_drift(denoised_exponential, denoised_embeddings, verbose=True, embeddings=True)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Sweep configuration: correlation_results maps each window size to one
# correlation per blend value; exponentials holds one value per blend value.
window_sizes = [1, 2, 4, 6, 8, 10, 15]
blend_values = np.linspace(0, 1, 11)
correlation_results = {ws: [] for ws in window_sizes}
exponentials = []

# For every blend value: denoise projection and embeddings identically, once
# per window size and once with exponential smoothing, and record the result.
for blend in blend_values:
    for ws in window_sizes:
        denoised, denoised_embeddings = (
            denoise_projections(sequence, window_size=ws, blend=blend, mode='window')
            for sequence in (projections, results["subset_embeddings"])
        )
        corr = visualization_drift_vs_embedding_drift(
            denoised, denoised_embeddings, verbose=False, embeddings=True
        )
        correlation_results[ws].append(corr)

    denoised, denoised_embeddings = (
        denoise_projections(sequence, blend=blend, mode='exponential')
        for sequence in (projections, results["subset_embeddings"])
    )
    corr = visualization_drift_vs_embedding_drift(
        denoised, denoised_embeddings, verbose=False, embeddings=True
    )
    exponentials.append(corr)

In [None]:
# One curve per window size plus a thicker curve for exponential smoothing.
blend_ax = plt.figure(figsize=(8, 4)).gca()
for ws in window_sizes:
    curve = correlation_results[ws]
    blend_ax.plot(blend_values, curve, label=f'window_size={ws}')
blend_ax.plot(blend_values, exponentials, label='exponential', linewidth=3)
blend_ax.set_xlabel("Blend")
blend_ax.set_ylabel("Correlation")
blend_ax.set_title("Correlation vs. Blend for different denoise calculations")
blend_ax.legend()
blend_ax.grid(True)
plt.show()

In [None]:
# Sweep the window size at a fixed blend. Note that this rebinds
# window_sizes and correlation_results from the blend sweep.
window_sizes = range(1, 20)
correlation_results = []
blend = 0.9

for ws in window_sizes:
    denoised, denoised_embeddings = (
        denoise_projections(sequence, window_size=ws, blend=blend, mode='window')
        for sequence in (projections, results["subset_embeddings"])
    )
    corr = visualization_drift_vs_embedding_drift(
        denoised, denoised_embeddings, verbose=False, embeddings=True
    )
    correlation_results.append(corr)

# Correlation as a function of window size
window_ax = plt.figure(figsize=(8, 4)).gca()
window_ax.plot(window_sizes, correlation_results)
window_ax.set_xlabel("Window Size")
window_ax.set_ylabel("Correlation")
window_ax.set_title("Correlation vs. Window Size")
window_ax.grid(True)
plt.show()

# t-SNE Visualization

In [None]:
# t-SNE projection of every embedding snapshot, using generate_projections'
# defaults for all other options.
projections_tsne = generate_projections(
    embeddings_list=results["subset_embeddings"],
    method='tsne',
)

In [None]:
# Exponentially denoised PCA next to raw t-SNE, with shared_axes=False.
show_multiple_projections_with_slider(
    projections_list=[denoised_exponential, projections_tsne],
    labels=results["subset_labels"],
    titles=["PCA", "t-SNE"],
    interpolate=False,
    steps_per_transition=2,
    figsize_per_plot=(5, 5),
    shared_axes=False,
)

In [None]:
# Compare the raw t-SNE projection with the recorded embedding drifts.
visualization_drift_vs_embedding_drift(projections_tsne, embedding_drifts)

In [None]:
# Exponential denoising (blend=0.8) of the t-SNE projection and, with the
# same settings, of the raw embeddings.
denoised_tsne = denoise_projections(projections_tsne, blend=0.8, mode='exponential')
denoised_embeddings = denoise_projections(results["subset_embeddings"], blend=0.8, mode='exponential')

In [None]:
# Denoised PCA, denoised t-SNE and raw t-SNE side by side; `titles` follows
# the order of `projections_list`.
show_multiple_projections_with_slider(
    projections_list=[denoised_exponential, denoised_tsne, projections_tsne],
    labels=results["subset_labels"],
    titles=["PCA", "t-SNE denoised", "t-SNE"],
    interpolate=False,
    steps_per_transition=1,
    figsize_per_plot=(5, 5),
    shared_axes=False,
)

In [None]:
# The second argument holds denoised embeddings, not the stored drift table,
# so pass embeddings=True like every other embedding-based comparison in this
# notebook.
visualization_drift_vs_embedding_drift(denoised_tsne, denoised_embeddings, embeddings=True)

# td-SNE

In [None]:
import torch
import torch.nn.functional as F
from torch.optim import Adam
import time


class DynamicTSNE:
    """t-SNE over a sequence of embedding snapshots, optimised jointly in PyTorch.

    Usage: call :meth:`compute_affinities` with the snapshots first, then
    :meth:`fit` to obtain one low-dimensional layout per snapshot.
    """

    def __init__(
            self,
            output_dims=2,
            verbose=True,
    ):
        """Store the output dimensionality and verbosity; pick CUDA when available."""
        self.output_dims = output_dims
        self.verbose = verbose
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def compute_affinities(self, Xs, perplexity=30.0, k_neighbors=90):
        """Compute one symmetric affinity matrix per snapshot.

        For every point, a precision ``beta`` is found by bisection so that
        the entropy of its affinities over its nearest neighbours matches
        ``log(perplexity)``. The betas found for one snapshot seed the search
        for the next one.

        Args:
            Xs: sequence of array-likes, one ``[n, features]`` snapshot each;
                all snapshots must have the same ``n`` (results are stacked).
            perplexity: target perplexity.
            k_neighbors: number of nearest neighbours kept per point; capped
                at the number of available points.

        Side effects:
            Sets ``self.Xs`` (list of tensors) and ``self.Ps`` (stacked
            affinity tensor, one matrix per snapshot).
        """
        def Hbeta(D, beta):
            # Entropy H and (clamped) normalised affinities for precision beta
            P = torch.exp(-D * beta)
            sumP = torch.sum(P)
            sumP = torch.clamp(sumP, min=1e-8)
            H = torch.log(sumP) + beta * torch.sum(D * P) / sumP
            P = torch.clamp(P / sumP, min=1e-8)
            return H, P

        def compute_P(X, init_beta=None):
            t0 = time.time()

            n = X.shape[0]
            D = torch.cdist(X, X, p=2).pow(2)
            P = torch.zeros((n, n), device=X.device)
            beta = init_beta.clone() if init_beta is not None else torch.ones(n, device=X.device)
            logU = torch.log(torch.tensor(perplexity, device=X.device))
            all_tries = 0
            # topk cannot return more entries than there are points; the extra
            # slot accounts for the point itself, which is dropped below.
            k = min(k_neighbors + 1, n)

            for i in range(n):
                distances = D[i]
                topk = torch.topk(distances, k=k, largest=False)
                idx = topk.indices[topk.indices != i][:k_neighbors]
                Di = torch.clamp(distances[idx], max=1e3)

                # Bisection on beta: double/halve until both bounds are known
                betamin, betamax = None, None
                H, thisP = Hbeta(Di, beta[i])
                tries = 0
                while torch.abs(H - logU) > 1e-5 and tries < 50:
                    if H > logU:
                        betamin = beta[i].clone()
                        beta[i] = beta[i] * 2 if betamax is None else (beta[i] + betamax) / 2
                    else:
                        betamax = beta[i].clone()
                        beta[i] = beta[i] / 2 if betamin is None else (beta[i] + betamin) / 2
                    H, thisP = Hbeta(Di, beta[i])
                    tries += 1
                all_tries += tries
                P[i, idx] = thisP

            if self.verbose:
                print(f"Total affinity computation time: {time.time() - t0:.2f}s, {all_tries / n} Tries")

            # Symmetrise and normalise so the whole matrix sums to ~1
            P = (P + P.T) / (2 * n)
            return P, beta

        X_tensor = [torch.tensor(X, device=self.device) for X in Xs]
        self.Xs = X_tensor

        Ps = []
        prev_beta = None
        for X in X_tensor:
            P, prev_beta = compute_P(X, prev_beta)
            Ps.append(P)

        self.Ps = torch.stack(Ps)
        assert not torch.isnan(self.Ps).any(), "Affinity matrix has NaN"

    def fit(self, n_epochs=1000, exaggeration=12.0, exaggeration_epochs=250, lr=200.0, lambd=0.1):
        """Optimise the low-dimensional layouts for all snapshots.

        Each layout is initialised with a per-snapshot PCA. The loss is the
        sum of per-snapshot KL divergences; during the first
        ``exaggeration_epochs`` epochs the affinities are multiplied by
        ``exaggeration`` and an MSE penalty between consecutive layouts,
        weighted by ``lambd``, is added. After that phase the penalty weight
        is zero.

        Requires :meth:`compute_affinities` to have been called.

        Returns:
            list of numpy arrays, one ``[n, output_dims]`` layout per snapshot.
        """
        # Imported here so the class does not depend on a later notebook cell
        # having imported PCA first.
        from sklearn.decomposition import PCA

        T = len(self.Xs)

        Y_init = []
        for X in self.Xs:
            X_cpu = X.detach().cpu().numpy()
            pca = PCA(n_components=self.output_dims)
            Y_pca = pca.fit_transform(X_cpu)
            Y_init.append(torch.tensor(Y_pca, device=self.device, dtype=torch.float32))

        Y = torch.stack(Y_init)
        Y.requires_grad_()

        optimizer = Adam([Y], lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

        for epoch in range(n_epochs):
            optimizer.zero_grad()
            total_loss = 0

            exaggerating = epoch < exaggeration_epochs
            P_use = self.Ps * exaggeration if exaggerating else self.Ps
            # The temporal penalty is only active during early exaggeration
            smooth_weight = lambd if exaggerating else 0

            for t in range(T):
                Qt, _ = self._compute_lowdim_affinities(Y[t])
                loss = self._kl_divergence(P_use[t], Qt)
                if t > 0:
                    loss += smooth_weight * F.mse_loss(Y[t], Y[t - 1])
                total_loss += loss

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_([Y], max_norm=10.0)
            optimizer.step()
            scheduler.step()

            if self.verbose and (epoch % 100 == 0 or epoch == n_epochs - 1):
                print(f"Epoch {epoch}, Loss: {total_loss.item():.4f}")

        return [Y[t].detach().cpu().numpy() for t in range(T)]

    def _compute_lowdim_affinities(self, Y):
        """Return (Q, num): clamped normalised affinities and the raw 1/(1+d^2) kernel with a zeroed diagonal."""
        num = 1 / (1 + torch.cdist(Y, Y, p=2).pow(2))
        num.fill_diagonal_(0.0)
        Q = torch.clamp(num / num.sum(), min=1e-5)
        return Q, num

    def _kl_divergence(self, P, Q):
        """KL(P || Q) with a 1e-8 offset on both arguments of the log ratio."""
        return torch.sum(P * torch.log((P + 1e-8) / (Q + 1e-8)))

In [None]:
# Keep every 30th embedding snapshot, starting at index 10; the bare
# `len(subset)` shows the number of selected snapshots as the cell output.
subset = [results["subset_embeddings"][i] for i in range(10, len(results["subset_embeddings"]), 30)]
len(subset)
# Disabled: additionally restrict every snapshot to a fixed set of sample indices.
#samples_per_class = 100
#classes = 10  # assuming 1000 samples total and 100 per class
#indices = np.concatenate([np.arange(c * 100, c * 100 + samples_per_class) for c in range(classes)])
#subset = [emb[:][indices] for emb in subset]
#label_subset = [emb[:][indices] for emb in results["subset_labels"]]

In [None]:
# Build the per-snapshot affinity matrices. k_neighbors=250 must not exceed
# the number of other points in a snapshot, because it is passed to torch.topk.
tsne = DynamicTSNE()
tsne.compute_affinities(subset, perplexity=30.0, k_neighbors=250)

In [None]:
from sklearn.decomposition import PCA
# Optimise the dynamic t-SNE layouts.
# NOTE: this rebinds `projections`, which the denoising section uses for the
# selected PCA projection.
projections = tsne.fit(lr=200, lambd=0.1, n_epochs=1000, exaggeration_epochs=250, exaggeration=22.0)

In [None]:
%matplotlib widget

# Interactive view of the dynamic t-SNE layouts.
# NOTE(review): `projections` has one frame per entry of the strided `subset`,
# while `labels` holds every stored label frame — confirm show_with_slider
# tolerates the differing frame counts.
show_with_slider(
    projections,
    labels=results["subset_labels"],
    interpolate=False,
    steps_per_transition=4,
)

In [None]:
# Close the current figure; relies on the bare `import matplotlib` executed
# in an earlier cell.
matplotlib.pyplot.close()

In [None]:
# Compare the dynamic t-SNE layouts with the embedding snapshots they were
# fitted on. `projections` is the output of tsne.fit and `subset` the strided
# snapshot list; embeddings=True because the second argument holds embeddings
# rather than precomputed drifts. (The previously referenced names `thesne`
# and `embedding_drift_subset` are not defined anywhere in this notebook.)
visualization_drift_vs_embedding_drift(projections, subset, embeddings=True)

In [None]:
# Live-updating t-SNE on the snapshot at index 10, coloured by the first
# label frame.
# NOTE: tsne_with_live_callback is defined in the cell that follows this one;
# that cell has to be executed first.
Y = tsne_with_live_callback(results["subset_embeddings"][10],
                        labels=results["subset_labels"][0],
                        perplexity=30,
                        lr=200,
                        n_iter=10000,
                        interval=5)

In [None]:
from openTSNE import TSNE
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display, clear_output

def tsne_with_live_callback(X, labels=None, perplexity=30, lr=200, n_iter=1000, interval=50):
    """Run openTSNE on ``X`` while redrawing the layout during optimisation.

    Args:
        X: data to embed, passed to ``TSNE.fit``.
        labels: optional per-point values used to colour the scatter plot.
        perplexity: t-SNE perplexity.
        lr: learning rate.
        n_iter: number of optimisation iterations.
        interval: callback period in iterations; also the redraw period.

    Returns:
        The fitted embedding. If the run is interrupted with Ctrl-C / kernel
        interrupt, the most recent layout seen by the callback is returned as
        a numpy array; if no callback has fired yet, the unfitted estimator
        is returned.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    # Most recent layout handed to the callback, so an interrupted run can
    # still return an embedding instead of the estimator object.
    latest = {"Y": None}

    def callback(iteration, error, Y):
        latest["Y"] = np.array(Y)
        if iteration % interval == 0 or iteration == n_iter - 1:
            ax.clear()
            if labels is not None:
                ax.scatter(Y[:, 0], Y[:, 1], c=labels, cmap='tab10', s=5, alpha=0.7)
            else:
                ax.scatter(Y[:, 0], Y[:, 1], s=5, alpha=0.7)
            ax.set_title(f"t-SNE at iter {iteration}")
            ax.set_xticks([])
            ax.set_yticks([])
            clear_output(wait=True)
            display(fig)

    tsne = TSNE(
        n_components=2,
        perplexity=perplexity,
        learning_rate=lr,
        n_iter=n_iter,
        initialization="pca",
        callbacks=callback,
        callbacks_every_iters=interval,
        verbose=False,
    )

    try:
        Y = tsne.fit(X)
    except KeyboardInterrupt:
        print("Interrupted, returning current state.")
        return latest["Y"] if latest["Y"] is not None else tsne

    return Y

# OPENTSNE Dynamic

In [None]:
import numpy as np
from openTSNE import TSNE
from sklearn.decomposition import PCA


class DynamicTSNE_2:
    """Run an independent openTSNE fit per snapshot, optionally chaining initialisations."""

    def __init__(self, perplexity=30, n_iter=1000, init='pca', random_state=None):
        """
        init: 'pca', 'random', or 'previous'
            'pca'      - every snapshot starts from its own 2-component PCA
            'random'   - every snapshot starts from openTSNE's random init
            'previous' - the first snapshot starts from PCA, every later one
                         from the previous snapshot's embedding
        """
        assert init in ['pca', 'random', 'previous']
        self.perplexity = perplexity
        self.n_iter = n_iter
        self.init = init
        self.random_state = random_state

    def _select_initialization(self, X, previous_embedding):
        """Return the ``initialization`` argument for the snapshot ``X``."""
        if self.init == 'random':
            # Applies to every snapshot, not only the first one
            return 'random'
        if self.init == 'pca' or previous_embedding is None:
            return PCA(n_components=2).fit_transform(X)
        return previous_embedding

    def fit_transform(self, Xs):
        """
        Xs: List of np.ndarray (each shape: [n_samples, n_features])
        Returns: List of np.ndarray (each shape: [n_samples, 2])
        """
        embeddings = []
        previous_embedding = None

        for i, X in enumerate(Xs):
            print(i)
            tsne = TSNE(
                n_jobs=-1,
                perplexity=self.perplexity,
                n_iter=self.n_iter,
                initialization=self._select_initialization(X, previous_embedding),
                random_state=self.random_state,
                verbose=True
            )
            embedding = tsne.fit(X)
            embeddings.append(embedding)
            previous_embedding = embedding

        return embeddings

In [None]:
# Chain the openTSNE fits: with init='previous' each snapshot after the first
# starts from the previous snapshot's embedding.
dynamic_tsne = DynamicTSNE_2(n_iter=500, init='previous', random_state=42)
projections_2 = dynamic_tsne.fit_transform(subset)

In [None]:
%matplotlib widget

# Interactive view of the chained openTSNE layouts.
show_with_slider(
    projections_2,
    labels=results["subset_labels"],
    interpolate=False,
    steps_per_transition=4,
)

# MODERN DYNAMIC t-SNE

In [None]:
import numpy as np

frame_stride = 5       # keep every 5th snapshot
samples_per_class = 4  # samples kept per class
classes = 100
# NOTE(review): the index arithmetic below assumes every snapshot stores its
# samples grouped by class with 10 consecutive samples per class
# (100 classes x 10 = 1000 samples) — confirm against the data export.
stored_per_class = 10

subset = [
    results["subset_embeddings"][i]
    for i in range(0, len(results["subset_embeddings"]), frame_stride)
]
indices = np.concatenate([
    np.arange(c * stored_per_class, c * stored_per_class + samples_per_class)
    for c in range(classes)
])
subset = [emb[indices] for emb in subset]
# Labels get the same frame stride and sample selection as the embeddings, so
# label_subset[t] lines up with subset[t].
label_subset = [frame_labels[indices] for frame_labels in results["subset_labels"][::frame_stride]]
len(subset)

In [None]:
# Show the selected labels of the first frame as the cell output.
label_subset[0]

In [None]:
# Fit ModernDynamicTSNE on the class-balanced subset.
# NOTE: the class is defined in a cell further down; that cell has to be
# executed before this one.
modern_dynamic_tsne = ModernDynamicTSNE(
    n_epochs=500,
    perplexity=30,
)
projections_3 = modern_dynamic_tsne.fit_transform(subset)

In [None]:
# Plain t-SNE on the same subset, for comparison.
# NOTE: this rebinds projections_tsne, which earlier held the t-SNE projection
# of the full snapshot list.
projections_tsne = generate_projections(
    embeddings_list=subset,
    method='tsne',
)

In [None]:
# Display the fitted estimator object as the cell output.
modern_dynamic_tsne

In [None]:
# Both projections were computed on the class-balanced `subset`, so they must
# be coloured with the matching `label_subset`; the full
# results["subset_labels"] frames have a different number of points.
show_multiple_projections_with_slider(
    projections_list=[projections_tsne, projections_3],
    labels=label_subset,
    titles=["t-SNE", "Dynamic t-SNE"],
    interpolate=False,
    figsize_per_plot=(4, 4),
    dataset=dataset
)

In [None]:
# Interactive view of the ModernDynamicTSNE layouts, coloured with the labels
# selected for the same subset.
show_with_slider(
    projections_3,
    labels=label_subset,
    interpolate=False,
    steps_per_transition=4,
    dataset=dataset,
    alpha=0.7
)

In [None]:
import numpy as np
from sklearn.decomposition import PCA
import torch
import torch.nn.functional as F
from torch.optim import SGD


class ModernDynamicTSNE:
    """Joint t-SNE over a sequence of snapshots, optimised with SGD + momentum.

    The cost is the sum of per-snapshot KL divergences plus ``lmbda`` times a
    squared-distance penalty between consecutive layouts. With the default
    ``lmbda=0.0`` the penalty is disabled and snapshots are optimised
    independently of each other.
    """

    def __init__(
        self,
        perplexity=30,
        n_epochs=1000,
        output_dims=2,
        initial_lr=2400,
        final_lr=200,
        lr_switch=250,
        init_stdev=1e-4,
        initial_momentum=0.5,
        final_momentum=0.8,
        momentum_switch=250,
        lmbda=0.0,
        sigma_iters=50,
        verbose=True,
        device=None
    ):
        """Store the hyper-parameters.

        ``lr_switch`` / ``momentum_switch`` are the epochs at which the
        learning rate / momentum change from their initial to their final
        value. ``sigma_iters`` caps the bisection steps of the perplexity
        search. ``init_stdev`` is stored but not used by this class (layouts
        are initialised with PCA). ``device`` defaults to CUDA when available.
        """
        self.perplexity = perplexity
        self.n_epochs = n_epochs
        self.output_dims = output_dims
        self.initial_lr = initial_lr
        self.final_lr = final_lr
        self.lr_switch = lr_switch
        self.init_stdev = init_stdev
        self.initial_momentum = initial_momentum
        self.final_momentum = final_momentum
        self.momentum_switch = momentum_switch
        self.lmbda = lmbda
        self.sigma_iters = sigma_iters
        self.verbose = verbose
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    def _hbeta(self, D, beta):
        """Return (entropy, normalised affinities) for distances ``D`` at precision ``beta``."""
        P = torch.exp(-D * beta)
        sumP = torch.sum(P)
        sumP = torch.clamp(sumP, min=1e-8)
        H = torch.log(sumP) + beta * torch.sum(D * P) / sumP
        P = P / sumP
        return H, P

    def _binary_search_perplexity(self, D, tol=1e-5):
        """Turn a squared-distance matrix into a symmetric affinity matrix.

        For each row, bisection on ``beta`` (at most ``sigma_iters`` steps)
        matches the row entropy to ``log(perplexity)`` within ``tol``.

        Args:
            D: square tensor of pairwise squared distances.
            tol: entropy tolerance of the search.

        Returns:
            ``(P + P.T) / (2 * n)`` with a zero diagonal.
        """
        n = D.shape[0]
        device = D.device
        P = torch.zeros((n, n), device=device)
        # Built once on D's device; row i selects every column except i
        off_diagonal = ~torch.eye(n, dtype=torch.bool, device=device)

        logU = np.log(self.perplexity)
        for i in range(n):
            neighbours = off_diagonal[i]
            betamin = None
            betamax = None
            beta = torch.ones((), device=device)
            Di = D[i][neighbours]
            H, thisP = self._hbeta(Di, beta)

            tries = 0
            while torch.abs(H - logU) > tol and tries < self.sigma_iters:
                if H > logU:
                    betamin = beta
                    beta = beta * 2 if betamax is None else (beta + betamax) / 2
                else:
                    betamax = beta
                    beta = beta / 2 if betamin is None else (beta + betamin) / 2
                H, thisP = self._hbeta(Di, beta)
                tries += 1
            P[i, neighbours] = thisP
        return (P + P.T) / (2 * n)

    def _precompute_Ps(self, Xs):
        """Return one affinity matrix per snapshot tensor in ``Xs``."""
        return [self._binary_search_perplexity(torch.cdist(X, X).pow(2)) for X in Xs]

    def _compute_cost(self, Ys, Ps):
        """Total cost: summed KL divergences plus the weighted smoothness penalty."""
        total_kl = 0
        for Y, P in zip(Ys, Ps):
            Q_num = 1 / (1 + torch.cdist(Y, Y).pow(2))
            Q_num.fill_diagonal_(0)
            Q = Q_num / Q_num.sum()
            kl = torch.sum(P * torch.log((P + 1e-8) / (Q + 1e-8)))
            total_kl += kl
        smoothness = sum((Ys[i] - Ys[i + 1]).pow(2).sum() for i in range(len(Ys) - 1))
        return total_kl + self.lmbda * smoothness / (2 * Ys[0].shape[0])

    def fit_transform(self, Xs_np):
        """Embed every snapshot and return the layouts.

        Args:
            Xs_np: non-empty sequence of array-likes, one ``[n, features]``
                snapshot each.

        Returns:
            list of numpy arrays, one ``[n, output_dims]`` layout per snapshot.

        Raises:
            ValueError: if ``Xs_np`` is empty.
        """
        if len(Xs_np) == 0:
            raise ValueError("fit_transform needs at least one snapshot")

        Xs = [torch.tensor(X, device=self.device, dtype=torch.float32) for X in Xs_np]

        # Init Ys with PCA
        Ys = [
            torch.tensor(PCA(n_components=self.output_dims).fit_transform(X.cpu().numpy()),
                         device=self.device, dtype=torch.float32, requires_grad=True)
            for X in Xs
        ]

        # Precompute all P matrices once
        Ps = self._precompute_Ps(Xs)

        optimizer = SGD(Ys, lr=self.initial_lr, momentum=self.initial_momentum)

        for epoch in range(self.n_epochs):
            # Switch to the final learning rate / momentum at the configured epochs
            if epoch == self.lr_switch:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = self.final_lr
            if epoch == self.momentum_switch:
                for param_group in optimizer.param_groups:
                    param_group['momentum'] = self.final_momentum

            optimizer.zero_grad()
            loss = self._compute_cost(Ys, Ps)
            loss.backward()
            optimizer.step()

            if self.verbose and (epoch % 100 == 0 or epoch == self.n_epochs - 1):
                print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

        return [Y.detach().cpu().numpy() for Y in Ys]