In [1]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import plotly.express as px

import torch

from latent_geometry.model.mnist_vae import (
    load_decoder,
    load_encoder,
)
from latent_geometry.mapping import TorchModelMapping
from viz.plotly import (
    plot_traces,
    draw_balls,
    create_dot_background,
    create_scalar_field,
)
from latent_geometry.manifold import LatentManifold
from latent_geometry.metric import EuclideanMetric
from latent_geometry.path import ManifoldPath, Path
from latent_geometry.data import load_mnist_dataset
from latent_geometry.utils import project, lift
from viz.calc import create_circles, create_radials
from latent_geometry.config import FIGURES_DIR
import os

## Preparation: data, models and latent manifold

In [3]:
# Expose only GPU index 3 to CUDA for this process.
# NOTE(review): presumably this has to run before CUDA is first initialised to
# take effect, and DEVICE is pinned to the CPU further down — confirm it is
# still needed.
os.environ.update({"CUDA_VISIBLE_DEVICES": "3"})

In [4]:
# Load the MNIST training split and gather it into two tensors.
dataset = load_mnist_dataset(split="train")

# Walk the dataset once and split the (image, label) pairs, instead of
# iterating over every sample twice (once per comprehension).
_images, _labels = zip(*dataset)
images = torch.concat(_images)
labels = torch.tensor(_labels)
# Drop the per-sample tuples so they do not stay alive in the kernel namespace.
del _images, _labels
images.shape, labels.shape

(torch.Size([60000, 32, 32]), torch.Size([60000]))

In [23]:
# Everything runs on the CPU. To use a GPU when one is available, switch to
# torch.device("cuda" if torch.cuda.is_available() else "cpu").
DEVICE = torch.device("cpu")
model_name, latent_dim = "beta_10", 2

# Output directory for this model's HTML figures. parents=True also creates the
# intermediate "mnist/html" directories; with exist_ok=True alone the call
# raises FileNotFoundError when they do not exist yet (e.g. fresh checkout).
DIR = FIGURES_DIR / "mnist" / "html" / model_name
DIR.mkdir(parents=True, exist_ok=True)

# Load the encoder/decoder pair stored under the selected model name.
ENCODER = load_encoder(DEVICE, f"{model_name}_encoder.pt", latent_dim=latent_dim)
DECODER = load_decoder(DEVICE, f"{model_name}_decoder.pt", latent_dim=latent_dim)

# Smoke test: encode the first training image (with a leading batch axis added)
# and decode the sampled latent back, then display the reconstruction's shape.
z = ENCODER.sample(dataset[0][0].to(DEVICE)[None, :])
reconstruction = DECODER.decode(z)
reconstruction.shape

torch.Size([1, 1, 32, 32])

In [24]:
SOLVER_TOL = 0.001
ambient_metric = EuclideanMetric()

# Wrap the decoder as the latent -> image mapping. The latent shape is derived
# from `latent_dim` (defined with the model name) instead of a hard-coded 2, so
# it cannot drift out of sync with the loaded model.
# NOTE(review): the two positional tuples are presumably the input and output
# shapes of the mapping; (1, 32, 32) matches the decoder output displayed above
# without its batch axis — confirm against TorchModelMapping's signature.
latent_mapping = TorchModelMapping(
    DECODER, (latent_dim,), (1, 32, 32), batch_size=10_000, call_fn=DECODER.decode
)

# Latent manifold built from the decoder mapping and the Euclidean ambient metric.
manifold_mnist = LatentManifold(
    latent_mapping, ambient_metric, solver_tol=SOLVER_TOL, bvp_n_mesh_nodes=2_000
)

In [25]:
# Encode the whole training set in one forward pass. Gradients are not needed
# for plotting, so autograd is disabled to avoid keeping the computation graph
# of a 60000-image batch alive in memory.
# NOTE(review): `sts` is presumably the encoder's spread/std output — confirm.
with torch.no_grad():
    mus, sts = ENCODER(images.unsqueeze(1).to(DEVICE))
mus.shape, sts.shape

(torch.Size([60000, 2]), torch.Size([60000, 1]))

## Figures: metric scalar fields over the latent space

In [26]:
frac = 0.3

# Draw the background subsample from a seeded generator so that the same points
# (and therefore the same figure) are produced on every Restart & Run All.
BACKGROUND_SEED = 0
background_rng = np.random.default_rng(BACKGROUND_SEED)
idx = background_rng.choice(len(mus), size=int(len(mus) * frac), replace=False)

# Faint scatter of the selected encoded means together with their labels.
background_trace = create_dot_background(
    mus.detach().cpu()[idx], labels.numpy()[idx], opacity=0.1
)
fig = plot_traces([background_trace])
# Uncomment to display the figure inline:
# fig

In [27]:
NUM = 30
SPAN = 2.5
RADIUS = 1.0
N_DIV = 8

# Regular NUM x NUM grid over [-SPAN, SPAN]^2. Both axes share the same 1-D
# coordinates; the original passed a nested np.meshgrid(...) call as the second
# axis, which only worked because its single-element result was flattened again.
grid_axis = np.linspace(-SPAN, SPAN, num=NUM)
xs, ys = np.meshgrid(grid_axis, grid_axis)
xs.shape

(30, 30)

In [28]:
# Flatten both coordinate grids and pair them up: one (x, y) row per grid node.
CENTRES = np.array([xs.ravel(), ys.ravel()]).T
CENTRES.shape

(900, 2)

In [29]:
def calc_corr(points: np.ndarray) -> np.ndarray:
    """Return the normalised off-diagonal metric entry at each point.

    For the metric matrix ``g`` that ``manifold_mnist.metric`` reports at every
    point, this computes ``g[0, 1] / sqrt(g[0, 0] * g[1, 1])``.

    Args:
        points: points passed unchanged to ``metric_matrix``.

    Returns:
        One value per metric matrix (the result is indexed along its first axis).
    """
    g = manifold_mnist.metric.metric_matrix(points)
    off_diagonal = g[:, 0, 1]
    diagonal_product = g[:, 0, 0] * g[:, 1, 1]
    return off_diagonal / np.sqrt(diagonal_product)

In [30]:
# Scalar-field trace of calc_corr, labelled "basis vector correlation".
corr_field = create_scalar_field(
    calc_corr,
    num=500,
    opacity=1.0,
    field_title="basis vector correlation",
)

In [31]:
# Combine the encoded-points background with the correlation field in one figure.
corr_traces = [background_trace, corr_field]
corr_fig = plot_traces(corr_traces)
# Uncomment to display the figure inline:
# corr_fig

In [32]:
def calc_changes(points: np.ndarray) -> np.ndarray:
    """Return the largest eigenvalue of the metric matrix at each point.

    Args:
        points: points passed unchanged to ``metric_matrix``.

    Returns:
        The last column of ``np.linalg.eigvalsh`` applied to the metric
        matrices, i.e. one value per matrix.
    """
    metric_matrices = manifold_mnist.metric.metric_matrix(points)
    # eigvalsh reports eigenvalues in ascending order, so the last one is the largest.
    spectrum = np.linalg.eigvalsh(metric_matrices)
    return spectrum[:, -1]

In [33]:
# Scalar-field trace of calc_changes, labelled "magnitude of change" and drawn
# with the "hot" colour map.
magnitude_field_options = dict(
    num=500,
    opacity=1.0,
    field_title="magnitude of change",
    cmap="hot",
)
magnitude_field = create_scalar_field(calc_changes, **magnitude_field_options)

In [34]:
# Figure containing only the magnitude field (no background points).
magnitude_traces = [magnitude_field]
magnitude_fig = plot_traces(magnitude_traces)
# Uncomment to display the figure inline:
# magnitude_fig

In [35]:
# Save the correlation figure as HTML inside this model's figures directory.
corr_html_path = DIR / "corr.html"
corr_fig.write_html(corr_html_path)

In [36]:
# Save the magnitude (largest-eigenvalue) figure as HTML inside this model's figures directory.
eigvals_html_path = DIR / "eigvals.html"
magnitude_fig.write_html(eigvals_html_path)