In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from sklearn.datasets import make_swiss_roll
from umap import UMAP

from gne.models.Config import Config
from gne.models.GeometricEmbedding import GeometricEmbedding
from gne.utils.geometries import Euclidean

In [2]:
# Number of points sampled from the manifold
N = 1000

# Sample a noisy Swiss roll with a hole (fixed seed) and convert both returned
# arrays -- the 3-D coordinates and the per-sample parameter `t`, which is used
# later for colouring -- into torch tensors.
points, t = (
    torch.tensor(array)
    for array in make_swiss_roll(n_samples=N, noise=0.05, random_state=0, hole=True)
)

In [None]:
# Calculate the gNE embedding of the Swiss roll (a full run takes ~25 min).
# NOTE(review): the meaning of batch_size=-1 and patience=0 is defined by
# gne.models.Config -- presumably "full batch" and "no patience before the
# learning rate is adjusted"; confirm there.
gne_settings = dict(
    epochs=10,
    initialization_method='UMAP',
    batch_size=-1,
    k_neighbours=4,
    learning_rate=.8,
    lr_threshold=.5,
    patience=0,
)
config = Config(**gne_settings)

# The source geometry is Euclidean space sampled at the Swiss-roll points.
djinni = GeometricEmbedding(source_geometry=Euclidean(sample=points), config=config)

# Run the optimisation; plot_loss=True asks the model to plot its loss curve.
gne_embedding = djinni(plot_loss=True)

In [4]:
# Calculate 2d PCA embedding.
# Centre the data and use an exact, deterministic SVD instead of
# torch.pca_lowrank. That function is randomized, so the sign of its components
# (and hence the orientation of the plot) could change between runs.
# Projecting the *uncentred* points also shifted every component by a
# constant offset.
centered_points = points - points.mean(dim=0, keepdim=True)
_, _, Vh = torch.linalg.svd(centered_points, full_matrices=False)
V = Vh.T  # columns are principal directions, in descending singular-value order
principal_directions = V[:, :2]
principal_components = centered_points @ principal_directions
pca_embedding = principal_components

In [5]:
# Calculate the UMAP embedding using umap-learn's default hyper-parameters.
# The run is seeded and kept single-threaded so the layout is reproducible.
umap_model = UMAP(n_jobs=1, random_state=42)
umap_embedding = umap_model.fit_transform(points)

In [None]:
from matplotlib.colors import to_rgba
import colorsys

# Plot the ground truth and the three 2-D embeddings side by side for direct
# comparison (four panels; the first is replaced by a 3-D axis further down).
fig, axs = plt.subplots(1, 4, figsize=(24, 6))

# Normalize t to [0, 1]; it sets each point's brightness in every panel.
# Tensor reductions replace the builtins min()/max(), which iterate the tensor
# element by element in Python.
weight = (t - t.min()) / (t.max() - t.min())

# Function to adjust lightness of a color
def modify_brightness(color, amount=0.5):
    c = colorsys.rgb_to_hls(*to_rgba(color)[:3])
    return colorsys.hls_to_rgb(c[0], max(0, min(1, amount * c[1])), c[2]) + (1,)

def _as_numpy(values):
    """Return ``values`` in a form NumPy/matplotlib can index directly.

    torch tensors are detached and moved to the CPU before conversion, because
    a tensor that requires grad cannot be converted to NumPy directly. Anything
    else (e.g. the ndarray returned by UMAP) is passed through unchanged.
    """
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    return values


# One brightness value per point, as plain floats. Iterating a tensor would
# otherwise hand 0-d tensors to the colour helper.
shade = [float(amount) for amount in weight]

# Ground truth: the original 3-D point cloud, drawn on a 3-D axis that
# replaces the first (2-D) subplot.
axs[0].remove()
ax = fig.add_subplot(1, 4, 1, projection='3d')
ax.set_title('ground truth')
colors = [modify_brightness('grey', amount) for amount in shade]
xyz = _as_numpy(points)
ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], c=colors)

# 2-D embeddings: each panel is shaded from black to its base colour by `weight`.
embedding_panels = [
    (axs[1], 'PCA', pca_embedding, 'red'),
    (axs[2], 'UMAP', umap_embedding, 'green'),
    (axs[3], 'gNE', gne_embedding, 'blue'),
]
for panel_ax, title, embedding, base_color in embedding_panels:
    coords = _as_numpy(embedding)
    colors = [modify_brightness(base_color, amount) for amount in shade]
    panel_ax.set_title(title)
    # Per-point colours go through `c=` for every panel. The PCA panel used to
    # pass `color=`, which styled it differently from the others.
    sns.scatterplot(x=coords[:, 0], y=coords[:, 1], ax=panel_ax, c=colors)
    # 'square' gives an equal aspect ratio and equal x/y data ranges. The
    # removed axs[0] and the 3-D axis are deliberately not touched.
    panel_ax.axis('square')

# Make sure the output directory exists so savefig does not fail on a fresh checkout.
figure_path = Path('reports/figures/swiss-roll.pdf')
figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(figure_path)
plt.show()