# I. Hyperparameters: dataset and training

In [None]:
import os

# Name of the experiment; it is reused for the weight files and the plot folder.
#experiment_name = "current experiment"
experiment_name = "swissroll_curv_func_oscilations_curv_w=10_eps=0.01"
#experiment_name = "synthetic_curv_w=1e+1_ls=R^2+OOD"
#experiment_name = "swissroll_OOD_curv_w=10_sigma_ood=0.2_T_OOD=20_OOD_w=10_20epochs"
#experiment_name = "synthetic_curv_w=1e+1_ls=R^2"


build_report = True         # if True the saved pdf plots are merged into one report
weights_loaded = True       # if True pretrained weights are loaded and training is skipped
violent_saving = True # if False it will not save plots
model_weights_saved = False # if True the weights are written to disk after training
load_weight_name = experiment_name
#load_weight_name = "current_experiment"
save_weight_name = experiment_name
#save_weight_name = "swissroll_curv_w=1_ls=R^2_40epochs_bs=16"
# here you can choose a path for saving the pictures 
Path_pictures = f"/home/alazarev/CodeProjects/Experiments/{experiment_name}"
# makedirs also creates missing parent folders and is a no-op when the folder
# already exists, so this line never has to be commented out.
os.makedirs(Path_pictures, exist_ok=True)

In [None]:
# Hyperparameters for dataset
#set_name = "Synthetic"
#where_to_compute_curv = "random"
# "batch":  the curvature loss is evaluated at the encoded training batch;
# "random": it is evaluated at points drawn uniformly at random in latent space.
where_to_compute_curv = "batch"

set_name = "Swissroll"
sr_noise = 0.05
sr_numpoints = 18000 #k*n

# The ambient dimension D is fixed by the chosen dataset.
#D = 784       #dimension
#D = 2
if set_name == "Swissroll":
    D = 3 # for swissroll
elif set_name == "Synthetic":
    D = 784
else:
    # Fail here instead of with a NameError on D further down.
    raise ValueError(f"Unknown set_name {set_name!r}: expected 'Swissroll' or 'Synthetic'")
d = 2         # latent space dimension
k = 3         # num of 2d planes in dim D
n = 6*(10**3) # num of points in each plane
shift_class = 0
var_class = 1 # variation of each Gaussian initially 0.1
intercl_var = 0.1 # this creates a Gaussian, 
# i.e.random shift 
# proportional to the value of intercl_var
# initially 0.1

In [None]:
# Weight of the KL term in the loss.
# klw = 0 keeps the plain autoencoder (AE mode on); a positive value,
# e.g. klw = 5e-3, switches the VAE mode on.
klw = 0

In [None]:
# Data-loader hyperparameters
split_ratio = 0.2  # fraction of the dataset held out as test data
batch_size = 32    # was 32 initially

# Uncomment to set a manual seed for reproducibility
# torch.manual_seed(0)

In [None]:
# Loss composition: total = mse_w * MSE + curv_w * curvature (+ klw * KL in VAE mode)
compute_curvature = True
curv_w = 10.0  # weight on curvature
mse_w = 1.0    # weight on the reconstruction error

### The reconstruction loss is computed by hand in the training loop
#loss_fn = torch.nn.MSELoss()

### Optimizer settings (shared by the encoder and the decoder)
lr = 2e-5               # initially 4e-5 for synthetic, 1e-4 for swissroll
momentum = 0.8          # initially 0.8
num_epochs = 20         # initially 40
batches_per_plot = 400  # initially 200

### Uncomment to set the random seed for reproducible results
# torch.manual_seed(0)

### Imports

In [None]:
# Minimal imports
import math
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

# adding path to the set generating package
import sys
sys.path.append('../') # have to go 1 level up
import ricci_regularization as RR
#import torchvision


import sklearn
from sklearn import datasets

# I*. Choosing Dataset

In [None]:
# Build `train_dataset` for the dataset selected by `set_name`.
if set_name == "Swissroll":
    from torch.utils.data import TensorDataset

    # make_swiss_roll returns the 3d points and one scalar per point, used here as the color
    raw_points, raw_colors = sklearn.datasets.make_swiss_roll(
        n_samples=sr_numpoints,
        noise=sr_noise,
        random_state=1,
    )
    sr_points = torch.from_numpy(raw_points).to(torch.float32)
    #sr_points = torch.cat((sr_points,torch.zeros(sr_numpoints,D-3)),dim=1)
    sr_colors = torch.from_numpy(raw_colors).to(torch.float32)
    train_dataset = TensorDataset(sr_points, sr_colors)
elif set_name == "Synthetic":
    # old style
    # train_dataset = ricci_regularization.generate_dataset(D, k, n, shift_class=shift_class, intercl_var=intercl_var)

    # via classes, seeded for reproducibility
    torch.manual_seed(0)
    my_dataset = RR.SyntheticDataset(
        k=k,
        n=n,
        d=d,
        D=D,
        shift_class=shift_class,
        var_class=var_class,
        intercl_var=intercl_var,
    )
    train_dataset = my_dataset.create

m = len(train_dataset)
# Derive the train size from the test size so that the two always add up to m.
# Truncating m - m*split_ratio and m*split_ratio separately can lose one sample
# to float rounding, in which case random_split raises.
test_size = int(m*split_ratio)
train_size = m - test_size
train_data, test_data = torch.utils.data.random_split(train_dataset, [train_size, test_size])
test_loader  = torch.utils.data.DataLoader(test_data , batch_size=batch_size)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
# test_data[:][0] will give the vectors of data without labels from the test part of the dataset

# II. AE declaration and initialization

### Declaration

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
cuda_on = torch.cuda.is_available()
device = torch.device("cuda" if cuda_on else "cpu")
print(f'Selected device: {device}')

In [None]:

# sine AE
class Encoder(nn.Module):
    """Fully connected encoder: input_dim -> 512 -> 256 -> 128 -> hidden_dim.

    A sine activation follows each of the first three linear layers;
    the output of the last layer is returned without activation.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Attribute names are the state_dict keys of saved weights: keep them.
        self.linear1 = nn.Linear(input_dim, 512)
        self.linear2 = nn.Linear(512, 256)
        self.linear3 = nn.Linear(256, 128)
        self.linear4 = nn.Linear(128, hidden_dim)
        # sine activation (nn.ReLU() is the alternative that was tried)
        self.activation = torch.sin

    def forward(self, x):
        """Map a batch of inputs to latent codes."""
        hidden = x
        for layer in (self.linear1, self.linear2, self.linear3):
            hidden = self.activation(layer(hidden))
        # no activation on the latent code
        return self.linear4(hidden)
class Decoder(nn.Module):
    """Fully connected decoder: hidden_dim -> 128 -> 256 -> 512 -> output_dim.

    A sine activation follows each of the first three linear layers;
    the output of the last layer is returned without activation.
    """

    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        # Attribute names are the state_dict keys of saved weights: keep them.
        self.linear1 = nn.Linear(hidden_dim, 128)
        self.linear2 = nn.Linear(128, 256)
        self.linear3 = nn.Linear(256, 512)
        self.linear4 = nn.Linear(512, output_dim)
        # sine activation (torch.nn.ReLU() is the alternative that was tried)
        self.activation = torch.sin

    def forward(self, x):
        """Map a batch of latent codes back to data space."""
        hidden = x
        for layer in (self.linear1, self.linear2, self.linear3):
            hidden = self.activation(layer(hidden))
        # no activation (and no sigmoid) on the reconstruction
        return self.linear4(hidden)
"""
            
# RelU-sine AE
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, 512)
        self.linear2 = nn.Linear(512, 256)
        self.linear3 = nn.Linear(256, 128)
        self.linear4 = nn.Linear(128, hidden_dim)
        self.activation = nn.ReLU()
        #self.activation = torch.sin
    def forward(self, x):
        y = self.linear1(x)
        y = self.activation(y)
        y = self.linear2(y)
        y = self.activation(y)
        y = self.linear3(y)
        y = self.activation(y)
        out = self.linear4(y)
        #out = self.activation(out)
        #out = torch.sin(out)
        return out
class Decoder(nn.Module):
    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.linear1 = nn.Linear(hidden_dim, 128)
        self.linear2 = nn.Linear(128, 256)
        self.linear3 = nn.Linear(256, 512)
        self.linear4 = nn.Linear(512, output_dim)
        #self.activation = torch.sin
        self.activation = torch.nn.ReLU()
    def forward(self, x):
        y = self.linear1(x)
        y = self.activation(y)
        y = self.linear2(y)
        y = self.activation(y)
        y = self.linear3(y)
        y = self.activation(y)
        out = self.linear4(y)
        #out = self.activation(out)
        #out = torch.sigmoid(y)
        return out
"""

In [None]:
"""
# initial structure
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, 512)
        self.linear2 = nn.Linear(512, 256)
        self.linear3 = nn.Linear(256, 128)
        self.linear4 = nn.Linear(128, hidden_dim)
        #self.activation = nn.ReLU()
        self.activation = torch.sin
    def forward(self, x):
        y = self.linear1(x)
        y = self.activation(y)
        y = self.linear2(y)
        y = self.activation(y)
        y = self.linear3(y)
        y = self.activation(y)
        out = self.linear4(y)
        out = self.activation(out)
        return out
class Decoder(nn.Module):
    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.linear1 = nn.Linear(hidden_dim, 128)
        self.linear2 = nn.Linear(128, 256)
        self.linear3 = nn.Linear(256, 512)
        self.linear4 = nn.Linear(512, output_dim)
        #self.activation = nn.ReLU()
        self.activation = torch.sin
    def forward(self, x):
        y = self.linear1(x)
        y = self.activation(y)
        y = self.linear2(y)
        y = self.activation(y)
        y = self.linear3(y)
        y = self.activation(y)
        out = self.linear4(y)
        out = self.activation(out)
        #out = torch.sigmoid(y)
        return out
"""

In [None]:
class VariationalEncoder(nn.Module):
    """One-hidden-layer variational encoder with a sine activation.

    `linear2` outputs the posterior mean and `linear3` the log of the posterior
    standard deviation. `forward` returns a sample z = mu + sigma * eps and
    stores the KL divergence of the batch in `self.kl`.

    Args:
        input_dim: size of the flattened input vectors.
        hidden_dim: dimension of the latent space.
        cuda: if True the parameters of the noise distribution are moved to the GPU.
    """

    def __init__(self, input_dim, hidden_dim, cuda=True):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, 512)
        self.linear2 = nn.Linear(512, hidden_dim)
        self.linear3 = nn.Linear(512, hidden_dim)

        self.N = torch.distributions.Normal(0, 1)
        if cuda:
            self.N.loc = self.N.loc.cuda() # hack to get sampling on the GPU
            self.N.scale = self.N.scale.cuda()
        self.kl = 0

    def forward(self, x):
        """Sample latent codes for a batch and update `self.kl`."""
        x = torch.flatten(x, start_dim=1)
        #x = torch.nn.functional.relu(self.linear1(x))
        x = torch.sin(self.linear1(x))
        mu = self.linear2(x)
        log_sigma = self.linear3(x)
        sigma = torch.exp(log_sigma)
        z = mu + sigma*self.N.sample(mu.shape)
        # KL( N(mu, sigma^2) || N(0, 1) ) = (sigma^2 + mu^2 - 1)/2 - log(sigma),
        # summed over the batch and the latent coordinates. The 1/2 applies to
        # the quadratic terms as well, so the KL vanishes for mu = 0, sigma = 1.
        # log_sigma is used directly instead of log(exp(.)) for numerical stability.
        self.kl = (0.5*(sigma**2 + mu**2 - 1) - log_sigma).sum()
        return z

### Initialization 

In [None]:
# initially D=784, d=2
# AE/VAE switch
if klw > 0:
    # The input size must follow the dataset dimension D (3 for the swissroll);
    # a hard-coded 784 only matches the synthetic dataset.
    encoder = VariationalEncoder(input_dim=D, hidden_dim=d, cuda=cuda_on)
else:
    encoder = Encoder(input_dim=D, hidden_dim=d)
decoder = Decoder(hidden_dim=d, output_dim=D)

# ClassicalAE
#encoder = Encoder(input_dim=D, hidden_dim=d)
#decoder = Decoder(hidden_dim=d, output_dim=D)

# A single optimizer updates both networks
params_to_optimize = [
    {'params': encoder.parameters()},
    {'params': decoder.parameters()}
]

optimizer = torch.optim.RMSprop(params_to_optimize, lr=lr, momentum=momentum, weight_decay=0.0)

#optimizer = torch.optim.Adam(params_to_optimize, lr=lr,weight_decay=0.0)

# Move both the encoder and the decoder to the selected device
encoder.to(device)
decoder.to(device)

### loading weights

In [None]:
# Load pretrained weights and switch both networks to eval mode.
if weights_loaded:
    # map_location remaps the stored tensors to the selected device, so weights
    # saved on a GPU can also be loaded on a CPU-only machine.
    PATH_enc = f'../nn_weights/encoder_{load_weight_name}'
    encoder.load_state_dict(torch.load(PATH_enc, map_location=device))
    encoder.eval()
    PATH_dec = f'../nn_weights/decoder_{load_weight_name}'
    decoder.load_state_dict(torch.load(PATH_dec, map_location=device))
    decoder.eval()

### choice of curvature functional

In [None]:
def Func(encoded_data):
    """Curvature functional: the mean over the points of sqrt(det g) * Sc^2.

    `g` is what RR.metric_jacfwd_vmap returns for the decoder at the given
    latent points and `Sc` what RR.Sc_jacfwd_vmap returns (presumably the
    pullback metric and its scalar curvature; confirm in ricci_regularization).
    """
    metric = RR.metric_jacfwd_vmap(encoded_data, function=decoder)
    volume_density = torch.sqrt(torch.det(metric))
    scalar_curvature = RR.Sc_jacfwd_vmap(encoded_data, function=decoder)
    num_points = metric.shape[0]
    weighted_squared_curvature = volume_density*torch.square(scalar_curvature)
    return (1/num_points)*weighted_squared_curvature.sum()
"""
# minimizing |g-I|_F
def Func(encoded_data):
    metric_on_data = RR.metric_jacfwd_vmap(encoded_data,
                                           function=decoder)
    N = metric_on_data.shape[0]
    func = (1/N)*(metric_on_data-torch.eye(d)).norm(dim=(1,2)).sum()
    return func
"""

# III. Plotting tools

In [None]:
# borrowed from https://gist.github.com/jakevdp/91077b0cae40f8f8244a
def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map"""

    # plt.get_cmap takes the same (name, lut) arguments as plt.cm.get_cmap did;
    # the latter was deprecated in Matplotlib 3.7 and removed in 3.9.
    # Note that this simple form is meant for base_cmap being a string or None.
    return plt.get_cmap(base_cmap, N)
    # The following works for string, None, or a colormap instance:
"""
    base = plt.cm.get_cmap(base_cmap)
    color_list = base(np.linspace(0, 1, N))
    cmap_name = base.name + str(N)
    return base.from_list(cmap_name, color_list, N)
"""

In [None]:
def point_plot(encoder, data, batch_idx, show_title = True, colormap = 'jet',s=40,draw_grid = True,figsize = (8, 6)):
    """Scatter-plot the latent codes of `data`, colored by their labels.

    Args:
        encoder: network mapping the data points to 2d latent codes.
        data: dataset where data[:][0] are the points and data[:][1] the labels.
        batch_idx: batch counter shown in the title.
        show_title: whether to add the title.
        colormap: name of the colormap.
        s: marker size.
        draw_grid: whether to draw the grid.
        figsize: size of the created figure.

    Returns:
        The `matplotlib.pyplot` module, so that callers can call
        `savefig` / `show` on the current figure.
    """
    labels = data[:][1]
    data   = data[:][0]

    # Encode in eval mode, then restore the mode the encoder was in, so that a
    # call in the middle of training does not leave it in eval mode.
    was_training = encoder.training
    encoder.eval()
    try:
        with torch.no_grad():
            data = data.view(-1,D) # reshape the img
            data = data.to(device)
            encoded_data = encoder(data)
    finally:
        encoder.train(was_training)

    # Record codes
    latent = encoded_data.cpu().numpy()
    labels = labels.numpy()

    #Plot
    plt.figure(figsize=figsize)

    if set_name == "Swissroll":
        # continuous labels: plain colormap, no colorbar
        plt.scatter( latent[:,0], latent[:,1],s=s, c=labels, alpha=0.5, marker='o', edgecolor='none', cmap=colormap)
    else:
        # k discrete classes: k-bin colormap with one colorbar tick per class
        plt.scatter( latent[:,0], latent[:,1],s=s, c=labels, alpha=0.5, marker='o', edgecolor='none', cmap=discrete_cmap(k, colormap))
        #plt.scatter( latent[:,0], latent[:,1], c=labels, alpha=0.5, marker='o')
        plt.colorbar(ticks=range(k),orientation='vertical',shrink = 0.7)
    if show_title:
        plt.title( f'''Latent space for test data in AE at batch {batch_idx}''')
    plt.grid(draw_grid)

    return plt

In [None]:
# Batches per epoch (train part)
print( "Reality check of batch splitting: ")
print( "-- Batches per epoch", len(train_loader) )
print( "batch size:", batch_size )
print( "product: ", len(train_loader)*batch_size )
# Compare with the actual number of train samples; (1.0-split_ratio)*n*k is
# only the right reference for the synthetic dataset.
print( "-- To be compared to:", len(train_data))


In [None]:
# Batches per epoch (test part)
print( "Reality check of batch splitting: ")
print( "-- Batches per epoch", len(test_loader) )
print( "batch size:", batch_size )
print( "product: ", len(test_loader)*batch_size )
# Compare with the actual number of test samples; split_ratio*n*k is only the
# right reference for the synthetic dataset.
print( "-- To be compared to:", len(test_data))

# IV. Training

In [None]:
num_plots = 0 # to enumerate the plots
batch_idx = 0
batches_per_epoch = len(train_loader)

# Per-batch histories (one float / array per training batch)
mse_loss = []
kl_loss = []
curv_loss = []
test_mse_loss_list = []
test_curv_loss_list = []

# to iterate through the batches of test data
# simultaneously with train data
iter_test_loader = iter(test_loader)


def _mean_over_last_epoch(values):
    """Mean of the last `batches_per_epoch` entries, nan if nothing was recorded."""
    recent = values[-batches_per_epoch:]
    return np.mean(recent) if recent else float('nan')


for epoch in range(num_epochs):
    # Training is skipped altogether when pretrained weights were loaded
    if weights_loaded == True:
        break
    # Set train mode for both the encoder and the decoder
    encoder.train()
    decoder.train()

    # Iterate the dataloader: no need for the label
    # values, this is unsupervised learning
    for image_batch, _ in train_loader: # with "_" we just ignore the labels (the second element of the dataloader tuple)
        # shaping the images properly
        image_batch = image_batch.view(-1,D)
        # Move tensor to the proper device
        image_batch = image_batch.to(device)
        # True batch size (the last batch can be smaller than batch_size)
        true_batch_size = image_batch.shape[0]

        optimizer.zero_grad()

        # Forward pass
        # -- Encode data
        encoded_data = encoder(image_batch)
        # -- Decode data
        decoded_data = decoder(encoded_data)
        # -- Evaluate loss
        mse_loss_batch = torch.sum( (decoded_data-image_batch)**2 )/true_batch_size

        if compute_curvature == True:
            if where_to_compute_curv == "batch":
                curvature_train_batch = Func(encoded_data)
            elif where_to_compute_curv == "random":
                # uniform points in [-1,1]^d, created on the device of the decoder
                curvature_train_batch = Func(2*torch.rand(batch_size, d, device=device)-1)
            else:
                raise ValueError(f"Unknown where_to_compute_curv {where_to_compute_curv!r}: expected 'batch' or 'random'")
        else:
            curvature_train_batch = 0.0

        loss = mse_w*mse_loss_batch + curv_w*curvature_train_batch

        # if VAE mode is on
        if klw > 0:
            kl_loss_batch = encoder.kl
            loss += klw*kl_loss_batch
            # store a plain float, like the other histories
            kl_loss.append(float(kl_loss_batch.detach().cpu()))
        else:
            kl_loss_batch = 0.0

        # Backward pass
        loss.backward()
        optimizer.step()
        # Print batch loss
        #print('\t MSE loss per batch (single batch): %f' % (mse_loss_batch.data))
        #print('\t Total loss per batch (single batch): %f' % (loss.data))

        # Restart the test iterator once all test batches have been visited,
        # i.e. every len(test_loader) training batches. Restarting on every
        # other step would always evaluate the same first test batches.
        if batch_idx % len(test_loader) == 0:
            iter_test_loader = iter(test_loader)
        test_images = next(iter_test_loader)[0].view(-1,D).to(device)

        # True test_batch size
        true_test_batch_size = test_images.shape[0]
        # The test losses are only monitored, so no autograd graph is needed
        with torch.no_grad():
            encoded_test_data = encoder(test_images)
            decoded_test_data = decoder(encoded_test_data)
            test_mse_loss = torch.sum( (decoded_test_data - test_images)**2 )/true_test_batch_size
            test_mse_loss_list.append(test_mse_loss.detach().cpu().numpy())
            if compute_curvature == True:
                test_curv_loss = Func(encoded_test_data)
                test_curv_loss_list.append(test_curv_loss.detach().cpu().numpy())
            # end if
        # end with
        #print('\t test MSE loss per batch (single batch): %f' % (test_mse_loss.data))
        #print('\t partial train loss (single batch): {:.6} \t curv_loss {:.6} \t mse {:.6}'.format(loss.data, new_loss, only_mse.data))

        mse_loss.append(float(mse_loss_batch.detach().cpu().numpy()))
        if compute_curvature == True:
            curv_loss.append(float(curvature_train_batch.detach().cpu().numpy()))

        # Plot the latent space and the loss curves every batches_per_plot batches
        if (batch_idx % batches_per_plot == 0):
            # plotting
            plot = point_plot(encoder, test_data, batch_idx)
            if violent_saving == True:
                plot.savefig('../plots/pointplots_in_training_testdata/pp{0}.eps'.format(num_plots),format='eps')
            num_plots += 1
            plot.show()
            # plotting may leave the encoder in eval mode; make sure the
            # remaining batches of the epoch run in train mode
            encoder.train()

            # plotting losses
            if batch_idx>0:
                fig, ax1 = plt.subplots()

                ax1.set_xlabel('Batches')
                ax1.set_ylabel('MSE')
                ax1.semilogy(mse_loss, label='train_MSE_loss', color='tab:orange')
                ax1.semilogy(test_mse_loss_list, label='test_MSE_loss', color='tab:red')

                ax1.tick_params(axis='y')
                plt.legend(loc='lower left')

                if compute_curvature == True:

                    ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis

                    ax2.set_ylabel('Curvature')  # we already handled the x-label with ax1
                    ax2.semilogy(curv_loss, label='train_Curv_loss',color='tab:olive')
                    ax2.semilogy(test_curv_loss_list, label='test_Curv_loss', color='tab:green')

                    ax2.tick_params(axis='y')
                    plt.legend(loc='lower right')
                # end if
                fig.tight_layout()  # otherwise the right y-label is slightly clipped
                plt.show()
            # end if
        # end if
        batch_idx += 1
    # end for
    #print('\n EPOCH {}/{}. \t Average values over epoch:\n MSE loss: {}, Curvature loss: {}'.format(epoch + 1, num_epochs, np.mean(mse_loss[-batches_per_epoch:]),np.mean(curv_loss[-batches_per_epoch:])))
    # Histories that were not recorded (e.g. KL in AE mode) are reported as nan
    print(f'\n EPOCH {epoch + 1}/{num_epochs}. \t Average values over epoch:\n MSE loss: {_mean_over_last_epoch(mse_loss)}, Curvature loss: {_mean_over_last_epoch(curv_loss)}, KL loss: {_mean_over_last_epoch(kl_loss)}')

# V. Saving the model

In [None]:
# Save the encoder and decoder weights under save_weight_name.
if model_weights_saved:
    # torch.save does not create missing folders
    os.makedirs('../nn_weights', exist_ok=True)
    PATH_enc = f'../nn_weights/encoder_{save_weight_name}'
    torch.save(encoder.state_dict(), PATH_enc)
    PATH_dec = f'../nn_weights/decoder_{save_weight_name}'
    torch.save(decoder.state_dict(), PATH_dec)

# VI. Plots

### determination coefficient and training curves

In [None]:
# Determination coefficient R^2 = 1 - |reconstruction - x|^2 / |x - mean|^2
# over the whole dataset. The networks run on `device`; the reconstruction is
# computed once, without autograd, and moved back to the CPU for the statistics.
if set_name == "Swissroll":
    # sr_points is already a tensor: copy it instead of calling torch.tensor on it
    points_tensor = sr_points.clone().detach()
    cov_matrix = torch.cov(points_tensor.T)
    print("Covariance matrix:\n", cov_matrix)
    mean = points_tensor.mean(dim=0)
    print("Mean vector:", mean)
    with torch.no_grad():
        reconstruction = decoder(encoder(sr_points.to(device))).cpu()
    squared_error = (reconstruction-sr_points).norm()**2
    R_squared = 1 - squared_error/((sr_points-mean).norm()**2)
    print("Determination coef R^2:", R_squared)
    MSE = (1/sr_numpoints)*squared_error
    print("MSE:", MSE)
    print("tr(Q):", cov_matrix.trace())
    print("MSE/tr(Q):", MSE/cov_matrix.trace())
    print("R^2 = 1-MSE/tr(Q):", 1-MSE/cov_matrix.trace())
elif set_name =="Synthetic":
    points_tensor = train_dataset[:][0]
    cov_matrix = torch.cov(points_tensor.T)
    #print("Covariance matrix:\n", cov_matrix)
    mean = points_tensor.mean(dim=0)
    #print("Mean vector:", mean)
    with torch.no_grad():
        reconstruction = decoder(encoder(points_tensor.to(device))).cpu()
    R_squared = 1 - ((reconstruction-points_tensor).norm()**2)/((points_tensor-mean).norm()**2)
    print("Determination coef R^2:", R_squared)

In [None]:
# Curvature functional over the whole train split. The points are moved to the
# device of the networks and encoded without autograd; the result is only
# reported, so it is detached.
with torch.no_grad():
    train_latent = encoder(train_data[:][0].to(device))
curv_loss_on_train_data = Func(train_latent).detach()

In [None]:
plt.rcParams.update({'font.size': 17}) # default font size for the plot

fig, ax1 = plt.subplots()

# Left axis: train / test reconstruction error per batch
ax1.set_xlabel('Batches')
ax1.set_ylabel('MSE')
ax1.semilogy(mse_loss, label='train_MSE_loss', color='tab:orange')
ax1.semilogy(test_mse_loss_list, label='test_MSE_loss', color='tab:red')

ax1.tick_params(axis='y')
plt.legend(loc='lower left')

# Right axis: train / test curvature loss per batch
if compute_curvature:
    ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis
    #ax2.set_ylabel('G-I loss')
    ax2.set_ylabel('Curvature')  # we already handled the x-label with ax1
    ax2.semilogy(curv_loss, label='train_Curv_loss',color='tab:olive')
    ax2.semilogy(test_curv_loss_list, label='test_Curv_loss', color='tab:green')
    ax2.tick_params(axis='y')
    plt.legend(loc='upper right')
fig.tight_layout()  # otherwise the right y-label is slightly clipped
#fig.text(0.05,-0.25,f'The dataset has {k*n} points originated \nby {k} Gaussian(s). \nAverage losses over the last epoch: \nMSE loss: {np.mean(mse_loss[-batches_per_epoch:]):.3f}, \nCurvature loss: {np.mean(curv_loss[-batches_per_epoch:]):.3f}')
# Average losses over the last epoch
# LaTeX backslashes are written as "\\" (or in raw strings): "\s" and "\l" are
# invalid escape sequences and trigger a SyntaxWarning on recent Python versions.
if set_name == "Swissroll":
    fig.text(0.05,-0.25,f"Set params: n={sr_numpoints}, noise={sr_noise}. \nMSE loss: {np.mean(mse_loss[-batches_per_epoch:]):.3f}, \nCurvature loss: {np.mean(curv_loss[-batches_per_epoch:]):.3f}, \n$R^2=${R_squared:.4f}")
else:    
    fig.text(0.05,-0.25,f"Set params: n={n}, k={k}, d={d}, D={D}, $\\sigma$={var_class}, $\\sigma_{{I}}$={intercl_var}. \nMSE loss: {np.mean(mse_loss[-batches_per_epoch:]):.3f}, \nCurvature loss: {np.mean(curv_loss[-batches_per_epoch:]):.3f}, \n $R^2=${R_squared:.4f}, \n Curvature loss over train dataset: {curv_loss_on_train_data:.4f}" )
str_lambda_recon = r"$\lambda_{recon}$"
str_lambda_curv = r"$\lambda_{curv}$"
plt.title(f"Params: lr={lr}, batch_size={batch_size},\n {str_lambda_recon}={mse_w}, {str_lambda_curv}={curv_w}")
#plt.title("Params: lr={0}, batch_size={1},\n $\lambda_r$={2}, $\lambda_c$={3},{4}".format(lr,batch_size,mse_w,curv_w,str1))
# The curves are only meaningful (and only saved) when training was actually run
if violent_saving and not weights_loaded:
    plt.savefig(f'{Path_pictures}/losses.pdf',bbox_inches='tight',format='pdf')
#plt.savefig(f'{Path_pictures}/losses.pdf',bbox_inches='tight',format='pdf')
plt.show()

### MSE and metric losses heatmaps

In [None]:
#choose data
#data_for_plot = train_data
data_for_plot = test_data

labels = data_for_plot[:][1]
init_data = data_for_plot[:][0]
# The networks live on `device`: move the inputs there, run the forward passes
# without autograd, and bring the results back to the CPU for plotting.
with torch.no_grad():
    init_data_on_device = init_data.to(device)
    latent_on_device = encoder(init_data_on_device.squeeze())
    recon_data = decoder(encoder(init_data_on_device)).cpu()
latent = latent_on_device.cpu()
abs_error_array = (recon_data-init_data).squeeze()
# squared l2 reconstruction error of every point
mse_array = abs_error_array.norm(dim=1)**2
mse_array = mse_array.detach()
# scalar curvature at the latent code of every point
curvature_array = RR.Sc_jacfwd_vmap(latent_on_device,function=decoder).detach().cpu()

In [None]:
plt.rcParams.update({'font.size': 16})

size_of_points = 20
fig, (ax00,ax0)= plt.subplots(ncols=2, nrows=1,figsize=(15,6),dpi=300)
# (ax3,ax4) can  be added

fig.tight_layout(pad=2.0)

# Marker style shared by both panels
shared_scatter_style = dict(alpha=0.5, s=size_of_points, marker='o', edgecolor='none')
is_synthetic = set_name == "Synthetic"

# Left panel: latent codes colored by the original labels
# (k discrete classes for the synthetic set, continuous colors otherwise)
ax00.title.set_text("AE latent space")
label_cmap = discrete_cmap(k, "jet") if is_synthetic else 'jet'
p00 = ax00.scatter( latent[:,0], latent[:,1], c=labels, cmap=label_cmap, **shared_scatter_style)
if is_synthetic:
    fig.colorbar(p00,label="initial color", ticks=(np.arange(k)))
else:
    fig.colorbar(p00,label="initial color")

# Right panel: latent codes colored by the reconstruction error, log color scale
ax0.title.set_text("Reconstruction loss")
p0 = ax0.scatter( latent[:,0], latent[:,1], c=mse_array, cmap='jet', norm=matplotlib.colors.LogNorm(), **shared_scatter_style)
fig.colorbar(p0,label="squared l2 norm errors")

"""
ax1.title.set_text("Absolute value of scalar curvature")
p1 = ax1.scatter( latent[:,0], latent[:,1], c=abs(curvature_array), alpha=0.5, s = size_of_points, marker='o', edgecolor='none', cmap='jet',norm=matplotlib.colors.LogNorm())
fig.colorbar(p1,label="curvature abs value")

ax2.title.set_text("Scalar curvature")
p2 = ax2.scatter( latent[:,0], latent[:,1], c=curvature_array, alpha=0.5, s = size_of_points, marker='o', edgecolor='none', cmap='jet',norm=matplotlib.colors.SymLogNorm(linthresh=1e-2))
fig.colorbar(p2,label="curvature")
"""

#ax3.title.set_text("Only negative curvature loss in logscale")
#p3 = ax3.scatter( latent[:,0], latent[:,1], c=-curvature_array, alpha=0.5, s = size_of_points, marker='o', edgecolor='none', cmap='jet',norm=matplotlib.colors.LogNorm())
#fig.colorbar(p3,label="-curvature")

#ax4.title.set_text("Only positive curvature loss in logscale")
#p4 = ax4.scatter( latent[:,0], latent[:,1], c=curvature_array, alpha=0.5, s = size_of_points, marker='o', edgecolor='none', cmap='jet',norm=matplotlib.colors.LogNorm())
#fig.colorbar(p4,label="curvature")

if violent_saving == True:
    fig.savefig(f'{Path_pictures}/init_colors_recon_loss.pdf',bbox_inches='tight',format='pdf')
plt.show()

In [None]:
# Metric at the latent codes of the plotted points, with its determinant and trace.
# The points are encoded on the device of the networks; the metric is brought
# back to the CPU because it is only used for plotting.
with torch.no_grad():
    latent_for_metric = encoder(init_data.to(device))
metric_array = RR.metric_jacfwd_vmap(latent_for_metric,function=decoder).detach().cpu()
det_array = torch.det(metric_array)
trace_array = torch.einsum('jii->j',metric_array)

### Heatmaps unscaled

In [None]:
# Bounding box of the latent codes ...
left, right = latent[:,0].min(), latent[:,0].max()
bottom, top = latent[:,1].min(), latent[:,1].max()

# ... and its size and center, used to place the grid and the axis ticks
xsize, ysize = right - left, top - bottom
xcenter, ycenter = 0.5*(left + right), 0.5*(bottom + top)

In [None]:
linsize = 200 # number of grid points per axis
#xsize = 8
#ysize = 10
#xcenter = 0.0
#ycenter = -1.0

import torch.func as TF
grid_on_ls = RR.make_grid(linsize,xsize=xsize,ysize=ysize,xcenter=xcenter,ycenter=ycenter)
# The decoder lives on `device`, so the grid is evaluated there. The results are
# only plotted: detach them (no autograd graph kept for the whole grid) and
# bring them back to the CPU.
# NOTE(review): assumes RR.make_grid returns a tensor of latent points - confirm
grid_on_device = grid_on_ls.to(device)
metric_on_grid = RR.metric_jacfwd_vmap(grid_on_device,function=decoder).detach().cpu()
metric_det_on_grid = torch.det(metric_on_grid)
metric_trace_on_grid = TF.vmap(torch.trace)(metric_on_grid)
curv_on_the_grid = RR.Sc_jacfwd_vmap(grid_on_device, function = decoder).detach().cpu()

In [None]:
# Histogram of the scalar curvature at the latent codes of the plotted data
plt.hist(curvature_array, bins = 60)
plt.show()

In [None]:
# Histogram of the scalar curvature over the latent grid
plt.hist(curv_on_the_grid.detach(), bins = 200)
plt.show()

In [None]:
#xcenter = 0.0 
#ycenter = 0.0
xshift = 0.0
yshift = 0.0
numticks = 5
tick_decimals = 2 if set_name == "Synthetic" else 1
plt.rcParams.update({'font.size': 16})

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(ncols=2, nrows=2, figsize=(15,12),dpi=300)

fig.tight_layout(pad=2.0)

# Tick values in latent coordinates, spanning the bounding box of the grid
xticks = np.linspace(xcenter - 0.5*xsize, xcenter + 0.5*xsize, numticks) 
yticks = np.linspace(ycenter - 0.5*ysize, ycenter + 0.5*ysize, numticks)

xtick_labels = [f"{tick:.{tick_decimals}f}" for tick in (xticks+xshift).tolist()]
ytick_labels = [f"{tick:.{tick_decimals}f}" for tick in (yticks+yshift).tolist()]

# Tick positions in pixel coordinates of the linsize x linsize images
ticks_places = np.linspace(0, 1, numticks)*(linsize-1)

im1 = ax1.imshow(abs(curv_on_the_grid.detach().reshape(linsize,linsize)),
                 origin="lower",cmap="jet",
                 norm = matplotlib.colors.LogNorm())
fig.colorbar(im1,ax = ax1, shrink = 1, label = "curvature abs value")
ax1.set_title("Absolute value of scalar curvature")

# symmetric log scale, linear within 1% of the absolute mean curvature
im2 = ax2.imshow(curv_on_the_grid.detach().reshape(linsize,linsize),
                 origin="lower",cmap="jet",
                 norm = matplotlib.colors.SymLogNorm(linthresh=abs(0.01*curv_on_the_grid.mean()).item()))
fig.colorbar(im2,ax = ax2, shrink = 1, label = "curvature")
ax2.set_title("Scalar curvature")

# LaTeX labels are raw strings: "\s" and "\c" are invalid escape sequences and
# trigger a SyntaxWarning on recent Python versions.
im3 = ax3.imshow((torch.sqrt(metric_det_on_grid)).detach().reshape(linsize,linsize),
                 origin="lower",cmap="jet",norm = None)
fig.colorbar(im3,ax = ax3, shrink = 1, label = r"$\sqrt{det(G)}$")
ax3.set_title(r"$\sqrt{det(G)}$")

im4 = ax4.imshow((0.5*(metric_trace_on_grid)).detach().reshape(linsize,linsize),
                 origin="lower",cmap="jet",norm = None)
fig.colorbar(im4, ax = ax4, shrink = 1, label = r"0.5$\cdot$tr(G)")
ax4.set_title(r"0.5$\cdot$tr(G)")

axs = (ax1, ax2, ax3, ax4)
for ax in axs:
    ax.set_xticks(ticks_places,labels = xtick_labels)
    ax.set_yticks(ticks_places,labels = ytick_labels)

if violent_saving == True:
    plt.savefig(f'{Path_pictures}/heatmaps_not_scaled.pdf',bbox_inches='tight',format='pdf')
plt.show()

### scalar curvature

In [None]:
# Color-scale limits shared by the scatter plots and the grid heatmaps.
abs_curv_on_the_grid = abs(curv_on_the_grid)

# signed curvature: full range, linear within 1% of the absolute mean value
min_curvature = curv_on_the_grid.min().item()
max_curvature = curv_on_the_grid.max().item()
linthresh_curvature = 0.01*abs(curv_on_the_grid.mean()).item()

# absolute curvature: from 1% of its mean up to its maximum
min_abs_curvature = 0.01*abs_curv_on_the_grid.mean().item()
max_abs_curvature = abs_curv_on_the_grid.max().item()

In [None]:

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(ncols=2, nrows=2, figsize=(15,12),dpi=300)

fig.tight_layout(pad=2.0)

# Tick values (latent coordinates) and their formatted labels
tick_format = '%.{0}f'.format(tick_decimals)
xticks = np.linspace(xcenter - 0.5*xsize, xcenter + 0.5*xsize, numticks) 
yticks = np.linspace(ycenter - 0.5*ysize, ycenter + 0.5*ysize, numticks)
xtick_labels = [tick_format % value for value in (xticks+xshift).tolist()]
ytick_labels = [tick_format % value for value in (yticks+yshift).tolist()]

# Tick positions in pixel coordinates of the grid images
ticks_places = np.linspace(0, 1, numticks)*(linsize-1)

# Settings shared by the panels; the left column shows the data points,
# the right column the whole grid, with the same color scale on each row.
point_style = dict(alpha=1, s = size_of_points, marker='o', edgecolor='none', cmap='jet')
image_style = dict(origin="lower", cmap="jet")
abs_limits = dict(vmin = min_abs_curvature, vmax = max_abs_curvature)
signed_limits = dict(linthresh=linthresh_curvature, vmin = min_curvature, vmax = max_curvature)

ax1.title.set_text("Absolute value of scalar curvature")
p1 = ax1.scatter( latent[:,0], latent[:,1], c=abs(curvature_array),
                 norm=matplotlib.colors.LogNorm(**abs_limits), **point_style)
fig.colorbar(p1,label="curvature abs value")

ax2.title.set_text("Absolute value of scalar curvature overall")
im1 = ax2.imshow(abs(curv_on_the_grid.detach().reshape(linsize,linsize)),
                 norm = matplotlib.colors.LogNorm(**abs_limits), **image_style)
fig.colorbar(im1,ax = ax2, shrink = 1, label = "curvature abs value")
ax1.set_title("Absolute value of scalar curvature")

ax3.title.set_text("Scalar curvature")
p2 = ax3.scatter( latent[:,0], latent[:,1], c=curvature_array,
                 norm=matplotlib.colors.SymLogNorm(**signed_limits), **point_style)
fig.colorbar(p2,label="curvature")

ax4.title.set_text("Scalar curvature overall")
im2 = ax4.imshow(curv_on_the_grid.detach().reshape(linsize,linsize),
                 norm = matplotlib.colors.SymLogNorm(**signed_limits), **image_style)
fig.colorbar(im2,ax = ax4, shrink = 1, label = "curvature")
ax4.set_title("Scalar curvature overall")

# Scatter panels use latent coordinates directly ...
axs = (ax1, ax3)
for ax in axs:
    ax.set_ylim(bottom,top)
    ax.set_xlim(left,right)
    ax.set_xticks([float(label) for label in xtick_labels], labels = xtick_labels)
    ax.set_yticks([float(label) for label in ytick_labels], labels = ytick_labels)

# ... while image panels are indexed by pixel
axs = (ax2, ax4)
for ax in axs:
    ax.set_xticks(ticks_places,labels = xtick_labels)
    ax.set_yticks(ticks_places,labels = ytick_labels)
if violent_saving == True:
    plt.savefig(f'{Path_pictures}/curvature_heatmaps.pdf',bbox_inches='tight',format='pdf')
plt.show()

In [None]:
# Left column: sqrt(det G) and 0.5*tr(G) at the data points;
# right column: the same quantities over the whole latent grid.
fig, ((ax1,ax3),(ax2,ax4))= plt.subplots(ncols=2,nrows=2,figsize = (15,12),dpi=300)

fig.tight_layout(pad=2.0)

# LaTeX labels are raw strings: "\s" and "\c" are invalid escape sequences and
# trigger a SyntaxWarning on recent Python versions.
sqrt_det_label = r"$\sqrt{det(G)}$"
half_trace_label = r"0.5$\cdot$tr(G)"

# vmax is taken from the grid so that the scatter plots share its upper color limit
ax1.title.set_text(sqrt_det_label)
p = ax1.scatter( latent[:,0], latent[:,1],
                c=torch.sqrt(abs(det_array)), alpha=1, s = size_of_points, 
                marker='o', edgecolor='none', cmap='jet',
                vmax=metric_det_on_grid.max().sqrt().item())
fig.colorbar(p,label=sqrt_det_label)
ax2.title.set_text(half_trace_label)
q = ax2.scatter( latent[:,0], latent[:,1], 
                c=0.5*(trace_array), alpha=1, s= size_of_points, 
                marker='o', edgecolor='none', cmap='jet',
                vmax=0.5*metric_trace_on_grid.max().item())
fig.colorbar(q,label=half_trace_label)

im3 = ax3.imshow((torch.sqrt(metric_det_on_grid)).detach().reshape(linsize,linsize),
                 origin="lower",cmap="jet",norm = None)
fig.colorbar(im3,ax = ax3, shrink = 1, label = sqrt_det_label)
ax3.set_title(sqrt_det_label)

im4 = ax4.imshow((0.5*(metric_trace_on_grid)).detach().reshape(linsize,linsize),
                 origin="lower",cmap="jet",norm = None,
                 vmax=0.5*metric_trace_on_grid.max().item())
fig.colorbar(im4, ax = ax4, shrink = 1, label = half_trace_label)
ax4.set_title(half_trace_label)

# image panels are indexed by pixel
axs = (ax3, ax4)
for ax in axs:
    ax.set_xticks(ticks_places,labels = xtick_labels)
    ax.set_yticks(ticks_places,labels = ytick_labels)

# scatter panels use latent coordinates directly
axs = (ax1, ax2)
for ax in axs:
    ax.set_ylim(bottom,top)
    ax.set_xlim(left,right)
    ax.set_xticks(list(map(float, xtick_labels)), labels = xtick_labels)
    ax.set_yticks(list(map(float, ytick_labels)), labels = ytick_labels)

if violent_saving == True:
    #plt.savefig(f'{Path_pictures}/metric_det_trace.eps',bbox_inches='tight',format='eps')
    plt.savefig(f'{Path_pictures}/metric_det_trace.pdf',bbox_inches='tight',format='pdf')
plt.show()

### Merge pdfs

In [None]:
from pypdf import PdfMerger

In [None]:
# Concatenate the per-figure PDFs of this experiment into one report file
# named after `save_weight_name`.
#build_report = True
if build_report:
    # Order of the list is the page order of the report.
    report_parts = [
        "losses",
        "9losses",
        "init_colors_recon_loss",
        "curvature_heatmaps",
        "metric_det_trace",
    ]
    pdfs = [f'{Path_pictures}/{part}.pdf' for part in report_parts]

    # NOTE(review): recent pypdf releases deprecate PdfMerger in favour of
    # PdfWriter - confirm against the installed version before upgrading.
    merger = PdfMerger()
    try:
        for pdf in pdfs:
            merger.append(pdf)
        merger.write(f"{Path_pictures}/report_{save_weight_name}.pdf")
    finally:
        # Always close the merger, also when a part is missing and
        # append()/write() raises.
        merger.close()


### Latex plots

In [None]:
# Stand-alone scatter of `latent` coloured by `labels`, saved as latex_ls.pdf.
plt.rcParams.update({'font.size': 16})

size_of_points = 40
fig, ax1 = plt.subplots(ncols=1, nrows=1, figsize=(9, 9), dpi=300)
fig.tight_layout(pad=2.0)

# The two dataset cases differ only in the colormap: a k-level discrete "jet"
# for the Synthetic set, the continuous one otherwise.
latent_cmap = discrete_cmap(k, "jet") if set_name == "Synthetic" else 'jet'
p = ax1.scatter(
    latent[:, 0],
    latent[:, 1],
    c=labels,
    alpha=0.5,
    s=size_of_points,
    marker='o',
    edgecolor='none',
    cmap=latent_cmap,
)

# Fixed viewing window with ticks at the numeric value of each tick label.
ax1.set_xlim(left, right)
ax1.set_ylim(bottom, top)
ax1.set_xticks([float(tick) for tick in xtick_labels], labels=xtick_labels)
ax1.set_yticks([float(tick) for tick in ytick_labels], labels=ytick_labels)

if violent_saving == True:
    fig.savefig(f'{Path_pictures}/latex_ls.pdf', bbox_inches='tight', format='pdf')
plt.show()

In [None]:
# Stand-alone heatmap of `curv_on_the_grid` on a symmetric-log colour scale,
# saved as latex_curvature_heatmap.pdf.
plt.rcParams.update({'font.size': 16})

fig, ax2 = plt.subplots(ncols=1, nrows=1, figsize=(9, 9), dpi=300)

curvature_norm = matplotlib.colors.SymLogNorm(
    linthresh=linthresh_curvature,
    vmin=min_curvature,
    vmax=max_curvature,
)
curvature_image = curv_on_the_grid.detach().reshape(linsize, linsize)
im2 = ax2.imshow(curvature_image, origin="lower", cmap="jet", norm=curvature_norm)

cbar = fig.colorbar(im2, ax=ax2, shrink=0.8, label="curvature")
cbar.ax.tick_params(rotation=0)

# Remove the colorbar ticks that fall inside the linear zone of the norm
# (|tick| <= linthresh), but keep the tick at exactly 0.
cbar_ticks = cbar.get_ticks()
in_linear_zone = (abs(cbar_ticks) <= linthresh_curvature) & (cbar_ticks != 0)
new_cbar_ticks = np.delete(cbar_ticks, np.where(in_linear_zone))
cbar.ax.set_yticks(new_cbar_ticks)

# The image is indexed in grid pixels, hence ticks at `ticks_places`.
ax2.set_xticks(ticks_places, labels=xtick_labels)
ax2.set_yticks(ticks_places, labels=ytick_labels)

if violent_saving == True:
    fig.savefig(f'{Path_pictures}/latex_curvature_heatmap.pdf', bbox_inches='tight', format='pdf')
plt.show()

# VII. Level sets of Gaussians

In [None]:
# Rotation matrices \phi_j and shifts y_j stored on the dataset object when
# the set was constructed; both are indexed by plane below.
phi, shifts = my_dataset.rotations, my_dataset.shifts

### Distances to means of Gaussians

In [None]:
# For every sample of `test_data`: encode it, look up the shift of its class
# and measure the distance between the raw sample and that shift.
data_for_plot = test_data

init_data, labels = data_for_plot[:][0], data_for_plot[:][1]
latent = encoder(init_data.squeeze()).detach()
int_labels = labels.to(int)

# One centre per sample, picked by the sample's integer label.
centers = [shifts[label] for label in int_labels]
centers_tensor = torch.from_numpy(np.array(centers).squeeze())
distances = torch.norm(init_data - centers_tensor, dim=1)

In [None]:
# Scatter of `latent` coloured by `distances`, saved as
# distance_to_means_heatmap.pdf.
plt.figure(dpi=400, figsize=(9, 9))
# For a log colour scale pass norm=matplotlib.colors.LogNorm() as well.
marker_style = dict(s=40, alpha=0.5, marker='o', edgecolor='none', cmap='jet')
plt.scatter(latent[:, 0], latent[:, 1], c=distances, **marker_style)
if violent_saving == True:
    plt.savefig(
        f'{Path_pictures}/distance_to_means_heatmap.pdf',
        bbox_inches='tight',
        format='pdf',
    )
plt.show()

### 3 different colormaps with cbars

In [None]:
import pandas as pd

# One scatter per plane, each with its own colormap and its own colorbar,
# coloured by `distances`.
plt.rcParams.update({'font.size': 20})  # base font size for all text in the figure
# Row layout: latent coordinates, then the label, then the distance.
# Assumes `latent` has exactly d columns, so the label lands in column d.
latent_labels_distances = torch.cat((latent, labels.unsqueeze(1), distances.unsqueeze(1)), dim=1)
my_dataframe = pd.DataFrame(latent_labels_distances)
cmaps = ["jet", "hsv", "twilight"]
# Other colormap triples that were tried:
#cmaps = ["jet","plasma","twilight"]
#cmaps = ["jet","turbo","hsv"]
# Per-plane colorbar placement; every list is indexed by plane_idx.
colorbar_locations = ["right", "bottom", "left"]
colorbar_orientations = ["vertical", "horizontal", "vertical"]
colorbar_shrinks = [0.5, 0.5, 0.5]
colorbar_anchors = [(0.5, 0.75), (0.75, 0.5), (0.5, 0.5)]

fig, ax = plt.subplots(figsize=(9, 9), dpi=400)
#plt.title("AE latent space for the Synthetic dataset")
for plane_idx in range(k):
    # Column d holds the label; keep only the rows of this plane.
    results_df = my_dataframe.loc[my_dataframe[d] == plane_idx]
    # Drop the label column: what remains is (latent coordinates..., distance).
    latent_points_in_plane = torch.tensor(results_df.loc[:, results_df.columns != d].values)
    # The distance is the last remaining column (index 2 when d == 2).
    p = ax.scatter(latent_points_in_plane[:, 0], latent_points_in_plane[:, 1],
                   c=latent_points_in_plane[:, -1], alpha=0.5, marker='o',
                   edgecolor='none', cmap=cmaps[plane_idx])
    fig.colorbar(p, label=f"Distance to the center of cloud {plane_idx}",
                 orientation=colorbar_orientations[plane_idx],
                 shrink=colorbar_shrinks[plane_idx],
                 location=colorbar_locations[plane_idx],
                 pad=0.05, anchor=colorbar_anchors[plane_idx])
if violent_saving:
    fig.savefig(f'{Path_pictures}/distance_to_means_3heatmaps_withcbars.pdf', bbox_inches='tight', format='pdf')
# plt.show() as in the other figure cells; Figure.show() only warns on a
# non-GUI backend.
plt.show()

### 3 colormaps no cbar

In [None]:
# Echo the current marker size as the cell output (it is reused by the scatter below).
size_of_points

In [None]:
import pandas as pd

# One scatter per plane, each with its own colormap and no colorbars,
# coloured by `distances` and drawn in the fixed latent-space window.
plt.rcParams.update({'font.size': 20})  # base font size for all text in the figure
# Row layout: latent coordinates, then the label, then the distance.
# Assumes `latent` has exactly d columns, so the label lands in column d.
latent_labels_distances = torch.cat((latent, labels.unsqueeze(1), distances.unsqueeze(1)), dim=1)
my_dataframe = pd.DataFrame(latent_labels_distances)
cmaps = ["jet", "hsv", "twilight"]

fig, ax = plt.subplots(figsize=(9, 9), dpi=300)
#plt.title("AE latent space for the Synthetic dataset")
for plane_idx in range(k):
    # Column d holds the label; keep only the rows of this plane.
    results_df = my_dataframe.loc[my_dataframe[d] == plane_idx]
    # Drop the label column: what remains is (latent coordinates..., distance).
    latent_points_in_plane = torch.tensor(results_df.loc[:, results_df.columns != d].values)
    # The distance is the last remaining column (index 2 when d == 2).
    p = ax.scatter(latent_points_in_plane[:, 0], latent_points_in_plane[:, 1],
                   c=latent_points_in_plane[:, -1], alpha=0.5, marker='o',
                   s=size_of_points, edgecolor='none', cmap=cmaps[plane_idx])
ax.set_ylim(bottom, top)
ax.set_xlim(left, right)
ax.set_xticks(list(map(float, xtick_labels)), labels=xtick_labels)
ax.set_yticks(list(map(float, ytick_labels)), labels=ytick_labels)
if violent_saving:
    fig.savefig(f'{Path_pictures}/distance2means_synthetic_curv_w={curv_w}.pdf', bbox_inches='tight', format='pdf')
# plt.show() as in the other figure cells; Figure.show() only warns on a
# non-GUI backend.
plt.show()

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl

# A figure that contains nothing but a horizontal 'jet' colorbar.
fig = plt.figure()
# [left, bottom, width, height] in figure fractions: a thin strip near the top.
ax = fig.add_axes([0.05, 0.80, 0.9, 0.1])

cb = mpl.colorbar.ColorbarBase(ax, orientation='horizontal', cmap='jet')

# Honour the global save switch like the other figure cells do.
if violent_saving:
    fig.savefig(f'{Path_pictures}/just_colorbar.pdf', bbox_inches='tight')

### Disc of circles

In [None]:
import matplotlib.pyplot as plt
import torch

# Reference picture: concentric circles in the plane, coloured by polar angle.
num_circles = 5 # circles per plane
numpoints = 100 # points per circle
maxrad = 3
# Equally spaced radii; the sqrt-spaced alternative is
# maxrad*np.sqrt(np.linspace(0,1,num_circles)).
radius_array = maxrad*np.linspace(0, 1, num_circles)

plt.title( "Canonical version of disc" )
# Angle as a fraction of a full turn, sampled at `numpoints` points; the
# endpoint 1 is left out so the start point is not drawn twice.
theta_array = torch.linspace(0, 1 - 1/numpoints, numpoints)
x = torch.cos(2*torch.pi*theta_array)
y = torch.sin(2*torch.pi*theta_array)
for r in radius_array:
    r = r.item() # plain Python float instead of a numpy scalar
    # alpha fades from 1 (centre) down to 0.5 (outermost circle)
    plt.scatter(r*x, r*y, c=theta_array, marker='.', cmap='hsv', alpha=0.5*(2 - r/maxrad))
if violent_saving:
    plt.savefig(f'{Path_pictures}/Canonical_disk.eps', bbox_inches='tight', format='eps')
plt.show()

### Disk with colormap by polar angle

In [None]:
# One figure per plane: concentric circles are rotated and shifted into the
# plane, passed through the encoder and drawn coloured by polar angle.
# num_circles, numpoints, maxrad and radius_array are not set here; they must
# already be defined by an earlier cell.

# The angle grid is the same for every circle, so build it once.
theta = torch.linspace(0., 2*torch.pi*(numpoints-1)/numpoints, numpoints)

for plane_idx in range(k):
    plt.title( f'''Embedding of {num_circles} circles of radius up to {maxrad} in \n the latent space in plane # {plane_idx} with {str_lambda_curv}={curv_w}''')
    for rad in radius_array:
        # Column vectors of x and y coordinates, glued into an (numpoints, 2) circle.
        x_array = (rad*torch.cos(theta)).unsqueeze(1)
        y_array = (rad*torch.sin(theta)).unsqueeze(1)
        circle = torch.cat((x_array, y_array), dim=1)
        # Apply phi[plane_idx] to every point, then add the plane's shift.
        circle_in_D = (phi[plane_idx] @ circle.T).T + shifts[plane_idx].squeeze()
        encoded_circle = encoder(circle_in_D).detach()
        # alpha fades from 1 (rad = 0) down to 0.5 (rad = maxrad)
        plt.scatter(
            encoded_circle[:, 0],
            encoded_circle[:, 1],
            c=theta,
            marker='.',
            cmap='hsv',
            alpha=0.5*(2-rad/maxrad),
        )
        plt.grid(True)

    ax = plt.gca()
    # For equal axis scales: ax.set_aspect('equal')
    if violent_saving == True:
        plt.savefig(f'{Path_pictures}/circle_in_plane#{plane_idx}_byangle.eps',bbox_inches='tight',format='eps')
    plt.show()

In [None]:
# One figure per plane: sqrt-spaced concentric circles are rotated and shifted
# into the plane, passed through the encoder and drawn with one default colour
# per circle.
num_circles = 5 # circles per plane
numpoints = 100 # points per circle
maxrad = 3

radius_array = maxrad*np.sqrt(np.linspace(0, 1, num_circles))

# The angle grid does not depend on the plane or the radius.
theta = torch.linspace(0., 2*torch.pi*(numpoints-1)/numpoints, numpoints)

for plane_idx in range(k):
    # "\\lambda": the backslash has to be escaped because "\l" is not a valid
    # escape sequence, while "\n" must stay a real newline (so no raw string).
    plt.title( f'''Embedding of {num_circles} circles of radius up to {maxrad} in \n the latent space in plane # {plane_idx} with $\\lambda_{{curv}}=${curv_w}''')
    for rad in radius_array:
        x_array = rad*torch.cos(theta).unsqueeze(0).T
        y_array = rad*torch.sin(theta).unsqueeze(0).T
        circle = torch.cat((x_array, y_array), dim=-1)
        # Apply phi[plane_idx] to every point, then add the plane's shift.
        circle_in_D = torch.matmul(phi[plane_idx], circle.T).T + shifts[plane_idx].squeeze()
        encoded_circle = encoder(circle_in_D).detach()
        plt.scatter(encoded_circle[:, 0], encoded_circle[:, 1], s=15)
        plt.grid(True)
    # For equal axis scales: plt.gca().set_aspect('equal')
    if violent_saving:
        plt.savefig(f'{Path_pictures}/circle_in_plane#{plane_idx}_bycolor.eps', bbox_inches='tight', format='eps')
    plt.show()

In [None]:
# All planes in ONE figure: encoded concentric circles, coloured by radius.
num_circles = 5 # circles per plane
numpoints = 100 # points per circle
maxrad = 3*math.sqrt(var_class)

radius_array = maxrad*np.sqrt(np.linspace(0, 1, num_circles))

# The angle grid does not depend on the plane or the radius.
theta = torch.linspace(0., 2*torch.pi*(numpoints-1)/numpoints, numpoints)

for plane_idx in range(k):
    plt.title( f'''Embedding of {num_circles} circles of radius up to {maxrad} in each \n of the {k} planes in the latent space''')
    for rad in radius_array:
        x_array = rad*torch.cos(theta).unsqueeze(0).T
        y_array = rad*torch.sin(theta).unsqueeze(0).T
        circle = torch.cat((x_array, y_array), dim=-1)
        # Apply phi[plane_idx] to every point, then add the plane's shift.
        circle_in_D = torch.matmul(phi[plane_idx], circle.T).T + shifts[plane_idx].squeeze()
        encoded_circle = encoder(circle_in_D).detach()

        # Colour by radius: a colormap is only applied when `c` is given, so
        # every point of the circle gets the value `rad`; vmin/vmax are fixed
        # so that equal radii share a colour across circles and planes.
        plt.scatter(encoded_circle[:, 0], encoded_circle[:, 1],
                    c=np.full(numpoints, rad), cmap='jet', vmin=0, vmax=maxrad,
                    marker='.', s=20)
        plt.grid(True)
    # For equal axis scales: plt.gca().set_aspect('equal')
if violent_saving:
    # The path ends in .png, so write PNG data (it used to request EPS output
    # into the .png file).
    plt.savefig(f'{Path_pictures}/circles_color_by_radius.png', bbox_inches='tight', format='png')
plt.show()

In [None]:
# Echo the radii currently in use as the cell output.
radius_array

In [None]:
# All planes in ONE figure: encoded concentric circles, coloured by polar
# angle. num_circles, numpoints, maxrad and radius_array must already be
# defined by an earlier cell.

# The angle grid does not depend on the plane or the radius.
theta = torch.linspace(0., 2*torch.pi*(numpoints-1)/numpoints, numpoints)

for plane_idx in range(k):
    plt.title( f'''Embedding of {num_circles} circles of radius up to {maxrad} in each \n of the {k} planes in the latent space''')
    for rad in radius_array:
        x_array = rad*torch.cos(theta).unsqueeze(0).T
        y_array = rad*torch.sin(theta).unsqueeze(0).T
        circle = torch.cat((x_array, y_array), dim=-1)
        # Apply phi[plane_idx] to every point, then add the plane's shift.
        circle_in_D = torch.matmul(phi[plane_idx], circle.T).T + shifts[plane_idx].squeeze()
        encoded_circle = encoder(circle_in_D).detach()

        # Colour by polar angle. alpha fades from 1 (rad = 0) to 0.5
        # (rad = maxrad); the former 1 - rad/maxrad reached 0 at rad = maxrad
        # and made that circle invisible.
        plt.scatter(encoded_circle[:, 0], encoded_circle[:, 1], c=theta,
                    marker='.', cmap='hsv', alpha=0.5*(2 - rad/maxrad))
        plt.grid(True)
    # For equal axis scales: plt.gca().set_aspect('equal')
if violent_saving:
    plt.savefig(f'{Path_pictures}/circles_color_by_angle.eps', bbox_inches='tight', format='eps')
plt.show()

### 3 ideal circles (for Aniti days)

In [None]:
# Illustration without the encoder: k scaled discs of concentric circles,
# each moved by a reproducible random shift, drawn in one figure.
num_circles = 7 # circles per plane
numpoints = 100 # points per circle
maxrad = 3*math.sqrt(var_class)

radius_array = maxrad*(np.linspace(0, 1, num_circles))
plt.rcParams.update({'font.size': 20})

# The angle grid does not depend on the plane or the radius.
theta = torch.linspace(0., 2*torch.pi*(numpoints-1)/numpoints, numpoints)

for plane_idx in range(k):
    plt.title("Canonical disks in the latent space")
    # Seeding per plane makes the shift reproducible; drawing it once gives
    # the same value that reseeding before every circle produced.
    torch.manual_seed(4*plane_idx)
    plane_shift = 1/2.5*(torch.randn(2))
    for rad in radius_array:
        x_array = rad*torch.cos(theta).unsqueeze(0).T
        y_array = rad*torch.sin(theta).unsqueeze(0).T
        circle = torch.cat((x_array, y_array), dim=-1)
        circle_shifted = 1/9*(circle) + plane_shift

        # Colour by polar angle. alpha fades from 1 (rad = 0) to 0.5
        # (rad = maxrad); the former 1 - rad/maxrad reached 0 at rad = maxrad
        # and made that circle invisible.
        plt.scatter(circle_shifted[:, 0], circle_shifted[:, 1], c=theta,
                    marker='.', cmap='hsv', alpha=0.5*(2 - rad/maxrad))
        plt.grid(False)
        plt.xlim(-1.0, 1.0)
        plt.ylim(-1.0, 1.0)
    # For equal axis scales: plt.gca().set_aspect('equal')
if violent_saving:
    plt.savefig(f'{Path_pictures}/3canonical_disks.eps', bbox_inches='tight', format='eps')
plt.show()

# VIII. Histograms of metric eigenvalues

In [None]:
# Encode the inputs of `train_data` and show where they land in latent space.
train_inputs = train_data[:][0]
samples_over_latent_space = encoder(train_inputs).detach().squeeze()

# Quick visual check of the samples.
plt.scatter(
    samples_over_latent_space[:, 0],
    samples_over_latent_space[:, 1],
    marker=".",
)
plt.title("Train dataset in latent space")
plt.show()

In [None]:
# Metric at every encoded training sample (presumably the pull-back metric of
# the decoder - confirm in RR.metric_jacfwd_vmap), detached from autograd,
# followed by the eigenvalues of each symmetric matrix.
metric_at_samples = RR.metric_jacfwd_vmap(
    samples_over_latent_space,
    function=decoder,
)
metric_at_samples = metric_at_samples.detach()
eigenvalues = torch.linalg.eigvalsh(metric_at_samples)

In [None]:
# Columns 0 and 1 of `eigenvalues` are taken as the smaller and the larger
# eigenvalue per sample; both are also stacked into one long vector.
min_eigenvalues, max_eigenvalues = eigenvalues[:, 0], eigenvalues[:, 1]
all_eigenvalues = torch.cat([min_eigenvalues, max_eigenvalues], dim=0)

In [None]:
# Per-sample trace and determinant, computed as the sum and the product of
# the two eigenvalues.
metric_trace = torch.add(min_eigenvalues, max_eigenvalues)
metric_det = torch.mul(min_eigenvalues, max_eigenvalues)

In [None]:
import math

# Histogram of the metric determinants; the bin count is the rounded square
# root of the number of samples.
det_bin_count = round(math.sqrt(metric_det.shape[0]))
plt.hist(np.array(metric_det), bins=det_bin_count)
plt.title("Metric determinants histogram over train data points")
plt.show()

In [None]:
import math

# Histogram of the metric traces; the bin count is the rounded square root of
# the number of samples.
trace_bin_count = round(math.sqrt(metric_trace.shape[0]))
plt.hist(np.array(metric_trace), bins=trace_bin_count)
plt.title("Metric traces histogram over train data points")
plt.show()