# Riemannian K-means
Steps:

1) Guess initial cluster centers (Euclidean K-means)
2) Shift the local chart center to cluster basepoints
3) Do log maps with base points
4) Recompute cluster centers in log maps
5) Return (only cluster centers) on the manifold by exp maps with corresponding 
base points
6) Unshift
7) Recluster
8) Check tolerance (cluster center shift):
if the centers moved by more than the tolerance, go back to step 2. Repeat until the clusters are stable

# Data loading

In [None]:
from tqdm.notebook import tqdm
import torch
import numpy as np
import ricci_regularization
import matplotlib.pyplot as plt
import matplotlib
import stochman
from stochman.manifold import EmbeddedManifold
from stochman.curves import CubicSpline

# Flag referenced by the (currently commented-out) savefig calls further down.
violent_saving = True

# Experiment configuration to load.
# Alternative without curvature penalty: '../experiments/MNIST_torus_AEexp34.json'
experiment_json = '../experiments/MNIST01_torus_AEexp7.json'
mydict = ricci_regularization.get_dataloaders_tuned_nn(
    Path_experiment_json=experiment_json,
)

In [None]:
# Unpack the objects returned by get_dataloaders_tuned_nn.
torus_ae, test_loader, json_cofig = (
    mydict[key] for key in ("tuned_neural_network", "test_loader", "json_config")
)
# Settings read from the experiment's JSON configuration.
Path_pictures, exp_number = (
    json_cofig[key] for key in ("Path_pictures", "experiment_number")
)
curv_w = json_cofig["losses"]["curv_w"]

In [None]:
D = 784  # flattened input size used by data.view(-1, D)
k = json_cofig["dataset"]["parameters"]["k"]  # later passed to K-means as the number of clusters
torus_ae.cpu()

# One entry per batch; the lists are concatenated into tensors afterwards.
input_dataset_list, recon_dataset_list = [], []
feature_space_encoding_list, enc_list, colorlist = [], [], []

# To encode the training set instead, iterate over train_loader here.
for data, labels in tqdm(test_loader, position=0):
    flat_batch = data.view(-1, D)
    input_dataset_list.append(data)
    # First element of the autoencoder output is stored as the reconstruction.
    recon_dataset_list.append(torus_ae(data)[0])
    feature_space_encoding_list.append(torus_ae.encoder_torus(flat_batch))
    enc_list.append(torus_ae.encoder2lifting(flat_batch))
    colorlist.append(labels)

In [None]:
# Stitch the per-batch lists into single tensors along the batch axis.
input_dataset = torch.cat(input_dataset_list, dim=0)
recon_dataset = torch.cat(recon_dataset_list, dim=0)
feature_space_encoding = torch.cat(feature_space_encoding_list, dim=0)
encoded_points = torch.cat(enc_list, dim=0)

# Detached versions, used for plotting and clustering.
encoded_points_no_grad = encoded_points.detach()
color_array = torch.cat(colorlist, dim=0).detach()

In [None]:
# Latent codes colored by their ground-truth labels.
plt.scatter(
    encoded_points_no_grad[:, 0],
    encoded_points_no_grad[:, 1],
    c=color_array,
    cmap=ricci_regularization.discrete_cmap(k, "jet"),
)
plt.show()

# 1. Initial guess: Euclidean K-means via Geomstats

In [None]:
#this adds an environmental variable
#%env GEOMSTATS_BACKEND=pytorch

import geomstats.backend as gs
import geomstats.visualization as visualization
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.learning.kmeans import RiemannianKMeans

In [None]:
# Two independent circles S^1 (hyperspheres of dimension 1), one per torus factor.
circumference1, circumference2 = (Hypersphere(dim=1) for _ in range(2))

# Building torus as a product $\mathcal{T} = \mathcal{S}^1 \times \mathcal{S}^1$ 

In [None]:
from geomstats.geometry.product_manifold import ProductManifold
# Torus built as the product of the two circles.
torus_factors = (circumference1, circumference2)
torus = ProductManifold(torus_factors)

Loading saved points and labels and plotting ground truth labels

In [None]:
encoded_angles = encoded_points_no_grad  # alternatively: torch.load("encoded_angles.pt")
gt_labels = color_array  # alternatively: torch.load("labels.pt")
# Convert gt_labels into a 0/1 integer array: shift so the smallest label is 0,
# then divide by the largest shifted label. The integer cast is applied to the
# whole quotient; it previously bound only to the denominator, which left the
# labels as floats. With more than two distinct labels every label except the
# largest is truncated to 0.
shifted_labels = gt_labels - gt_labels.min()
# clamp avoids 0/0 when every point carries the same label
label_range = shifted_labels.max().clamp(min=1)
gt_labels = (shifted_labels / label_range).to(torch.int)
gt_labels = gt_labels.numpy()

Putting MNIST data on torus

In [None]:
# Map each angular coordinate to extrinsic coordinates on its circle factor.
# NOTE(review): reshape(2, -1).T assumes the call returns all first components
# followed by all second components -- confirm against the geomstats version in use.
circ_1_coordinates, circ_2_coordinates = (
    torus.factors[axis]
    .intrinsic_to_extrinsic_coords(encoded_angles[:, axis])
    .reshape(2, -1)
    .T
    for axis in (0, 1)
)

In [None]:
# Columns after concatenation: cos(phi), sin(phi), cos(psi), sin(psi).
stacked_circle_coordinates = np.concatenate(
    (circ_1_coordinates, circ_2_coordinates), axis=1
)
# Regroup every point as a 2x2 block: one row per circle factor.
MNIST_data_on_torus_4d = stacked_circle_coordinates.reshape(-1, 2, 2)

In [None]:
# Cluster the torus-embedded points into k groups.
kmeans_tolerance = 1e-3
kmeans = RiemannianKMeans(torus, k, tol=kmeans_tolerance)
kmeans.fit(MNIST_data_on_torus_4d)

# Fitted results: centroids and the per-point cluster index.
cluster_centers = kmeans.centroids_
kmeans_latent_space_euclidean_labels = kmeans.labels_

In [None]:
# Angles of the centroid with index 1.
# NOTE(review): the original `cluster_centers[:][1]` is the same as
# `cluster_centers[1]` (one centroid), not a column of every centroid -- confirm intent.
torus.factors[0].extrinsic_to_angle(cluster_centers[1])

In [None]:
# Convert the extrinsic centroids back to intrinsic (angular) coordinates.
unit_circle = Hypersphere(dim=1)
cluster_centers_in_local_chart = unit_circle.extrinsic_to_intrinsic_coords(
    cluster_centers
).squeeze()

In [None]:
fig,(ax1,ax2) = plt.subplots(2,1,figsize=(8, 12))

# Top panel: latent codes colored by K-means cluster index, centroids drawn as orange stars.
p1 = ax1.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1], c=kmeans_latent_space_euclidean_labels, marker='o', edgecolor='none', cmap=ricci_regularization.discrete_cmap(k, 'jet'))
plt.colorbar(p1,ticks=range(k))
ax1.title.set_text("Latent space colored by K-means on Torus with Euclidean metric")
ax1.grid(True)
ax1.scatter(cluster_centers_in_local_chart[:,0],cluster_centers_in_local_chart[:,1],marker = '*',s=150,c ="orange")

# The absolute difference is nonzero where cluster index and ground-truth label disagree.
# Cluster indices are arbitrary: if fewer than half of the points disagree, the
# indices line up with the labels and the mask is inverted so that it marks the
# matches; otherwise the indices are taken to be swapped and the disagreement
# mask already marks the correctly clustered points.
# This reasoning assumes two clusters with 0/1 labels.
correcltly_detected_labels = abs(kmeans_latent_space_euclidean_labels - gt_labels)
if correcltly_detected_labels.sum() < len(gt_labels)//2:
    correcltly_detected_labels = np.logical_not(correcltly_detected_labels)

# Bottom panel: correct vs. incorrect assignments.
# plt.get_cmap replaces plt.cm.get_cmap, which newer Matplotlib releases no longer provide.
p2 = ax2.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1], c=correcltly_detected_labels, marker='o', edgecolor='none', cmap=plt.get_cmap("viridis", k))
cbar = plt.colorbar(p2,ticks=[0.25,0.75])
cbar.ax.set_yticklabels(["incorrect","correct"])
#if violent_saving == True:
#    plt.savefig(f"{Path_pictures}/Kmeans_latent_space.pdf",format="pdf")

# K-means in input data space

In [None]:
# Upper bound on how many points of each cluster are kept.
num_points_in_clusters = 100

# For every cluster index: the latent points assigned to it, truncated to the
# first num_points_in_clusters (a smaller cluster keeps all of its points).
clusters = [
    encoded_points_no_grad[np.where(kmeans_latent_space_euclidean_labels == i)][
        :num_points_in_clusters
    ]
    for i in range(k)
]

In [None]:
import matplotlib.colors as mcolors
# Tableau color names, cycled when there are more clusters than colors.
colors = list(mcolors.TABLEAU_COLORS)
for i in range(k):
    cluster_color = colors[i % len(colors)]
    cluster_points = clusters[i]
    center = cluster_centers_in_local_chart[i]
    plt.scatter(cluster_points[:, 0], cluster_points[:, 1], c=cluster_color, s=30)
    # Cluster center as a large star with a black outline.
    plt.scatter(center[0], center[1], marker='*', s=250, c=cluster_color, edgecolors="black")


In [None]:
shifted_clusters, clusters_in_R2 = [], []
for i in range(k):
    center = cluster_centers_in_local_chart[i]
    # Express the points relative to the cluster center (center moves to (0, 0)).
    offsets = clusters[i] - center
    # Wrap every coordinate into [-pi, pi).
    wrapped_offsets = torch.remainder(offsets + torch.pi, 2 * torch.pi) - torch.pi
    shifted_clusters.append(wrapped_offsets)
    # Translate back so the points sit around the original center.
    clusters_in_R2.append(wrapped_offsets + center)

# Shifts

In [None]:
# squeeze=False always yields a 2-D array of Axes, so indexing also works for k == 1
# (with the default, a single subplot comes back as a bare Axes object).
fig, axes = plt.subplots(nrows=k, dpi=300, figsize=(6, k*6), squeeze=False)
axes = axes[:, 0]  # single column: one Axes per cluster
axes[0].set_title("Clusters with centers shifted to 0")

# Shared view window and tick positions for all panels.
plot_limit = 3/2*torch.pi
tick_positions = torch.linspace(-plot_limit, plot_limit, 7)

for i in range(k):
    ax = axes[i]
    cluster_color = colors[i%len(colors)]
    center = cluster_centers_in_local_chart[i]
    # Points relative to the cluster center.
    ax.scatter(shifted_clusters[i][:,0], shifted_clusters[i][:,1], c=cluster_color, label=f"Points of cluster # {i}")
    # Same points translated back around the center (black outline).
    ax.scatter(clusters_in_R2[i][:,0], clusters_in_R2[i][:,1], c=cluster_color, label=f"Points of cluster # {i} in the universal cover", edgecolor="black")
    ax.scatter(center[0], center[1], marker='*', s=250, c="magenta", label=f"Cluster # {i} center.")
    ax.set_xlim(-plot_limit, plot_limit)
    ax.set_ylim(-plot_limit, plot_limit)
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.grid()
    ax.legend(loc="lower left")

# 3-4. Logarithmic maps and new barycenters in the log maps

In [None]:
# NOTE: the metric needs to be shifted together with the chart!
# To achieve this, the log maps can be computed w.r.t. the periodic metric on the universal cover of the torus.

In [None]:
from stochman.manifold import EmbeddedManifold
# Geodesics are computed by minimizing the "energy" in the embedding of the
# manifold, so there is no need to compute the pullback metric, which keeps
# the algorithm fast.
class Autoencoder(EmbeddedManifold):
    """Latent space treated as a manifold embedded via ``torus_ae.decoder_torus``."""

    def embed(self, c, jacobian = False):
        """Decode the latent coordinates ``c``.

        The ``jacobian`` flag is accepted but ignored: only the decoded
        points are returned.
        """
        decoded = torus_ae.decoder_torus(c)
        return decoded


#selected_labels = json_cofig["dataset"]["selected_labels"]
model = Autoencoder()

In [None]:
# Log map of every cluster, taken at that cluster's center.
clusters_logmaps = []
# TODO: vectorize over clusters
for i in range(k):
    cluster_points = clusters_in_R2[i]
    base_point = torch.from_numpy(cluster_centers_in_local_chart[i])
    # One copy of the base point per cluster member. A cluster may hold fewer
    # than num_points_in_clusters points, so the count comes from the cluster itself.
    base_points = base_point.repeat(cluster_points.shape[0], 1)
    clusters_logmaps.append(model.logmap(base_points, cluster_points))

In [None]:
# Mean of each cluster's log-map vectors, one entry per cluster.
baricenters_logmap = [
    torch.mean(clusters_logmaps[i], dim=0) for i in range(k)
]

In [None]:
# squeeze=False always yields a 2-D array of Axes, so indexing also works for k == 1
# (with the default, a single subplot comes back as a bare Axes object).
fig, axes = plt.subplots(nrows=k, dpi=300, figsize=(6, k*6), squeeze=False)
axes = axes[:, 0]  # single column: one Axes per cluster
axes[0].set_title("Clusters after log maps with base points at cluster centers")
for i in range(k):
    ax = axes[i]
    logmap_points = clusters_logmaps[i]
    baricenter = baricenters_logmap[i]
    ax.scatter(logmap_points[:,0], logmap_points[:,1], c=colors[i%len(colors)], label=f"Points of cluster # {i} in the logmaps cover")
    # Mean of the log-map vectors, drawn as a magenta star.
    ax.scatter(baricenter[0], baricenter[1], marker='*', s=150, c="magenta", label=f"New cluster # {i} baricenter:({baricenter[0]:.4f},{baricenter[1]:.4f}).")
    ax.legend(loc="lower left")

# 5-6. Return barycenters via exp map + unshift
i.e. shoot a geodesic from the base point $p_i$ (the old barycenter in the universal cover) with initial velocity equal to the new barycenter computed in the log map

# Geodesic shooting

In [None]:
def geod_vect(x,dxdt):
    """Right-hand side of the geodesic equation as a first-order system.

    For a batch of positions and velocities, returns the time derivatives
    ``(du/dt, dv/dt)`` with ``du/dt = v`` and
    ``dv^l/dt = -sum_{i,j} Ch[l, i, j] * v^i * v^j``.

    Args:
        x: batch of positions, indexed as ``[geodesic, coordinate]``.
        dxdt: batch of velocities with the same layout as ``x``.

    Returns:
        Tuple ``(dudt, dvdt)``; ``dudt`` is ``dxdt`` itself.
    """
    u = x
    v = dxdt
    # Presumably the Christoffel symbols of the decoder's pullback metric at u,
    # indexed as [geodesic, l, i, j] -- confirm in ricci_regularization.
    Ch_at_u = ricci_regularization.Ch_jacfwd_vmap(u,function=torus_ae.decoder_torus,device=torch.device("cpu"))
    # Contract both lower indices with the velocity for the whole batch at once.
    # Broadcasting replaces the previous triple loop that was hard-coded to two
    # coordinates.
    dvdt = -(Ch_at_u * v[:, None, :, None] * v[:, None, None, :]).sum(dim=(-2, -1))
    return v, dvdt

def rungekutta_vect(f, initial_point_array, initial_speed_array, t, args=()):
    """Integrate a batch of second-order ODEs along the time grid ``t``.

    Despite the name, each step is a single explicit (forward) Euler update:
    ``x[i+1] = x[i] + dt * dudt`` and ``dxdt[i+1] = dxdt[i] + dt * dvdt``,
    where ``dt = t[i+1] - t[i]``.

    Args:
        f: callable ``f(x, dxdt, *args) -> (dudt, dvdt)``.
        initial_point_array: positions at ``t[0]``.
        initial_speed_array: velocities at ``t[0]``.
        t: sequence of time values.
        args: extra positional arguments forwarded to ``f``.

    Returns:
        ``(x, dxdt)``: trajectories of shape
        ``(len(t), *initial_point_array.shape)`` and
        ``(len(t), *initial_speed_array.shape)``; for the geodesics this is
        ``[num_grid_points, num_geodesics, dimension=2]``.
    """
    num_grid_points = len(t)
    # Buffers use torch's default dtype; the initial values are cast on assignment.
    positions = torch.zeros((num_grid_points, *initial_point_array.shape))
    velocities = torch.zeros((num_grid_points, *initial_speed_array.shape))
    positions[0] = initial_point_array
    velocities[0] = initial_speed_array

    for step in range(num_grid_points - 1):
        dudt, dvdt = f(positions[step], velocities[step], *args)
        dt = t[step + 1] - t[step]
        positions[step + 1] = positions[step] + dt * dudt
        velocities[step + 1] = velocities[step] + dt * dvdt

    return positions, velocities

In [None]:
# Time grid for the integration.
num_approximation_points = 101  # number of grid points: how good the approximation is
max_parameter_value = 1  # final parameter value: how far to go
time_array = torch.linspace(0, max_parameter_value, num_approximation_points)

# One geodesic per cluster: start at the current center, initial velocity is
# the mean log-map vector of that cluster.
starting_points = torch.from_numpy(cluster_centers_in_local_chart)
starting_speeds = torch.cat(baricenters_logmap).reshape(k, 2)

geodesics2plot, _ = rungekutta_vect(
    f=geod_vect,
    initial_point_array=starting_points,
    initial_speed_array=starting_speeds,
    t=time_array,
)
geodesics2plot = geodesics2plot.detach()

In [None]:
# Draw the integrated path of every cluster center in the latent chart.
for i in range(k):
    path = geodesics2plot[:, i]
    plt.plot(path[:, 0], path[:, 1], c="black")
plt.show()

In [None]:
# Old centers versus the end points of the integrated geodesics.
start = torch.tensor(cluster_centers_in_local_chart)
end = geodesics2plot[-1]
# Euclidean norm of each center's displacement in chart coordinates.
torch.norm(start - end, dim=1)

In [None]:
# Curve joining each old center to its new position, plus the flag returned
# alongside it.
geodesic_result = model.connecting_geodesic(start, end)
c, success = geodesic_result
print(success.item())

In [None]:
# Evaluate the connecting curves on the time grid and measure their length.
curve_points = c(time_array)
length = model.curve_length(curve_points).detach()
print(f"Geodesic length of cluster center shifts:\n{length}")