In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import plotly.express as px

import torch

from latent_geometry.model.mnist_vae import (
    load_decoder,
    load_encoder,
)
from latent_geometry.mapping import TorchModelMapping
from viz.plotly import (
    plot_traces,
    draw_balls,
    create_dot_background,
)
from latent_geometry.manifold import LatentManifold
from latent_geometry.metric import EuclideanMetric
from latent_geometry.path import ManifoldPath, Path
from latent_geometry.data import load_mnist_dataset
from latent_geometry.utils import project, lift
from viz.calc import create_circles, create_radials, create_loop
from latent_geometry.config import FIGURES_DIR
import os
from tqdm import tqdm

## prep

In [3]:
# Restrict the CUDA devices visible to this process to the one with index "3".
# NOTE(review): the index is machine-specific, and the models below are loaded
# onto DEVICE, which is currently the CPU — confirm this setting is still needed.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

In [4]:
dataset = load_mnist_dataset(split="train")
# Walk the dataset once and split the (image, label) pairs afterwards, so each
# item is fetched a single time instead of once for images and again for labels.
image_seq, label_seq = zip(*dataset)
images = torch.concat(image_seq)
labels = torch.tensor(label_seq)
images.shape, labels.shape

(torch.Size([60000, 32, 32]), torch.Size([60000]))

In [6]:
# Everything runs on the CPU; swap in the commented line to use a GPU when available.
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = torch.device("cpu")
model_name, latent_dim = "beta_10", 2

# Checkpoint file names are derived from `model_name`.
ENCODER = load_encoder(DEVICE, f"{model_name}_encoder.pt", latent_dim=latent_dim)
DECODER = load_decoder(DEVICE, f"{model_name}_decoder.pt", latent_dim=latent_dim)

# Smoke test: push the first training image (with a leading batch axis added)
# through the encoder's sampler and back through the decoder, then show the shape.
z = ENCODER.sample(dataset[0][0].to(DEVICE)[None, :])
reconstruction = DECODER.decode(z)
reconstruction.shape

torch.Size([1, 1, 32, 32])

In [9]:
# Absolute tolerance handed to the manifold's solver (`solver_atol`).
SOLVER_TOL = 0.01
ambient_metric = EuclideanMetric()
# Wrap the decoder as a latent -> image mapping.
# The two shape tuples are presumably the batched input and output shapes
# (TODO confirm against TorchModelMapping); the latent one now follows
# `latent_dim` instead of repeating the literal 2, so it stays in sync with
# the value used to load the encoder/decoder.
latent_mapping = TorchModelMapping(
    DECODER,
    (-1, latent_dim),
    (-1, 1, 32, 32),
    batch_size=10_000,
    call_fn=DECODER.decode,
)
manifold_mnist = LatentManifold(
    latent_mapping, ambient_metric, solver_atol=SOLVER_TOL, bvp_n_mesh_nodes=2_000
)

In [10]:
# Encode the whole training set in one forward pass (a channel axis is added
# first). The outputs are only used as plotting data in this notebook, so
# gradient tracking is switched off to avoid keeping the autograd graph for
# all 60000 images in memory.
with torch.no_grad():
    mus, sts = ENCODER(images.unsqueeze(1).to(DEVICE))
mus.shape, sts.shape

(torch.Size([60000, 2]), torch.Size([60000, 1]))

## figures

In [30]:
# Fraction of the encoded training points shown as the scatter background.
frac = 0.05
# Seeded generator: the same subsample (and therefore the same figure) is
# produced on every Restart & Run All.
rng = np.random.default_rng(0)
idx = rng.choice(len(mus), size=int(len(mus) * frac), replace=False)
background_trace = create_dot_background(
    mus.detach().cpu()[idx], labels.numpy()[idx], opacity=1
)
fig = plot_traces([background_trace])
fig

In [None]:
NUM = 20  # grid points per axis
SPAN = 2.5  # the grid covers [-SPAN, SPAN] along both latent axes
RADIUS = 1.0  # `radius` passed to the ball constructors
N_DIV = 8  # `n_div` passed to the ball constructors

# Regular NUM x NUM grid of candidate ball centres. The previous version passed
# a nested `np.meshgrid(...)` as the second argument, which only produced the
# right result because meshgrid flattens its inputs; both axes now share one
# explicit 1-D array.
grid_axis = np.linspace(-SPAN, SPAN, num=NUM)
xs, ys = np.meshgrid(grid_axis, grid_axis)
xs.shape

In [None]:
# Flatten the grid into one (x, y) row per grid point.
CENTRES = np.stack((xs.ravel(), ys.ravel()), axis=1)
CENTRES.shape

In [None]:
def create_exact_ball(centre, n_div: int, radius: float):
    """Build the boundary of a ball around `centre` on `manifold_mnist`.

    Radial curves are generated with `create_radials`, joined into circles
    with `create_circles`, and the first circle is wrapped in a `Path`.

    Args:
        centre: Centre point, forwarded to `create_radials` as `centre`.
        n_div: Number of radial divisions (`divisions` of `create_radials`).
        radius: Length of each radial (`length` of `create_radials`).

    Returns:
        A `Path` built from the first element returned by `create_circles`.
    """
    spokes = create_radials(
        centre=centre, divisions=n_div, manifold=manifold_mnist, length=radius
    )
    # A single circle is requested, so only index 0 is used.
    return Path(create_circles(spokes, 1)[0])


def create_local_ball(centre, n_div: int, radius: float, verbose: bool = False):
    """Build a loop of rescaled circle points using the metric at `centre`.

    `n_div` points are placed evenly on a Euclidean circle of the given
    `radius`, each is scaled by the value `manifold_mnist.metric.vector_length`
    returns for it at `centre`, and the result is passed to `create_loop`.

    Args:
        centre: 1-D NumPy array; it is tiled to one row per circle point.
        n_div: Number of points on the circle.
        radius: Euclidean radius of the initial circle.
        verbose: When True, print the metric matrix at `centre`. This used to
            be an unconditional debug print; it is now opt-in.

    Returns:
        Whatever `create_loop` returns for the rescaled points.
    """
    thetas = np.linspace(0, 2 * np.pi, num=n_div, endpoint=False)
    points = np.vstack((radius * np.cos(thetas), radius * np.sin(thetas))).T
    centres = centre[None, :].repeat(n_div, 0)
    ambient_lengths = manifold_mnist.metric.vector_length(points, centres)
    if verbose:
        print(manifold_mnist.metric.metric_matrix(centre[None, :]))
    # NOTE(review): the points are multiplied by their lengths; if the goal is
    # a ball of constant metric radius, a division may be intended — confirm.
    # NOTE(review): the points are not shifted by `centre`, so the loop stays
    # around the origin — confirm whether `create_loop` expects offsets.
    points_rescaled = points * ambient_lengths[:, None]
    return create_loop(points_rescaled)

In [None]:
# Try the local approximation once: 9 points, radius 0.4, centred at the origin.
path = create_local_ball(centre=np.array([0, 0]), n_div=9, radius=0.4)
path

In [None]:
# Disabled preview: uncomment to pass the `path` from the previous cell to draw_balls.
# draw_balls([path])

In [None]:
# One exact ball per grid centre; tqdm reports progress over the centres.
BALLS = []
for grid_centre in tqdm(CENTRES):
    BALLS.append(create_exact_ball(centre=grid_centre, n_div=N_DIV, radius=RADIUS))

In [None]:
# Draw all balls together with the subsampled background trace; the bare name
# on the last line makes the notebook display the figure.
spiders_fig = draw_balls(BALLS, background_trace)
spiders_fig

In [None]:
# Disabled export: uncomment to write the figure to FIGURES_DIR/mnist/html/exact_balls.html.
# spiders_fig.write_html(FIGURES_DIR / "mnist" / "html" / "exact_balls.html")

## trash

In [None]:
# Convert the latent means to NumPy once (as is already done for `labels`)
# instead of passing torch tensors to plotly and repeating detach/cpu per axis.
mus_np = mus.detach().cpu().numpy()
fig_px = px.scatter(
    x=mus_np[:, 0],
    y=mus_np[:, 1],
    color=labels.numpy().astype(str),
    opacity=0.05,
)