# I. Hyperparameters: dataset and learning

In [None]:
import os
# Name of the current experiment; figures are saved under Path_pictures.
# Previous runs: "vae_klw=1e-5", "swissroll_curv_w=1e-3", "1Gaussians_in_D=3_lr=4e-5"
experiment_name = "current_experiment"

violent_saving = True # if False it will not save plots
# here you can choose a path for saving the pictures
Path_pictures = f"/home/alazarev/CodeProjects/Experiments/{experiment_name}"
if violent_saving:
    # Create the folder when needed; exist_ok makes re-running this cell safe
    # (replaces the manual os.mkdir that had to be commented out afterwards).
    os.makedirs(Path_pictures, exist_ok=True)

In [None]:
# --- Dataset hyperparameters ---
set_name = "Synthetic"           # alternative: "Swissroll"
where_to_compute_curv = "batch"  # alternative: "random"

# Swiss-roll options (used when set_name == "Swissroll")
sr_noise = 0.05
sr_numpoints = 18_000  # k*n

# Synthetic-set options
D = 784    # ambient dimension (use 3 for the swiss roll)
d = 2      # latent space dimension
k = 3      # number of 2d planes in dimension D
n = 6_000  # number of points in each plane
shift_class = 0
var_class = 1       # variance of each Gaussian
intercl_var = 0.1   # scale of the random (Gaussian) shift between classes; initially 0.1

In [None]:
# Weight of the KL term: 0 trains a plain AE, any positive value selects the VAE
klw = 5e-4

In [None]:
# --- Data-loader hyperparameters ---
batch_size = 32    # was 32 initially
split_ratio = 0.2  # fraction of the samples held out as test data

# For reproducibility one may call torch.manual_seed(0) here

In [None]:
# --- Loss weights ---
mse_w = 1.0    # weight on the reconstruction (MSE) term
curv_w = 0.0   # weight on curvature
compute_curvature = False

# --- Optimizer (shared by the encoder and the decoder) and schedule ---
lr = 4e-5               # initially 4e-5 for synthetic
momentum = 0.8          # initially 0.8
num_epochs = 10         # initially 40
batches_per_plot = 600  # initially 200

# For reproducible results one may call torch.manual_seed(0) here

### Imports

In [None]:
# Minimal imports
import math
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

# adding path to the set generating package
import sys
sys.path.append('../') # have to go 1 level up
import ricci_regularization as RR
#import torchvision


import sklearn
from sklearn import datasets

# I*. Choosing Dataset

In [None]:
# Build the full dataset, then split it into train/test parts and loaders.
if set_name == "Synthetic":
    # Generate the dataset via the RR.SyntheticDataset class
    # (old style: ricci_regularization.generate_dataset(D, k, n, shift_class=shift_class, intercl_var=intercl_var))
    torch.manual_seed(0) # reproducibility
    my_dataset = RR.SyntheticDataset(k=k,n=n,d=d,D=D,
                                        shift_class=shift_class,
                                        var_class = var_class,
                                        intercl_var=intercl_var)

    train_dataset = my_dataset.create
elif set_name == "Swissroll":
    from torch.utils.data import TensorDataset
    # make_swiss_roll returns the points and their position along the roll,
    # the latter is used as the label/color
    swiss_roll = sklearn.datasets.make_swiss_roll(n_samples=sr_numpoints, noise=sr_noise)
    sr_points = torch.from_numpy(swiss_roll[0]).to(torch.float32)
    sr_colors = torch.from_numpy(swiss_roll[1]).to(torch.float32)
    train_dataset = TensorDataset(sr_points,sr_colors)
else:
    raise ValueError(f"Unknown set_name {set_name!r}: expected 'Synthetic' or 'Swissroll'")

m = len(train_dataset)
# Size the test part first and give the rest to training so the two lengths
# always add up to m (int(m - m*r) + int(m*r) can be m - 1, which makes
# random_split raise).
test_size = int(m*split_ratio)
train_data, test_data = torch.utils.data.random_split(train_dataset, [m - test_size, test_size])
test_loader  = torch.utils.data.DataLoader(test_data , batch_size=batch_size)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
# test_data[:][0] will give the vectors of data without labels from the test part of the dataset

In [None]:
# Display the shape of the train samples (labels dropped via [:][0])
train_data[:][0].shape

# II. AE declaration and initialization

### Declaration

In [None]:
# Check if the GPU is available
cuda_on = torch.cuda.is_available()
if cuda_on:
    device  = torch.device("cuda") 
else :
    device = torch.device("cpu")
print(f'Selected device: {device}')

In [None]:
# Sine autoencoder: 4-layer MLPs input_dim -> 512 -> 256 -> 128 -> hidden_dim
# and back, with torch.sin as activation. The encoder also applies sin to its
# output; the decoder output layer is linear.
class Encoder(nn.Module):
    """Map input_dim-dimensional samples to hidden_dim-dimensional codes."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, 512)
        self.linear2 = nn.Linear(512, 256)
        self.linear3 = nn.Linear(256, 128)
        self.linear4 = nn.Linear(128, hidden_dim)
        self.activation = torch.sin  # alternative: nn.ReLU()

    def forward(self, x):
        hidden = x
        # every layer, the last one included, is followed by the activation
        for layer in (self.linear1, self.linear2, self.linear3, self.linear4):
            hidden = self.activation(layer(hidden))
        return hidden


class Decoder(nn.Module):
    """Map hidden_dim-dimensional codes back to output_dim-dimensional samples."""

    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.linear1 = nn.Linear(hidden_dim, 128)
        self.linear2 = nn.Linear(128, 256)
        self.linear3 = nn.Linear(256, 512)
        self.linear4 = nn.Linear(512, output_dim)
        self.activation = torch.sin  # alternative: torch.nn.ReLU()

    def forward(self, x):
        hidden = x
        for layer in (self.linear1, self.linear2, self.linear3):
            hidden = self.activation(layer(hidden))
        # linear output layer (no activation, no sigmoid)
        return self.linear4(hidden)

In [None]:
# Disabled alternative ("initial structure"), kept as a string literal so it is
# not executed. It matches the active sine AE except that the decoder also
# applies sin to its output.
"""
# initial structure
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, 512)
        self.linear2 = nn.Linear(512, 256)
        self.linear3 = nn.Linear(256, 128)
        self.linear4 = nn.Linear(128, hidden_dim)
        #self.activation = nn.ReLU()
        self.activation = torch.sin
    def forward(self, x):
        y = self.linear1(x)
        y = self.activation(y)
        y = self.linear2(y)
        y = self.activation(y)
        y = self.linear3(y)
        y = self.activation(y)
        out = self.linear4(y)
        out = self.activation(out)
        return out
class Decoder(nn.Module):
    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.linear1 = nn.Linear(hidden_dim, 128)
        self.linear2 = nn.Linear(128, 256)
        self.linear3 = nn.Linear(256, 512)
        self.linear4 = nn.Linear(512, output_dim)
        #self.activation = nn.ReLU()
        self.activation = torch.sin
    def forward(self, x):
        y = self.linear1(x)
        y = self.activation(y)
        y = self.linear2(y)
        y = self.activation(y)
        y = self.linear3(y)
        y = self.activation(y)
        out = self.linear4(y)
        out = self.activation(out)
        #out = torch.sigmoid(y)
        return out
"""

In [None]:
# Disabled alternative: a purely linear encoder/decoder pair (one nn.Linear
# each), kept as a string literal so it is not executed.
"""
# linear map
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        #self.activation = nn.ReLU()
        self.activation = torch.sin
    def forward(self, x):
        out = self.linear1(x)
        #out = self.activation(out)
        return out
class Decoder(nn.Module):
    def __init__(self, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.linear1 = nn.Linear(hidden_dim, output_dim)
    def forward(self, z):
        z = self.linear1(z)
        #z = torch.nn.functional.relu(self.linear1(z))
        #z = torch.sigmoid(self.linear2(z))
        #return z.reshape((-1, 1, 28, 28))
        return z
"""

In [None]:
class VariationalEncoder(nn.Module):
    """Gaussian encoder of the VAE.

    The input is flattened and passed through one sine hidden layer of width
    512. From there, linear heads produce ``mu`` and ``log(sigma)``. The
    forward pass returns the reparameterized sample
    ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``. It also stores in
    ``self.kl`` the KL divergence ``KL(N(mu, sigma^2) || N(0, I))``, summed
    over the batch and the latent coordinates.

    Parameters
    ----------
    input_dim : size of the flattened input.
    hidden_dim : latent dimension.
    cuda : kept for backward compatibility and ignored. The noise is drawn
        with ``torch.randn_like(mu)``, so it follows the module to whatever
        device ``.to(device)`` puts it on. The old ``.cuda()`` hack crashed
        on CPU-only machines with the default ``cuda=True``.
    """

    def __init__(self, input_dim, hidden_dim, cuda=True):
        super(VariationalEncoder, self).__init__()
        self.linear1 = nn.Linear(input_dim, 512)
        self.linear2 = nn.Linear(512, hidden_dim)  # -> mu
        self.linear3 = nn.Linear(512, hidden_dim)  # -> log(sigma)
        self.kl = 0

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = torch.sin(self.linear1(x))
        mu = self.linear2(x)
        log_sigma = self.linear3(x)
        sigma = torch.exp(log_sigma)
        z = mu + sigma*torch.randn_like(mu)
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)) per coordinate:
        #   (sigma^2 + mu^2)/2 - log(sigma) - 1/2,
        # which vanishes exactly at mu = 0, sigma = 1. The previous formula
        # (sigma^2 + mu^2 - log(sigma) - 1/2) was 0.5 there and pulled sigma
        # towards 1/sqrt(2).
        self.kl = (0.5*(sigma**2 + mu**2) - log_sigma - 0.5).sum()
        return z

### Initialization 

In [None]:
# Build the encoder/decoder pair; klw > 0 selects the VAE encoder (AE/VAE switch).
# Both encoders take D-dimensional inputs, matching the .view(-1, D) used in
# training (the VAE input size was hard-coded to 784, which broke e.g. D=3).
if klw > 0:
    encoder = VariationalEncoder(input_dim=D, hidden_dim=d, cuda=cuda_on)
else:
    encoder = Encoder(input_dim=D, hidden_dim=d)
decoder = Decoder(hidden_dim=d, output_dim=D)

# Move both the encoder and the decoder to the selected device *before*
# building the optimizer, as recommended by the torch.optim documentation.
encoder.to(device)
decoder.to(device)

params_to_optimize = [
    {'params': encoder.parameters()},
    {'params': decoder.parameters()},
]

# Alternative: torch.optim.Adam(params_to_optimize, lr=lr, weight_decay=0.0)
optimizer = torch.optim.RMSprop(params_to_optimize, lr=lr, momentum=momentum, weight_decay=0.0)

### loading weights

In [None]:
# Disabled: load pretrained encoder/decoder weights from ../nn_weights (remove
# the surrounding quotes to enable). torch.load unpickles the file, so only
# load trusted checkpoints.
"""
PATH_enc = '../nn_weights/encoder_sin_curv=10.pt'
encoder.load_state_dict(torch.load(PATH_enc))
encoder.eval()
PATH_dec = '../nn_weights/decoder_sin_curv=10.pt'
decoder.load_state_dict(torch.load(PATH_dec))
decoder.eval()
"""

### choice of curvature functional

In [None]:

def Func(encoded_data):
    """Curvature functional evaluated at the latent points `encoded_data`.

    Returns (1/N) * sum_i sqrt(det G(z_i)) * Sc(z_i)**2. Here G is the
    metric computed by RR.metric_jacfwd_vmap for the global `decoder`, Sc is
    the quantity computed by RR.Sc_jacfwd_vmap (presumably the scalar
    curvature), and N is the number of metric matrices.
    """
    metric = RR.metric_jacfwd_vmap(encoded_data,
                                   function=decoder)
    volume_density = torch.sqrt(torch.det(metric))
    scalar_curv = RR.Sc_jacfwd_vmap(encoded_data,
                                    function=decoder)
    num_points = metric.shape[0]
    weighted_sq_curv = volume_density*torch.square(scalar_curv)
    return (1/num_points)*weighted_sq_curv.sum()

# Disabled alternative functional: mean Frobenius norm of G - I.
"""
# minimizing |g-I|_F
def Func(encoded_data):
    metric_on_data = RR.metric_jacfwd_vmap(encoded_data,
                                           function=decoder)
    N = metric_on_data.shape[0]
    func = (1/N)*(metric_on_data-torch.eye(d)).norm(dim=(1,2)).sum()
    return func
"""

# III. Plotting tools

In [None]:
# borrowed from https://gist.github.com/jakevdp/91077b0cae40f8f8244a
def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map.

    Parameters
    ----------
    N : number of discrete colors.
    base_cmap : colormap name, or None for matplotlib's default colormap.

    Uses ``pyplot.get_cmap(name, lut)``. ``plt.cm.get_cmap``
    (``matplotlib.cm.get_cmap``) was deprecated in Matplotlib 3.7 and removed
    in 3.9.
    """
    return plt.get_cmap(base_cmap, N)

In [None]:
def point_plot(encoder, data, batch_idx, colormap='jet', draw_grid=True, figsize=(8, 6)):
    """Scatter-plot the latent codes of `data` as produced by `encoder`.

    Parameters
    ----------
    encoder : module mapping D-dimensional samples to latent codes; the
        first two latent coordinates are plotted.
    data : dataset where data[:][0] are the samples and data[:][1] the labels
        used to color the points.
    batch_idx : training batch index shown in the title.
    colormap : matplotlib colormap name.
    draw_grid : whether to draw a grid.
    figsize : figure size passed to plt.figure.

    Returns
    -------
    The matplotlib.pyplot module, so callers can chain .title/.savefig/.show.

    The encoder is evaluated in eval mode without gradients. Its previous
    train/eval mode is restored afterwards, because the training loop calls
    this function in the middle of an epoch.
    """
    labels = data[:][1]
    samples = data[:][0]

    was_training = encoder.training
    encoder.eval()
    try:
        with torch.no_grad():
            # flatten every sample to a D-vector on the model's device
            encoded_data = encoder(samples.view(-1, D).to(device))
    finally:
        encoder.train(was_training)

    latent = encoded_data.cpu().numpy()
    labels = labels.numpy()

    plt.figure(figsize=figsize)
    if set_name == "Swissroll":
        plt.scatter(latent[:, 0], latent[:, 1], c=labels, alpha=0.5, marker='o', edgecolor='none', cmap=colormap)
    else:
        plt.scatter(latent[:, 0], latent[:, 1], c=labels, alpha=0.5, marker='o', edgecolor='none', cmap=discrete_cmap(k, colormap))
        plt.colorbar(ticks=range(k), orientation='vertical', shrink=0.7)

    plt.title(f'''Latent space for test data in AE at batch {batch_idx}''')
    plt.grid(draw_grid)

    return plt

In [None]:
# Sanity check: the number of train batches times the batch size should match
# the size of the train split (the last batch may be smaller).
train_batches = len(train_loader)
print("Reality check of batch splitting: ")
print(f"-- Batches per epoch {train_batches}")
print(f"batch size: {batch_size}")
print(f"product:  {train_batches*batch_size}")
print(f"-- To be compared to: {(1.0-split_ratio)*n*k}")


In [None]:
# Same sanity check for the test loader: test batches times the batch size
# versus the size of the test split.
test_batches = len(test_loader)
print("Reality check of batch splitting: ")
print(f"-- Batches per epoch {test_batches}")
print(f"batch size: {batch_size}")
print(f"product:  {test_batches*batch_size}")
print(f"-- To be compared to: {split_ratio*n*k}")

# IV. Training

In [None]:
# Training loop. The loss is mse_w * MSE + curv_w * curvature functional,
# plus klw * KL in VAE mode (klw > 0). After every train batch one test batch
# is evaluated. Every `batches_per_plot` batches the latent space of the test
# data and the loss curves are plotted.
num_plots = 0  # to enumerate the saved plots
batch_idx = 0  # global batch counter across epochs
batches_per_epoch = len(train_loader)

mse_loss = []
kl_loss = []
curv_loss = []
test_mse_loss_list = []
test_curv_loss_list = []

# Iterator over the test batches, consumed one batch per train batch and
# restarted when exhausted.
iter_test_loader = iter(test_loader)


def _epoch_mean(values):
    """Mean of the last `batches_per_epoch` entries of `values` (nan if none)."""
    recent = values[-batches_per_epoch:]
    return float(np.mean(recent)) if recent else float("nan")


for epoch in range(num_epochs):

    # Set train mode for both the encoder and the decoder
    encoder.train()
    decoder.train()

    # The labels are ignored: this is unsupervised learning
    for image_batch, _ in train_loader:
        # Flatten the samples and move them to the proper device
        image_batch = image_batch.view(-1, D).to(device)
        # True batch size (the last batch may be smaller)
        true_batch_size = image_batch.shape[0]

        optimizer.zero_grad()

        # Forward pass
        encoded_data = encoder(image_batch)
        decoded_data = decoder(encoded_data)
        # Squared reconstruction error averaged over the samples of the batch
        mse_loss_batch = torch.sum((decoded_data - image_batch)**2) / true_batch_size

        if compute_curvature:
            if where_to_compute_curv == "batch":
                curvature_train_batch = Func(encoded_data)
            elif where_to_compute_curv == "random":
                # uniform samples in [-1, 1]^d on the same device as the decoder
                curvature_train_batch = Func(2*torch.rand(batch_size, d, device=device) - 1)
            else:
                raise ValueError(f"Unknown where_to_compute_curv: {where_to_compute_curv!r}")
        else:
            curvature_train_batch = 0.0

        loss = mse_w*mse_loss_batch + curv_w*curvature_train_batch

        # VAE mode: add the KL term stored by the encoder during its forward pass
        if klw > 0:
            kl_loss_batch = encoder.kl
            loss = loss + klw*kl_loss_batch
            # store a plain float: np.mean cannot handle CUDA tensors
            kl_loss.append(kl_loss_batch.item())
        else:
            kl_loss_batch = 0.0

        # Backward pass
        loss.backward()
        optimizer.step()

        # Test loss on the next test batch. The original condition
        # `if batch_idx % len(test_loader):` was inverted and reset the
        # iterator almost every batch, so it kept measuring the first batch.
        try:
            test_batch = next(iter_test_loader)
        except StopIteration:
            iter_test_loader = iter(test_loader)
            test_batch = next(iter_test_loader)
        test_images = test_batch[0].view(-1, D).to(device)
        true_test_batch_size = test_images.shape[0]
        with torch.no_grad():
            encoded_test_data = encoder(test_images)
            decoded_test_data = decoder(encoded_test_data)
            test_mse_loss = torch.sum((decoded_test_data - test_images)**2) / true_test_batch_size
            test_mse_loss_list.append(test_mse_loss.item())
            if compute_curvature:
                test_curv_loss = Func(encoded_test_data)
                test_curv_loss_list.append(test_curv_loss.item())

        mse_loss.append(mse_loss_batch.item())
        if compute_curvature:
            curv_loss.append(curvature_train_batch.item())

        # Plot the latent space and the loss curves
        if batch_idx % batches_per_plot == 0:
            plot = point_plot(encoder, test_data, batch_idx)
            # point_plot may leave the encoder in eval mode: switch back
            encoder.train()
            if violent_saving:
                plot.savefig('../plots/pointplots_in_training_testdata/pp{0}.eps'.format(num_plots), format='eps')
            num_plots += 1
            plot.show()

            # Loss curves so far (from the second plot on)
            if batch_idx > 0:
                fig, ax1 = plt.subplots()

                ax1.set_xlabel('Batches')
                ax1.set_ylabel('MSE')
                ax1.semilogy(mse_loss, label='train_MSE_loss', color='tab:orange')
                ax1.semilogy(test_mse_loss_list, label='test_MSE_loss', color='tab:red')
                ax1.tick_params(axis='y')
                ax1.legend(loc='lower left')

                if compute_curvature:
                    ax2 = ax1.twinx()  # second y-axis sharing the batch axis
                    ax2.set_ylabel('Curvature')
                    ax2.semilogy(curv_loss, label='train_Curv_loss', color='tab:olive')
                    ax2.semilogy(test_curv_loss_list, label='test_Curv_loss', color='tab:green')
                    ax2.tick_params(axis='y')
                    ax2.legend(loc='lower right')

                fig.tight_layout()  # otherwise the right y-label is slightly clipped
                plt.show()
        batch_idx += 1

    print(f'\n EPOCH {epoch + 1}/{num_epochs}. \t Average values over epoch:\n MSE loss: {_epoch_mean(mse_loss)}, Curvature loss: {_epoch_mean(curv_loss)}, KL loss: {_epoch_mean(kl_loss)}')

### Heatmap of MSE over the datapoints

In [None]:
# Heatmap of the squared reconstruction error of every point, drawn at the
# point's latent code with a logarithmic color scale.
# (`matplotlib` itself was never imported, so matplotlib.colors.LogNorm failed.)
from matplotlib.colors import LogNorm

#choose data
data_for_plot = test_data

init_data = data_for_plot[:][0]
labels = data_for_plot[:][1]
num_points = init_data.shape[0]
with torch.no_grad():
    # Encode only once: in VAE mode each encoder call draws a new sample, so
    # decoding a second encoding would not match the plotted latent codes.
    codes = encoder(init_data.to(device))
    recon_data = decoder(codes).cpu()
latent = codes.cpu().reshape(num_points, -1)
# one row per sample on both sides, whatever the per-sample shape
abs_error_array = recon_data.reshape(num_points, -1) - init_data.reshape(num_points, -1)
mse_array = abs_error_array.norm(dim=1)**2

plt.scatter(latent[:,0].numpy(), latent[:,1].numpy(), c=mse_array.numpy(), alpha=0.5, marker='o', edgecolor='none', cmap='jet', norm=LogNorm())
plt.colorbar(label="squared l2 norm errors")
plt.title("Heatmap of the squared errors in l2 norm over the data\n points encoded into the latent space in logscale")
if violent_saving:
    plt.savefig(f'{Path_pictures}/logscale_MSE_heatmap.eps',bbox_inches='tight',format='eps')
plt.show()

### Relative and absolute errors

In [None]:
# Average absolute and relative reconstruction errors (l1 and l2 norms) over
# the train data. The model runs on its device without gradients; the
# statistics are computed on the CPU.
init_data = train_data[:][0]
num_points = init_data.shape[0]
with torch.no_grad():
    recon_data = decoder(encoder(init_data.to(device))).cpu()
flat_init = init_data.reshape(num_points, -1)
abs_error_array = recon_data.reshape(num_points, -1) - flat_init
errors_l2 = abs_error_array.norm(dim=1)
errors_l1 = abs_error_array.abs().sum(dim=1)

average_l1_error = errors_l1.mean()
average_l2_error = errors_l2.mean()

init_data_l1_norms = flat_init.abs().sum(dim=1)
init_data_l2_norms = flat_init.norm(dim=1)

average_l1_norm = init_data_l1_norms.mean()
average_l2_norm = init_data_l2_norms.mean()

rel_l1_errors = errors_l1/init_data_l1_norms
rel_l2_errors = errors_l2/init_data_l2_norms

average_relative_error_l1 = rel_l1_errors.mean()
average_relative_error_l2 = rel_l2_errors.mean()


print(f"average l_1 relative error is: {average_relative_error_l1.item()*100:.4f}%")
print(f"average l_1 absolute error is: {average_l1_error.item()}")
print(f"average l_1 norm of initial {D} dimensional data is: {average_l1_norm.item()}")

print(f"average l_2 relative error is: {average_relative_error_l2.item()*100:.4f}%")
print(f"average l_2 absolute error is: {average_l2_error.item()}")
print(f"average l_2 norm of initial {D} dimensional data is: {average_l2_norm.item()}")


In [None]:
# Histogram of the squared l2 reconstruction errors over the train data
squared_errors = (errors_l2**2).detach()
num_bins = round(math.sqrt(errors_l2.shape[0]))

fig, ax = plt.subplots()
ax.set_title("Histogram of squared euclidean norms\n of reconstruction errors over train data")
ax.hist(squared_errors, bins=num_bins)
fig.text(0.0, -0.10, f"MSE:{squared_errors.mean().item():.4f}")
if violent_saving == True:
    fig.savefig(f'{Path_pictures}/reconstruction_errors.eps', bbox_inches='tight', format='eps')

In [None]:
# Histogram of the squared l2 norms of the train points
plt.title("Histogram of squared euclidean norms \n of all points of the swiss roll")
plt.hist(train_data[:][0].norm(dim=1)**2,bins=round(math.sqrt(n)))
plt.xlabel("squared l2 norm")
plt.ylabel("number of points")
# call pyplot directly: `plot` only exists if a point_plot cell ran earlier
plt.show()

In [None]:
plt.rcParams.update({'font.size': 17}) # makes all fonts on the plot be 17

fig, ax1 = plt.subplots()

ax1.set_xlabel('Batches')
ax1.set_ylabel('MSE')
ax1.semilogy(mse_loss, label='train_MSE_loss', color='tab:orange')
ax1.semilogy(test_mse_loss_list, label='test_MSE_loss', color='tab:red')
ax1.tick_params(axis='y')
ax1.legend(loc='lower left')

# The curvature losses are only recorded when compute_curvature is on;
# otherwise the twin axis would hold empty curves (as in the training loop plot).
if compute_curvature:
    ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis
    ax2.set_ylabel('Curvature')  # we already handled the x-label with ax1
    ax2.semilogy(curv_loss, label='train_Curv_loss', color='tab:olive')
    ax2.semilogy(test_curv_loss_list, label='test_Curv_loss', color='tab:green')
    ax2.tick_params(axis='y')
    ax2.legend(loc='upper right')
fig.tight_layout()  # otherwise the right y-label is slightly clipped

last_mse = np.mean(mse_loss[-batches_per_epoch:])
last_curv = np.mean(curv_loss[-batches_per_epoch:]) if curv_loss else float("nan")
losses_text = f"Average losses over the last epoch: \nMSE loss: {last_mse:.3f}, \nCurvature loss: {last_curv:.3f}"
if set_name == "Swissroll":
    fig.text(0.05, -0.25, f"Set params: n={sr_numpoints}, noise={sr_noise}. \n" + losses_text)
else:
    # "\\sigma": a bare "\s" is an invalid escape sequence (SyntaxWarning)
    fig.text(0.05, -0.25, f"Set params: n={n}, k={k}, d={d}, D={D}, $\\sigma$={var_class}, $\\sigma_{{I}}$={intercl_var}. \n" + losses_text)
# raw strings: "\l" is an invalid escape sequence in a normal string literal
str_lambda_recon = r"$\lambda_{recon}$"
str_lambda_curv = r"$\lambda_{curv}$"
plt.title(f"Params: lr={lr}, batch_size={batch_size},\n {str_lambda_recon}={mse_w}, {str_lambda_curv}={curv_w}")
if violent_saving:
    plt.savefig(f'{Path_pictures}/losses.eps',bbox_inches='tight',format='eps')
plt.show()

### Final latent space

In [None]:
# Display which dataset is in use ("Synthetic" or "Swissroll")
set_name

In [None]:
# Final latent space of the test data, colored by label
scale = 0.8
plot = point_plot(encoder,test_data,batch_idx,colormap='jet',draw_grid=False,figsize = (8*scale, 6*scale))
plt.rcParams.update({'font.size': 20}) # makes all fonts on the plot be 20
# Name the model after the mode actually trained (klw > 0 selects the VAE)
# instead of hard-coding "AE"/"VAE" per dataset.
model_name = "VAE" if klw > 0 else "AE"
if set_name == "Swissroll":
    plot.title( f'Latent space of the {model_name} for the swiss roll')
else:
    plot.title( f'Latent space of the {model_name} for the\n Synthetic dataset')
if violent_saving:
    plot.savefig(f'{Path_pictures}/VAE_latent_space.eps',bbox_inches='tight',format='eps')
plot.show()

# V. Level sets of Gaussians

In [None]:
# Rotation matrices phi_j and shifts y_j used to build the synthetic set
# (my_dataset only exists when set_name == "Synthetic")
phi, shifts = my_dataset.rotations, my_dataset.shifts

Disc of circles

In [None]:
import matplotlib.pyplot as plt
import torch

num_circles = 5 # circles per plane
numpoints = 100 #points per circle
maxrad = 3
# alternative spacing: maxrad*np.sqrt(np.linspace(0,1,num_circles))
radius_array = maxrad*np.linspace(0,1,num_circles)

plt.title( "Canonical version of disc" )
# numpoints equally spaced fractions of a full turn in [0, 1); this used a
# hard-coded 100 that ignored numpoints
theta_array = torch.linspace(0, 1-1/numpoints, numpoints)
x = torch.cos( 2*torch.pi*theta_array )
y = torch.sin( 2*torch.pi*theta_array )
for r in radius_array:
    r = r.item() # extracting the value of r
    # colored by polar angle; outer circles are more transparent
    plt.scatter( r*x, r*y, c=theta_array, marker='.', cmap='hsv', alpha=0.5*(2-r/maxrad) )
if violent_saving:
    plt.savefig(f'{Path_pictures}/Canonical_disk.eps',bbox_inches='tight',format='eps')
plt.show()

### Disk with colormap by polar angle

In [None]:
# One figure per plane: the circles of radii radius_array are placed in each
# plane of the synthetic set and encoded. Points are colored by polar angle
# and fade with the radius.
theta = torch.linspace(0.,2*torch.pi*(numpoints-1)/numpoints,
                       numpoints)
for plane_idx in range(k):
    plt.title( f'''Embedding of {num_circles} circles of radius up to {maxrad} in \n the latent space in plane # {plane_idx} with {str_lambda_curv}={curv_w}''')
    for rad in radius_array:
        # circle of radius rad in canonical 2d coordinates, shape (numpoints, 2)
        circle = torch.stack((rad*torch.cos(theta), rad*torch.sin(theta)), dim=-1)
        # rotate and shift it into plane plane_idx of the D-dimensional space
        circle_in_D = torch.matmul(phi[plane_idx],circle.T).T+shifts[plane_idx].squeeze()
        # encode on the model's device (crashed on GPU before) and plot on CPU
        with torch.no_grad():
            encoded_circle = encoder(circle_in_D.to(device)).cpu()
        plt.scatter(encoded_circle[:,0],encoded_circle[:,1],c=theta, marker='.', cmap='hsv', alpha=0.5*(2-rad/maxrad))
    plt.grid(True)
    if violent_saving:
        plt.savefig(f'{Path_pictures}/circle_in_plane#{plane_idx}_byangle.eps',bbox_inches='tight',format='eps')
    plt.show()

In [None]:
num_circles = 5 # circles per plane
numpoints = 100 #points per circle
maxrad = 3

radius_array = maxrad*np.sqrt(np.linspace(0,1,num_circles))

# One figure per plane; each circle gets the next color of the default cycle.
theta = torch.linspace(0.,2*torch.pi*(numpoints-1)/numpoints,
                       numpoints)
for plane_idx in range(k):
    # "\\lambda": a bare "\l" is an invalid escape sequence (SyntaxWarning)
    plt.title( f'''Embedding of {num_circles} circles of radius up to {maxrad} in \n the latent space in plane # {plane_idx} with $\\lambda_{{curv}}=${curv_w}''')
    for rad in radius_array:
        circle = torch.stack((rad*torch.cos(theta), rad*torch.sin(theta)), dim=-1)
        circle_in_D = torch.matmul(phi[plane_idx],circle.T).T+shifts[plane_idx].squeeze()
        # encode on the model's device and plot on the CPU
        with torch.no_grad():
            encoded_circle = encoder(circle_in_D.to(device)).cpu()
        plt.scatter(encoded_circle[:,0],encoded_circle[:,1],s=15)
    plt.grid(True)
    if violent_saving:
        plt.savefig(f'{Path_pictures}/circle_in_plane#{plane_idx}_bycolor.eps',bbox_inches='tight',format='eps')
    plt.show()

In [None]:
num_circles = 5 # circles per plane
numpoints = 100 #points per circle
maxrad = 3*math.sqrt(var_class)  # 3 standard deviations of each Gaussian

radius_array = maxrad*np.sqrt(np.linspace(0,1,num_circles))

# Circles of all k planes overlaid in a single figure; each (plane, radius)
# gets the next color of the default color cycle.
theta = torch.linspace(0.,2*torch.pi*(numpoints-1)/numpoints,
                       numpoints)
for plane_idx in range(k):
    plt.title( f'''Embedding of {num_circles} circles of radius up to {maxrad} in each \n of the {k} planes in the latent space''')
    for rad in radius_array:
        circle = torch.stack((rad*torch.cos(theta), rad*torch.sin(theta)), dim=-1)
        circle_in_D = torch.matmul(phi[plane_idx],circle.T).T+shifts[plane_idx].squeeze()
        # encode on the model's device and plot on the CPU
        with torch.no_grad():
            encoded_circle = encoder(circle_in_D.to(device)).cpu()
        # no `c` is given, so a cmap would be ignored; colors come from the cycle
        plt.scatter(encoded_circle[:,0],encoded_circle[:,1],marker='.',s=20)
    plt.grid(True)
if violent_saving:
    # the file is written as EPS, so give it the matching extension (was .png)
    plt.savefig(f'{Path_pictures}/circles_color_by_radius.eps',bbox_inches='tight',format='eps')
plt.show()

In [None]:
# Display the radii used by the circle plots
radius_array

In [None]:
# Circles of all k planes overlaid in one figure, colored by polar angle and
# fading with the radius.
theta = torch.linspace(0.,2*torch.pi*(numpoints-1)/numpoints,
                       numpoints)
for plane_idx in range(k):
    plt.title( f'''Embedding of {num_circles} circles of radius up to {maxrad} in each \n of the {k} planes in the latent space''')
    for rad in radius_array:
        circle = torch.stack((rad*torch.cos(theta), rad*torch.sin(theta)), dim=-1)
        circle_in_D = torch.matmul(phi[plane_idx],circle.T).T+shifts[plane_idx].squeeze()
        # encode on the model's device and plot on the CPU
        with torch.no_grad():
            encoded_circle = encoder(circle_in_D.to(device)).cpu()
        # alpha in [0.5, 1] as in the other circle plots; 1-rad/maxrad made the
        # outermost circle (rad == maxrad) fully transparent
        plt.scatter(encoded_circle[:,0],encoded_circle[:,1],c=theta, marker='.', cmap='hsv', alpha=0.5*(2-rad/maxrad))
    plt.grid(True)
if violent_saving:
    plt.savefig(f'{Path_pictures}/circles_color_by_angle.eps',bbox_inches='tight',format='eps')
plt.show()

# VI. Metric

### Histogram of eigenvalues

In [None]:
# N latent points uniformly distributed over [-1, 1]^2 (fixed seed). Previously
# torch.rand(N,2)-1.0 only covered [-1, 0)^2.
N = 1000
torch.manual_seed(0)
samples_over_latent_space = 2*torch.rand(N,2)-1.0

#visualize samples
plt.scatter(samples_over_latent_space[:,0],samples_over_latent_space[:,1])

In [None]:
# Metric G at the latent samples: evaluated on the decoder's device, then
# brought back to the CPU; eigvalsh returns its eigenvalues in ascending order.
metric_at_samples = RR.metric_jacfwd_vmap(samples_over_latent_space.to(device),function=decoder).detach().cpu()
eigenvalues = torch.linalg.eigvalsh(metric_at_samples)

In [None]:
# eigvalsh sorts ascending, so for the 2x2 metric (d = 2) column 0 holds the
# smaller and column 1 the larger eigenvalue
min_eigenvalues, max_eigenvalues = eigenvalues[:, 0], eigenvalues[:, 1]
all_eigenvalues = torch.cat([min_eigenvalues, max_eigenvalues])

In [None]:
import math
# Histogram of all eigenvalues, with sqrt(N) bins
num_bins = round(math.sqrt(N))
plt.title(f"Histogram of eigenvalues of metric $G$ evaluated \n at {N} samples uniformly distributed \n over the latent space with penalty \n {curv_w} on Frobenius norm of $G-I$")
plt.hist(all_eigenvalues, bins=num_bins)
plt.xlabel("All eigenvalues of metric")
plt.show()

In [None]:
import math
# Histogram of the smaller eigenvalue at each sample
num_bins = round(math.sqrt(N))
plt.title(f"Histogram of minimal eigenvalues of metric $G$ evaluated at {N} \n samples uniformly distributed over the latent space ")
plt.hist(min_eigenvalues, bins=num_bins)
plt.xlabel("Min eigenvalues of metric")
plt.show()

In [None]:
# Histogram of the larger eigenvalue at each sample
num_bins = round(math.sqrt(N))
plt.title(f"Histogram of maximal eigenvalues of metric $G$ evaluated at {N} \n samples uniformly distributed over the latent space ")
plt.hist(max_eigenvalues, bins=num_bins)
plt.xlabel("Max eigenvalues of metric")
plt.show()

In [None]:
# Latent codes of the train data (encoded on the model's device, kept on the CPU)
with torch.no_grad():
    samples_over_latent_space = encoder(train_data[:][0].to(device)).cpu().squeeze()
samples_over_latent_space.shape

### Heatmaps of det(G) and tr(G) and histograms of errors 

In [None]:

# Latent points at which the metric is evaluated in the following cells.
# Alternatives: the uniform grid below (samples_over_latent_space = grid_on_ls)
# or random points: samples_over_latent_space = 2*torch.rand(N,2)-1.0
grid_on_ls = RR.make_grid(100,xsize=2,ysize=2,xcenter=-0.0,ycenter=0.0)

# on train_data (encoded on the model's device, kept on the CPU)
with torch.no_grad():
    samples_over_latent_space = encoder(train_data[:][0].to(device)).cpu().squeeze()

#visualize samples
plt.title("Visualized samples")
plt.scatter(samples_over_latent_space[:,0],samples_over_latent_space[:,1],marker=".")
plt.show()

In [None]:
#RR.metric_jacfwd(torch.tensor([-0.2,0.0]), function=decoder)

In [None]:
# Metric G at the train-data latent codes: evaluated on the decoder's device,
# then brought back to the CPU; eigvalsh returns eigenvalues in ascending order.
metric_at_samples = RR.metric_jacfwd_vmap(samples_over_latent_space.to(device),function=decoder).detach().cpu()
eigenvalues = torch.linalg.eigvalsh(metric_at_samples)

In [None]:
# Ascending order from eigvalsh: column 0 = smaller, column 1 = larger
# eigenvalue of the 2x2 metric (d = 2)
min_eigenvalues, max_eigenvalues = eigenvalues[:, 0], eigenvalues[:, 1]
all_eigenvalues = torch.cat([min_eigenvalues, max_eigenvalues])

In [None]:
# For a 2x2 metric the trace is the sum and the determinant the product of
# its two eigenvalues
metric_trace, metric_det = (min_eigenvalues + max_eigenvalues,
                            min_eigenvalues * max_eigenvalues)

In [None]:
import math
# Histogram of det(G) with sqrt(#samples) bins
det_values = np.array(metric_det)
plt.hist(det_values, bins=round(math.sqrt(det_values.shape[0])))
plt.title("Metric determinants histogram over train data points")
plt.show()

In [None]:
import math
# Histogram of tr(G) with sqrt(#samples) bins
trace_values = np.array(metric_trace)
plt.hist(trace_values, bins=round(math.sqrt(trace_values.shape[0])))
plt.title("Metric traces histogram over train data points")
plt.show()

In [None]:
import torch.func as TF
# Metric G, det(G) and tr(G) on a 100x100 latent grid (size 2 centered at 0,
# presumably [-1, 1]^2). The decoder runs on its device and the results are
# detached and moved to the CPU for plotting.
grid_on_ls = RR.make_grid(100,xsize=2,ysize=2,xcenter=-0.0,ycenter=0.0)
metric_on_grid = RR.metric_jacfwd_vmap(grid_on_ls.to(device),function=decoder).detach().cpu()
metric_det_on_grid = torch.det(metric_on_grid)
metric_trace_on_grid = TF.vmap(torch.trace)(metric_on_grid)

In [None]:
det_plot_name = f"det(G) with curv_w={curv_w}"
RR.draw_scalar_on_grid(
    metric_det_on_grid.view(100, 100),
    plot_name=det_plot_name,
    numsteps=100, xsize=2, ysize=2, xcenter=-0.0, ycenter=0.0,
)

In [None]:
trace_plot_name = f"tr(G) curv_w={curv_w}"
RR.draw_scalar_on_grid(
    metric_trace_on_grid.view(100, 100),
    plot_name=trace_plot_name,
    numsteps=100, xsize=2, ysize=2, xcenter=-0.0, ycenter=0.0,
)

In [None]:
# Train samples without labels, used by the reconstruction-error cells that follow
init_points = train_data[:][0]
#init_points

In [None]:
# Display the shape of the train samples
init_points.shape

In [None]:
# Per-point squared reconstruction error of the train data, divided by the
# number of points (so the values sum to the mean squared error). The model
# runs on its device without gradients; the errors are kept on the CPU.
# NOTE(review): norm over dims (1, 2) assumes 3-d samples, e.g. (N, 1, D) — confirm.
with torch.no_grad():
    errors = (decoder(encoder(init_points.to(device))).cpu() - init_points).norm(dim=(1,2))**2
errors = errors/len(errors)

In [None]:
import numpy as np
# Display the per-point errors as a NumPy array
np.asarray(errors.detach())

In [None]:
import math
# Histogram of the per-point reconstruction errors with sqrt(N) bins
error_values = np.array(errors.detach())
num_bins = round(math.sqrt(errors.shape[0]))
plt.hist(error_values, bins=num_bins)
plt.title("Reconstruction errors histogram")
plt.show()

# VII. Saving the model

In [None]:
# Disabled: save the trained encoder/decoder weights to ../nn_weights (remove
# the surrounding quotes to enable).
"""
PATH_enc = '../nn_weights/encoder_synthetic_lr=6e-5.pt'
torch.save(encoder.state_dict(), PATH_enc)
PATH_dec = '../nn_weights/decoder_synthetic_lr=6e-5.pt'
torch.save(decoder.state_dict(), PATH_dec)
"""