In [1]:
# Pick up edits to project modules (latent_geometry) without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [55]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from plotly.basedatatypes import BaseTraceType

import torch

from latent_geometry.model.mnist_vae import (
    load_decoder,
    load_encoder,
)
from latent_geometry.mapping import TorchModelMapping
from latent_geometry.viz.plotly import (
    plot_traces,
    draw_spiders,
)
from latent_geometry.manifold import LatentManifold
from latent_geometry.metric import EuclideanMetric
from latent_geometry.path import ManifoldPath, Path
from latent_geometry.data import load_mnist_dataset
from latent_geometry.utils import project, lift
from latent_geometry.viz.calc import create_circles, create_radials
from latent_geometry.config import FIGURES_DIR
import os

## Setup: data, models and latent manifold

In [3]:
# Restrict this process to a single GPU. This only takes effect if CUDA has not
# been initialised yet (torch is imported above but initialises CUDA lazily on
# first use), so keep this cell before any cell that touches the GPU.
# `setdefault` lets a value exported by the launching shell win over this
# machine-specific default.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")

In [4]:
dataset = load_mnist_dataset(split="train")
# Walk the dataset once: each item is an (image, label) pair, and two separate
# comprehensions would materialise every item twice.
image_list, label_list = zip(*dataset)
images = torch.concat(image_list)
labels = torch.tensor(label_list)
assert len(images) == len(labels), "every image needs exactly one label"
images.shape, labels.shape

(torch.Size([60000, 32, 32]), torch.Size([60000]))

In [5]:
model_name, latent_dim = "beta_10", 2

# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ENCODER = load_encoder(DEVICE, f"{model_name}_encoder.pt", latent_dim=latent_dim)
DECODER = load_decoder(DEVICE, f"{model_name}_decoder.pt", latent_dim=latent_dim)

# Sanity check: sample a latent code for the first training image and decode it.
z = ENCODER.sample(dataset[0][0].to(DEVICE)[None, :])
reconstruction = DECODER.decode(z)
reconstruction.shape

torch.Size([1, 1, 32, 32])

In [6]:
SOLVER_TOL = 0.001
# Per-sample decoder output shape; matches reconstruction.shape[1:] from the sanity check.
IMAGE_SHAPE = (1, 32, 32)
ambient_metric = EuclideanMetric()
# The decoder viewed as a map from the latent space into image space. Use
# `latent_dim` rather than a literal so switching models only needs one edit.
latent_mapping = TorchModelMapping(
    DECODER, (latent_dim,), IMAGE_SHAPE, batch_size=10_000, call_fn=DECODER.decode
)
manifold_mnist = LatentManifold(
    latent_mapping, ambient_metric, solver_tol=SOLVER_TOL, bvp_n_mesh_nodes=2_000
)

In [7]:
# Encode the whole training set once; `mus` feeds the background scatter below
# (presumably posterior means; `sts` presumably scales — confirm in mnist_vae).
# Only detached values are used downstream, so skip building an autograd graph
# over all 60k images.
with torch.no_grad():
    mus, sts = ENCODER(images.unsqueeze(1).to(DEVICE))
mus.shape, sts.shape

(torch.Size([60000, 2]), torch.Size([60000, 1]))

## figures

In [23]:
def create_background_trace(mus: np.ndarray, labels: np.ndarray) -> BaseTraceType:
    """Build a marker scatter of 2-D latent points coloured by class label.

    Args:
        mus: array-like of shape (n, >=2); only the first two columns are
            plotted. Anything ``np.asarray`` accepts works (e.g. a CPU torch
            tensor).
        labels: integer array-like of shape (n,) with class ids. Colours are
            taken from plotly's qualitative G10 palette; ids beyond the
            palette size wrap around instead of raising ``IndexError``.

    Returns:
        A ``go.Scatter`` trace named "mnist" with semi-transparent markers.

    Raises:
        ValueError: if ``mus`` is not 2-D with at least two columns, or if
            ``mus`` and ``labels`` have different lengths.
    """
    points = np.asarray(mus)
    label_ids = np.asarray(labels, dtype=int)
    if points.ndim != 2 or points.shape[1] < 2:
        raise ValueError(
            f"mus must be 2-D with at least 2 columns, got shape {points.shape}"
        )
    if len(label_ids) != len(points):
        raise ValueError(
            f"mus and labels length mismatch: {len(points)} vs {len(label_ids)}"
        )

    cmap = np.array(px.colors.qualitative.G10)
    colors = cmap[label_ids % len(cmap)]
    return go.Scatter(
        x=points[:, 0],
        y=points[:, 1],
        mode="markers",
        marker=dict(color=colors, opacity=0.5),
        name="mnist",
    )

In [24]:
frac = 0.1  # fraction of encoded training points shown as background
SUBSAMPLE_SEED = 0  # fixed so the background subsample is identical on every run
rng = np.random.default_rng(SUBSAMPLE_SEED)
idx = rng.choice(len(mus), size=int(len(mus) * frac), replace=False)
# create_background_trace expects numpy arrays, not torch tensors.
mus_np = mus.detach().cpu().numpy()
background_trace = create_background_trace(mus_np[idx], labels.numpy()[idx])
fig = plot_traces([background_trace])

In [47]:
NUM = 3  # grid nodes per latent axis
SPAN = 1  # grid covers [-SPAN, SPAN] on both latent axes
# The same 1-D axis serves for x and y. The original wrapped the y-axis in a
# stray np.meshgrid call, which only worked because meshgrid flattens its inputs.
grid_1d = np.linspace(-SPAN, SPAN, num=NUM)
xs, ys = np.meshgrid(grid_1d, grid_1d)
xs.shape

(3, 3)

In [49]:
# Flatten the grid into one (x, y) row per node, in row-major grid order.
CENTRES = np.column_stack([xs.ravel(), ys.ravel()])
CENTRES.shape

(9, 2)

In [53]:
N_DIV, N_CIRCLES = 8, 4  # `divisions` for create_radials; circle count for create_circles
LENGTH = 8.0  # `length` passed to create_radials


def create_spider(centre, divisions=None, n_circles=None, length=None, manifold=None):
    """Build one "spider" (radials plus circles) around ``centre``.

    Every optional argument defaults to the corresponding notebook global
    (N_DIV, N_CIRCLES, LENGTH, manifold_mnist), read at call time. Calling
    ``create_spider(centre)`` therefore behaves exactly as before, while other
    resolutions or manifolds can be used without editing the globals.

    Args:
        centre: latent-space point the radials start from.
        divisions: ``divisions`` argument for create_radials.
        n_circles: number of circles passed to create_circles.
        length: ``length`` argument for create_radials.
        manifold: manifold the radials are computed on.

    Returns:
        The list of radials followed by the list of circles.
    """
    divisions = N_DIV if divisions is None else divisions
    n_circles = N_CIRCLES if n_circles is None else n_circles
    length = LENGTH if length is None else length
    manifold = manifold_mnist if manifold is None else manifold

    radials = create_radials(
        centre=centre, divisions=divisions, manifold=manifold, length=length
    )
    circles = create_circles(radials, n_circles)
    return radials + circles


# One spider per grid centre (iterating a 2-D array yields its rows).
spiders = list(map(create_spider, CENTRES))
len(spiders)

9

In [54]:
# Draw every spider together with the subsampled latent background scatter.
spiders_fig = draw_spiders(spiders, background_trace)
spiders_fig

In [56]:
spiders_html_path = FIGURES_DIR / "html" / "mnist_spiders.html"
# The html/ subdirectory may not exist on a fresh checkout; write_html would
# otherwise fail with FileNotFoundError.
os.makedirs(os.path.dirname(spiders_html_path), exist_ok=True)
spiders_fig.write_html(spiders_html_path)

## Scratch (unused experiments)

In [None]:
# Unused experiment: all encoded points, coloured by label as categorical strings.
mus_cpu = mus.detach().cpu()
fig_px = px.scatter(
    x=mus_cpu[:, 0],
    y=mus_cpu[:, 1],
    color=labels.numpy().astype(str),
    opacity=0.05,
)