In [10]:
import itertools
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from umap import UMAP

from gne.models.GeometricEmbedding import GeometricEmbedding, Config
from gne.utils.geometries import Euclidean
from gne.utils.complex import Simplex
from gne.utils.geometry_to_complex import kneighbors_complex

In [11]:
# Global Variables for the experiment

# Seed torch's global RNG so that the noisy data (torch.normal), the randomized
# PCA (torch.pca_lowrank) and, presumably, gNE's 'random' target initialisation
# are reproducible under Restart Kernel -> Run All.
# UMAP is seeded separately through its own random_state argument.
SEED = 42
torch.manual_seed(SEED)

k_neighbours = 4  # k used to build every kNN complex in this notebook

In [12]:
# Create Data: a noisy circle in the plane.
# The n_points angles are uniformly spaced on [0, 2*pi). linspace also produces
# the endpoint 2*pi, which is dropped because it coincides with angle 0.
n_points = 10

radius = 1  # radius of the underlying circle
std = .1    # std of the isotropic Gaussian noise added to every coordinate

angles = torch.linspace(0, 2 * torch.pi, n_points + 1)[:-1]
x, y = radius * torch.cos(angles), radius * torch.sin(angles)

points = torch.stack((x, y), dim=1)  # shape (n_points, 2)
points = points + torch.normal(mean=0, std=std, size=points.shape)

In [13]:
# Calculate 2d PCA representation.
# torch.pca_lowrank centres its input by default (center=True), so V holds the
# principal directions of the *centred* data. Project the centred points onto
# them so that pca_embedding contains the usual mean-zero PCA scores rather than
# scores shifted by mean(points) @ V.
_, _, V = torch.pca_lowrank(points)
principal_directions = V[:, :2]
centered_points = points - points.mean(dim=0)
principal_components = centered_points @ principal_directions
pca_embedding = principal_components

# Create kNN complex on the PCA embedding
pca_geometry = Euclidean(sample=pca_embedding)
pca_complex = kneighbors_complex(pca_geometry, k_neighbours)

In [None]:
# Calculate UMAP embedding.
# UMAP truncates n_neighbors to n_samples - 1 (with a warning) when n_neighbors is
# not smaller than the number of samples; the previous hard-coded 10 equalled
# n_points. Cap it explicitly so the effective value is visible here and the cell
# stays warning-free if the dataset size changes.
umap_n_neighbors = min(10, len(points) - 1)
umap_model = UMAP(n_neighbors=umap_n_neighbors, min_dist=0.1, n_components=2, random_state=42, n_jobs=1)
# UMAP works on array-likes; hand it a numpy array instead of a torch tensor.
umap_embedding = umap_model.fit_transform(points.numpy())
umap_embedding = torch.tensor(umap_embedding)

# Create kNN complex on the UMAP embedding
umap_geometry = Euclidean(sample=umap_embedding)
umap_complex = kneighbors_complex(umap_geometry, k_neighbours)

In [None]:
# Initialize gNE Embedding

# set configurations
# NOTE(review): epochs=1 presumably means one epoch per djinni(...) call; the
# evolution plot below titles each call "Epoch n" on that assumption.
config = Config(
    target_geometry={'initialization_method': 'random'},
    training={'epochs': 1, 'learning_rate': 0.9, 'burnin': 0},
    k_neighbours=k_neighbours,
)

# Initialize GeometricEmbedding instance
djinni = GeometricEmbedding(
    source_geometry=Euclidean(sample=points),
    target_geometry=Euclidean(dimension=2),
    config=config,
)

# Snapshot the initial embedding for plotting. detach() alone returns a tensor that
# shares storage with target_geometry.sample, so if training updates the sample in
# place the "initial" snapshot would change with it; clone() makes a real copy.
gne_embedding = djinni.target_geometry.sample.detach().clone()

# create source_complex and target_complex for plotting
djinni.source_complex = kneighbors_complex(djinni.source_geometry, k_neighbours)
djinni.target_complex = kneighbors_complex(djinni.target_geometry, k_neighbours)

djinni

In [None]:
# Plot initial configuration compared to ground truth (with skeleton of kNN-Complex)
#   left : noisy input points, with source-complex edges in black
#   right: initial gNE embedding, source-complex edges dashed grey,
#          target-complex edges solid blue

fig, ax = plt.subplots(1, 2, figsize=(16, 6))

# Weighted simplices of each complex, collected once instead of being rebuilt for
# every vertex pair inside the loops below (membership is still tested with ==).
source_edges = [s for s in djinni.source_complex if s.weight is not None]
target_edges = [s for s in djinni.target_complex if s.weight is not None]

ax[0].scatter(points[:, 0], points[:, 1], c='black')
for (i, p), (j, q) in itertools.combinations(enumerate(points), 2):
    if Simplex([i, j]) in source_edges:
        ax[0].plot([p[0], q[0]], [p[1], q[1]], color='black')
# px/py rather than x/y so the circle coordinates from the data cell survive
for index, (px, py) in enumerate(points):
    ax[0].annotate(str(index), (px, py), color='black', fontsize=10)

ax[1].scatter(gne_embedding[:, 0], gne_embedding[:, 1], c='blue')
for (i, p), (j, q) in itertools.combinations(enumerate(gne_embedding), 2):
    edge = Simplex([i, j])
    if edge in source_edges:
        ax[1].plot([p[0], q[0]], [p[1], q[1]], linestyle='dashed', color='grey')
    if edge in target_edges:
        ax[1].plot([p[0], q[0]], [p[1], q[1]], color='blue')
for index, (px, py) in enumerate(gne_embedding):
    ax[1].annotate(str(index), (px, py), color='blue', fontsize=10)

for panel in ax:
    panel.set_aspect('equal')
    panel.axis('square')

plt.show()

In [None]:
# Run Optimization (using a workaround for plotting intermediate configurations)
# ~ 2 seconds/epoch

# Prepare plot
number_of_iterations = 25

ncols = 5
nrows = -(-number_of_iterations // ncols)  # ceiling division

# squeeze=False keeps axs two-dimensional even when nrows == 1 (or ncols == 1),
# so the axs[row, col] indexing below works for any number_of_iterations.
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*4, nrows*4), squeeze=False)

for iteration in range(number_of_iterations):

    gne_embedding = djinni(plot_loss=False)

    # Collected once per call (not once per vertex pair). Re-read after every call
    # in case djinni(...) updates its complexes, which cannot be seen from here.
    source_edges = [s for s in djinni.source_complex if s.weight is not None]
    target_edges = [s for s in djinni.target_complex if s.weight is not None]

    ax = axs[iteration // ncols, iteration % ncols]

    ax.scatter(gne_embedding[:, 0], gne_embedding[:, 1], c='blue')
    for (i, p), (j, q) in itertools.combinations(enumerate(gne_embedding), 2):
        edge = Simplex([i, j])
        if edge in source_edges:
            ax.plot([p[0], q[0]], [p[1], q[1]], linestyle='dashed', color='grey')
        if edge in target_edges:
            ax.plot([p[0], q[0]], [p[1], q[1]], color='blue')
    for index, (px, py) in enumerate(gne_embedding):
        ax.annotate(str(index), (px, py), color='blue', fontsize=10)

    ax.set_aspect('equal')
    ax.axis('square')
    ax.set_title(f'Epoch {iteration+1}')

# Hide grid cells left empty when number_of_iterations is not a multiple of ncols.
for unused_ax in axs.flat[number_of_iterations:]:
    unused_ax.axis('off')

# savefig does not create missing directories; make sure the target exists.
figures_dir = Path('reports/figures')
figures_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(figures_dir / 'noisy-circle-2d-evolution.pdf')
plt.show()

In [None]:
# Plot all 3 embeddings (PCA, UMAP, gNE) next to the ground truth for direct comparison.
# Every embedding panel shows the source-complex edges dashed grey and the edges of
# that embedding's own kNN complex in its panel colour.

fig, ax = plt.subplots(1, 4, figsize=(24, 6))

# Weighted simplices of the source complex, collected once for all panels.
source_edges = [s for s in djinni.source_complex if s.weight is not None]

# (title, embedding, complex drawn solid, colour, overlay source complex dashed?)
panels = [
    ('ground truth', points, djinni.source_complex, 'black', False),
    ('PCA', pca_embedding, pca_complex, 'red', True),
    ('UMAP', umap_embedding, umap_complex, 'green', True),
    ('gNE', gne_embedding, djinni.target_complex, 'blue', True),
]

for panel_ax, (title, embedding, own_complex, colour, show_source) in zip(ax, panels):
    own_edges = [s for s in own_complex if s.weight is not None]

    panel_ax.set_title(title)
    panel_ax.scatter(embedding[:, 0], embedding[:, 1], c=colour)
    for (i, p), (j, q) in itertools.combinations(enumerate(embedding), 2):
        edge = Simplex([i, j])
        # dashed reference first so the panel's own (solid) edge is drawn on top
        if show_source and edge in source_edges:
            panel_ax.plot([p[0], q[0]], [p[1], q[1]], linestyle='dashed', color='grey')
        if edge in own_edges:
            panel_ax.plot([p[0], q[0]], [p[1], q[1]], color=colour)
    for index, (px, py) in enumerate(embedding):
        panel_ax.annotate(str(index), (px, py), color=colour, fontsize=8)

    panel_ax.set_aspect('equal')
    panel_ax.axis('square')

# savefig does not create missing directories; make sure the target exists.
figures_dir = Path('reports/figures')
figures_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(figures_dir / 'noisy-circle-2d-comparison.pdf')
plt.show()