In [52]:
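# Optional environment setup, disabled by default: uncomment the commands below to
# clone the latent-geometry repository and install it in editable mode.
# !rm -rf latent-geometry
# !git clone -b final-project-presentation --single-branch https://github.com/quczer/latent-geometry.git
# # !pip uninstall numpy -y
# !pip install -e latent-geometry[dev]
# # !python latent-geometry/setup.py install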

In [53]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [54]:
import numpy as np
import torchvision.transforms as transforms

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader

from latent_geometry.model.mnist_vae import EncoderVAE, DecoderVAE
from latent_geometry.mapping import TorchModelMapping
from latent_geometry.visual.plotly import (
    create_topology_fig,
    create_topology_fig_geodesics,
)
from latent_geometry.manifold import LatentManifold
from latent_geometry.metric import EuclideanMetric

import pandas as pd
import plotly.graph_objects as go
import plotly.express as px

from latent_geometry.config import NOTEBOOKS_DIR

In [55]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The encoder and the decoder are built from the same hyper-parameters.
vae_kwargs = dict(init_channels=8, latent_dim=2, kernel_size=4, image_channels=1)
encoder = EncoderVAE(**vae_kwargs).to(device)
decoder = DecoderVAE(**vae_kwargs).to(device)

# One flat parameter list so that a single optimizer covers both networks.
params = list(encoder.parameters()) + list(decoder.parameters())

# Learning hyper-parameters.
lr = 0.001
epochs = 100
batch_size = 64
optimizer = optim.Adam(params, lr=lr)
# reduction="sum" adds the squared errors up instead of averaging them.
criterion = nn.MSELoss(reduction="sum")
# NOTE(review): a later cell rebinds `encoder` and `decoder` to objects loaded from
# disk, while `optimizer` keeps the parameters of the networks created here --
# confirm this cell is still needed when no training happens in the notebook.

In [56]:
# Every image is resized to 32x32 and converted to a tensor.
transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])

# Both splits are downloaded to / read from the same folder.
mnist_root = NOTEBOOKS_DIR / "input"

# Training split, served in shuffled batches.
trainset = torchvision.datasets.MNIST(
    root=mnist_root,
    train=True,
    download=True,
    transform=transform,
)
trainloader = DataLoader(trainset, shuffle=True, batch_size=batch_size)

# Validation split, served in a fixed order.
testset = torchvision.datasets.MNIST(
    root=mnist_root,
    train=False,
    download=True,
    transform=transform,
)
testloader = DataLoader(testset, shuffle=False, batch_size=batch_size)

In [57]:
# Rebuild the test loader so that it yields batches of up to 500 images.
# NOTE(review): `batch_size`, `transform` and `testset` repeat the definitions of the
# previous cell unchanged; only the loader's batch size (500) is new, and
# `batch_size` is not used by the loader below -- confirm the duplication is wanted.
batch_size = 64

transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])
testset = torchvision.datasets.MNIST(
    root=NOTEBOOKS_DIR / "input",
    train=False,
    download=True,
    transform=transform,
)
# `testloader` is rebound here and replaces the loader created in the previous cell.
testloader = DataLoader(testset, shuffle=False, batch_size=500)

In [58]:
# Replace the freshly initialised networks with the ones stored under
# NOTEBOOKS_DIR / "output" (presumably the trained checkpoints -- confirm).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

checkpoint_dir = NOTEBOOKS_DIR / "output"
# NOTE(review): torch.load unpickles the files, so only load checkpoints you trust.
# Recent PyTorch releases default to weights_only=True, which refuses whole pickled
# modules -- confirm against the installed version.
encoder = torch.load(checkpoint_dir / "encoder", map_location=device)
decoder = torch.load(checkpoint_dir / "decoder", map_location=device)

# Push one training batch through both networks.
x = next(iter(trainloader))  # x[0] is the part of the batch fed to the encoder
z, mu, log_var = encoder(x[0].to(device))
reconstruction = decoder(z)

In [59]:
# Defaults of create_background_df.
NUM_POINTS = 7000  # default number of random background points
WIDTH = 4  # default half-width of the sampling square: coordinates fall in [-WIDTH, WIDTH)
# NOTE(review): the three constants below are not referenced by any cell visible
# here -- confirm they are still needed.
Z_SCALE = 5.0
SPLINE_POLY_DEG = 3
SIN_MULT = 1.5

In [60]:
def create_df(points, labels) -> pd.DataFrame:
    """Pack 2-D points and their labels into a frame ready for plotting.

    Args:
        points: array indexable as ``points[:, 0]`` and ``points[:, 1]``; only
            these two columns are used.
        labels: one value per point.

    Returns:
        A DataFrame with the columns "x" and "y" (the two coordinates) and "z"
        (the label). The three columns are stacked into one array first, so
        they share a single dtype.
    """
    first_coord, second_coord = points[:, 0], points[:, 1]
    table = np.stack([first_coord, second_coord, labels], axis=1)
    return pd.DataFrame(table, columns=["x", "y", "z"])


def create_background_df(
    manifold: LatentManifold, n_points: int = NUM_POINTS, width: float = WIDTH
) -> pd.DataFrame:
    """Sample random background points in the square [-width, width) x [-width, width).

    Args:
        manifold: currently unused; the only code that read it was commented out.
        n_points: number of points to draw.
        width: half of the side length of the sampling square.

    Returns:
        A DataFrame with the columns "x", "y" and "cluster". "cluster" encodes
        the quadrant as ``(x > 0) + 2 * (y > 0)``, i.e. a value from 0 to 3.
    """
    coords = {}
    # All "x" values are drawn before the "y" values from numpy's global random state.
    for axis in ("x", "y"):
        coords[axis] = (np.random.rand(n_points, 1) * width * 2 - width).ravel()

    background = pd.DataFrame(coords)
    right_half = background["x"] > 0
    upper_half = background["y"] > 0
    background["cluster"] = right_half + 2 * upper_half
    return background


def create_latent_fig(df: pd.DataFrame, three_d: bool = False) -> go.Figure:
    """Scatter-plot the points of ``df`` with plotly express, coloured by "z".

    Args:
        df: frame with the columns "x", "y" and "z".
        three_d: when true, "z" is also used as the third axis of a 3-D scatter
            drawn with opacity 1; otherwise a 2-D scatter with opacity 0.5 is
            drawn.

    Returns:
        The figure created by plotly express.
    """
    if not three_d:
        # Author's note (translated from Polish): the variant without the colour
        # mapping, px.scatter(df, x="x", y="y", opacity=0.5), "worked".
        return px.scatter(df, x="x", y="y", color="z", opacity=0.5)
    return px.scatter_3d(df, x="x", y="y", z="z", color="z", opacity=1)

In [61]:
def create_topology_fig_geodesics_given_manifold(
    points,
    labels,
    centers1: list[np.ndarray],
    centers2: list[np.ndarray],
    manifold: LatentManifold,
) -> go.Figure:
    """Build the geodesics figure on top of a scatter of the labelled points.

    Thin wrapper around ``create_topology_fig_geodesics``: the points and labels
    are turned into a 2-D scatter trace, which is handed over as the background
    together with the untouched ``centers1``, ``centers2`` and ``manifold``.

    Args:
        points: 2-D points for the background scatter (see ``create_df``).
        labels: one label per point, used to colour the background.
        centers1: forwarded as the first argument of
            ``create_topology_fig_geodesics``.
        centers2: forwarded as the second argument.
        manifold: forwarded as the third argument.

    Returns:
        Whatever ``create_topology_fig_geodesics`` returns.
    """
    latent_scatter = create_latent_fig(create_df(points, labels), three_d=False)
    # Only the first trace of the scatter figure is used as the background.
    return create_topology_fig_geodesics(
        centers1, centers2, manifold, latent_scatter.data[0]
    )

In [62]:
def create_topology_fig_given_manifold(
    points,
    labels,
    centre: np.ndarray,
    manifold: LatentManifold,
    num_lines: int,
    num_circles: int,
    line_length: float = 10,
    show_lines: bool = True,
    show_circles: bool = True,
) -> go.Figure:
    """Build the topology figure on top of a scatter of the labelled points.

    Thin wrapper around ``create_topology_fig``: the points and labels are turned
    into a 2-D scatter trace, which is handed over as the background; every other
    argument is forwarded unchanged and in the same order.

    Args:
        points: 2-D points for the background scatter (see ``create_df``).
        labels: one label per point, used to colour the background.
        centre: forwarded as the first argument of ``create_topology_fig``.
        manifold: forwarded as the second argument.
        num_lines: forwarded to ``create_topology_fig``.
        num_circles: forwarded to ``create_topology_fig``.
        line_length: forwarded to ``create_topology_fig``.
        show_lines: forwarded to ``create_topology_fig``.
        show_circles: forwarded to ``create_topology_fig``.

    Returns:
        Whatever ``create_topology_fig`` returns.
    """
    latent_scatter = create_latent_fig(create_df(points, labels), three_d=False)
    # Only the first trace of the scatter figure is used as the background.
    background_trace = latent_scatter.data[0]
    forwarded = (num_lines, num_circles, line_length, show_lines, show_circles)
    return create_topology_fig(centre, manifold, background_trace, *forwarded)

In [63]:
# Latent manifold of the decoder. The shapes (2,) and (1, 1, 32, 32) are presumably
# the decoder's input and output shapes (confirm against TorchModelMapping);
# 1 * 1 * 32 * 32 = 1024 matches the dimension given to the Euclidean metric.
ambient_metric = EuclideanMetric(1024)
decoder_mapping = TorchModelMapping(decoder, (2,), (1, 1, 32, 32))
manifold_mnist = LatentManifold(decoder_mapping, ambient_metric)

# Encode the first test batch; the first of the three encoder outputs becomes the
# latent points.
# NOTE(review): the trainloader cell names that first output `z`; if it is a random
# sample rather than the mean, the points change on every run (no seed is set).
batch = next(iter(testloader))
points, mu, log_var = encoder(batch[0].to(device))
points = points.detach().cpu().numpy()  # plain numpy array for the plotting helpers
labels = batch[1]

# Reference points for the figures in the following cells.
center = points[0]
centers_1, centers_2 = points[:10], points[10:20]

In [65]:
# Author's timing notes:
#   6 points = 2min on gpu
#   2 points = 2min on cpu
# The run recorded in this notebook was interrupted (KeyboardInterrupt).
topology_fig = create_topology_fig_given_manifold(
    points=points,
    labels=labels,
    centre=center,
    manifold=manifold_mnist,
    num_lines=2,
    num_circles=4,
)
topology_fig.show()

KeyboardInterrupt: 

In [66]:
# Show the two latent points that the next cell passes to the geodesics figure.
points[0], points[1]

(array([-0.8701538 , -0.22426684], dtype=float32),
 array([ 0.28401083, -1.2436879 ], dtype=float32))

In [68]:
# One pair of points only: the first latent point and the second latent point.
# The run recorded in this notebook was interrupted (KeyboardInterrupt).
geodesic_fig = create_topology_fig_geodesics_given_manifold(
    points=points,
    labels=labels,
    centers1=[points[0]],
    centers2=[points[1]],
    manifold=manifold_mnist,
)
geodesic_fig.show()

[-0.8701538  -0.22426684] [ 0.28401083 -1.2436879 ]
[-0.8701538  -0.22426684] [ 0.28401083 -1.2436879 ]
[-0.8701538  -0.22426684] [ 1.1541646 -1.019421 ]
[ 0.28401077 -1.2436879 ] [ 1.1541646 -1.019421 ]
[ 1.15416455 -1.01942098] [-4.59777119  4.73567804]
[ 1.15416455 -1.01942098] [ 1.92695013 -1.87758385]
[ 1.15416455 -1.01942098] [-4.59777119  4.73567804]
[ 1.15416455 -1.01942098] [ 1.92695013 -1.87758385]
[ 0.33857439 -0.19276325] [ 0.12282949 -0.34265077]
[ 1.15416455 -1.01942098] [-4.59777119  4.73567804]
[ 1.15416455 -1.01942098] [ 1.9269496  -1.87758315]
[ 1.15416455 -1.01942098] [-4.59777046  4.73567754]
[ 1.15416455 -1.01942098] [ 1.92695013 -1.87758385]
[ 1.15416458 -1.01942098] [-4.59777144  4.73567822]
[ 1.15416458 -1.01942098] [ 1.92695022 -1.87758391]
[ 1.15416455 -1.01942095] [-4.59777118  4.73567795]
[ 1.15416455 -1.01942095] [ 1.92695011 -1.8775838 ]
[ 0.33857439 -0.19276325] [ 0.1228295  -0.34265087]
[ 0.33857439 -0.19276325] [ 0.12282938 -0.34265072]
[ 0.33857441 -0.

KeyboardInterrupt: 