In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
import matplotlib.pyplot as plt

import torch

from latent_geometry.model.mnist_vae import (
    load_decoder,
    load_encoder,
)
from viz.plotly import (
    plot_traces,
    draw_spiders,
    create_digit_background,
    create_dot_background,
    _path_to_trace,
)
from latent_geometry.mapping import TorchModelMapping, Mapping
from latent_geometry.manifold import LatentManifold
from latent_geometry.metric import EuclideanMetric, EuclideanPullbackMetric
from latent_geometry.path import ManifoldPath
from latent_geometry.data import load_mnist_dataset
from latent_geometry.utils import project, lift
import os

from typing import Union, Callable, Optional

from scipy.interpolate import splev, splprep
from functools import partial

In [None]:
# raise Exception("double check that we wont use already taken gpu ($ nvidia-smi)")
# os.environ["CUDA_VISIBLE_DEVICES"] = "3"

In [None]:
# Load the MNIST training split and materialise it as two tensors.
dataset = load_mnist_dataset(split="train")
# Concatenate every image along dim 0.
# NOTE(review): presumably each item is (1, H, W), giving (N, H, W) — TODO confirm.
images = torch.concat([img for img, _ in dataset])
labels = torch.tensor([label for _, label in dataset])
# Display both shapes as a quick sanity check.
images.shape, labels.shape

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing at the first `.to(DEVICE)`.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name, latent_dim = "beta_10", 2

# Checkpoint file names are derived from `model_name`.
ENCODER = load_encoder(DEVICE, f"{model_name}_encoder.pt", latent_dim=latent_dim)
DECODER = load_decoder(DEVICE, f"{model_name}_decoder.pt", latent_dim=latent_dim)

# Smoke test: push the first training image through encoder -> decoder and
# display the shape of the reconstruction.
z = ENCODER.sample(dataset[0][0].to(DEVICE)[None, :])
reconstruction = DECODER.decode(z)
reconstruction.shape

In [None]:
# Encode the whole training set in one pass. The outputs are only used as
# plain values further down (always through `.detach()`), so disable autograd
# to avoid holding a computation graph for every image.
with torch.no_grad():
    mus, sts = ENCODER(images.unsqueeze(1).to(DEVICE))
mus.shape, sts.shape

In [None]:
# Tolerance forwarded to the manifold's solver.
SOLVER_TOL = 0.001
ambient_metric = EuclideanMetric()
# The first shape follows `latent_dim` (the value the encoder/decoder were
# loaded with) rather than a second hard-coded 2, so the two cannot drift apart.
# NOTE(review): the two tuples look like (latent shape, decoded image shape) — confirm.
latent_mapping = TorchModelMapping(
    DECODER, (latent_dim,), (1, 32, 32), batch_size=10_000, call_fn=DECODER.decode
)
manifold_mnist = LatentManifold(
    latent_mapping, ambient_metric, solver_tol=SOLVER_TOL, bvp_n_mesh_nodes=2_000
)

In [None]:
# NOTE(review): this cell re-runs the full encoding pass that an earlier cell
# already performs; it is kept so `mus`/`sts` are (re)defined here, but it is a
# candidate for removal.
# Autograd is disabled because the outputs are only consumed as detached values.
with torch.no_grad():
    mus, sts = ENCODER(images.unsqueeze(1).to(DEVICE))
mus.shape, sts.shape

In [None]:
def create_straight_path(from_: np.ndarray, to_: np.ndarray) -> ManifoldPath:
    """Build the straight segment between two points as a ``ManifoldPath``.

    Args:
        from_: Point returned by the path at ``t = 0``.
        to_: Point returned by the path at ``t = 1``.

    Returns:
        A ``ManifoldPath`` wrapping the linear interpolation
        ``from_ + (to_ - from_) * t`` together with ``manifold_mnist.metric``.
    """

    def interpolate(t: float) -> np.ndarray:
        # Linear blend between the two endpoints.
        return from_ + (to_ - from_) * t

    metric = manifold_mnist.metric
    return ManifoldPath(interpolate, metric)


def create_latent_path(from_: np.ndarray, theta: float, length: float) -> ManifoldPath:
    """Request a geodesic from ``manifold_mnist`` leaving ``from_`` at angle ``theta``.

    Args:
        from_: Start point handed to ``manifold_mnist.geodesic``.
        theta: Angle in radians; converted to the 2-D direction
            ``(cos(theta), sin(theta))``.
        length: Third argument forwarded unchanged to ``manifold_mnist.geodesic``.

    Returns:
        Whatever ``manifold_mnist.geodesic`` returns for these arguments.
    """
    direction = np.array([np.cos(theta), np.sin(theta)])
    return manifold_mnist.geodesic(from_, direction, length)


def create_geodesic_path(from_: np.ndarray, to_: np.ndarray) -> ManifoldPath:
    """Return the path ``manifold_mnist.shortest_path`` finds between two points.

    Args:
        from_: First endpoint, forwarded unchanged.
        to_: Second endpoint, forwarded unchanged.
    """
    shortest = manifold_mnist.shortest_path(from_, to_)
    return shortest

In [None]:
def show_path_in_ambient(path: ManifoldPath, n_points: int = 9):
    """Show decoded images at evenly spaced parameters along ``path``.

    One panel is drawn per sample. Each title reports three lengths measured
    from ``t = 0`` to the sample: ``path.euclidean_length``,
    ``path.manifold_length`` and ``path.ambient_path.euclidean_length``.

    Args:
        path: Path to sample; it is evaluated at parameters in ``[0, 1]``.
        n_points: Number of panels. With a single panel only ``t = 0`` is shown.

    Raises:
        ValueError: If ``n_points`` is smaller than 1.
    """
    if n_points < 1:
        raise ValueError(f"n_points must be at least 1, got {n_points}")

    # squeeze=False always yields a 2-D grid of axes, so indexing also works
    # when n_points == 1 (where subplots would otherwise return a bare Axes).
    fig, axes_grid = plt.subplots(
        1, n_points, figsize=(1.5 * n_points, 2), squeeze=False
    )
    axes = axes_grid[0]

    for i in range(n_points):
        # Guard the single-panel case, which would otherwise divide by zero.
        t = i / (n_points - 1) if n_points > 1 else 0.0
        latent_dist = path.manifold_length(0, t)
        euclidean_dist = path.euclidean_length(0, t)
        ambient_dist = path.ambient_path.euclidean_length(0, t, dt=0.01)

        # Decode the latent point and view it as a 32x32 image.
        image = project(latent_mapping)(path(t)).reshape((32, 32))
        ax = axes[i]
        ax.imshow(image, cmap="gray")
        ax.set_title(
            (
                f"v: Euc: {euclidean_dist:.3f}, P-B: {latent_dist: .3f}\n"
                f"ambient dist: {ambient_dist:.3f}"
            ),
            fontsize=8,
        )
        ax.set_axis_off()

    plt.tight_layout()
    plt.show()

In [None]:
# Start point handed to `create_latent_path` when building `latent_path`.
START = np.array([-1, -1])
# Optional check: decode START and view it as an image.
# img = project(latent_mapping)(START).reshape(32, 32)
# fig = px.imshow(img)
# fig

## Figures

In [None]:
# Fraction of the encoded points to draw in the scatter background.
frac = 0.1
# Seeded generator so the plotted subsample is the same on every
# Restart & Run All rather than changing between runs.
rng = np.random.default_rng(0)
idx = rng.choice(len(mus), size=int(len(mus) * frac), replace=False)
dot_background = create_dot_background(
    mus.detach().cpu()[idx], labels.numpy()[idx], opacity=0.2
)
fig = plot_traces([dot_background])
# fig

In [None]:
# Build a path from START at angle pi/4 with length argument 10.0, then the
# straight segment joining that path's two endpoints for comparison.
latent_path = create_latent_path(START, np.pi * 1 / 4, 10.0)
straight_path = create_straight_path(latent_path(0), latent_path(1))
# Display the endpoints of the latent path.
latent_path(0), latent_path(1)

In [None]:
# Background traces built from the decoder mapping; the result is a list
# (it is concatenated with other trace lists in the plotting cells).
# NOTE(review): presumably a grid of decoded digits with `num` controlling its resolution — confirm.
heatmaps = create_digit_background(num=30, opacity=0.5, mapping=latent_mapping)

In [None]:
# Combine the digit background with the encoded-point scatter in one figure.
# The figure is built but not displayed (display line is commented out).
fig = plot_traces(heatmaps + [dot_background])
# fig

In [None]:
# Convert both paths to plot traces with distinct colours and legend groups.
# NOTE(review): `_path_to_trace` is a private (underscore-prefixed) helper of viz.plotly.
latent_trace = _path_to_trace(
    latent_path, color="green", legend_group="geodesic", show_legend=True
)
straight_trace = _path_to_trace(
    straight_path, color="black", legend_group="straight", show_legend=True
)

In [None]:
# Draw both path traces over the digit background and display the figure.
fig = plot_traces(heatmaps + [latent_trace, straight_trace])
fig

In [None]:
# geodesic = create_geodesic_path(latent_path(0), latent_path(1))

In [None]:
def plot_paths(
    paths: list[tuple[ManifoldPath, str]],
    ts: np.ndarray = np.linspace(0, 1),
):
    """Draw every labelled path on the current matplotlib axes.

    Args:
        paths: ``(path, label)`` pairs; the label becomes the legend entry.
        ts: Parameter values passed to ``lift(path)``. The first two columns
            of the result are used as the x and y coordinates.
    """
    for curve, label in paths:
        samples = lift(curve)(ts)
        xs, ys = samples[:, 0], samples[:, 1]
        plt.plot(xs, ys, label=label)

    plt.legend()


def plot_accelerations(
    paths: list[tuple[ManifoldPath, str]],
    scale: float,
    num: int = 20,
    dt: float = 0.001,
):
    """Draw scaled acceleration arrows along each path on the current axes.

    Args:
        paths: ``(path, label)`` pairs; the label is not used here.
        scale: Factor applied to each acceleration vector before drawing.
        num: Number of sample parameters per path, spread over ``[0, 1 - dt]``.
        dt: Step of the forward difference used to estimate velocity.
    """
    for curve, _label in paths:
        sample_times = np.linspace(0, 1 - dt, num=num)
        for t in sample_times:
            here = curve(t)
            ahead = curve(t + dt)
            # Forward-difference estimate of the velocity at ``t``.
            velocity = (ahead - here) / dt
            # Uses the path's private ``_manifold_metric`` attribute.
            accel_fn = project(curve._manifold_metric.acceleration)
            arrow = accel_fn(position=here, velocity=velocity) * scale
            plt.arrow(here[0], here[1], arrow[0], arrow[1])

In [None]:
# (path, label) pairs consumed by the plotting cells below.
PATHS = [
    (straight_path, "straight"),
    # (geodesic, "geodesic"),
    (latent_path, "latent"),
]

In [None]:
# Plot every path listed in PATHS on one matplotlib figure.
plt.figure(figsize=(10, 6))
plot_paths(PATHS)
# NOTE(review): `ALL_PATHS` is not defined in the visible cells; `PATHS` is
# probably what is meant — confirm before re-enabling the line below.
# plot_accelerations(ALL_PATHS, scale=0.015, num=20, dt=0.001)
plt.show()

In [None]:
# Spot check: length of the straight path's ambient image between t=0 and
# t=0.1, with 0.001 as the third positional argument.
# NOTE(review): the third argument is presumably the integration step `dt` — confirm.
# straight_path.ambient_path.euclidean_length(0, 0.1)
straight_path.ambient_path.euclidean_length(0, 0.1, 0.001)
# straight_path.euclidean_length(0, 1, 0.0001)
# straight_path.euclidean_length(0, 1,)

In [None]:
# Show the decoded image strip for every path, in the order printed below.
# pths = [(straight_path, "straight")] + SPIDER_PATHS
pths = PATHS
# Print the labels so the figures that follow can be matched to paths.
print([name for _, name in pths])
for path, name in pths:
    show_path_in_ambient(path, n_points=9)

In [None]:
# Scatter of all encoder outputs `mus` (first two columns), coloured by label.
# Labels are cast to str before being passed as `color`.
# NOTE(review): the cast is presumably so plotly treats them as categories — confirm.
fig_px = px.scatter(
    x=mus[:, 0].detach().cpu(),
    y=mus[:, 1].detach().cpu(),
    color=labels.numpy().astype(str),
    opacity=0.05,
)

In [None]:
# Display the scatter figure built in the previous cell.
fig_px