NB! Geomstats package is required.

The latent space of the AE is topologically a $d$-dimensional torus $\mathcal{T}^d$, i.e. it can be considered as a periodic box $[-\pi, \pi]^d$.

The notebook includes 3 K-means clustering applications:
1) For the input data (uses scikit-learn K-means, see https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html)
2) For the output data of the AE (also uses scikit-learn K-means)
3) For the torus latent space of the AE with the Euclidean metric (uses the geomstats package, see https://geomstats.github.io/notebooks/07_practical_methods__riemannian_kmeans.html#)

In this notebook data is the part of MNIST dataset with 2 selected labels (5 and 8).

F-scores (see https://en.wikipedia.org/wiki/F-score) of the clusterizations vs. the ground truth labels are computed. They are used to measure the efficiency of each clusterization.

The contents of the notebook are:

1) Setting hyperparameters, dataset loading, plotting embedded data for a pre-trained AE.
2) Geomstats K-means: Euclidean metric on torus latent space
3) K-means in input data space
4) K-means in the output data space
5) F-scores comparison

In [None]:
# prerequisites
%matplotlib inline
import sklearn
from sklearn.cluster import KMeans
from sklearn import datasets
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

import numpy as np
from tqdm.notebook import tqdm


# 1. Setting hyperparameters, dataset loading, plotting embedded data for a pre-trained AE.

In [None]:
# Path of the JSON file describing the experiment to analyse.
experiment_json = '../experiments/MNIST01_torus_AEexp8.json'

violent_saving = True # if False it will not save plots
build_report = True

# Load the experiment configuration (JSON is UTF-8, so decode it explicitly
# instead of relying on the platform default encoding).
import json
with open(experiment_json, encoding="utf-8") as json_file:
    json_config = json.load(json_file)

# Echo the configuration so the notebook output records what was analysed.
print(json.dumps(json_config, indent=2))

# Paths and identifiers taken from the configuration.
Path_experiments = json_config["Path_experiments"]
experiment_name = json_config["experiment_name"]
experiment_number = json_config["experiment_number"]
Path_pictures = json_config["Path_pictures"]

# # Number of workers in DataLoader
# num_workers = 10

In [None]:
# Dataset identifier, then the test-split fraction and batch size used by the loaders.
dataset_name = json_config["dataset"]["name"]
split_ratio, batch_size = (
    json_config["optimization_parameters"][setting]
    for setting in ("split_ratio", "batch_size")
)

## Dataset loading

In [None]:
# import sys
# sys.path.append('../') # have to go 1 level up
import ricci_regularization

In [None]:
if dataset_name == "MNIST":
    # Full MNIST: 28x28 images, i.e. D = 784 features once flattened.
    D = 784
    train_dataset = datasets.MNIST(root='../datasets/', train=True, transform=transforms.ToTensor(), download=True)
    test_dataset  = datasets.MNIST(root='../datasets/', train=False, transform=transforms.ToTensor(), download=False)
elif dataset_name == "MNIST01":
    from torch.utils.data import Subset

    D = 784
    full_mnist_dataset = datasets.MNIST(root='../datasets/', train=True, transform=transforms.ToTensor(), download=True)
    test_dataset  = datasets.MNIST(root='../datasets/', train=False, transform=transforms.ToTensor(), download=False)
    selected_labels = json_config["dataset"]["selected_labels"]
    # Start from an all-False mask and switch on every sample whose digit is selected.
    mask = torch.zeros_like(full_mnist_dataset.targets, dtype=torch.bool)
    for label in selected_labels:
        mask |= full_mnist_dataset.targets == label
    indices01 = torch.where(mask)[0]

    train_dataset = Subset(full_mnist_dataset, indices01) # MNIST restricted to the selected labels
else:
    # Fail here with a clear message instead of a NameError on train_dataset below.
    raise ValueError(f"Unsupported dataset name: {dataset_name!r}")

# Split the (possibly filtered) training set into train and test parts;
# the test part holds a split_ratio fraction of the samples.
m = len(train_dataset)
test_size = int(m * split_ratio)
train_data, test_data = torch.utils.data.random_split(train_dataset, [m - test_size, test_size])

test_loader  = torch.utils.data.DataLoader(test_data , batch_size=batch_size)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

## AE structure

In [None]:
# Architecture settings taken from the experiment configuration.
latent_dim = json_config["architecture"]["latent_dim"]
input_dim  = json_config["architecture"]["input_dim"]
architecture_type = json_config["architecture"]["name"]

# Instantiate the autoencoder variant named in the configuration.
if architecture_type == "TorusAE":
    torus_ae   = ricci_regularization.Architectures.TorusAE(x_dim=input_dim, h_dim1= 512, h_dim2=256, z_dim=latent_dim)
elif architecture_type == "TorusConvAE":
    torus_ae   = ricci_regularization.Architectures.TorusConvAE(x_dim=input_dim, h_dim1= 512, h_dim2=256, z_dim=latent_dim,pixels=28)
else:
    # Fail here with a clear message instead of a NameError on torus_ae below.
    raise ValueError(f"Unsupported architecture name: {architecture_type!r}")

# Use the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    torus_ae.cuda()
else:
    torus_ae.cpu()

### Loading the saved weights

In [None]:
# The checkpoint path is read from the experiment configuration.
# NOTE(review): the original note asks for weights stored under
# ../experiments/<Your experiment>/nn_weights/ -- confirm the config points there.
PATH_ae_wights = json_config["weights_saved_at"]
# map_location="cpu" lets a checkpoint saved on a GPU be read on a CPU-only machine;
# load_state_dict then copies the values into the model's existing parameters.
torus_ae.load_state_dict(torch.load(PATH_ae_wights, map_location="cpu"))
# Switch to evaluation mode for inference.
torus_ae.eval()

In [None]:
# borrowed from https://gist.github.com/jakevdp/91077b0cae40f8f8244a
def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map.

    Parameters
    ----------
    N : int
        Number of discrete colour bins.
    base_cmap : str, Colormap or None
        Colormap to sample from; ``None`` selects Matplotlib's default colormap.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        Colormap named ``<base name><N>`` built from N evenly spaced samples
        of the base colormap.
    """
    from matplotlib.colors import LinearSegmentedColormap

    # plt.get_cmap accepts a name, None or a Colormap instance
    # (plt.cm.get_cmap is no longer available in recent Matplotlib releases).
    base = plt.get_cmap(base_cmap)
    color_list = base(np.linspace(0, 1, N))
    cmap_name = base.name + str(N)
    # Call the constructor on the class itself: ListedColormap instances
    # (e.g. "viridis") have no from_list method.
    return LinearSegmentedColormap.from_list(cmap_name, color_list, N)

## Torus latent space

In [None]:
# Number of classes, read from the "k" dataset parameter; used below as the
# number of K-means clusters and as the number of colormap bins.
N = json_config["dataset"]["parameters"]["k"]

In [None]:
# Push the whole test split through the autoencoder on the CPU, collecting
# per batch: raw inputs, reconstructions, both encoder outputs and labels.
torus_ae.cpu()
colorlist = []
enc_list = []
feature_space_encoding_list = []
input_dataset_list = []
recon_dataset_list = []
# Inference only: disabling autograd avoids keeping a graph for every batch.
with torch.no_grad():
    # Iterate over train_loader instead to embed the training split.
    for (data, labels) in tqdm(test_loader, position=0):
        flattened = data.view(-1, D)  # one flat row of D values per image
        input_dataset_list.append(data)
        recon_dataset_list.append(torus_ae(data)[0])
        feature_space_encoding_list.append(torus_ae.encoder_torus(flattened))
        enc_list.append(torus_ae.encoder2lifting(flattened))
        colorlist.append(labels)

In [None]:
# Stitch the per-batch results into single tensors along the batch axis.
input_dataset, recon_dataset, encoded_points, feature_space_encoding = (
    torch.cat(batches, dim=0)
    for batches in (
        input_dataset_list,
        recon_dataset_list,
        enc_list,
        feature_space_encoding_list,
    )
)
# Graph-free copies used for plotting and clustering.
color_array = torch.cat(colorlist, dim=0).detach()
encoded_points_no_grad = encoded_points.detach()

## Plots

In [None]:
# Scatter the embedded points using the first two latent coordinates,
# one colour per ground-truth class.
plt.figure(figsize=(8, 6))
plt.scatter(
    encoded_points_no_grad[:, 0],
    encoded_points_no_grad[:, 1],
    c=color_array,
    marker='o',
    edgecolor='none',
    cmap=discrete_cmap(N, 'jet'),
)
plt.colorbar(ticks=range(N))
plt.title("Latent space colored by ground truth labels")
plt.grid(True)
# Persist the figure only when saving is switched on.
if violent_saving:
    plt.savefig(f"{Path_pictures}/latent_space.pdf", format="pdf")

# 2. Geomstats K-means: Euclidean metric on torus latent space

In [None]:
#this adds an environmental variable
#%env GEOMSTATS_BACKEND=pytorch

import geomstats.backend as gs
import geomstats.visualization as visualization
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.learning.kmeans import RiemannianKMeans

In [None]:
# Two separate 1-dimensional hyperspheres (circles), one per latent angle.
circumference1, circumference2 = (Hypersphere(dim=1) for _ in range(2))

## Building torus as a product $\mathcal{T} = \mathcal{S}^1 \times \mathcal{S}^1$ 

In [None]:
from geomstats.geometry.product_manifold import ProductManifold
# Torus assembled as the product manifold of the two circle factors.
torus = ProductManifold((circumference1,circumference2))

Loading saved points and labels and plotting ground truth labels

In [None]:
encoded_angles = encoded_points_no_grad #torch.load("encoded_angles.pt")
gt_labels = color_array #torch.load("labels.pt")
# Convert the two ground-truth digits into a 0/1 integer array:
# the smaller label becomes 0 and the larger one becomes 1.
label_min = gt_labels.min()
# clamp keeps the divisor non-zero if only one class is present (all labels map to 0).
label_span = (gt_labels.max() - label_min).clamp(min=1)
# The integer cast must wrap the whole quotient, not only the divisor.
gt_labels = ((gt_labels - label_min) / label_span).to(torch.int)
gt_labels = gt_labels.numpy()

Putting MNIST data on torus

In [None]:
# Convert each latent angle into coordinates on its circle factor of the torus.
# NOTE(review): reshape(2,-1).T presumably yields an (n_points, 2) array with one
# (cos, sin)-like pair per point -- confirm against the installed geomstats version.
circ_1_coordinates = torus.factors[0].intrinsic_to_extrinsic_coords(encoded_angles[:,0]).reshape(2,-1).T
circ_2_coordinates = torus.factors[1].intrinsic_to_extrinsic_coords(encoded_angles[:,1]).reshape(2,-1).T
# Debug output, disabled:
#print("1st", circ_1_coordinates)
#print("2nd", circ_2_coordinates)

In [None]:
# Place both circles' coordinates side by side (cos\phi, sin\phi, cos\psi, sin\psi per point),
# then regroup every point as a 2x2 block: one row per circle factor.
MNIST_data_on_torus_4d = np.hstack((circ_1_coordinates, circ_2_coordinates)).reshape(-1, 2, 2)

In [None]:
# Riemannian K-means with N clusters on the torus product manifold.
# NOTE(review): tol is presumably the convergence tolerance -- confirm in the geomstats docs.
kmeans = RiemannianKMeans(torus, N, tol=1e-3)
kmeans.fit(MNIST_data_on_torus_4d)
# Cluster assignment of every embedded point and the fitted cluster centres.
kmeans_latent_space_euclidean_labels = kmeans.labels_
cluster_centers = kmeans.centroids_# NOTE(review): other geomstats versions may expose this as kmeans.cluster_centers_ -- confirm

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 12))

# Top panel: latent points coloured by the cluster found on the torus.
p1 = ax1.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1], c=kmeans_latent_space_euclidean_labels, marker='o', edgecolor='none', cmap=discrete_cmap(N, 'jet'))
plt.colorbar(p1,ticks=range(N))
ax1.title.set_text("Latent space colored by K-means on Torus with Euclidean metric")
ax1.grid(True)

# The absolute difference is 1 wherever cluster id and ground-truth label disagree.
# Cluster ids are arbitrary: if disagreements are the minority the ids already line
# up with the labels and the array is inverted to mark matches; otherwise the ids
# are swapped and a disagreement is itself a correct detection.
correcltly_detected_labels = abs(kmeans_latent_space_euclidean_labels - gt_labels)
if correcltly_detected_labels.sum() < len(gt_labels)//2:
    correcltly_detected_labels = np.logical_not(correcltly_detected_labels)

# Bottom panel: the same points coloured by whether they were clustered correctly.
# plt.get_cmap replaces plt.cm.get_cmap, which recent Matplotlib releases no longer provide.
p2 = ax2.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1], c=correcltly_detected_labels, marker='o', edgecolor='none', cmap=plt.get_cmap("viridis", N))
cbar = plt.colorbar(p2,ticks=[0.25,0.75])
cbar.ax.set_yticklabels(["incorrect","correct"])
if violent_saving:
    plt.savefig(f"{Path_pictures}/Kmeans_latent_space.pdf",format="pdf")

# 3. K-means in input data space

In [None]:
# Euclidean K-means on the raw inputs, each image flattened to a D-dimensional row.
kmeans_input_space = KMeans(n_clusters=N, random_state=0, n_init="auto")
kmeans_input_space.fit(input_dataset.detach().reshape(-1, D))
kmeans_input_space_labels = kmeans_input_space.labels_
print(f"k-means clusterisation to {N} clusters")

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 12))

# Top panel: latent points coloured by the cluster found in input space.
p1 = ax1.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1], c=kmeans_input_space_labels, marker='o', edgecolor='none', cmap=discrete_cmap(N, 'jet'))
plt.colorbar(p1,ticks=range(N))
ax1.title.set_text(f"K-means clusterization on input data, K = {N}, \n Euclidean metric in input space $R^D$")
ax1.grid(True)

# The absolute difference is 1 wherever cluster id and ground-truth label disagree.
# Cluster ids are arbitrary: if disagreements are the minority the ids already line
# up with the labels and the array is inverted to mark matches; otherwise the ids
# are swapped and a disagreement is itself a correct detection.
correcltly_detected_labels = abs(kmeans_input_space_labels - gt_labels)
if correcltly_detected_labels.sum() < len(gt_labels)//2:
    correcltly_detected_labels = np.logical_not(correcltly_detected_labels)

# Bottom panel: the same points coloured by whether they were clustered correctly.
# plt.get_cmap replaces plt.cm.get_cmap, which recent Matplotlib releases no longer provide.
p2 = ax2.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1], c=correcltly_detected_labels, marker='o', edgecolor='none', cmap=plt.get_cmap("viridis", N))
cbar = plt.colorbar(p2,ticks=[0.25,0.75])
cbar.ax.set_yticklabels(["incorrect","correct"])
if violent_saving:
    plt.savefig(f"{Path_pictures}/Kmeans_input_space.pdf",format="pdf")

# 4. K-means in the output data space

In [None]:
# Euclidean K-means on the autoencoder reconstructions.
kmeans_recon_space = KMeans(n_clusters=N, random_state=0, n_init="auto")
kmeans_recon_space.fit(recon_dataset.detach())
kmeans_recon_space_labels = kmeans_recon_space.labels_
print(f"k-means clusterisation to {N} clusters")

In [None]:
# Latent points coloured by the cluster found in the reconstructed (output) data space.
plt.figure(figsize=(8, 6))

plt.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1], c=kmeans_recon_space_labels, marker='o', edgecolor='none', cmap=discrete_cmap(N, 'jet'))
plt.colorbar(ticks=range(N))
plt.title(f"K-means clusterization on reconstructed data, K = {N}, \n Euclidean metric in output space $R^D$")
plt.grid(True)
if violent_saving:
    # Own file name: "Kmeans_latent_space.pdf" is already used by the latent-space
    # K-means figure and would be overwritten here.
    plt.savefig(f"{Path_pictures}/Kmeans_recon_space.pdf",format="pdf")

# 5. F-score comparison

In [None]:
# Cluster ids are arbitrary, so also keep every labelling with ids 0 and 1 swapped.
(
    kmeans_latent_space_euclidean_permuted_labels,
    kmeans_input_space_permuted_labels,
    kmeans_recon_space_permuted_labels,
) = (
    abs(cluster_labels - 1)
    for cluster_labels in (
        kmeans_latent_space_euclidean_labels,
        kmeans_input_space_labels,
        kmeans_recon_space_labels,
    )
)

In [None]:


# Import the metric explicitly rather than relying on sklearn.metrics having
# been loaded as a side effect of another sklearn import.
from sklearn.metrics import f1_score


def _best_f1_score(true_labels, cluster_labels, permuted_cluster_labels):
    """Return the higher F1 score of a clustering under both cluster-id assignments."""
    return max(f1_score(true_labels, cluster_labels),
               f1_score(true_labels, permuted_cluster_labels))


F_score_latent_space_eucl = _best_f1_score(gt_labels, kmeans_latent_space_euclidean_labels,
                                           kmeans_latent_space_euclidean_permuted_labels)

F_score_input_space_eucl = _best_f1_score(gt_labels, kmeans_input_space_labels,
                                          kmeans_input_space_permuted_labels)

F_score_recon_space_eucl = _best_f1_score(gt_labels, kmeans_recon_space_labels,
                                          kmeans_recon_space_permuted_labels)

# Report the scores together with the curvature weight read from the config.
curv_w = json_config["losses"]["curv_w"]
print(f"Curvature penalization weight: {curv_w}")
print(f"F-score Euclidean k-means in latent space vs ground truth: \n{F_score_latent_space_eucl}")
print(f"F-score Euclidean k-means in input data space vs ground truth: \n{F_score_input_space_eucl}")
print(f"F-score Euclidean k-means in reconstructed data space vs ground truth: \n{F_score_recon_space_eucl}")

In [None]:
# Summary of the clustering quality for this experiment. Named kmeans_report
# so that the builtin `dict` is not shadowed for the rest of the session.
kmeans_report = {
    "curv_w" : curv_w,
    # Read from the config so this cell also works when the dataset branch
    # that defines `selected_labels` was not taken (value is then None).
    "labels": json_config["dataset"].get("selected_labels"),
    "F-score Euclidean k-means in latent space vs ground truth" : F_score_latent_space_eucl,
    "F-score Euclidean k-means in reconstructed data space vs ground truth" : F_score_recon_space_eucl,
    "F-score Euclidean k-means in input data space vs ground truth" : F_score_input_space_eucl
}
# Write the summary as JSON into the pictures directory.
with open(f'{Path_pictures}/K-means_exp{experiment_number}.json', 'w', encoding="utf-8") as json_file:
    json.dump(kmeans_report, json_file, indent=4)