In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import plotly.graph_objects as go
import plotly.express as px

import numpy as np
import pandas as pd

import torch
import torch.nn as nn

from latent_geometry.mapping import TorchModelMapping
from latent_geometry.manifold import LatentManifold
from latent_geometry.metric import EuclideanMetric
from latent_geometry.visual.plotly import create_topology_fig

In [3]:
# Background sampling: how many random latent points to draw, and the half-width
# of the square [-WIDTH, WIDTH) they are drawn from.
NUM_POINTS = 7_000
WIDTH = 4
# Height multiplier applied by the ridge surfaces (Net1, Net2, FlatShortNet).
Z_SCALE = 5.0
# NOTE(review): not referenced in the cells shown here — confirm it is still needed.
SPLINE_POLY_DEG = 3
# Multiplier applied to x / y inside sin / cos by the sin-cos surfaces.
SIN_MULT = 1.5

In [4]:
class Net1(nn.Module):
    """Double-ridge surface ``z = Z_SCALE * exp(-log(|1.5 - y + x**2|) ** 2)``.

    Writing ``u = 1.5 - y + x**2``, the height peaks at ``Z_SCALE`` where
    ``|u| == 1`` (the parabolas ``y = x**2 + 0.5`` and ``y = x**2 + 2.5``),
    falls to 0 on ``u == 0`` and decays as ``|u|`` grows.
    """

    # |u| is floored here before taking the log. The height on the floor is
    # Z_SCALE * exp(-log(1e-6) ** 2) ~ 1e-83 (0 in float32), i.e. the same as the
    # exact limit, but autograd no longer evaluates 0 * inf = NaN on u == 0.
    _ABS_FLOOR = 1e-6

    def forward(self, in_):
        # in_ is unpacked along its first dimension into the two latent coords.
        x, y = in_
        u_abs = torch.clamp(torch.abs(1.5 - y + x**2), min=self._ABS_FLOOR)
        z = torch.exp(-torch.log(u_abs) ** 2) * Z_SCALE
        return torch.stack([x, y, z])


class Net2(nn.Module):
    """Net1-style log-bump whose argument is floored at 1.2, giving a plateau.

    With ``u = 1.5 - y + x**2`` replaced by ``max(u, 1.2)``, the height is
    constant (about 0.97 * Z_SCALE) wherever ``u <= 1.2``, i.e. on and above
    the parabola ``y = x**2 + 0.3``, and decays as ``u`` grows beyond that.
    """

    def forward(self, in_):
        x, y = in_
        floored_arg = torch.maximum(1.5 - y + x**2, torch.tensor(1.2))
        bump = torch.exp(-(torch.log(floored_arg) ** 2))
        return torch.stack([x, y, bump * Z_SCALE])


class FlatShortNet(nn.Module):
    """Same floored log-bump as Net2, scaled down to one third of its height.

    The argument ``1.5 - y + x**2`` is floored at 1.2 before the log, so the
    plateau (``y >= x**2 + 0.3``) sits at about 0.32 * Z_SCALE.
    """

    def forward(self, in_):
        x, y = in_
        floored_arg = torch.maximum(1.5 - y + x**2, torch.tensor(1.2))
        full_height = torch.exp(-(torch.log(floored_arg) ** 2)) * Z_SCALE
        return torch.stack([x, y, full_height / 3])


class SinCosNet(nn.Module):
    """Egg-crate surface with height ``sin(SIN_MULT * x) + cos(SIN_MULT * y)``."""

    def forward(self, in_):
        x, y = in_
        height = torch.sin(SIN_MULT * x) + torch.cos(SIN_MULT * y)
        return torch.stack([x, y, height])


class SinCosFlatNet(nn.Module):
    """SinCosNet surface with every negative height clipped to 0 (a flat floor)."""

    def forward(self, in_):
        x, y = in_
        wave = torch.sin(x * SIN_MULT) + torch.cos(y * SIN_MULT)
        clipped = torch.maximum(wave, torch.tensor(0))
        return torch.stack([x, y, clipped])


ambient_metric = EuclideanMetric(3)


def _make_manifold(net: nn.Module) -> LatentManifold:
    """Wrap a 2-in / 3-out torch surface map as a LatentManifold in Euclidean 3-space."""
    # (2,) and (3,) are presumably the mapping's input and output shapes — the
    # nets unpack two latent coords and stack three outputs.
    return LatentManifold(TorchModelMapping(net, (2,), (3,)), ambient_metric)


manifold1 = _make_manifold(Net1())
manifold_flat = _make_manifold(Net2())
manifold_flat_short = _make_manifold(FlatShortNet())
manifold_sin_cos = _make_manifold(SinCosNet())
manifold_sin_cos_flat = _make_manifold(SinCosFlatNet())

# Every manifold defined above, in definition order (previously manifold_flat_short
# was missing, so the 3D preview loop skipped it).
MANIFOLDS = [
    manifold1,
    manifold_flat,
    manifold_flat_short,
    manifold_sin_cos,
    manifold_sin_cos_flat,
]

In [5]:
def create_background_df(
    manifold: LatentManifold,
    n_points: int = NUM_POINTS,
    width: float = WIDTH,
    rng: "np.random.Generator | None" = None,
) -> pd.DataFrame:
    """Sample latent points uniformly in ``[-width, width)^2`` and record their height.

    Parameters
    ----------
    manifold:
        Manifold whose ``metric._mapping`` is evaluated at every sampled point.
    n_points:
        Number of points to sample.
    width:
        Half side length of the sampled square.
    rng:
        Optional ``numpy.random.Generator`` for reproducible sampling. When
        omitted, the legacy global ``np.random`` state is used, drawing exactly
        the same numbers as before this parameter existed.

    Returns
    -------
    pd.DataFrame
        Columns ``x`` and ``y`` (sampled coordinates), ``cluster`` (quadrant
        label 0-3: +1 if ``x > 0``, +2 if ``y > 0``) and ``z`` (element 2 of the
        mapping's output at that point).
    """
    sampler = np.random if rng is None else rng
    # `random(shape)` on the legacy module is `rand(*shape)`, so the default
    # path consumes the global stream exactly as the original code did.
    x = sampler.random((n_points, 1)) * width * 2 - width
    y = sampler.random((n_points, 1)) * width * 2 - width

    df = pd.DataFrame(data=np.hstack([x, y]), columns=["x", "y"])
    df["cluster"] = (df.x > 0) + 2 * (df.y > 0)
    # Point-by-point evaluation; only the third output coordinate is kept.
    # NOTE(review): relies on the private `metric._mapping` attribute.
    df["z"] = df.apply(lambda r: manifold.metric._mapping(r[:2])[2], axis=1)
    return df


def create_lantent_fig(df: pd.DataFrame, three_d: bool = False) -> go.Figure:
    """Scatter the sampled points of ``df`` coloured by their ``z`` column.

    With ``three_d`` false (the default) this is a flat x/y scatter at 0.5
    opacity; otherwise an opaque 3D scatter with ``z`` on the vertical axis.
    The misspelt name is kept because other cells call it.
    """
    if not three_d:
        return px.scatter(df, x="x", y="y", color="z", opacity=0.5)
    return px.scatter_3d(df, x="x", y="y", z="z", color="z", opacity=1)

In [6]:
# 3D preview of each manifold in MANIFOLDS; every iteration resamples a fresh background.
for manifold in MANIFOLDS:
    preview_fig = create_lantent_fig(create_background_df(manifold), three_d=True)
    preview_fig.show()

## Topology plots (plotly)

In [7]:
def create_topology_fig_given_manifold(
    centre: np.ndarray,
    manifold: LatentManifold,
    num_lines: int,
    num_circles: int,
    line_length: float = 2.5,
    show_lines: bool = True,
    show_circles: bool = True,
) -> go.Figure:
    """Build a topology figure around ``centre`` over a freshly sampled background.

    A new random background is drawn on every call (``create_background_df``)
    and rendered as a flat scatter coloured by height; its first trace, together
    with all other arguments, is forwarded positionally to ``create_topology_fig``.
    """
    background_fig = create_lantent_fig(create_background_df(manifold), three_d=False)
    background_trace = background_fig.data[0]
    # Positional on purpose: keeps the exact argument order create_topology_fig expects.
    forwarded = (
        centre,
        manifold,
        background_trace,
        num_lines,
        num_circles,
        line_length,
        show_lines,
        show_circles,
    )
    return create_topology_fig(*forwarded)

## manifold 1

In [8]:
# Lines only. The centre (-1, 1.5) lies on Net1's ridge, where 1.5 - y + x**2 == 1.
manifold1_lines_fig = create_topology_fig_given_manifold(
    centre=np.array([-1.0, 1.5]),
    manifold=manifold1,
    num_lines=16,
    num_circles=4,
    line_length=3.0,
    show_circles=False,
)
manifold1_lines_fig.show()

[3.008003453563485, 3.006364539659625, 3.0369408199777723, 3.016141106815157, 3.007026606583331, 3.000797183719385, 3.0069672810553385, 3.1006576707215925, 3.0183141580451567, 3.0430221953108507, 3.036049158531908, 3.014770544895676, 3.007153643090455, 3.0035085269704136, 2.999681364193755, 3.0813372887432737]
[1.667313806152444, 0.4651856010504793, 1.0026309287613344, 2.0088591448267867, 2.7212286786188575, 2.9973271136015316, 2.8028483472799857, 2.2179683399318564, 1.3503943279527721, 0.6002084755892672, 1.1850005680293316, 2.019027053474616, 2.429176791515656, 2.5406680102496506, 2.879644627214505, 2.6504250375262353]


In [9]:
# Same centre as above, now with the circles drawn as well (show_circles defaults to True).
manifold1_full_fig = create_topology_fig_given_manifold(
    centre=np.array([-1.0, 1.5]),
    manifold=manifold1,
    num_lines=16,
    num_circles=4,
    line_length=3.0,
)
manifold1_full_fig.show()

[3.008003453563485, 3.006364539659625, 3.0369408199777723, 3.016141106815157, 3.007026606583331, 3.000797183719385, 3.0069672810553385, 3.1006576707215925, 3.0183141580451567, 3.0430221953108507, 3.036049158531908, 3.014770544895676, 3.007153643090455, 3.0035085269704136, 2.999681364193755, 3.0813372887432737]
[1.667313806152444, 0.4651856010504793, 1.0026309287613344, 2.0088591448267867, 2.7212286786188575, 2.9973271136015316, 2.8028483472799857, 2.2179683399318564, 1.3503943279527721, 0.6002084755892672, 1.1850005680293316, 2.019027053474616, 2.429176791515656, 2.5406680102496506, 2.879644627214505, 2.6504250375262353]


## Flat-topped manifolds (`manifold_flat`, `manifold_flat_short`)

In [10]:
# Centre (-1.5, 1): here 1.5 - y + x**2 == 2.75 > 1.2, so it lies off the plateau.
flat_fig = create_topology_fig_given_manifold(
    centre=np.array([-1.5, 1]),
    manifold=manifold_flat,
    num_lines=15,
    num_circles=6,
    line_length=2.5,
)
flat_fig.show()

[2.5004187110209926, 2.509568511651965, 2.509023262796512, 2.505437927022971, 2.5016168490704764, 2.5035605110978096, 2.505489293363612, 2.5054735973650297, 2.5081696766567316, 2.504512889558663, 2.5042959900588775, 2.5006419861215554, 2.500218914785118, 2.5084738655419354, 2.512020923001283]
[0.5186233507000773, 0.45585216620275054, 0.5359250197770414, 0.8468778395411442, 1.8429937919245636, 1.971380738815764, 1.5621997980829907, 1.4789182345085619, 1.4644677982992973, 1.4758801033546556, 1.53229289159542, 1.7147053874715177, 2.489730953617848, 1.3377510620257278, 0.7414456386249909]


In [11]:
# Same centre and settings as the previous cell, on the one-third-height variant.
flat_short_fig = create_topology_fig_given_manifold(
    centre=np.array([-1.5, 1]),
    manifold=manifold_flat_short,
    num_lines=15,
    num_circles=6,
    line_length=2.5,
)
flat_short_fig.show()

[2.1782710255211395, 1.859746517230764, 1.9268688489091614, 2.22940216538867, 2.500020511371284, 2.499443805264502, 2.498726246619802, 2.498915188672082, 2.498746198222419, 2.4988390045092497, 2.499186631142089, 2.4999304451433173, 2.5000008230569692, 2.5039260203207667, 2.381326829266348]
[1.632133942716958, 1.274835999989497, 1.3765793958176338, 1.812065399388703, 2.3765689912281025, 2.4359296041245218, 2.3542681186093812, 2.318341587422383, 2.3076734377948513, 2.3145134835096957, 2.341233910120489, 2.4012670938612746, 2.4930935850844973, 2.3156042952531677, 1.9566248045413985]


## manifold sin cos

In [12]:
# Lines only around (0, 1). The centre is a float array like in every other cell
# (np.array([0, 1]) would be int64).
create_topology_fig_given_manifold(
    np.array([0.0, 1.0]),
    manifold=manifold_sin_cos,
    num_lines=20,
    num_circles=4,
    line_length=2.5,
    show_circles=False,
).show()

[2.4997591103311017, 2.5012740816228507, 2.528843150126731, 2.5294664307399017, 2.4989853163602325, 2.5011533238999792, 2.5000269287207413, 2.500353162815248, 2.5007420253780506, 2.5001997404429135, 2.4995845072429397, 2.5005494418813723, 2.499742096899598, 2.516078364591057, 2.4977586089688972, 2.4998555878411945, 2.5004976942991477, 2.50018101760789, 2.50028839152478, 2.5000731669625615]
[1.8415821193367277, 2.1954521971094434, 2.355262517968851, 2.3143569532150723, 2.1775772786130325, 1.7987899361220958, 1.5192051532383823, 1.379011401787809, 1.390250370625625, 1.5455745502408873, 1.8270604907655754, 2.196965865887455, 2.2836008267436254, 2.3408987068270473, 2.213496932391696, 1.8709679638864354, 1.613654118728235, 1.4736100134570618, 1.463575907308549, 1.588074424758486]


In [13]:
# Lines only around (0, 1), which lies above the clipped floor (sin(0) + cos(1.5) > 0).
# Float centre and line_length, consistent with the other cells and the float annotation.
create_topology_fig_given_manifold(
    np.array([0.0, 1.0]),
    manifold=manifold_sin_cos_flat,
    num_lines=10,
    num_circles=4,
    line_length=3.0,
    show_circles=False,
).show()

[2.9987413060755554, 2.910848351591342, 2.544093265076157, 1.4513147237990744, 1.3442546602453702, 1.7057313624402843, 3.039678469508104, 2.9966748461043675, 3.001554387656973, 3.001385473032745]
[2.330810436245156, 2.798148616734591, 2.137837073932706, 1.4034437765506884, 1.2930487853870898, 1.665665225915345, 2.79285319390505, 2.7024197369871623, 2.0882483707199997, 1.9027209377102696]
