In [1]:
import itertools
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from umap import UMAP

from gne.models.GeometricEmbedding import GeometricEmbedding, Config
from gne.utils.complex import Simplex
from gne.utils.geometries import Euclidean
from gne.utils.geometry_to_complex import kneighbors_complex

In [2]:
# Global Variables
n_points = 50       # points sampled per helix strand (the data set holds 2 * n_points)
k_neighbours = 4    # neighbourhood size used for the kNN complexes in this notebook
batch_size = -1     # passed to the gNE Config; presumably "full batch" -- TODO confirm
random_seed = 42

# Seed torch's global RNG so that randomised steps further down
# (e.g. torch.pca_lowrank) give the same result after "Restart & Run All".
torch.manual_seed(random_seed)

# Create Data
# 3d Double Helix
winding_number = 2  # upper end of the curve parameter t, i.e. number of full turns
separation = 1 / 4  # offset between the two strands along the z-axis

# Curve parameter: n_points evenly spaced values in [0, winding_number]
t = torch.linspace(0, winding_number, n_points)
x = torch.sin(2 * torch.pi * t)
y = torch.cos(2 * torch.pi * t)

# Both strands share x and y; the second one is shifted along z
z1 = t
z2 = t + separation

# Stack points for both helices, one row per point with columns (x, y, z)
points_helix_1 = torch.stack((x, y, z1), dim=1)
points_helix_2 = torch.stack((x, y, z2), dim=1)

# Combine the two helices: rows [0, n_points) are strand 1, the rest strand 2
points = torch.cat((points_helix_1, points_helix_2), dim=0)

In [3]:
# 2d reference view of the data: keep the x and z coordinates of every point
# (orthogonal projection into the xz-plane)
projection = torch.stack((points[:, 0], points[:, 2]), dim=1)

In [4]:
# 2d PCA baseline: project the points onto their first two principal directions.
# Only the third factor returned by pca_lowrank (the directions V) is needed.
V = torch.pca_lowrank(points)[2]
principal_directions = V[:, :2]
principal_components = points @ principal_directions
pca_embedding = principal_components

# kNN complex of the PCA embedding; its simplices are drawn as edges in the comparison plot
pca_geometry = Euclidean(sample=pca_embedding)
pca_complex = kneighbors_complex(pca_geometry, k_neighbours)

In [5]:
# 2d UMAP baseline (fixed random_state, single job)
umap_model = UMAP(
    n_components=2,
    n_neighbors=10,
    min_dist=0.1,
    random_state=42,
    n_jobs=1,
)
# fit_transform output is wrapped in a torch tensor so it can be handled like the other embeddings
umap_embedding = torch.tensor(umap_model.fit_transform(points))

# kNN complex of the UMAP embedding; its simplices are drawn as edges in the comparison plot
umap_geometry = Euclidean(sample=umap_embedding)
umap_complex = kneighbors_complex(umap_geometry, k_neighbours)

In [None]:
# Set up the gNE embedding

# Optimisation settings
config = Config(
    initialization_method="PCA",
    k_neighbours=k_neighbours,
    batch_size=batch_size,
    epochs=50,
    learning_rate=1,
    # burnin=0,
)

# Model that embeds the 3d helix sample into a 2-dimensional Euclidean target
djinni = GeometricEmbedding(
    config=config,
    source_geometry=Euclidean(sample=points),
    target_geometry=Euclidean(dimension=2),
)

# Starting configuration of the target sample, detached from autograd, for plotting
gne_embedding = djinni.target_geometry.sample.detach()

# Attach kNN complexes of source and target to the model so the plotting cells can read them
djinni.source_complex = kneighbors_complex(
    djinni.source_geometry,
    k_neighbours,
)
djinni.target_complex = kneighbors_complex(
    djinni.target_geometry,
    k_neighbours,
)

In [None]:
# Plot the initial gNE configuration next to the projection of the ground truth,
# including the edges of the corresponding kNN complexes

fig, ax = plt.subplots(1, 2, figsize=(16, 6))

# Simplices that carry a weight, for the source data and for the current embedding
source_edges = [s for s in djinni.source_complex if s.weight is not None]
target_edges = [s for s in djinni.target_complex if s.weight is not None]

# Left panel: projection of the data with the source kNN edges
ax[0].set_title('Projection')
ax[0].scatter(projection[:, 0], projection[:, 1], c='black')
for (i, p), (j, q) in itertools.combinations(enumerate(projection), 2):
    if Simplex([i, j]) in source_edges:
        ax[0].plot([p[0], q[0]], [p[1], q[1]], color='black')
# px/py instead of x/y: notebook loop variables are globals, and x/y would
# overwrite the helix coordinate tensors created in the data cell
for index, (px, py) in enumerate(projection):
    ax[0].annotate(str(index), (px, py), color='black', fontsize=10)

# Right panel: gNE embedding, edges coloured by agreement with the source complex
ax[1].set_title('gNE (initial configuration)')
ax[1].scatter(gne_embedding[:, 0], gne_embedding[:, 1], c='blue')
for (i, p), (j, q) in itertools.combinations(enumerate(gne_embedding), 2):
    # Build the candidate edge once and test it against both complexes
    edge = Simplex([i, j])
    in_target = edge in target_edges
    in_source = edge in source_edges
    if in_target:
        # blue: edge present in both complexes, red: edge only in the embedding
        ax[1].plot([p[0], q[0]], [p[1], q[1]], color='blue' if in_source else 'red')
    elif in_source:
        # dashed black: source edge that is missing in the embedding
        ax[1].plot([p[0], q[0]], [p[1], q[1]], linestyle='dashed', color='black')
for index, (px, py) in enumerate(gne_embedding):
    ax[1].annotate(str(index), (px, py), color='blue', fontsize=10)

# Same scale on both axes of every panel
for axis in ax:
    axis.set_aspect('equal')
    axis.axis('square')

plt.show()

In [None]:
# Run Optimization (using a workaround for plotting intermediate configurations)
# ~ 8s/epoch

for iteration in range(3):

    # Calling the model runs the optimiser (presumably for config.epochs epochs
    # per call -- TODO confirm) and returns the embedding plotted below
    gne_embedding = djinni(plot_loss=True)

    # Plot current configuration
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.set_title(f'gNE after optimisation run {iteration + 1}')

    # Plot gNE embedding
    ax.scatter(gne_embedding[:, 0], gne_embedding[:, 1], c='blue')
    # Re-read both complexes every run, since the model call may have changed them
    source_edges = [s for s in djinni.source_complex if s.weight is not None]
    target_edges = [s for s in djinni.target_complex if s.weight is not None]
    for (i, p), (j, q) in itertools.combinations(enumerate(gne_embedding), 2):
        # Build the candidate edge once and test it against both complexes
        edge = Simplex([i, j])
        in_target = edge in target_edges
        in_source = edge in source_edges
        if in_target:
            # blue: edge present in both complexes, red: edge only in the embedding
            ax.plot([p[0], q[0]], [p[1], q[1]], color='blue' if in_source else 'red')
        elif in_source:
            # dashed black: source edge that is missing in the embedding
            ax.plot([p[0], q[0]], [p[1], q[1]], linestyle='dashed', color='black')

    # px/py instead of x/y: notebook loop variables are globals, and x/y would
    # overwrite the helix coordinate tensors created in the data cell
    for index, (px, py) in enumerate(gne_embedding):
        ax.annotate(str(index), (px, py), color='blue', fontsize=10)

    ax.set_aspect('equal')
    ax.axis('square')

    plt.show()

In [None]:
# Plot the projection and all three embeddings (PCA, UMAP, gNE) side by side
# for direct comparison


def plot_embedding_panel(axis, embedding, edges, color, title, reference_edges=None):
    """Draw one comparison panel: points, kNN edges and point indices.

    Parameters
    ----------
    axis : matplotlib Axes
        Axes to draw on.
    embedding : 2d tensor
        One row per point; columns 0 and 1 are used as plot coordinates.
    edges : list
        Simplices of the kNN complex built on ``embedding``.
    color : str
        Colour of the points, the index labels and the edges that agree
        with the reference.
    title : str
        Panel title.
    reference_edges : list, optional
        Simplices of the source kNN complex. Edges of ``embedding`` that are
        not in the reference are drawn in gray, reference edges missing from
        ``embedding`` are drawn dashed gray. With ``None`` every edge of
        ``embedding`` is drawn in ``color`` and nothing else.
    """
    axis.set_title(title)
    axis.scatter(embedding[:, 0], embedding[:, 1], c=color)
    for (i, p), (j, q) in itertools.combinations(enumerate(embedding), 2):
        # Build the candidate edge once and test it against both edge lists
        edge = Simplex([i, j])
        in_embedding = edge in edges
        in_reference = in_embedding if reference_edges is None else edge in reference_edges
        if in_embedding:
            axis.plot([p[0], q[0]], [p[1], q[1]], color=color if in_reference else 'gray')
        elif in_reference:
            axis.plot([p[0], q[0]], [p[1], q[1]], linestyle='dashed', color='gray')
    # px/py instead of x/y so the helix coordinate tensors from the data cell
    # are never shadowed, not even if this loop is moved back to cell level
    for index, (px, py) in enumerate(embedding):
        axis.annotate(str(index), (px, py), color=color, fontsize=10)


fig, ax = plt.subplots(1, 4, figsize=(24, 6))

# Weighted simplices of the source complex and of each embedding's kNN complex
source_edges = [s for s in djinni.source_complex if s.weight is not None]
pca_edges = [s for s in pca_complex if s.weight is not None]
umap_edges = [s for s in umap_complex if s.weight is not None]
gne_edges = [s for s in djinni.target_complex if s.weight is not None]

# Projection: shows the source edges themselves, so there is no reference to compare with
plot_embedding_panel(ax[0], projection, source_edges, color='black', title='Projection')

# PCA
plot_embedding_panel(ax[1], pca_embedding, pca_edges, color='red', title='PCA',
                     reference_edges=source_edges)

# UMAP
plot_embedding_panel(ax[2], umap_embedding, umap_edges, color='green', title='UMAP',
                     reference_edges=source_edges)

# gNE
plot_embedding_panel(ax[3], gne_embedding, gne_edges, color='blue', title='gNE',
                     reference_edges=source_edges)

# Same scale on both axes of every panel
for axis in ax:
    axis.set_aspect('equal')
    axis.axis('square')

# savefig does not create missing directories, so make sure the folder exists.
# Save before show(), which may release the figure.
figure_path = Path('reports/figures/double-helix.pdf')
figure_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(figure_path)
plt.show()

In [10]:
# TODO: Also run gNE with random and UMAP initializations, and with the xz-projection as the initial configuration, for a better comparison.