In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import polars as pl
import seaborn as sns
import torch
from umap import UMAP

from gne.models.Config import Config
from gne.models.GeometricEmbedding import GeometricEmbedding
from gne.utils.geometries import Euclidean

In [None]:
# Load data: read the space-separated point sample and turn it into a torch tensor
bunny_table = pl.read_csv("data/bunny.csv", separator=" ")
bunny = torch.tensor(bunny_table.to_numpy())

In [None]:
# Calculate gNE embedding 
# ca. 45s/epoch with given settings
# typical number of epochs is ~50 --> ~75min

djinni = GeometricEmbedding(
    source_geometry = Euclidean(sample=bunny),
    target_geometry = Euclidean(dimension=2),
    config = Config(epochs=100, learning_rate=.1, patience=10, cooldown=1)
)

gne_embedding = djinni(plot_loss=True)

# NB: while we need to specify target geometry and its dimension, 
# the dimension of the source geometry is infered from the point sample

In [None]:
# Calculate 2d PCA embedding
# Principal components are projections of the *mean-centred* data onto the
# principal directions, so centre the sample before both the decomposition
# and the projection. An exact SVD replaces the randomised torch.pca_lowrank,
# so reruns give the same axes. This is cheap as long as the sample has few
# columns (presumably the spatial coordinates — TODO confirm).
centered = bunny - bunny.mean(dim=0, keepdim=True)
_, _, Vh = torch.linalg.svd(centered, full_matrices=False)
V = Vh.T  # columns are the principal directions, in descending singular-value order
principal_directions = V[:, :2]
principal_components = torch.matmul(centered, principal_directions)
pca_embedding = principal_components

In [None]:
# Calculate UMAP embedding
# Hyper-parameters are spelled out at umap-learn's default values.
# A fixed random_state makes the embedding reproducible; umap-learn runs
# single-threaded when seeded, so n_jobs=1 is stated explicitly.
umap_model = UMAP(
    n_neighbors=15,
    min_dist=0.1,
    n_components=2,
    random_state=42,
    n_jobs=1,
)
umap_embedding = umap_model.fit_transform(bunny)

In [None]:
from matplotlib.colors import to_rgba
import colorsys

# Plot the raw 2d projection next to the three embeddings (4 panels)
fig, ax = plt.subplots(1, 4, figsize=(24, 6))

# Normalised height weight used to shade every panel: 1 for the lowest point,
# 0 for the highest point, measured along column 1 of the raw sample.
# Tensor reductions are used instead of Python's builtin min()/max(), which
# would iterate over the tensor element by element.
height = bunny[:, 1]
height_min, height_max = height.min(), height.max()
weight = 1 - (height - height_min) / (height_max - height_min)

# Function to adjust lightness of a color
def modify_brightness(color, amount=0.5):
    """Return ``color`` as an opaque RGBA tuple with its HLS lightness scaled.

    Parameters
    ----------
    color : matplotlib colour spec
        Anything ``matplotlib.colors.to_rgba`` accepts (name, hex, RGB(A) tuple).
    amount : float-convertible
        Factor multiplied onto the HLS lightness. The result is clamped to
        [0, 1], so 0 yields black and large values yield white. Numpy/torch
        scalars are accepted and converted with ``float()``.

    Returns
    -------
    tuple
        ``(r, g, b, 1)``. Any alpha of the input colour is discarded.
    """
    hue, lightness, saturation = colorsys.rgb_to_hls(*to_rgba(color)[:3])
    # float() keeps tensor/array scalars (e.g. elements of a weight tensor)
    # out of the returned colour tuple.
    new_lightness = max(0.0, min(1.0, float(amount) * lightness))
    return colorsys.hls_to_rgb(hue, new_lightness, saturation) + (1,)

# Panel specification: (title, 2d point coordinates, base colour).
# The "2d projection" panel plots only columns 0 and 1 of the raw sample.
panels = [
    ('2d projection', bunny, 'grey'),
    ('PCA', pca_embedding, 'red'),
    ('UMAP', umap_embedding, 'green'),
    ('gNE', gne_embedding, 'blue'),
]

# Plain Python floats, so no tensor scalars end up inside the colour tuples
amounts = [float(w) for w in weight]

for panel_ax, (title, points, base_color) in zip(ax, panels):
    panel_ax.set_title(title)
    # Per-point colour: the panel's base colour with lightness scaled by the height weight
    colors = [modify_brightness(base_color, amount) for amount in amounts]
    # Per-point colours are passed the same way (`c=`) on every panel
    sns.scatterplot(x=points[:, 0], y=points[:, 1], ax=panel_ax, c=colors)
    # axis('square') also sets an equal aspect ratio, so no separate set_aspect call is needed
    panel_ax.axis('square')

# Create the output directory if needed, so savefig does not fail on a fresh checkout
figure_path = Path('reports/figures/bunny.pdf')
figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(figure_path)
plt.show()