NB: UMAP must be installed to run this notebook. Install it with: 'pip install umap-learn'.

This notebook visualises the Swissroll dataset and compares its embedding in the latent space of a pre-trained autoencoder (AE) with the embeddings produced by standard dimensionality reduction techniques:

0) PCA https://pytorch.org/docs/stable/generated/torch.pca_lowrank.html
1) LLE https://cs.nyu.edu/~roweis/lle/papers/lleintroa4.pdf
2) t-SNE https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding
3) UMAP https://umap-learn.readthedocs.io/en/latest/

In [None]:
# Minimal imports
import math
import torch
import matplotlib.pyplot as plt
import ricci_regularization
import yaml
from sklearn import datasets
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from sklearn import datasets, manifold
import torch
import math
import torchvision
import numpy as np

In [None]:
# Directory (relative path) into which the plotting cells save their PDF figures.
Path_pictures = "../../plots"
# Figures are written to disk only when this flag is set.
violent_saving = False
# Opacity of the scatter-plot points.
alpha = 0.5

Loading the dataset and the tuned AE

In [None]:
# Read the experiment configuration.
# NOTE(review): yaml.FullLoader can build more than plain data types; switch to
# yaml.safe_load if this config could ever come from an untrusted source.
with open('../../experiments/MNIST_without_curvature_regularization_config.yaml', 'r') as yaml_file:
    yaml_config = yaml.load(yaml_file, Loader=yaml.FullLoader)

# Load data loaders based on YAML configuration.
# The result gets a descriptive name so that the builtin `dict` is not shadowed
# for the rest of the notebook session.
dataloaders = ricci_regularization.DataLoaders.get_dataloaders(
    dataset_config=yaml_config["dataset"],
    data_loader_config=yaml_config["data_loader_settings"]
)
train_loader = dataloaders["train_loader"]
test_loader = dataloaders["test_loader"]
test_dataset = dataloaders.get("test_dataset")  # Assuming 'test_dataset' is a key returned by get_dataloaders

print("Data loaders created successfully.")

# Build the autoencoder and load its trained weights as described by the config.
torus_ae = ricci_regularization.DataLoaders.get_tuned_nn(config=yaml_config, additional_path="../")

print("AE weights loaded successfully.")
experiment_name = yaml_config["experiment"]["name"]
curv_w = yaml_config["loss_settings"]["lambda_curv"]

dataset_name = yaml_config["dataset"]["name"]
# D is the dimension of the dataset
D = yaml_config["architecture"]["input_dim"]
# k is the number of classes, used for the discrete colour map of the latent plot
if dataset_name in ["MNIST01", "Synthetic"]:
    # the number of classes is the number of labels selected in the YAML configuration
    selected_labels = yaml_config["dataset"]["selected_labels"]
    k = len(selected_labels)
elif dataset_name == "MNIST":
    k = 10
else:
    # No class count is defined for other datasets; bind k explicitly so that a
    # stale value from an earlier run of this cell cannot be picked up silently.
    k = None
print("Experiment name:", experiment_name)

In [None]:
# choose train or test loader
loader = test_loader
#loader = train_loader

torus_ae.cpu()
colorlist = []
enc_list = []
input_dataset_list = []
recon_dataset_list = []

# Pure inference: with autograd disabled no computation graph is kept for the
# reconstructions and encodings of every batch, which keeps memory usage bounded.
with torch.no_grad():
    for (data, labels) in tqdm(loader, position=0):
        input_dataset_list.append(data)
        # the first element of the model output is stored as the reconstruction
        recon_dataset_list.append(torus_ae(data)[0])
        # encode the batch after flattening each sample to D features
        enc_list.append(torus_ae.encoder2lifting(data.view(-1, D)))
        colorlist.append(labels)

# Concatenate the per-batch results into whole-dataset tensors.
input_dataset = torch.cat(input_dataset_list).reshape(-1, D)
recon_dataset = torch.cat(recon_dataset_list)
encoded_points = torch.cat(enc_list)
encoded_points_no_grad = encoded_points.detach()
color_array = torch.cat(colorlist).detach()

# latent \in [-1,1]. grid reparametrization for plotting
# (presumably the encoder output lies in [-pi, pi], so dividing by pi maps it to [-1, 1])
encoded_points_no_grad = encoded_points_no_grad/math.pi

# PCA

In [None]:
# Rank-2 low-rank PCA of the flattened inputs.
# input_dataset is already a tensor, so it is passed as is (detached) rather than
# through torch.tensor(...), which would copy it and emit a UserWarning.
u, s, v = torch.pca_lowrank(input_dataset.detach(), q=2)

In [None]:
# Scatter plot of the two columns of U returned by torch.pca_lowrank, coloured by label.
plt.figure(figsize=(9,9),dpi=400)
plt.rcParams.update({'font.size': 20}) # makes all fonts on the plot be 20
plt.scatter( u[:,0], u[:,1], c=color_array, s= 40,alpha=alpha, cmap='jet',marker='o',edgecolors=None )
#plt.title( "PCA embedding of the swiss roll")
#plt.colorbar(orientation='vertical',shrink = 0.7)
# Save before plt.show(); the flag is a plain boolean, so test its truthiness directly.
if violent_saving:
    plt.savefig(f'{Path_pictures}/{dataset_name}_pca.pdf',bbox_inches='tight',format='pdf')
plt.show()

# LLE 

In [None]:
# Two-component locally linear embedding of the flattened inputs.
# sr_lle holds the embedded coordinates, sr_err the second value returned by the call.
# random_state is fixed (as in the t-SNE cell) so that the figure is reproducible.
sr_lle, sr_err = manifold.locally_linear_embedding(
    input_dataset, n_neighbors=12, n_components=2, random_state=0
)


In [None]:
# Scatter plot of the LLE coordinates, coloured by label.
plt.rcParams.update({'font.size': 20})
fig = plt.figure(figsize=(9,9),dpi=400)
plt.scatter(sr_lle[:, 0], sr_lle[:, 1], c=color_array,cmap='jet',s=40,alpha=alpha)
#plt.title("LLE Embedding of the swiss roll")
# Save before plt.show(); the flag is a plain boolean, so test its truthiness directly.
if violent_saving:
    plt.savefig(f'{Path_pictures}/{dataset_name}_lle.pdf',bbox_inches='tight',format='pdf')
plt.show()

# t-SNE

In [None]:
# Two-component t-SNE of the flattened inputs, seeded for reproducibility.
tsne_model = manifold.TSNE(n_components=2, perplexity=40, random_state=0)
sr_tsne = tsne_model.fit_transform(input_dataset)

In [None]:
# Scatter plot of the t-SNE coordinates, coloured by label.
fig = plt.figure(figsize=(9,9),dpi=400)
plt.rcParams.update({'font.size': 20})
plt.scatter(sr_tsne[:, 0], sr_tsne[:, 1], c=color_array,cmap='jet',s=40,alpha=alpha)
#plt.title("t-SNE embedding of the swiss roll")
# Save before plt.show(); the flag is a plain boolean, so test its truthiness directly.
if violent_saving:
    plt.savefig(f'{Path_pictures}/{dataset_name}_tsne.pdf',bbox_inches='tight',format='pdf')
#plt.savefig(f'{Path_pictures}/swissroll_tsne.pdf',bbox_inches='tight',format='pdf')
plt.show()

# UMAP

In [None]:
import umap

In [None]:
# Fit UMAP with its default settings on the flattened inputs.
# NOTE(review): no random_state is passed here, unlike the t-SNE cell, so the
# embedding may differ between runs - confirm whether that is intended.
mapper = umap.UMAP().fit(input_dataset)

In [None]:
# Coordinates of the fitted UMAP embedding.
# NOTE: this rebinds `encoded_points`, which previously held the AE encodings;
# the AE latent-space plot reads `encoded_points_no_grad`, not this name.
encoded_points = mapper.embedding_

In [None]:
# Scatter plot of the UMAP coordinates, coloured by label.
fig = plt.figure(figsize=(9,9),dpi=400)
plt.rcParams.update({'font.size': 20}) # makes all fonts on the plot be 20
plt.scatter( encoded_points[:,0], encoded_points[:,1], c=color_array, s= 40,alpha=alpha, cmap='jet',marker='o',edgecolors=None )
#plt.title( "UMAP embedding of the swiss roll")
#plt.colorbar(orientation='vertical',shrink = 0.7)
# Save before plt.show(); the flag is a plain boolean, so test its truthiness directly.
if violent_saving:
    plt.savefig(f'{Path_pictures}/{dataset_name}_umap.pdf',bbox_inches='tight',format='pdf')
plt.show()

# AE latent space

In [None]:
# Scatter plot of the AE encodings (already rescaled by 1/pi), coloured by label.
plt.rcParams.update({'font.size': 20})
plt.figure(figsize=(9, 9),dpi=400)

if dataset_name == "Swissroll":
    # continuous colour map, same marker size and opacity as the other embeddings
    plt.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1], c=color_array, marker='o',s=40,alpha=alpha, edgecolor='none', cmap= 'jet')
else:
    # labelled datasets: discrete colour map with k colours and one colorbar tick per class
    plt.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1], c=color_array, marker='o', edgecolor='none', cmap=ricci_regularization.discrete_cmap(k, 'jet'))
    plt.colorbar(ticks=range(k))
# Fix both axes to the rescaled latent square [-1, 1] x [-1, 1].
plt.xticks([-1.,-0.5,0.,0.5,1.])
plt.yticks([-1.,-0.5,0.,0.5,1.])
plt.ylim(-1., 1.)
plt.xlim(-1., 1.)
#plt.grid(True)
# Save before plt.show(); the flag is a plain boolean, so test its truthiness directly.
if violent_saving:
    plt.savefig(f'{Path_pictures}/{dataset_name}_not_regularized_Torus_AE_latent_space.pdf',format="pdf",bbox_inches='tight')
#plt.savefig(f"{Path_pictures}/latent_space_{experiment_name}.jpg",bbox_inches='tight', format="pdf")
plt.show()

# Manifold plot REDO THIS

In [None]:
def show_image(img):
    """Display a channel-first image tensor with matplotlib.

    Args:
        img: tensor laid out as (channels, height, width); its axes are
            permuted to (height, width, channels) before ``plt.imshow``.
    """
    # detach()/cpu() keep the conversion valid for tensors that still track
    # gradients or do not live in CPU memory; both are no-ops otherwise.
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

In [None]:
# Uniform numsteps x numsteps grid on the 2-D latent square [-pi, pi] x [-pi, pi].
numsteps = 10
xs = torch.linspace(-torch.pi, torch.pi, steps=numsteps)
ys = torch.linspace(-torch.pi, torch.pi, steps=numsteps)
uniform_grid = torch.cartesian_prod(xs, ys)

# Grid fed to the decoder for the manifold plot:
# row i * numsteps + j holds the latent point (xs[j], -ys[i]).
truegrid = torch.cartesian_prod(-ys, xs).roll(1, 1)

#img_recon = torus_ae.decoder_torus(torch.rand(100,2)).reshape(-1,1,28,28)
# Decode every latent grid point and view the outputs as single-channel 28x28 images
# (assumes the decoder emits 784 values per point, as for MNIST - TODO confirm).
img_recon = torus_ae.decoder_torus(truegrid).reshape(-1,1,28,28)
fig, ax  = plt.subplots(figsize=(20, 20),dpi=400)
ax.set_xticklabels([]) #no tick labels
ax.set_yticklabels([])

# Tile the decoded images with numsteps images per row. The image count and row
# length follow numsteps instead of repeating the literals 100 and 10, so changing
# the grid resolution keeps the layout consistent.
img_grid = torchvision.utils.make_grid(img_recon[:numsteps**2], nrow=numsteps, padding=10)
show_image(img_grid.detach())
# Save before plt.show(); the flag is a plain boolean, so test its truthiness directly.
if violent_saving:
    plt.savefig(f'{Path_pictures}/{dataset_name}_manifold_plot.pdf',format="pdf",bbox_inches='tight')
plt.show()