In [2]:
# Automatically reload edited project modules (e.g. latent_geometry) before
# each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
import matplotlib.pyplot as plt

import torch

# need to import classes in order for torch.load to work
from latent_geometry.model.mnist_vae import (
    load_decoder,
    load_encoder,
    EncoderVAE,
    DecoderVAE,
)
from latent_geometry.mapping import TorchModelMapping, BaseTorchModelMapping
from latent_geometry.visual.plotly import (
    create_topology_fig,
)
from latent_geometry.manifold import LatentManifold
from latent_geometry.metric import EuclideanMetric
from latent_geometry.path import ManifoldPath
from latent_geometry.data import load_mnist_dataset
import os

In [4]:
# NOTE: before running, check `nvidia-smi` to make sure GPU 2 is not already in use.
# Expose only physical GPU 2 to this process. This must happen before CUDA is
# initialised (i.e. before the first CUDA call in the cells below).
os.environ.update(CUDA_VISIBLE_DEVICES="2")

In [5]:
# Load the MNIST test split and peek at its first (image, label) pair.
dataset = load_mnist_dataset(split="test")
(img, label) = dataset[0]
print(img.shape, label, sep=" ")

torch.Size([1, 32, 32]) 7


In [6]:
# Use the GPU when one is visible (see CUDA_VISIBLE_DEVICES above) and fall
# back to the CPU otherwise. Hardcoding "cuda" made the notebook unusable on
# machines without CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ENCODER = load_encoder(DEVICE)
DECODER = load_decoder(DEVICE)

# Sanity check: encode the first test image and decode its latent sample.
# `[None, :]` adds a leading (batch) axis; the encoder returns (z, mu, log_var).
z, mu, log_var = ENCODER(dataset[0][0].to(DEVICE)[None, :])
reconstruction = DECODER(z)
reconstruction.shape

torch.Size([1, 1, 32, 32])

In [7]:
SOLVER_TOL = 1e-1  # tolerance forwarded to LatentManifold(solver_tol=...) below

In [8]:
# Ambient space: the decoder's (1, 1, 32, 32) output viewed as 32 * 32 = 1024
# flat coordinates with the Euclidean metric.
ambient_metric = EuclideanMetric(32 * 32)

# Two mappings of the same decoder from the 2-D latent space: the reference
# BaseTorchModelMapping ("slow") and TorchModelMapping ("fast").
slow_latent_mapping = BaseTorchModelMapping(DECODER, (2,), (1, 1, 32, 32))
fast_latent_mapping = TorchModelMapping(DECODER, (2,), (1, 1, 32, 32))

# Build one latent manifold per mapping (slow first, then fast).
slow_manifold_mnist, fast_manifold_mnist = (
    LatentManifold(latent_mapping, ambient_metric, solver_tol=SOLVER_TOL)
    for latent_mapping in (slow_latent_mapping, fast_latent_mapping)
)

In [9]:
def create_straight_path(
    from_: np.ndarray, to_: np.ndarray, manifold: LatentManifold
) -> ManifoldPath:
    """Build the straight latent-space segment from ``from_`` to ``to_``.

    The segment is parametrised as ``x(t) = from_ + (to_ - from_) * t``, so
    ``x(0) == from_`` and ``x(1) == to_``. Its velocity is the constant
    ``dx/dt = to_ - from_``.

    Args:
        from_: start point in latent coordinates.
        to_: end point in latent coordinates.
        manifold: manifold whose ``metric`` and private
            ``_euclidean_latent_metric`` are attached to the returned path.

    Returns:
        A ``ManifoldPath`` built from the position and velocity functions.
    """

    def x_fun(t: float) -> np.ndarray:
        return from_ + (to_ - from_) * t

    def v_fun(t: float) -> np.ndarray:
        # Derivative of x_fun with respect to t. It must be `to_ - from_`;
        # `from_ - to_` would point the velocity against the direction of travel.
        return to_ - from_

    return ManifoldPath(
        x_fun, v_fun, manifold.metric, manifold._euclidean_latent_metric
    )


def create_latent_path(
    from_: np.ndarray, theta: float, length: float, manifold: LatentManifold
) -> ManifoldPath:
    """Shoot a path on ``manifold`` from ``from_`` in a given planar direction.

    Args:
        from_: starting point in the (2-D) latent space.
        theta: angle, in radians, of the direction ``(cos theta, sin theta)``.
        length: length argument forwarded to ``path_given_direction``.
        manifold: manifold that constructs the path.

    Returns:
        The path produced by ``manifold.path_given_direction``.
    """
    unit_direction = np.array([np.cos(theta), np.sin(theta)])
    return manifold.path_given_direction(from_, unit_direction, length)


def create_geodesic_path(
    from_: np.ndarray, to_: np.ndarray, manifold: LatentManifold
) -> ManifoldPath:
    """Return the geodesic that ``manifold`` computes between two latent points.

    Thin wrapper that keeps the same call shape as the other path builders.
    """
    geodesic = manifold.geodesic(from_, to_)
    return geodesic

In [10]:
START = np.array([0.0, 0.5])  # 2-D latent starting point for the paths below

In [12]:
# Path of length 10 from START at angle 0.9*pi on the slow manifold.
slow_latent_path = create_latent_path(
    START, theta=0.9 * np.pi, length=10.0, manifold=slow_manifold_mnist
)

In [13]:
# Same path parameters as above, but built on the fast manifold.
fast_latent_path = create_latent_path(
    START, theta=0.9 * np.pi, length=10.0, manifold=fast_manifold_mnist
)

In [14]:
# Straight latent segment joining the endpoints (t=0 and t=1) of the fast path.
straight_path = create_straight_path(
    from_=fast_latent_path(0),
    to_=fast_latent_path(1),
    manifold=fast_manifold_mnist,
)

In [18]:
# Geodesic between the slow path's endpoints on the slow manifold.
# Recorded runtime: ~8m 30s.
geo1 = create_geodesic_path(
    from_=slow_latent_path(0),
    to_=slow_latent_path(1),
    manifold=slow_manifold_mnist,
)

In [None]:
# Geodesic between the fast path's endpoints on the fast manifold.
# Recorded runtime: ~5m.
geo2 = create_geodesic_path(
    from_=fast_latent_path(0),
    to_=fast_latent_path(1),
    manifold=fast_manifold_mnist,
)

In [None]:
# Relative time saved, presumably using the recorded geodesic runtimes above:
# b = slow run (8m 30s = 8.5 min), a = fast run (5 min).
b, a = 8.5, 5
(b - a) / b

0.4117647058823529

In [15]:
import cProfile
import pstats

# Profile a single geodesic computation on the fast manifold. The returned
# path is discarded; only the statistics collected in `pr` are inspected.
with cProfile.Profile() as pr:
    create_geodesic_path(
        from_=fast_latent_path(0),
        to_=fast_latent_path(1),
        manifold=fast_manifold_mnist,
    )

In [16]:
# Show only the 20 most expensive functions by internal time; the full listing
# is very long. The trailing `;` hides the `<pstats.Stats ...>` repr in IPython.
# In the run recorded in this notebook, `Tensor.cpu` dominated internal time
# (~150 s of ~243 s), followed by `torch.tensor` (~40 s).
pstats.Stats(pr).sort_stats("tottime").print_stats(20);

         7901289 function calls (7690775 primitive calls) in 242.514 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    14516  149.577    0.010  149.577    0.010 {method 'cpu' of 'torch._C._TensorBase' objects}
    29032   40.155    0.001   40.155    0.001 {built-in method torch.tensor}
    10887   16.868    0.002   16.868    0.002 {method 'run_backward' of 'torch._C._EngineBase' objects}
     7322   11.326    0.002   11.326    0.002 {built-in method numpy.zeros}
    18145    6.508    0.000    6.508    0.000 {method 'to' of 'torch._C._TensorBase' objects}
    58064    3.024    0.000    3.024    0.000 {built-in method torch.conv_transpose2d}
     3629    2.747    0.001   15.936    0.004 /home/quczer/repos/latent-geometry/src/latent_geometry/metric/abstract.py:159(metric_matrix)
    14516    1.040    0.000    1.040    0.000 {built-in method torch._C._nn.linear}
    43548    0.664    0.000    0.664    0.000 {built-in method t

<pstats.Stats at 0x7f463ddda7c0>