# Embedding Visualization

## Import Data

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np
import random

In [None]:
# Pick the saved training run to visualize (keep exactly one assignment active).
#
# MNIST:
#   run = "mnist_MLP_32_0.9595"
#
# CIFAR-10:
#   run = "cifar10_MLP_128_0.5672"
#   run = "cifar10_ViT_128_0.6311"
#   run = "cifar10_ViT_128_0.6530"
run = "cifar10_CNN_128_0.8977"

In [None]:
from data_manager import load_training_data

results = load_training_data(run)

# Re-key the embedding drifts with ints, in ascending numeric order
# (the stored keys may be strings, e.g. after JSON serialization).
results["embedding_drifts"] = dict(
    sorted(
        ((int(key), drift) for key, drift in results["embedding_drifts"].items()),
        key=lambda item: item[0],
    )
)

## Visualize Training

In [None]:
# Training overview plots — currently disabled. Uncommenting the lines below draws a
# 2x2 grid with loss/accuracy curves, gradient statistics and embedding drift
# (only three of the four panels are filled).
from train_viz import _plot_loss_accuracy, _plot_gradients, _plot_embedding_drift

#fig, axs = plt.subplots(2, 2, figsize=(10, 8))
#epochs = len(results["train_losses"])
#_plot_loss_accuracy(axs[0][0], epochs-1, epochs, results["train_losses"], results["val_losses"], results["train_accuracies"], results["val_accuracies"])
#_plot_gradients(axs[0][1], range(0, len(results["gradient_norms"])),results["gradient_norms"], results["max_gradients"], results["grad_param_ratios"], 20)
#_plot_embedding_drift(axs[1][0], results["embedding_drifts"])

In [None]:
%matplotlib widget
from train_viz import _plot_embedding_drift

embedding_drifts = results["embedding_drifts"].copy()
# fig, axs = plt.subplots(1, 1, figsize=(10, 4))

# Plot 2x Drifts
# axs.plot(range(1, len(embedding_drifts[1]) + 1), np.array(embedding_drifts[1]) * 2, color="green", label="2x Drift 1", alpha=0.3)
# axs.plot(range(1, len(embedding_drifts[2]) + 1), np.array(embedding_drifts[2]) * 2, color="blue", label="2x Drift 2", alpha=0.3)
# axs.plot(range(1, len(embedding_drifts[4]) + 1), np.array(embedding_drifts[4]) * 2, color="orange", label="2x Drift 2", alpha=0.3)
# axs.plot(range(1, len(embedding_drifts[8]) + 1), np.array(embedding_drifts[8]) * 2, color="red", label="2x Drift 2", alpha=0.3)
# _plot_embedding_drift(axs, embedding_drifts)
# 
# plt.legend()
# plt.show()

# PCA Visualization

In [None]:
from train_viz import generate_projections, animate_projections, show_with_slider, show_multiple_projections_with_slider, visualization_drift_vs_embedding_drift

In [None]:
# PCA projections fitted on three different bases ('first', 'last', 'all'),
# generated in that order.
projections_pca_first, projections_pca_last, projections_pca_all = (
    generate_projections(
        embeddings_list=results["subset_embeddings"],
        method='pca',
        pca_fit_basis=fit_basis,
    )
    for fit_basis in ('first', 'last', 'all')
)

In [None]:
# PCA refitted on a sliding window of 20 snapshots.
projections_pca_window = generate_projections(
    method='pca',
    pca_fit_basis='window',
    window_size=20,
    embeddings_list=results["subset_embeddings"],
)

In [None]:
# Step through the sliding-window PCA projection with a slider.
show_with_slider(
    projections_pca_window,
    steps_per_transition=2,
    interpolate=True,
    labels=results["subset_labels"],
)

In [None]:
# Compare the four PCA fitting strategies side by side.
show_multiple_projections_with_slider(
    titles=["PCA on first", "PCA on last", "PCA on all", "PCA sliding window"],
    projections_list=[projections_pca_first, projections_pca_last, projections_pca_all, projections_pca_window],
    labels=results["subset_labels"],
    figsize_per_plot=(4, 4),
    interpolate=True,
    steps_per_transition=1,
)

In [None]:
# Presumably plots the step-to-step movement of the sliding-window PCA projection
# against the high-dimensional embedding drifts — see train_viz to confirm.
visualization_drift_vs_embedding_drift(projections_pca_window, embedding_drifts)

In [None]:
import matplotlib

# Close the current pyplot figure (plt is matplotlib.pyplot, imported at the top).
plt.close()

In [None]:
from train_viz import _calculate_embedding_drift

def adjust_visualization_speed(projections, embedding_drifts, drift_key, *, scaling_difference=None):
    """
    Slow down a projection sequence wherever it moves faster than the embeddings justify.

    A reference ratio ``scaling_difference`` between visualization drift and
    high-dimensional embedding drift is estimated (or supplied). Then, for every
    step ``i >= drift_key``, the mean per-point movement from the previous
    *adjusted* frame to ``projections[i]`` is compared with
    ``|embedding_drifts[drift_key][i - 1]| * scaling_difference``. Steps at least
    that large are shrunk to exactly that size; smaller steps are kept as-is.
    Because each step is measured from the previous adjusted frame, a slowed
    frame lags behind and later frames catch up toward the original positions.

    Args:
        projections (list of np.ndarray or np.ndarray): Low-dimensional projections,
            one per snapshot, each indexed as (point, coordinate).
        embedding_drifts (dict): Maps integer keys to sequences of
            high-dimensional embedding drift values.
        drift_key (int): Key in ``embedding_drifts`` used as the speed reference;
            must be >= 1. The first ``drift_key`` projections are kept unchanged.
        scaling_difference (float, optional): Reference ratio of visualization
            drift to embedding drift. If None (default), it is estimated as the mean
            of ``_calculate_embedding_drift(projections)[drift_key] /
            embedding_drifts[drift_key]`` from index ``drift_key - 1`` onwards,
            ignoring entries whose embedding drift is zero.

    Returns:
        list of np.ndarray: Adjusted projections, same length as ``projections``.

    Raises:
        ValueError: If ``drift_key < 1``, or if ``scaling_difference`` has to be
            estimated but every reference embedding drift is zero.
    """
    if drift_key < 1:
        # With drift_key == 0 there is no starting frame to measure the first step from.
        raise ValueError(f"drift_key must be >= 1, got {drift_key}")

    target_drift = np.asarray(embedding_drifts[drift_key]).flatten()

    if scaling_difference is None:
        # np.asarray so the division also works when either side is a plain list.
        vis_drift = np.asarray(_calculate_embedding_drift(projections)[drift_key], dtype=float)[drift_key - 1:]
        emb_drift = np.asarray(embedding_drifts[drift_key], dtype=float)[drift_key - 1:]
        # A zero embedding drift would turn the ratio into inf/NaN and, through the
        # scaling factor, silently make every adjusted frame NaN.
        nonzero = emb_drift != 0
        if not np.any(nonzero):
            raise ValueError(
                f"embedding_drifts[{drift_key}] has no non-zero entries to estimate the scaling from"
            )
        scaling_difference = np.mean(vis_drift[nonzero] / emb_drift[nonzero])
    print(f"Scaling difference: {scaling_difference}")

    # list() copies the leading frames: appending never mutates the caller's
    # sequence, and numpy-array input (which has no .append) works too.
    adjusted_projections = list(projections[:drift_key])
    n_steps = max(len(projections) - drift_key, 0)
    changes = 0
    for i in range(drift_key, len(projections)):
        current_drift = projections[i] - adjusted_projections[-1]
        vis_drift_step = np.linalg.norm(current_drift, axis=1).mean()
        # Largest allowed mean movement for this step, in visualization units.
        max_step = np.abs(target_drift[i - 1]) * scaling_difference

        if vis_drift_step == 0 or vis_drift_step < max_step:
            scaling_factor = 1.0
        else:
            scaling_factor = max_step / vis_drift_step
            changes += 1

        adjusted_projections.append(adjusted_projections[-1] + current_drift * scaling_factor)

    # Report the share of adjusted steps as a real percentage.
    percent_changed = 100.0 * changes / n_steps if n_steps else 0.0
    print(f"{percent_changed:.1f}% changes ({changes} of {n_steps} steps)")
    return adjusted_projections

In [None]:
# Slow the sliding-window PCA projection down to match the drift of key 1.
adapted = adjust_visualization_speed(projections_pca_window, embedding_drifts, drift_key=1)

In [None]:
# Re-check the drift comparison after slowing the PCA projection down.
visualization_drift_vs_embedding_drift(adapted, embedding_drifts)

In [None]:
# Original vs. speed-adjusted sliding-window PCA, without frame interpolation.
show_multiple_projections_with_slider(
    projections_list=[projections_pca_window, adapted],
    titles=["PCA sliding window", "slowed"],
    labels=results["subset_labels"],
    figsize_per_plot=(4, 4),
    steps_per_transition=2,
    interpolate=False,
)

In [None]:
# denoise_projections is not imported by any earlier cell, so this cell raised NameError.
# NOTE(review): assumes it lives in train_viz next to the other projection helpers — confirm.
from train_viz import denoise_projections

denoised = denoise_projections(adapted, window_size=5, blend=1)

In [None]:
# Raw, speed-adjusted, and speed-adjusted + denoised sliding-window PCA side by side.
show_multiple_projections_with_slider(
    titles=["PCA sliding window", "adapted", "adapted+denoised"],
    projections_list=[projections_pca_window, adapted, denoised],
    labels=results["subset_labels"],
    figsize_per_plot=(4, 4),
    interpolate=False,
)

In [None]:
# Drift comparison for the speed-adjusted + denoised PCA projection (`denoised`);
# the name used here before, projections_pca_window_denoised, is never defined.
visualization_drift_vs_embedding_drift(denoised, embedding_drifts)

# t-SNE Visualization

In [None]:
# t-SNE projection of the stored subset embeddings.
projections_tsne = generate_projections(
    method='tsne',
    embeddings_list=results["subset_embeddings"],
)

In [None]:
# Speed-adjusted + denoised PCA next to raw t-SNE. shared_axes=False presumably gives
# each panel its own axis limits, since t-SNE coordinates have an arbitrary scale.
show_multiple_projections_with_slider(
    projections_list=[denoised, projections_tsne],
    titles=["PCA", "t-SNE"],
    labels=results["subset_labels"],
    figsize_per_plot=(5, 5),
    shared_axes=False,
    interpolate=True,
    steps_per_transition=2,
)

In [None]:
# Drift comparison for the raw t-SNE projection.
visualization_drift_vs_embedding_drift(projections_tsne, embedding_drifts)

In [None]:
# Apply the same speed adjustment (drift key 1) to the t-SNE projection.
adapted_tsne = adjust_visualization_speed(projections_tsne, embedding_drifts, drift_key=1)

In [None]:
# denoise_projections is not imported by any earlier import cell; import it here so this
# cell does not depend on the PCA denoising cell having run first.
# NOTE(review): assumes it lives in train_viz next to the other projection helpers — confirm.
from train_viz import denoise_projections

denoised_tsne = denoise_projections(adapted_tsne, window_size=5, blend=1)

In [None]:
# Final comparison: adjusted PCA, speed-adjusted + denoised t-SNE, and raw t-SNE.
show_multiple_projections_with_slider(
    titles=["PCA", "t-SNE slowed", "t-SNE"],
    projections_list=[denoised, denoised_tsne, projections_tsne],
    labels=results["subset_labels"],
    figsize_per_plot=(5, 5),
    shared_axes=False,
    steps_per_transition=1,
    interpolate=False,
)

In [None]:
# Drift comparison for the speed-adjusted + denoised t-SNE projection.
visualization_drift_vs_embedding_drift(denoised_tsne, embedding_drifts)