In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import plotly.graph_objects as go
import plotly.express as px

import numpy as np
import pandas as pd

import torch
import torch.nn as nn

from latent_geometry.mapping import TorchModelMapping
from latent_geometry.manifold import LatentManifold
from latent_geometry.metric import EuclideanMetric
from latent_geometry.visual.plotly import create_topology_fig

In [3]:
# Background sampling: NUM_POINTS latent points drawn uniformly from [-WIDTH, WIDTH)^2.
NUM_POINTS = 7000
WIDTH = 4
# Height multiplier for the Net1 / Net2 / FlatShortNet surfaces.
Z_SCALE = 5.0
# NOTE(review): appears unused in this notebook — TODO confirm it is still needed.
SPLINE_POLY_DEG = 3
# Frequency multiplier for the SinCosNet / SinCosFlatNet surfaces.
SIN_MULT = 1.5

# Seed the global RNGs so the random backgrounds (np.random.rand in
# create_background_df) come out identical on every Restart & Run All.
# torch is seeded too in case any downstream plotting code draws from it.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
class Net1(nn.Module):
    """Lift a latent point ``(x, y)`` to ``(x, y, z)`` on a bumpy surface.

    With ``t = |1.5 - y + x**2|`` the height is ``z = Z_SCALE * exp(-log(t) ** 2)``:
    it peaks at ``Z_SCALE`` where ``t == 1`` and decays towards 0 as ``t``
    approaches 0 or grows large.
    """

    def forward(self, in_):
        # ``in_`` is unpacked along its first dimension into the two latent coords.
        x, y = in_
        offset = torch.abs(1.5 - y + x**2)
        # log(0) = -inf, so points with offset == 0 get height exactly 0 (exp(-inf)).
        height = Z_SCALE * torch.exp(-(torch.log(offset) ** 2))
        return torch.stack([x, y, height])


class Net2(nn.Module):
    """Variant of ``Net1`` whose surface has a flat plateau.

    The inner term ``1.5 - y + x**2`` is floored at 1.2 before the log, so on the
    whole region ``y >= x**2 + 0.3`` the height is the constant
    ``Z_SCALE * exp(-log(1.2) ** 2)`` (about ``0.967 * Z_SCALE``); outside it the
    height decays towards 0.
    """

    def forward(self, in_):
        # ``in_`` is unpacked along its first dimension into the two latent coords.
        x, y = in_
        # torch.maximum is the same element-wise op as the two-tensor torch.max.
        floored = torch.maximum(1.5 - y + x**2, torch.tensor(1.2))
        height = Z_SCALE * torch.exp(-(torch.log(floored) ** 2))
        return torch.stack([x, y, height])


class FlatShortNet(nn.Module):
    """Same plateau surface as ``Net2`` but one third as tall.

    ``z = Z_SCALE * exp(-log(max(1.5 - y + x**2, 1.2)) ** 2) / 3``.
    """

    def forward(self, in_):
        # ``in_`` is unpacked along its first dimension into the two latent coords.
        x, y = in_
        floored = torch.maximum(1.5 - y + x**2, torch.tensor(1.2))
        height = Z_SCALE * torch.exp(-(torch.log(floored) ** 2)) / 3
        return torch.stack([x, y, height])


class SinCosNet(nn.Module):
    """Lift ``(x, y)`` onto the wavy surface ``z = sin(SIN_MULT * x) + cos(SIN_MULT * y)``."""

    def forward(self, in_):
        # ``in_`` is unpacked along its first dimension into the two latent coords.
        x, y = in_
        height = torch.sin(SIN_MULT * x) + torch.cos(SIN_MULT * y)
        return torch.stack([x, y, height])


class SinCosFlatNet(nn.Module):
    """Like ``SinCosNet`` but with negative heights clipped to 0 (flat troughs).

    ``z = max(sin(SIN_MULT * x) + cos(SIN_MULT * y), 0)``.
    """

    def forward(self, in_):
        # ``in_`` is unpacked along its first dimension into the two latent coords.
        x, y = in_
        wave = torch.sin(x * SIN_MULT) + torch.cos(y * SIN_MULT)
        # Integer 0-dim floor tensor kept as-is; type promotion yields wave's float dtype.
        height = torch.maximum(wave, torch.tensor(0))
        return torch.stack([x, y, height])


# Metric on the ambient (output) space shared by every manifold below;
# presumably the Euclidean metric on R^3, matching the nets' 3-component output.
ambient_metric = EuclideanMetric(3)


def _make_manifold(net: nn.Module) -> LatentManifold:
    """Wrap ``net`` as a ``LatentManifold`` pulled back from ``ambient_metric``.

    The nets above map 2 latent coordinates to 3 output coordinates; ``(2,)`` and
    ``(3,)`` are presumably TorchModelMapping's input/output shapes — confirm.
    """
    return LatentManifold(TorchModelMapping(net, (2,), (3,)), ambient_metric)


manifold1 = _make_manifold(Net1())
manifold_flat = _make_manifold(Net2())
manifold_flat_short = _make_manifold(FlatShortNet())
manifold_sin_cos = _make_manifold(SinCosNet())
manifold_sin_cos_flat = _make_manifold(SinCosFlatNet())

# Every manifold defined above, so the overview cell renders all of them.
MANIFOLDS = [
    manifold1,
    manifold_flat,
    manifold_flat_short,
    manifold_sin_cos,
    manifold_sin_cos_flat,
]

In [5]:
def create_background_df(
    manifold: LatentManifold, n_points: int = NUM_POINTS, width: float = WIDTH
) -> pd.DataFrame:
    """Sample random latent points and evaluate the manifold mapping at each.

    Points are drawn uniformly from the square ``[-width, width)^2`` using the
    global NumPy RNG, so results are only reproducible if ``np.random`` is seeded.

    Args:
        manifold: manifold whose ``metric._mapping`` lifts a latent point to 3-D.
        n_points: number of points to sample; 0 yields an empty frame.
        width: half-side length of the sampling square.

    Returns:
        DataFrame with columns ``x`` and ``y`` (latent coordinates), ``cluster``
        (quadrant id 0-3: +1 if x > 0, +2 if y > 0) and ``z`` (third component
        of the mapped point, as float).
    """
    x = np.random.rand(n_points, 1) * width * 2 - width
    y = np.random.rand(n_points, 1) * width * 2 - width

    df = pd.DataFrame(data=np.hstack([x, y]), columns=["x", "y"])
    df["cluster"] = (df.x > 0) + 2 * (df.y > 0)

    # Iterate over a plain ndarray instead of DataFrame.apply(axis=1): no per-row
    # Series is built, the mapping receives a numeric array, and an empty frame
    # (n_points == 0) yields an empty column instead of a DataFrame that cannot
    # be assigned to "z".
    # NOTE(review): `_mapping` is a private attribute of the metric — consider a
    # public accessor if latent_geometry offers one.
    mapping = manifold.metric._mapping
    latent_points = df[["x", "y"]].to_numpy()
    df["z"] = [float(mapping(point)[2]) for point in latent_points]
    return df


def create_lantent_fig(df: pd.DataFrame, three_d: bool = False) -> go.Figure:
    """Scatter-plot the samples in ``df``, coloured by their height ``z``.

    (The misspelled name is kept because other cells call it.)

    Args:
        df: frame with ``x``, ``y`` and ``z`` columns, e.g. from ``create_background_df``.
        three_d: if True, draw the lifted points in 3-D with ``z`` as the vertical
            axis (fully opaque); otherwise draw the 2-D latent plane at 50% opacity.

    Returns:
        The plotly express figure.
    """
    if not three_d:
        return px.scatter(df, x="x", y="y", color="z", opacity=0.5)
    return px.scatter_3d(df, x="x", y="y", z="z", color="z", opacity=1)

In [6]:
# Overview: for each manifold, lift a fresh random latent sample and show it in 3-D.
for manifold in MANIFOLDS:
    background_fig = create_lantent_fig(create_background_df(manifold), three_d=True)
    background_fig.show()

## Topology figures (plotly)

In [7]:
def create_topology_fig_given_manifold(
    centre: np.ndarray,
    manifold: LatentManifold,
    num_lines: int,
    num_circles: int,
    line_length: float = 2.5,
    show_lines: bool = True,
    show_circles: bool = True,
) -> go.Figure:
    """Build a topology figure around ``centre`` over a 2-D latent background.

    A fresh random background is sampled for ``manifold`` with
    ``create_background_df`` and drawn as a 2-D scatter. Its first plotly trace
    is then passed, with the remaining arguments, to
    ``latent_geometry.visual.plotly.create_topology_fig``.

    Returns:
        Whatever ``create_topology_fig`` returns (annotated as ``go.Figure``).
    """
    background_fig = create_lantent_fig(create_background_df(manifold), three_d=False)
    background_trace = background_fig.data[0]

    # Positional order must match create_topology_fig's signature.
    topology_args = (
        centre,
        manifold,
        background_trace,
        num_lines,
        num_circles,
        line_length,
        show_lines,
        show_circles,
    )
    return create_topology_fig(*topology_args)

## Manifold 1 (`Net1`)

In [8]:
# Lines only (circles hidden) around (-1.0, 1.5) on manifold1.
create_topology_fig_given_manifold(
    centre=np.array([-1.0, 1.5]),
    manifold=manifold1,
    num_lines=16,
    num_circles=4,
    line_length=3.0,
    show_circles=False,
).show()

In [9]:
# Same centre and settings as the previous cell, now with the circles shown too.
create_topology_fig_given_manifold(
    centre=np.array([-1.0, 1.5]),
    manifold=manifold1,
    num_lines=16,
    num_circles=4,
    line_length=3.0,
).show()

## Flat-plateau manifolds (`Net2` and the shorter `FlatShortNet`)

In [10]:
# Lines and circles around (-1.5, 1.0) on the Net2 (plateau) manifold.
create_topology_fig_given_manifold(
    centre=np.array([-1.5, 1.0]),
    manifold=manifold_flat,
    num_lines=15,
    num_circles=6,
    line_length=2.5,
).show()

In [11]:
# Same view for the one-third-height FlatShortNet manifold, for comparison.
create_topology_fig_given_manifold(
    centre=np.array([-1.5, 1.0]),
    manifold=manifold_flat_short,
    num_lines=15,
    num_circles=6,
    line_length=2.5,
).show()

## Sin–cos manifolds (`SinCosNet` and `SinCosFlatNet`)

In [12]:
# Lines only around (0, 1) on the sin-cos manifold.
# The centre is a float array, like in the other cells; an int64 array could
# lead to integer arithmetic or autograd errors inside the plotting code.
create_topology_fig_given_manifold(
    np.array([0.0, 1.0]),
    manifold=manifold_sin_cos,
    num_lines=20,
    num_circles=4,
    line_length=2.5,
    show_circles=False,
).show()

In [13]:
# Lines only around (0, 1) on the clipped sin-cos manifold.
# Float centre and line_length, matching the other cells and the `float`
# annotation, instead of the int64 array / int literal used previously.
create_topology_fig_given_manifold(
    np.array([0.0, 1.0]),
    manifold=manifold_sin_cos_flat,
    num_lines=10,
    num_circles=4,
    line_length=3.0,
    show_circles=False,
).show()