In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import cProfile
import os
import pstats

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
import matplotlib.pyplot as plt
import torch

# need to import classes in order for torch.load to work
from latent_geometry.model.mnist_vae import (
    load_decoder,
    load_encoder,
    EncoderVAE,
    DecoderVAE,
)
from latent_geometry.mapping import TorchModelMapping, BaseTorchModelMapping
from latent_geometry.visual.plotly import (
    create_topology_fig,
)
from latent_geometry.manifold import LatentManifold
from latent_geometry.metric import EuclideanMetric
from latent_geometry.path import ManifoldPath
from latent_geometry.data import load_mnist_dataset

In [3]:
# Pin this notebook to a single GPU. Check `nvidia-smi` first so we don't grab a
# GPU someone else is already using, then flip GPU_CHECKED_FREE to True.
GPU_ID = "2"
GPU_CHECKED_FREE = False

if not GPU_CHECKED_FREE:
    raise Exception("double check that we wont use already taken gpu ($ nvidia-smi)")
# CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is initialized
# in this process; setting it afterwards is silently ignored, so fail loudly.
if torch.cuda.is_initialized():
    raise RuntimeError(
        "CUDA is already initialized; restart the kernel to change CUDA_VISIBLE_DEVICES"
    )
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID

In [4]:
# Load the MNIST test split and inspect the first (image, label) sample.
dataset = load_mnist_dataset(split="test")
img, label = dataset[0]
print(img.shape, label)  # recorded output: torch.Size([1, 32, 32]) 7

torch.Size([1, 32, 32]) 7


In [5]:
# Use the GPU selected above when CUDA is available; otherwise fall back to CPU
# so the notebook still runs (slowly) on a machine without CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ENCODER = load_encoder(DEVICE)
DECODER = load_decoder(DEVICE)

# Sanity check: encode the first test image (with a batch dim added) and decode it back.
z, mu, log_var = ENCODER(dataset[0][0].to(DEVICE)[None, :])
reconstruction = DECODER(z)
reconstruction.shape

torch.Size([1, 1, 32, 32])

In [6]:
# Solver settings shared by both LatentManifold instances built below:
#   SOLVER_TOL   -> passed as `solver_tol`
#   N_MESH_NODES -> passed as `bvp_n_mesh_nodes`
SOLVER_TOL = 1e-1
N_MESH_NODES = 100

In [7]:
# Input (latent code) and output (image) shapes of the VAE decoder.
LATENT_SHAPE = (2,)
IMAGE_SHAPE = (1, 32, 32)

# The ambient metric acts on the flattened decoder output, so its dimension is
# prod(IMAGE_SHAPE) (= 1024). Deriving it keeps the two from drifting apart.
ambient_metric = EuclideanMetric(int(np.prod(IMAGE_SHAPE)))

# Two mappings over the same decoder ("slow" and "fast" per their names);
# the geodesic cells below compare their run times.
slow_latent_mapping = BaseTorchModelMapping(DECODER, LATENT_SHAPE, IMAGE_SHAPE)
fast_latent_mapping = TorchModelMapping(
    DECODER, LATENT_SHAPE, IMAGE_SHAPE, batch_size=None
)


def _build_manifold(mapping) -> LatentManifold:
    """Wrap ``mapping`` in a LatentManifold with the shared metric and solver settings."""
    return LatentManifold(
        mapping,
        ambient_metric,
        solver_tol=SOLVER_TOL,
        bvp_n_mesh_nodes=N_MESH_NODES,
    )


slow_manifold_mnist = _build_manifold(slow_latent_mapping)
fast_manifold_mnist = _build_manifold(fast_latent_mapping)

In [8]:
def create_straight_path(
    from_: np.ndarray, to_: np.ndarray, manifold: LatentManifold
) -> ManifoldPath:
    """Build the straight latent-space segment from ``from_`` to ``to_``.

    The position is x(t) = from_ + (to_ - from_) * t, so x(0) == from_ and
    x(1) == to_. Its velocity dx/dt is the constant vector (to_ - from_).

    Args:
        from_: start point in latent coordinates (array-like).
        to_: end point in latent coordinates (array-like, same shape as ``from_``).
        manifold: supplies ``metric`` and ``_euclidean_latent_metric`` for the path.

    Returns:
        A ManifoldPath built from the position function ``x_fun`` and the
        velocity function ``v_fun``.
    """
    start = np.asarray(from_)
    displacement = np.asarray(to_) - start

    def x_fun(t: float) -> np.ndarray:
        return start + displacement * t

    def v_fun(t: float) -> np.ndarray:
        # Derivative of x_fun w.r.t. t. The previous version returned
        # from_ - to_, i.e. a velocity pointing backwards along the path.
        # Return a copy so a caller mutating the result can't corrupt later calls.
        return displacement.copy()

    return ManifoldPath(
        x_fun, v_fun, manifold.metric, manifold._euclidean_latent_metric
    )


def create_latent_path(
    from_: np.ndarray, theta: float, length: float, manifold: LatentManifold
) -> ManifoldPath:
    """Shoot a geodesic from ``from_`` in the latent direction at angle ``theta``.

    The initial direction is the 2-D unit vector (cos theta, sin theta), so this
    helper assumes a 2-dimensional latent space.

    Args:
        from_: starting point in latent coordinates.
        theta: angle (radians) of the initial direction in the latent plane.
        length: forwarded unchanged to ``manifold.geodesic`` as its third argument.
        manifold: manifold on which the geodesic is computed.

    Returns:
        Whatever ``manifold.geodesic`` returns (annotated as ManifoldPath).
    """
    direction = np.array([np.cos(theta), np.sin(theta)])
    return manifold.geodesic(from_, direction, length)


def create_geodesic_path(
    from_: np.ndarray, to_: np.ndarray, manifold: LatentManifold
) -> ManifoldPath:
    """Return ``manifold``'s shortest path between latent points ``from_`` and ``to_``.

    Thin wrapper over ``manifold.shortest_path``; it takes the same
    (from_, to_, manifold) argument order as ``create_straight_path``.
    """
    path = manifold.shortest_path(from_, to_)
    return path

In [9]:
# Latent-space starting point shared by the slow and fast path experiments.
START = np.array([0.0, 0.5])

In [10]:
# Geodesic from START at angle 0.9*pi (length argument 10.0) on the slow manifold.
slow_latent_path = create_latent_path(
    START, theta=0.9 * np.pi, length=10.0, manifold=slow_manifold_mnist
)

In [11]:
# The same geodesic setup (START, angle 0.9*pi, length 10.0) on the fast manifold.
fast_latent_path = create_latent_path(
    START, theta=0.9 * np.pi, length=10.0, manifold=fast_manifold_mnist
)

In [12]:
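# Optional baseline (disabled): the straight latent segment between the fast
# latent path's endpoints, for comparison with the geodesics.
# straight_path = create_straight_path(
#     fast_latent_path(0), fast_latent_path(1), fast_manifold_mnist
# )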

In [13]:
# Shortest path between the slow latent path's endpoints, on the slow manifold.
# Timing note from earlier runs of this cell: 8m 30s -> 2m 30s.
geo1 = create_geodesic_path(
    from_=slow_latent_path(0),
    to_=slow_latent_path(1),
    manifold=slow_manifold_mnist,
)

In [14]:
# Shortest path between the fast latent path's endpoints, on the fast manifold.
geo2 = create_geodesic_path(
    from_=fast_latent_path(0),
    to_=fast_latent_path(1),
    manifold=fast_manifold_mnist,
)

# Timing notes from earlier runs: 8m 30s -> 2m 30s -> 0.6s with the fast mapping.

In [15]:
# Profile a single shortest-path solve on the fast manifold to see where time goes.
# The result is discarded; only the profiler stats in `pr` are inspected below.
with cProfile.Profile() as pr:
    create_geodesic_path(fast_latent_path(0), fast_latent_path(1), fast_manifold_mnist)

In [18]:
# Top 10 functions by internal time (time spent in the function itself, excluding
# sub-calls). The trailing `;` suppresses the Stats object's repr in the output.
pstats.Stats(pr).sort_stats(pstats.SortKey.TIME).print_stats(10);

         244606 function calls (236350 primitive calls) in 0.580 seconds

   Ordered by: internal time
   List reduced from 450 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      504    0.267    0.001    0.267    0.001 {built-in method torch.conv_transpose2d}
      126    0.047    0.000    0.082    0.001 {built-in method torch._C._nn.linear}
      378    0.026    0.000    0.026    0.000 {built-in method torch.relu}
      126    0.013    0.000    0.013    0.000 {method 'cpu' of 'torch._C._TensorBase' objects}
      126    0.013    0.000    0.546    0.004 /home/quczer/repos/latent-geometry/src/latent_geometry/utils.py:31(__wrapper)
      321    0.012    0.000    0.012    0.000 {built-in method numpy.core._multiarray_umath.c_einsum}
      126    0.011    0.000    0.011    0.000 {built-in method torch.sigmoid}
       63    0.010    0.000    0.010    0.000 {built-in method torch.mm}
       63    0.010    0.000    0.010    0.000 /home

<pstats.Stats at 0x7fbe2d918070>