This notebook provides a first naive benchmark for the geodesics in the latent space of the autoencoder. The benchmark evaluates the effect of curvature regularization by measuring how much the geodesics differ from straight lines. Namely, for a curve $\gamma(t)$ one can consider the functional:
\begin{equation}
    E_{ij} = \int\limits_0^1 \|\gamma'' (t)\|^2 dt \ ,
\end{equation}
where $\gamma(t) = \Psi(t \Phi_{\theta}(X_i) + (1-t) \Phi_{\theta}(X_j))$; recall that $\Phi_{\theta}$ is the encoder and $\Psi$ is the decoder function of the autoencoder. Averaging $E_{ij}$ over all pairs of $K$ sample points, we obtain the average Euclidean energy:
\begin{equation}
    \mathcal{E} = \frac{1}{\binom{K}{2}} \sum\limits_{1 \leq i < j \leq K} E_{ij} \ .
\end{equation}

In [None]:
# Name of the dataset used by the benchmark; the dataset cell below branches
# on this value ("Swissroll" or "Synthetic").
dataset_name = "Swissroll"
#dataset_name = "Synthetic"

# Number of test points drawn for the benchmark; all K*(K-1)/2 pairs of
# these points are evaluated.
K = 50

# The first benchmark 

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# adding path to the set generating package
import sys
sys.path.append('../') # have to go 1 level up

## Dataset downloading

In [None]:
import sklearn
from sklearn import datasets

# Set hyperparameters
# Swiss roll
sr_noise = 0.05
sr_numpoints = 18000 #k*n
split_ratio = 0.2  # fraction of the dataset held out as test data

# Synthetic
d = 2         # latent space dimension
k = 3         # num of 2d planes in dim D
n = 6*(10**3) # num of points in each plane
shift_class = 0
variance_of_classes = 1 # variation of each Gaussian initially 0.1
interclass_variance = 0.1 # this creates a Gaussian,
# i.e.random shift
# proportional to the value of interclass_variance
# initially 0.1

# Fix the torch RNG so that both random splits below are reproducible.
torch.manual_seed(0)

#K = 50

if dataset_name == "Synthetic":
    # This branch uses the project package, which was never imported in the
    # notebook; import it here (it is found through the sys.path entry added
    # in the imports cell).
    import ricci_regularization

    D = 784
    my_dataset = ricci_regularization.SyntheticDataset(k=k,n=n,d=d,D=D,
                                        shift_class=shift_class,
                                        variance_of_classes = variance_of_classes,
                                        interclass_variance=interclass_variance)

    # NOTE(review): `create` is accessed without being called -- confirm
    # whether it is a property or a method that should be invoked.
    train_dataset = my_dataset.create
elif dataset_name == "Swissroll":
    from torch.utils.data import TensorDataset

    D = 3
    # make_swiss_roll returns (points, colors); keep both as float32 tensors.
    train_dataset =  sklearn.datasets.make_swiss_roll(n_samples=sr_numpoints, noise=sr_noise, random_state=1)
    sr_points = torch.from_numpy(train_dataset[0]).to(torch.float32)
    #sr_points = torch.cat((sr_points,torch.zeros(sr_numpoints,D-3)),dim=1)
    sr_colors = torch.from_numpy(train_dataset[1]).to(torch.float32)
    train_dataset = TensorDataset(sr_points,sr_colors)
else:
    # Fail early with a clear message instead of a NameError further down.
    raise ValueError(f"Unknown dataset_name: {dataset_name!r}")

m = len(train_dataset)

# Derive the train size from the test size so the two always sum to m;
# truncating both products independently can lose one sample and make
# random_split raise.
test_size = int(m*split_ratio)
train_size = m - test_size
train_data, test_data = torch.utils.data.random_split(train_dataset, [train_size, test_size])
#test_loader  = torch.utils.data.DataLoader(test_data , batch_size=batch_size)
#train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Number of test points (no need to materialise the subset to count it).
l = len(test_data)

if K > l:
    raise ValueError(f"K={K} benchmark points requested but the test split only has {l} points")

# Draw the K benchmark points from the test split; `rest` is unused.
first_benchmark_data,rest = torch.utils.data.random_split(test_data, [K, l - K])

## introducing gamma 

In [None]:
from torch import nn
class Encoder(nn.Module):
    """Fully connected encoder: input_dim -> 512 -> 256 -> 128 -> hidden_dim.

    A sine activation follows every layer except the last one, whose output
    is returned as is.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        widths = (input_dim, 512, 256, 128, hidden_dim)
        # The layers are registered as linear1..linear4, in this order: the
        # attribute names determine the state_dict keys and the creation
        # order determines the seeded initialisation.
        for position, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"linear{position}", nn.Linear(n_in, n_out))
        self.activation = torch.sin

    def forward(self, x):
        features = x
        for layer in (self.linear1, self.linear2, self.linear3):
            features = self.activation(layer(features))
        # No activation on the output layer.
        return self.linear4(features)

class Decoder(nn.Module):
    """Fully connected decoder: hidden_dim -> 128 -> 256 -> 512 -> output_dim.

    A sine activation follows every layer except the last one, whose output
    is returned as is.
    """

    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        widths = (hidden_dim, 128, 256, 512, output_dim)
        # The layers are registered as linear1..linear4, in this order: the
        # attribute names determine the state_dict keys and the creation
        # order determines the seeded initialisation.
        for position, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"linear{position}", nn.Linear(n_in, n_out))
        self.activation = torch.sin

    def forward(self, x):
        features = x
        for layer in (self.linear1, self.linear2, self.linear3):
            features = self.activation(layer(features))
        # No activation on the output layer.
        return self.linear4(features)

In [None]:
# Build the two halves of the autoencoder: the encoder maps R^D -> R^d and
# the decoder maps R^d -> R^D. Trained weights are loaded in later cells.
encoder = Encoder(input_dim=D, hidden_dim=d)
decoder = Decoder(hidden_dim=d, output_dim=D)

In [None]:
def gamma (t, x, y):
    """Decode the point t*x + (1-t)*y of the latent segment joining y and x.

    t = 0 gives decoder(y) and t = 1 gives decoder(x). Uses the module-level
    `decoder`.
    """
    latent_point = x*t + y*(1-t)
    return decoder(latent_point)

In [None]:
#gamma (torch.arange(10).reshape(10,1), torch.rand(d), torch.rand(d))

In [None]:
def gamma_second(t, h,  x, y):
    """Central finite-difference estimate of the second derivative of gamma in t.

    Computes (gamma(t+h) - 2*gamma(t) + gamma(t-h)) / h**2 for the curve
    between the points x and y.
    """
    ahead = gamma(t+h, x, y)
    here = gamma(t, x, y)
    behind = gamma(t-h, x, y)
    return (ahead - 2*here + behind)/(h**2)

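The integral is approximated with the trapezoidal rule on the uniform grid $t_i = \frac{i}{n-1}$, $i = 0, \dots, n-1$: $\int\limits_0^1 f(t)\, dt \approx \frac{1}{n-1}\sum\limits_{i=0}^{n-2} \frac{f(t_i) + f(t_{i+1})}{2}$, where $f$ is the integrand of $E_{ij}$.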

In [None]:
def E(x_i,x_j,n_partition):
    """Approximate E_ij = int_0^1 ||gamma''(t)||^2 dt for one pair of latent points.

    gamma is the decoded straight segment between x_j (t = 0) and x_i (t = 1).
    The second derivative is estimated by central finite differences with
    step h = 1/n_partition, and the integral by the trapezoidal rule on a
    uniform grid of n_partition points in [0, 1].

    Args:
        x_i: latent point (1-D tensor), end of the segment at t = 1.
        x_j: latent point (1-D tensor), end of the segment at t = 0.
        n_partition: number of grid points, at least 2.

    Returns:
        0-dim tensor with the approximated energy (kept as a tensor so the
        function can be vectorized with vmap).

    Raises:
        ValueError: if n_partition < 2.
    """
    n = n_partition
    if n < 2:
        raise ValueError("n_partition must be at least 2")
    # Grid t_0 = 0, ..., t_{n-1} = 1 as a column of shape (n, 1), so that
    # every t multiplies the whole latent vector. Reshaping to (-1, d)
    # instead would pair different t values with different latent
    # coordinates and evaluate only n/d points.
    segment_partition = ((1/(n-1))*torch.arange(n,dtype=torch.float32)).reshape(-1,1)
    gamma_second_array = gamma_second (segment_partition,
              h = 1/n, x = x_i, y = x_j)
    # Squared Euclidean norm of gamma'' at every grid point, matching the
    # definition of E_ij at the top of the notebook.
    gamma_second_sq_norm_array = gamma_second_array.pow(2).sum(dim=1)
    # Trapezoidal rule with uniform spacing 1/(n-1).
    E_ij = (0.5/(n-1))*torch.sum((gamma_second_sq_norm_array[:-1]+gamma_second_sq_norm_array[1:]))
    # return E_ij.item() doesnot work with vmap
    return E_ij

In [None]:
# Sanity check: evaluate E once on two random latent points (seed fixed for
# reproducibility).
torch.manual_seed(10)
E(torch.rand(d),torch.rand(d), n_partition = 10)

### vmap vectorization

In [None]:
# Vectorize E over dim 0 of its positional arguments (a batch of x_i and a
# batch of x_j); keyword arguments such as n_partition are not mapped over.
E_vmap = torch.func.vmap(E)

In [None]:
# Number of test points left over after drawing the K benchmark points.
l-K

In [None]:
# more efficiently
import math

#K = 100 # this can go up to 300 in practice

def make_pairs():
    """Build all K*(K-1)/2 unordered pairs (i < j) of the benchmark points.

    Reads the module-level `first_benchmark_data` and `K`.

    Returns:
        (start_points, end_points): two tensors with one row per pair. Row r
        of start_points is point i and row r of end_points is point j, with
        pairs enumerated as (0,1), (0,2), ..., (0,K-1), (1,2), ..., (K-2,K-1).
        For K < 2 both tensors have zero rows.
    """
    initial_points = first_benchmark_data[:][0]
    # Strict upper-triangle indices in row-major order reproduce the order of
    # the nested loop `for i in range(K): for j in range(i+1, K)` without a
    # Python-level O(K^2) loop.
    pair_indices = torch.triu_indices(K, K, offset=1, device=initial_points.device)
    start_points = initial_points[pair_indices[0]]
    end_points = initial_points[pair_indices[1]]
    return start_points, end_points
# now let us build C_K^2 pairs of start and end points we want to
"""
start_points = first_bencmark_data_in_ls.repeat((math.ceil(K/2),1))

def cyclic_perm (tensor):
    new_tensor = torch.cat((tensor[1:],tensor[0].unsqueeze(0)),dim=0)
    return new_tensor

end_points_list = []

first_bencmark_data_in_ls_permuted = first_bencmark_data_in_ls
for i in range(math.ceil(K/2)):
    first_bencmark_data_in_ls_permuted = cyclic_perm(first_bencmark_data_in_ls_permuted)
    end_points_list.append(first_bencmark_data_in_ls_permuted)

end_points = torch.cat(end_points_list)

if (K%2 == 0):
    start_points = start_points[:-(K//2)]
    end_points = end_points[:-(K//2)]
"""
# finally we get K*(K-1)/2 pairs of points

In [None]:
# no penalty on curvature
load_weight_name = "swissroll_curv_w=0_ls=R^2"
PATH_enc = f'../nn_weights/encoder_{load_weight_name}'
# map_location="cpu": the models in this notebook are on the CPU, so the
# checkpoint loads even if it was saved from a GPU.
# weights_only=True: only tensors/plain containers are unpickled, which is
# all a state_dict needs.
encoder.load_state_dict(torch.load(PATH_enc, map_location="cpu", weights_only=True))
encoder.eval()
PATH_dec = f'../nn_weights/decoder_{load_weight_name}'
decoder.load_state_dict(torch.load(PATH_dec, map_location="cpu", weights_only=True))
decoder.eval()

start_points, end_points = make_pairs()
# Pure inference: skip building the autograd graph and encode each set of
# points only once.
with torch.no_grad():
    start_latent = encoder(start_points)
    end_latent = encoder(end_points)
    Energy_pairwise_array_no_curv_pen = E_vmap(start_latent,end_latent,n_partition=100)
    Distance_ls_pairwize_array_no_curv_pen = (end_latent - start_latent).norm(dim=1)
Distance_RD_pairwize_array_no_curv_pen = (end_points - start_points).norm(dim=1)

In [None]:
# with penalty on curvature
load_weight_name = "swissroll_curv_w=1_ls=R^2"
#load_weight_name = "swissroll_curv_w=10_ls=R^2"
#load_weight_name = "swissroll_curv_w=10_ls=R^2_20epochs_bs=32"
PATH_enc = f'../nn_weights/encoder_{load_weight_name}'
# map_location="cpu": the models in this notebook are on the CPU, so the
# checkpoint loads even if it was saved from a GPU.
# weights_only=True: only tensors/plain containers are unpickled, which is
# all a state_dict needs.
encoder.load_state_dict(torch.load(PATH_enc, map_location="cpu", weights_only=True))
encoder.eval()
PATH_dec = f'../nn_weights/decoder_{load_weight_name}'
decoder.load_state_dict(torch.load(PATH_dec, map_location="cpu", weights_only=True))
decoder.eval()

start_points, end_points = make_pairs()

# Pure inference: skip building the autograd graph and encode each set of
# points only once.
with torch.no_grad():
    start_latent = encoder(start_points)
    end_latent = encoder(end_points)
    Energy_pairwise_array_with_curv_pen = E_vmap(start_latent,end_latent,n_partition=100)
    Distance_ls_pairwize_array_with_curv_pen = (end_latent - start_latent).norm(dim=1)
Distance_RD_pairwize_array_with_curv_pen = (end_points - start_points).norm(dim=1)
#Energy_pairwise_array_with_curv_pen = E_vmap(start_points,end_points,n_partition=100)
#Distance_ls_pairwize_array_with_curv_pen = (end_points - start_points).norm(dim=1)
#Distance_RD_pairwize_array_with_curv_pen = (decoder(end_points) - decoder(start_points)).norm(dim=1)

In [None]:
# with penalty on curvature (weights trained with curvature weight 10)

load_weight_name = "swissroll_curv_w=10_ls=R^2_20epochs_bs=32"
PATH_enc = f'../nn_weights/encoder_{load_weight_name}'
# map_location="cpu": the models in this notebook are on the CPU, so the
# checkpoint loads even if it was saved from a GPU.
# weights_only=True: only tensors/plain containers are unpickled, which is
# all a state_dict needs.
encoder.load_state_dict(torch.load(PATH_enc, map_location="cpu", weights_only=True))
encoder.eval()
PATH_dec = f'../nn_weights/decoder_{load_weight_name}'
decoder.load_state_dict(torch.load(PATH_dec, map_location="cpu", weights_only=True))
decoder.eval()

start_points, end_points = make_pairs()

# Pure inference: skip building the autograd graph.
with torch.no_grad():
    Energy_pairwise_array_curv_w10_pen = E_vmap(encoder(start_points),encoder(end_points),n_partition=100)


In [None]:
# Display the pairwise distances in R^D stored for the curvature-penalized run.
Distance_RD_pairwize_array_with_curv_pen

In [None]:
# Display the pairwise distances in R^D stored for the run without curvature penalty.
Distance_RD_pairwize_array_no_curv_pen

# ground truth check

In [None]:
# Both arrays are computed from the same data points (the model weights are
# not involved), so this histogram of their difference should be concentrated
# at zero.
plt.hist((Distance_RD_pairwize_array_no_curv_pen-Distance_RD_pairwize_array_with_curv_pen).detach().numpy(), bins = 50)
# Raw string: "\m" is not a valid Python escape sequence.
plt.title(r"Distances in $\mathbb{R}^D$ with and without curvature penalization")
plt.show()

In [None]:
# Histogram of the per-pair difference between latent-space distances
# obtained without and with curvature penalization. The tensor is converted
# to NumPy explicitly, as in the other plotting cells.
plt.hist((Distance_ls_pairwize_array_no_curv_pen-Distance_ls_pairwize_array_with_curv_pen).detach().numpy(), bins = 50)
plt.title("Distances in latent space with and without curvature penalization")
plt.show()

## plotting

In [None]:
# Scatter plot of the path energy E_ij against the distance between the two
# data points, one colour per curvature weight. The LaTeX labels are raw
# strings because "\l" and "\|" are not valid Python escape sequences.
plt.rcParams.update({'font.size': 24})
fig, ax = plt.subplots(figsize=(9,9),dpi = 300)
#plt.title(f"Swissroll: Energy of $C_{{{K}}}^2$ paths $E_{{ij}}$ vs distance in $\mathbb{{R}}^3$")
ax.scatter(Distance_RD_pairwize_array_no_curv_pen.detach().numpy(),Energy_pairwise_array_no_curv_pen.detach().numpy(),
            color = "red", s = 10, label = r"$\lambda_{curv} = 0$")
ax.scatter(Distance_RD_pairwize_array_with_curv_pen.detach().numpy(),Energy_pairwise_array_with_curv_pen.detach().numpy(),
            color = "blue", s = 10, label = r"$\lambda_{curv} = 1$")
# The distances in R^D do not depend on the model, so the array stored for
# the lambda = 1 run is reused as x-axis for the lambda = 10 energies.
ax.scatter(Distance_RD_pairwize_array_with_curv_pen.detach().numpy(),Energy_pairwise_array_curv_w10_pen.detach().numpy(),
            color = "green", s = 10, label = r"$\lambda_{curv} = 10$")
ax.legend(loc='upper left')
#ax.set_xlabel('$\|\Psi \circ \Theta(X_i) - \Psi \circ \Theta(X_j)\|_2$, $\Psi \circ \Theta(X_i)\in \mathbb{R}^3$')
#ax.set_xlabel('$\|X_i - X_j\|_2$, $X_i \in \mathbb{R}^3$')
ax.set_xlabel(r'$\|X_i - X_j\|_2$')
ax.set_ylabel('$E_{ij}$')
fig.savefig("Scatterplot_E_ij_dist_R3_swissroll.pdf",bbox_inches='tight', format = "pdf")
plt.show()

In [None]:
# Scatter plot of the per-pair change in energy (lambda = 1 minus lambda = 0)
# against the distance between the two data points.
plt.rcParams.update({'font.size': 24})
fig, ax = plt.subplots(figsize=(9,9),dpi = 300)
#plt.title(f"Swissroll: enegry of paths change $E_{{ij}}^{{\lambda_{{curv}} = 1}} - E_{{ij}}^{{\lambda_{{curv}} = 0}}$ \n vs distance in $\mathbb{{R}}^3$ for $C_{{{K}}}^2$ paths")
# Convert the x values to NumPy explicitly, as in the other plotting cells.
ax.scatter(Distance_RD_pairwize_array_with_curv_pen.detach().numpy(),
           (Energy_pairwise_array_with_curv_pen - Energy_pairwise_array_no_curv_pen).detach().numpy(),
            color = "magenta", s = 10)
#ax.set_xlabel('$\|\Psi \circ \Theta(X_i) - \Psi \circ \Theta(X_j)\|_2$, $\Psi \circ \Theta(X_i)\in \mathbb{R}^3$')
#ax.set_xlabel('$\|X_i - X_j\|_2$, $X_i \in \mathbb{R}^3$')
# Raw string: "\|" is not a valid Python escape sequence.
ax.set_xlabel(r'$\|X_i - X_j\|_2$')
ax.set_ylabel('Change in $E_{ij}$')
#ax.set_ylabel('$E_{ij}^{\lambda_{curv} = 1} - E_{ij}^{\lambda_{curv} = 0}$')
fig.savefig("Scatterplot_change_E_ij_dist_R3_swissroll.pdf",bbox_inches='tight', format = "pdf")
plt.show()

### logscale

In [None]:
# Same scatter plot of E_ij against the distance in R^3 for lambda = 0 and
# lambda = 1, with a logarithmic y-axis. The LaTeX strings are raw (r / rf)
# because "\m", "\l", "\|" and "\i" are not valid Python escape sequences.
plt.rcParams.update({'font.size': 16})
fig, ax = plt.subplots(figsize=(9,9),dpi = 300)
plt.title(rf"Swissroll: Energy of $C_{{{K}}}^2$ paths $E_{{ij}}$ vs distance in $\mathbb{{R}}^3$")
ax.scatter(Distance_RD_pairwize_array_no_curv_pen.detach().numpy(),Energy_pairwise_array_no_curv_pen.detach().numpy(),
            color = "red", s = 10, label = r"$\lambda_{curv} = 0$")
ax.scatter(Distance_RD_pairwize_array_with_curv_pen.detach().numpy(),Energy_pairwise_array_with_curv_pen.detach().numpy(),
            color = "blue", s = 10, label = r"$\lambda_{curv} = 1$")
ax.legend(loc='upper left')
#ax.set_xlabel('$\|\Psi \circ \Theta(X_i) - \Psi \circ \Theta(X_j)\|_2$, $\Psi \circ \Theta(X_i)\in \mathbb{R}^3$')
ax.set_xlabel(r'$\|X_i - X_j\|_2$, $X_i \in \mathbb{R}^3$')
ax.set_ylabel('$E_{ij}$')
#ax.set_xscale('log')
ax.set_yscale('log')
fig.savefig("Scatterplot_E_ij_dist_R3_swissroll_ylogscale.pdf", format = "pdf")
plt.show()

In [None]:
# Scatter plot of E_ij against the distance between the encoded points in
# latent space, for lambda = 0 and lambda = 1. The LaTeX labels are raw
# strings because "\l", "\|", "\T", "\i" and "\m" are not valid Python
# escape sequences.
plt.rcParams.update({'font.size': 16})
fig, ax = plt.subplots(figsize=(9,9),dpi = 300)
plt.title(f"Swissroll: Energy of $C_{{{K}}}^2$ paths $E_{{ij}}$ vs distance in latent space")
ax.scatter(Distance_ls_pairwize_array_no_curv_pen.detach().numpy(),Energy_pairwise_array_no_curv_pen.detach().numpy(),
            color = "red", s = 10, label = r"$\lambda_{curv} = 0$")
ax.scatter(Distance_ls_pairwize_array_with_curv_pen.detach().numpy(),Energy_pairwise_array_with_curv_pen.detach().numpy(),
            color = "blue", s = 10, label = r"$\lambda_{curv} = 1$")
ax.legend(loc='upper left')
ax.set_xlabel(r'$\|\Theta(X_i) - \Theta(X_j)\|_2$, $\Theta(X_i) \in \mathbb{R}^2$')
ax.set_ylabel('$E_{ij}$')
fig.savefig("Scatterplot_E_ij_dist_ls_swissroll.pdf", format = "pdf")
plt.show()