# I. Hyperparameters: set and learning

In [None]:
# prerequisites
# adding path to the set generating package
import sys
sys.path.append('../') # have to go 1 level up
import ricci_regularization as RR

# Minimal imports
from torch import nn

import sklearn
from sklearn import datasets
import matplotlib
import matplotlib.pyplot as plt
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import math
import numpy as np
import os

In [None]:
# Global experiment switches and training hyperparameters.
OOD_regime = True          # enable the out-of-distribution (OOD) curvature penalization
compute_curvature = True   # add the curvature functional to the training loss
batch_size  = 64 # was 32 initially
# curvature params
mse_w = 1.0   # weight on the reconstruction (MSE) term
curv_w = 10.0 # weight on curvature
start_curv = 0 # batch index to start curvature computation from

klw = 0 # weight of the KL term; 0 keeps the plain AE (VAE mode off)

# Training and plotting params
lr         = 4e-5 #initially 4e-5 for synthetic, 1e-4 for swissroll (or 2e-5 with curvature on)
momentum   = 0.8 #initially 0.8
batches_per_plot = 400 #initially 200 
split_ratio = 0.2   # fraction of the dataset held out as test data
num_epochs = 200 

random_state = 1   # seed passed to the swiss-roll generator
seed_number = 0    # seed passed to torch.manual_seed before model creation

# plot saving params
violent_saving = False # if False it will not save plots

#experiment_name = "swissroll_OOD_curv_w=10_sigma_ood=0.2_T_OOD=20_OOD_w=10_20epochs"
experiment_name = "MNIST"
#experiment_name = "synthetic_curv_w=1e+1_ls=R^2+OOD"
#experiment_name = f"swissroll_rs={random_state}_curv_w=100_lr=4e-6"
#experiment_name = "swissroll_curv_func_oscilations_curv_w=10_eps=0.01"
#experiment_name = "current experiment"


weights_loaded = False       # load previously saved weights before training
model_weights_saved = True   # save the weights after training
load_weight_name = "synthetic_curv_w=1e+1_ls=R^2"
#load_weight_name = "synthetic_curv_w=1e+1_ls=R^2"
save_weight_name = experiment_name
#save_weight_name = "swissroll_curv_w=1+OOD"

# here you can choose a path for saving the pictures
if violent_saving:
    Path_pictures = f"/home/alazarev/CodeProjects/Experiments/{experiment_name}"
    # makedirs also creates missing parent folders and is a no-op when the
    # folder already exists, so this line never has to be commented out
    os.makedirs(Path_pictures, exist_ok=True)

In [None]:
# Dataset hyperparameters.

# Where the curvature functional is evaluated during training:
# "batch" (on the encoded training batch) or "random".
where_to_compute_curv = "batch"

# Active dataset; the branches below recognise "Swissroll", "Synthetic" and "MNIST".
set_name = "Swissroll"
sr_noise = 0.05       # noise level handed to the swiss-roll generator
sr_numpoints = 18000  # number of swiss-roll samples (k*n)

# k: number of 2d planes in dim D (3 by default, 10 for MNIST)
k = 10 if set_name == "MNIST" else 3

# D: dimension of the data space, fixed by the chosen dataset
if set_name == "Swissroll":
    D = 3
elif set_name in ("Synthetic", "MNIST"):
    D = 784

d = 2  # latent space dimension

n = 6 * 10**3  # num of points in each plane
shift_class = 0
var_class = 1  # variation of each Gaussian, initially 0.1
# intercl_var creates a Gaussian, i.e. a random shift proportional to its
# value; initially 0.1
intercl_var = 0.1

In [None]:
# Parameters of the out-of-distribution (OOD) sampling scheme
start_ood = 400    # OOD penalization only happens for batch indices above this
T_ood = 20         # period (in batches) of OOD penalization; 100 was also tried
N_extr = 16        # batch size of extremal curvature points; 32 was also tried
n_ood = 5          # number of OOD samples per point
sigma_ood = 0.5    # sigma of OOD Gaussian samples: 2e-1 swissroll
r_ood = 0.001      # decay factor; 1e-2 was also tried
OOD_w = curv_w     # OOD weight, tied to the curvature weight

# I*. Choosing Dataset

In [None]:
# Build `train_dataset` for the selected `set_name`, split it into train and
# test subsets, and wrap both subsets in data loaders.
if set_name == "Synthetic":
    # Synthetic dataset generated through the project package.
    #torch.manual_seed(0) # reproducibility
    my_dataset = RR.SyntheticDataset(k=k,n=n,d=d,D=D,
                                        shift_class=shift_class,
                                        var_class = var_class, 
                                        intercl_var=intercl_var)

    # NOTE(review): `create` is read as an attribute, not called — confirm
    # that RR.SyntheticDataset exposes it as a property.
    train_dataset = my_dataset.create
elif set_name == "Swissroll":
    # make_swiss_roll returns (points, position along the roll); the second
    # array is stored as the per-point label and used for coloring.
    train_dataset =  sklearn.datasets.make_swiss_roll(n_samples=sr_numpoints, noise=sr_noise, random_state = random_state) #random_state=1
    sr_points = torch.from_numpy(train_dataset[0]).to(torch.float32)
    #sr_points = torch.cat((sr_points,torch.zeros(sr_numpoints,D-3)),dim=1)
    sr_colors = torch.from_numpy(train_dataset[1]).to(torch.float32)
    from torch.utils.data import TensorDataset
    train_dataset = TensorDataset(sr_points,sr_colors)
elif set_name == "MNIST":
    # `datasets` is imported from both sklearn and torchvision at the top of
    # the file, so the torchvision module is named explicitly here.
    # Both splits use ToTensor so that indexing either one yields float
    # tensors scaled to [0, 1].
    train_dataset = torchvision.datasets.MNIST(root='../datasets/', train=True, transform=transforms.ToTensor(), download=True)
    test_dataset  = torchvision.datasets.MNIST(root='../datasets/', train=False, transform=transforms.ToTensor(), download=True)

m = len(train_dataset)
# Derive the train size from the test size so the two lengths always sum to m;
# truncating both products independently can be off by one for some ratios.
test_size = int(m*split_ratio)
train_size = m - test_size
# NOTE(review): random_split draws from the global torch RNG, which is seeded
# only later in the notebook, so the split itself is not reproducible.
train_data, test_data = torch.utils.data.random_split(train_dataset, [train_size, test_size])
test_loader  = torch.utils.data.DataLoader(test_data , batch_size=batch_size)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
# test_data[:][0] will give the vectors of data without labels from the test part of the dataset

# II. AE declaration and initialization

### Declaration

In [None]:
# Select the compute device: CUDA when PyTorch reports it as available,
# otherwise the CPU.
cuda_on = torch.cuda.is_available()
device = torch.device("cuda" if cuda_on else "cpu")
print(f'Selected device: {device}')

In [None]:
# sine AE ls = R^2
class Encoder(nn.Module):
    """Fully connected encoder with sine activations.

    Maps ``input_dim`` features to ``hidden_dim`` latent coordinates through
    hidden layers of width 512, 256 and 128. The output layer is linear
    (no activation).
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Attribute names linear1..linear4 are the state_dict keys used when
        # weights are saved and loaded.
        self.linear1 = nn.Linear(in_features=input_dim, out_features=512)
        self.linear2 = nn.Linear(in_features=512, out_features=256)
        self.linear3 = nn.Linear(in_features=256, out_features=128)
        self.linear4 = nn.Linear(in_features=128, out_features=hidden_dim)
        # Smooth activation; nn.ReLU() is the alternative that was tried.
        self.activation = torch.sin

    def forward(self, x):
        """Return the latent code of ``x`` (last dimension ``input_dim``)."""
        hidden = x
        for layer in (self.linear1, self.linear2, self.linear3):
            hidden = self.activation(layer(hidden))
        # No activation on the output layer.
        return self.linear4(hidden)
class Decoder(nn.Module):
    """Fully connected decoder with sine activations.

    Mirror image of the encoder: maps ``hidden_dim`` latent coordinates to
    ``output_dim`` features through hidden layers of width 128, 256 and 512.
    The output layer is linear (no activation).
    """

    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        # Attribute names linear1..linear4 are the state_dict keys used when
        # weights are saved and loaded.
        self.linear1 = nn.Linear(in_features=hidden_dim, out_features=128)
        self.linear2 = nn.Linear(in_features=128, out_features=256)
        self.linear3 = nn.Linear(in_features=256, out_features=512)
        self.linear4 = nn.Linear(in_features=512, out_features=output_dim)
        # Smooth activation; torch.nn.ReLU() is the alternative that was tried.
        self.activation = torch.sin

    def forward(self, x):
        """Return the reconstruction of the latent points ``x``."""
        hidden = x
        for layer in (self.linear1, self.linear2, self.linear3):
            hidden = self.activation(layer(hidden))
        # No activation on the output layer.
        return self.linear4(hidden)

In [None]:
class VariationalEncoder(nn.Module):
    """Variational encoder with one sine-activated hidden layer.

    The forward pass returns a reparameterized sample ``z = mu + sigma * eps``
    and stores the KL divergence between N(mu, sigma^2) and N(0, 1), summed
    over the batch and latent dimensions, in ``self.kl``.

    Args:
        input_dim: number of features per sample after flattening.
        hidden_dim: latent dimension.
        cuda: if True and CUDA is available, the noise distribution samples
            on the GPU.
    """

    def __init__(self, input_dim, hidden_dim, cuda=True):
        super(VariationalEncoder, self).__init__()
        self.linear1 = nn.Linear(input_dim, 512)
        self.linear2 = nn.Linear(512, hidden_dim)  # mean head
        self.linear3 = nn.Linear(512, hidden_dim)  # log standard deviation head
        
        self.N = torch.distributions.Normal(0, 1)
        # Only move the sampler to the GPU when one exists, so the default
        # cuda=True does not crash on CPU-only machines.
        if cuda and torch.cuda.is_available():
            self.N.loc = self.N.loc.cuda() # hack to get sampling on the GPU
            self.N.scale = self.N.scale.cuda()
        self.kl = 0
    
    def forward(self, x):
        """Sample a latent code for ``x`` and update ``self.kl``."""
        x = torch.flatten(x, start_dim=1)
        #x = torch.nn.functional.relu(self.linear1(x))
        x = torch.sin(self.linear1(x)) 
        mu =  self.linear2(x)
        log_sigma = self.linear3(x)
        sigma = torch.exp(log_sigma)
        z = mu + sigma*self.N.sample(mu.shape)
        # Closed form KL(N(mu, sigma^2) || N(0, 1))
        #   = (sigma^2 + mu^2)/2 - log(sigma) - 1/2,
        # which vanishes for mu = 0, sigma = 1. log_sigma is used directly
        # instead of log(exp(.)) to avoid an unnecessary round trip.
        self.kl = (0.5*(sigma**2 + mu**2) - log_sigma - 0.5).sum()
        return z

### Initialization 

In [None]:
import time
# Only needed by the optional time-based seeding that is commented out below.
current_time = time.monotonic_ns()
torch.manual_seed(seed_number)
#torch.manual_seed(current_time)
#print("manual seed:", current_time)
# initially D=784, d=2
# AE/VAE switch: a positive KL weight selects the variational encoder.
if klw > 0:
    # The input dimension follows the dataset (D) so the VAE branch also
    # works for datasets that are not 784-dimensional.
    encoder = VariationalEncoder(input_dim=D, hidden_dim=d, cuda=cuda_on)
else:
    encoder = Encoder(input_dim=D, hidden_dim=d)
decoder = Decoder(hidden_dim=d, output_dim=D)

# ClassicalAE
#encoder = Encoder(input_dim=D, hidden_dim=d)
#decoder = Decoder(hidden_dim=d, output_dim=D)

# Encoder and decoder are optimized jointly, as two parameter groups.
params_to_optimize = [
    {'params': encoder.parameters()},
    {'params': decoder.parameters()}
]
#optimizer = torch.optim.Adam(params_to_optimize, lr=lr,weight_decay=0.0)

# Move both the encoder and the decoder to the selected device
encoder.to(device)
decoder.to(device)

In [None]:

# RMSprop over the encoder and decoder parameter groups, without weight decay.
optimizer = torch.optim.RMSprop(
    params_to_optimize,
    lr=lr,
    momentum=momentum,
    weight_decay=0.0,
)

### loading weights

In [None]:
# Optionally initialize both networks from previously saved state dicts.
if weights_loaded:
    # map_location puts the stored tensors on the currently selected device,
    # so weights saved on a GPU machine also load on a CPU-only one.
    PATH_enc = f'../nn_weights/encoder_{load_weight_name}'
    encoder.load_state_dict(torch.load(PATH_enc, map_location=device))
    encoder.eval()
    PATH_dec = f'../nn_weights/decoder_{load_weight_name}'
    decoder.load_state_dict(torch.load(PATH_dec, map_location=device))
    decoder.eval()

### choice of curvature functional

In [None]:
def Func(encoded_data):
    """Curvature functional evaluated at a batch of latent points.

    Returns the batch average of ``sqrt(det g) * Sc**2``, where ``g`` and
    ``Sc`` are the values returned by ``RR.metric_jacfwd_vmap`` and
    ``RR.Sc_jacfwd_vmap`` for the global ``decoder`` (presumably the metric
    and scalar curvature induced by the decoder — confirm in the RR package).
    """
    metric = RR.metric_jacfwd_vmap(encoded_data,
                                   function=decoder)
    scalar_curvature = RR.Sc_jacfwd_vmap(encoded_data,
                                         function=decoder)
    num_points = metric.shape[0]
    # sqrt(det g) weights each point's squared curvature
    volume_density = torch.sqrt(torch.det(metric))
    weighted_sq_curvature = volume_density*torch.square(scalar_curvature)
    return (1/num_points)*weighted_sq_curvature.sum()
"""
# minimizing |g-I|_F
def Func(encoded_data):
    metric_on_data = RR.metric_jacfwd_vmap(encoded_data,
                                           function=decoder)
    N = metric_on_data.shape[0]
    func = (1/N)*(metric_on_data-torch.eye(d)).norm(dim=(1,2)).sum()
    return func
"""

# III. Plotting tools

In [None]:
# borrowed from https://gist.github.com/jakevdp/91077b0cae40f8f8244a
def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map.

    Args:
        N: number of discrete colors.
        base_cmap: colormap name, colormap instance or None (Matplotlib's
            default colormap).

    Returns:
        The colormap resampled to N entries.
    """
    # plt.cm.get_cmap was deprecated and later removed from Matplotlib;
    # pyplot.get_cmap accepts the same (name, lut) arguments.
    return plt.get_cmap(base_cmap, N)
    # The following works for string, None, or a colormap instance:
"""
    base = plt.cm.get_cmap(base_cmap)
    color_list = base(np.linspace(0, 1, N))
    cmap_name = base.name + str(N)
    return base.from_list(cmap_name, color_list, N)
"""

In [None]:
def point_plot(encoder, data, batch_idx, show_title = True, colormap = 'jet',s=40,draw_grid = True,figsize = (8, 6)):
    """Scatter-plot the 2d latent codes of a dataset, colored by label.

    Args:
        encoder: network mapping D-dimensional points to latent codes; it is
            switched to eval mode by this function.
        data: dataset whose ``data[:]`` is ``(points, labels)``. For MNIST
            the global ``test_data`` subset is used instead.
        batch_idx: batch index shown in the title.
        show_title: whether to add the title.
        colormap: Matplotlib colormap name.
        s: marker size.
        draw_grid: whether to draw the grid.
        figsize: figure size in inches.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure current.
    """

    if set_name == "MNIST":
        # `test_data` is a random_split Subset: take only its own samples and
        # scale the raw uint8 pixels to [0, 1], the range that
        # transforms.ToTensor feeds to the network during training.
        subset_indices = torch.as_tensor(test_data.indices)
        data = test_data.dataset.data[subset_indices].to(torch.float32)/255
        labels = test_data.dataset.targets[subset_indices]
    else:
        labels = data[:][1]
        data   = data[:][0]

    # Encode
    encoder.eval()
    with torch.no_grad():
        data = data.view(-1,D) # reshape the img
        data = data.to(device)
        encoded_data = encoder(data)

    # Record codes
    latent = encoded_data.cpu().numpy()
    labels = labels.numpy()

    #Plot
    plt.figure(figsize=figsize)

    if set_name == "Swissroll":
        # continuous labels: plain colormap, no colorbar
        plt.scatter( latent[:,0], latent[:,1],s=s, c=labels, alpha=0.5, marker='o', edgecolor='none', cmap=colormap)
    else:
        # class labels: k-bin discrete colormap with one tick per class
        plt.scatter( latent[:,0], latent[:,1],s=s, c=labels, alpha=0.5, marker='o', edgecolor='none', cmap=discrete_cmap(k, colormap))
        #plt.scatter( latent[:,0], latent[:,1], c=labels, alpha=0.5, marker='o')
        plt.colorbar(ticks=range(k),orientation='vertical',shrink = 0.7)
    if show_title:
        plt.title( f'''Latent space for test data in AE at batch {batch_idx}''')
    plt.grid(draw_grid)
    
    return plt

In [None]:
def plot3losses(mse_train_list,curv_train_list,g_inv_train_list):
    """Plot three per-batch training curves on log-scaled y axes.

    Args:
        mse_train_list: reconstruction loss per batch.
        curv_train_list: curvature loss per batch.
        g_inv_train_list: mean norm of the inverse metric per batch.

    Returns:
        ``(fig, axes)`` of the 3x1 subplot grid.
    """
    fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(6,18))

    panels = (
        (mse_train_list, 'tab:red', 'MSE'),
        (curv_train_list, 'tab:olive', 'Curvature'),
        # raw string: '\|' is not a valid escape sequence in a normal literal
        (g_inv_train_list, 'tab:blue', r'$\|G^{-1}\|_F$'),
    )
    for ax, (series, color, ylabel) in zip(axes, panels):
        ax.semilogy(series, color = color)
        ax.set_ylabel(ylabel)
        ax.set_xlabel('Batches')
    fig.show()
    return fig,axes

In [None]:
# Sanity check of the training split: batches per epoch times the batch size
# should be close to the expected number of training samples.
print("Reality check of batch splitting: ")
for label, value in (
    ("-- Batches per epoch", len(train_loader)),
    ("batch size:", batch_size),
    ("product: ", len(train_loader)*batch_size),
    ("-- To be compared to:", (1.0-split_ratio)*n*k),
):
    print(label, value)


In [None]:
# Sanity check of the test split: batches times the batch size should be
# close to the expected number of held-out samples.
print("Reality check of batch splitting: ")
for label, value in (
    ("-- Batches per epoch", len(test_loader)),
    ("batch size:", batch_size),
    ("product: ", len(test_loader)*batch_size),
    ("-- To be compared to:", split_ratio*n*k),
):
    print(label, value)

# IV. Training

In [None]:
# Training loop: reconstruction loss plus curvature regularization, with an
# optional out-of-distribution (OOD) curvature penalty applied around the
# latent points where the curvature was most extreme.
num_plots = 0 # to enumerate the plots
batch_idx = 0
batches_per_epoch = len(train_loader)

# Per-batch logs
mse_loss = []
kl_loss = []
curv_loss = []
test_mse_loss_list = []
test_curv_loss_list = []
g_inv_norm_mean_train_list = []


def _mean_over_last_epoch(values):
   """Mean of the last `batches_per_epoch` entries, NaN if nothing was logged."""
   recent = values[-batches_per_epoch:]
   return float(np.mean(recent)) if recent else float("nan")


# Buffer of extreme-curvature latent points and their (decayed) curvature
# values. It is created on the training device so that it can be
# concatenated with the encoded batches.
with torch.no_grad():
   extreme_curv_points_tensor = torch.rand(N_extr, d, device=device)
   extreme_curv_value_tensor = torch.zeros(N_extr, device=device)

# to iterate through the batches of test data
# simultaneously with train data
iter_test_loader = iter(test_loader)

for epoch in range(num_epochs):
   # Set train mode for both the encoder and the decoder
   encoder.train()
   decoder.train()

   # Iterate the dataloader: no need for the label
   # values, this is unsupervised learning
   for image_batch, _ in train_loader:
      # shaping the images properly
      image_batch = image_batch.view(-1,D)
      # Move tensor to the proper device
      image_batch = image_batch.to(device)
      # True batch size (the last batch of an epoch can be smaller)
      true_batch_size = image_batch.shape[0]

      optimizer.zero_grad()

      # Forward pass
      encoded_data = encoder(image_batch)
      decoded_data = decoder(encoded_data)
      mse_loss_batch = torch.sum( (decoded_data-image_batch)**2 )/true_batch_size

      if OOD_regime:
         # Candidates for the extreme-curvature buffer: the encoded batch
         # with the curvature values RR.Sc_jacfwd_vmap returns at these points
         new_curv_points_tensor = encoded_data
         new_curv_value_tensor = RR.Sc_jacfwd_vmap(new_curv_points_tensor,
                                                   function = decoder)

      if compute_curvature:
         if where_to_compute_curv == "batch":
            curv_points = encoded_data
         elif where_to_compute_curv == "random":
            # uniform latent points in [-1, 1]^d, on the training device
            curv_points = 2*torch.rand(batch_size, d, device=device)-1
         else:
            raise ValueError(f"Unknown where_to_compute_curv: {where_to_compute_curv!r}")
         curvature_train_batch = Func(curv_points)
         # Diagnostic only (not part of the loss): mean norm of the inverse
         # metric at the points where the curvature was evaluated. It is
         # computed in both modes so the logging below never hits an
         # undefined name.
         g_inv_train_batch = torch.linalg.inv(RR.metric_jacfwd_vmap(curv_points,function=decoder))
         g_inv_norm_train_batch = torch.linalg.matrix_norm(g_inv_train_batch)
         g_inv_norm_mean_train_batch = torch.mean(g_inv_norm_train_batch)
      else:
         curvature_train_batch = 0.0

      loss = mse_w*mse_loss_batch + curv_w*curvature_train_batch

      if OOD_regime:
         with torch.no_grad():
            # merge extreme points and new batch
            extreme_curv_points_tensor = torch.cat((extreme_curv_points_tensor,
                                                    new_curv_points_tensor),dim=0)
            extreme_curv_value_tensor = torch.cat((extreme_curv_value_tensor,
                                                   new_curv_value_tensor),dim=0)

            # sort points and curvature values by curvature value
            indices = torch.argsort(extreme_curv_value_tensor)
            extreme_curv_points_tensor = torch.index_select(extreme_curv_points_tensor,dim = 0, index= indices)
            extreme_curv_value_tensor = torch.index_select(extreme_curv_value_tensor,dim = 0, index= indices)
            # keep the most negative and the most positive entries
            extreme_curv_points_tensor = torch.cat((extreme_curv_points_tensor[:N_extr//2],extreme_curv_points_tensor[-N_extr//2:]),dim=0)
            extreme_curv_value_tensor = torch.cat((extreme_curv_value_tensor[:N_extr//2],extreme_curv_value_tensor[-N_extr//2:]),dim=0)
            # multiply curv values by decay factor so that old entries are
            # eventually displaced by fresh ones
            extreme_curv_value_tensor = math.exp(-r_ood)*extreme_curv_value_tensor

         # OOD sampling: every T_ood batches (after start_ood) the loss of
         # this step is replaced by the curvature functional evaluated on
         # Gaussian samples around the stored extreme points.
         if (batch_idx % T_ood == 0) and (batch_idx > start_ood):
            with torch.no_grad():
               centers = extreme_curv_points_tensor.repeat_interleave(n_ood,dim=0)
               # NOTE(review): the noise is scaled by sigma_ood**2, i.e. its
               # standard deviation is sigma_ood squared — confirm that this
               # is intended.
               samples_centered_at_zero = (sigma_ood**2)*torch.randn(N_extr*n_ood, d, device=device)
               OOD_batch = centers + samples_centered_at_zero
            OOD_batch.requires_grad_()
            Func_val = Func(OOD_batch)
            print("Func of OOD points", Func_val)
            # OOD_w is the dedicated weight of the OOD penalty
            loss = OOD_w * Func_val

      # if VAE mode is on
      if klw > 0:
         kl_loss_batch = encoder.kl
         loss = loss + klw*kl_loss_batch
         # store a plain float so the epoch average works on any device
         kl_loss.append(kl_loss_batch.item())

      # Backward pass
      loss.backward()
      optimizer.step()

      # Evaluate one test batch per training batch, restarting the test
      # iterator whenever it would be exhausted.
      if batch_idx % len(test_loader) == 0:
         iter_test_loader = iter(test_loader)
      test_images = next(iter_test_loader)[0].view(-1,D).to(device)
      # True test_batch size
      true_test_batch_size = test_images.shape[0]
      # The whole evaluation runs without autograd tracking
      with torch.no_grad():
         encoded_test_data = encoder(test_images)
         decoded_test_data = decoder(encoded_test_data)
         test_mse_loss = torch.sum( (decoded_test_data - test_images)**2 )/true_test_batch_size
         test_mse_loss_list.append(test_mse_loss.item())
         if compute_curvature:
            test_curv_loss = Func(encoded_test_data)
            test_curv_loss_list.append(test_curv_loss.item())

      mse_loss.append(mse_loss_batch.item())
      if compute_curvature:
         curv_loss.append(curvature_train_batch.item())
         g_inv_norm_mean_train_list.append(g_inv_norm_mean_train_batch.item())

      # Periodic plots of the latent space and of the loss curves
      if (batch_idx % batches_per_plot == 0):
         plot = point_plot(encoder, test_data, batch_idx)
         if violent_saving:
            plot.savefig('../plots/pointplots_in_training_testdata/pp{0}.eps'.format(num_plots),format='eps')
         num_plots += 1
         plot.show()
         # point_plot switches the encoder to eval mode; restore train mode
         encoder.train()

         # plotting losses
         if batch_idx>0:
            if compute_curvature:
               plot3losses(mse_loss,curv_loss,g_inv_norm_mean_train_list)
            else:
               fig, ax1 = plt.subplots()

               ax1.set_xlabel('Batches')
               ax1.set_ylabel('MSE')
               ax1.semilogy(mse_loss, label='train_MSE_loss', color='tab:orange')
               ax1.semilogy(test_mse_loss_list, label='test_MSE_loss', color='tab:red')

               ax1.tick_params(axis='y')
               plt.legend(loc='lower left')
               fig.tight_layout()  # otherwise the right y-label is slightly clipped
               plt.show()
      batch_idx += 1
   # end for
   print(f'\n EPOCH {epoch + 1}/{num_epochs}. \t Average values over epoch:\n MSE loss: {_mean_over_last_epoch(mse_loss)}, Curvature loss: {_mean_over_last_epoch(curv_loss)}, KL loss: {_mean_over_last_epoch(kl_loss)}')

In [None]:
def plot_ae_outputs(encoder,decoder,n=10):
    """Show one test image per class next to its reconstruction.

    The top row shows the first image of each of the classes ``0..n-1`` in
    the global ``test_dataset``; the bottom row shows the autoencoder output
    for the same image. Both networks are switched to eval mode.

    Args:
        encoder: network mapping flattened images to latent codes.
        decoder: network mapping latent codes back to flattened images.
        n: number of classes (columns) to show.
    """
    targets = test_dataset.targets.numpy()
    # index of the first occurrence of each class
    t_idx = {i:np.where(targets==i)[0][0] for i in range(n)}
    encoder.eval()
    decoder.eval()
    plt.figure(figsize=(16,4.5))
    for i in range(n):
      # Read the raw pixels and scale them to [0, 1] (what
      # transforms.ToTensor produces), independently of the transform
      # attached to the dataset.
      img = test_dataset.data[t_idx[i]].to(torch.float32)/255
      with torch.no_grad():
         # the networks work on flattened images; reshape the output back
         flat_img = img.reshape(1,-1).to(device)
         rec_img  = decoder(encoder(flat_img)).reshape(img.shape)
      ax = plt.subplot(2,n,i+1)
      plt.imshow(img.numpy(), cmap='gist_gray')
      ax.get_xaxis().set_visible(False)
      ax.get_yaxis().set_visible(False)  
      if i == n//2:
        ax.set_title('Original images')
      ax = plt.subplot(2, n, i + 1 + n)
      plt.imshow(rec_img.cpu().numpy(), cmap='gist_gray')  
      ax.get_xaxis().set_visible(False)
      ax.get_yaxis().set_visible(False)  
      if i == n//2:
         ax.set_title('Reconstructed images')
    plt.show()   

# V. Saving the model

In [None]:
model_weights_saved = True

# Save the state dicts of both networks under ../nn_weights/.
if model_weights_saved:
    # torch.save does not create missing folders
    os.makedirs('../nn_weights', exist_ok=True)
    PATH_enc = f'../nn_weights/encoder_{save_weight_name}'
    print( "Saving encoder weights to", PATH_enc)
    torch.save(encoder.state_dict(), PATH_enc)
    PATH_dec = f'../nn_weights/decoder_{save_weight_name}'
    print( "Saving decoder weights to", PATH_dec)
    torch.save(decoder.state_dict(), PATH_dec)

# VI. Plots

### determination coefficient and training curves

In [None]:
# Determination coefficient R^2 = 1 - (mean squared reconstruction error) / tr(Q),
# where Q is the covariance matrix of the data.
def _reconstruct(points):
    """Return decoder(encoder(points)) on the CPU, computed on `device` without autograd."""
    with torch.no_grad():
        return decoder(encoder(points.to(device))).cpu()


if set_name == "Swissroll":
    points_tensor = sr_points.clone().detach()
    cov_matrix = torch.cov(points_tensor.T)
    print("Covariance matrix:\n", cov_matrix)
    mean = points_tensor.mean(dim=0)
    print("Mean vector:", mean)
    # squared norm of the reconstruction residual, computed once and reused
    residual_sq_norm = (_reconstruct(sr_points)-sr_points).norm()**2
    R_squared = 1 - residual_sq_norm/((sr_points-mean).norm()**2)
    print("Determination coef R^2:", R_squared)
    MSE = (1/sr_numpoints)*residual_sq_norm
    print("MSE:", MSE)
    print("tr(Q):", cov_matrix.trace())
    print("MSE/tr(Q):", MSE/cov_matrix.trace())
    print("R^2 = 1-MSE/tr(Q):", 1-MSE/cov_matrix.trace())
elif set_name =="Synthetic":
    points_tensor = train_dataset[:][0]
    cov_matrix = torch.cov(points_tensor.T)
    #print("Covariance matrix:\n", cov_matrix)
    mean = points_tensor.mean(dim=0)
    #print("Mean vector:", mean)
    R_squared = 1 - ((_reconstruct(points_tensor)-points_tensor).norm()**2)/((points_tensor-mean).norm()**2)
    print("Determination coef R^2:", R_squared)
elif set_name == "MNIST":
    # Raw uint8 pixels scaled to [0, 1], the range transforms.ToTensor feeds
    # to the network during training.
    mnist_points = test_data.dataset.data.to(torch.float32).view(-1,D)/255
    reconstruction = _reconstruct(mnist_points)
    # Squared error summed over features and averaged over samples, so that
    # it is on the same scale as tr(Q) (a sum of per-feature variances).
    mse = F.mse_loss(reconstruction, mnist_points, reduction="sum")/mnist_points.shape[0]
    cov_matrix = torch.cov(mnist_points.T)
    R_squared = 1 - mse/cov_matrix.trace()
    print("Determination coef R^2:", R_squared.item())

In [None]:
#from torcheval.metrics import R2Score
#metric = R2Score()
#input = torch.tensor([[0, 2], [1, 6]])
#target = torch.tensor([[0, 1], [2, 5]])
#metric.update(input[:2], target[:2])
#metric.compute()

In [None]:
#curv_loss_on_train_data = Func(encoder(train_data[:][0]))

In [None]:
# Summary figure: train/test MSE (left axis) and, when enabled, train/test
# curvature loss (right axis), annotated with the experiment parameters.
plt.rcParams.update({'font.size': 17}) # default font size for all text in the plots

fig, ax1 = plt.subplots()

ax1.set_xlabel('Batches')
ax1.set_ylabel('MSE')
ax1.semilogy(mse_loss, label='train_MSE_loss', color='tab:orange')
ax1.semilogy(test_mse_loss_list, label='test_MSE_loss', color='tab:red')

ax1.tick_params(axis='y')
plt.legend(loc='lower left')

if compute_curvature:
    ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis
    #ax2.set_ylabel('G-I loss')
    ax2.set_ylabel('Curvature')  # we already handled the x-label with ax1
    ax2.semilogy(curv_loss, label='train_Curv_loss',color='tab:olive')
    ax2.semilogy(test_curv_loss_list, label='test_Curv_loss', color='tab:green')
    ax2.tick_params(axis='y')
    plt.legend(loc='upper right')
fig.tight_layout()  # otherwise the right y-label is slightly clipped

# Average losses over the last epoch
if set_name == "Swissroll":
    fig.text(0.05,-0.25,f"Set params: n={sr_numpoints}, noise={sr_noise}. \nMSE loss: {np.mean(mse_loss[-batches_per_epoch:]):.3f}, \nCurvature loss: {np.mean(curv_loss[-batches_per_epoch:]):.3f}, \n$R^2=${R_squared:.4f}")
else:
    # backslashes of the LaTeX commands are escaped explicitly
    summary_text = f"Set params: n={n}, k={k}, d={d}, D={D}, $\\sigma$={var_class}, $\\sigma_{{I}}$={intercl_var}. \nMSE loss: {np.mean(mse_loss[-batches_per_epoch:]):.3f}, \nCurvature loss: {np.mean(curv_loss[-batches_per_epoch:]):.3f}, \n $R^2=${R_squared:.4f}"
    # The curvature loss over the whole train set is only reported when it
    # has been computed (the cell that defines it is optional).
    if "curv_loss_on_train_data" in globals():
        summary_text += f", \n Curvature loss over train dataset: {curv_loss_on_train_data:.4f}"
    fig.text(0.05,-0.25,summary_text)
if OOD_regime:
    fig.text(0.05, -0.35, f"OOD params: T_ood = {T_ood}, n_ood = {n_ood}, OOD_w={OOD_w}, \nsigma_ood = {sigma_ood}, N_extr = {N_extr}, r_ood = {r_ood}")
str_lambda_recon = r"$\lambda_{recon}$"
str_lambda_curv = r"$\lambda_{curv}$"
plt.title(f"Params: lr={lr}, batch_size={batch_size},\n {str_lambda_recon}={mse_w}, {str_lambda_curv}={curv_w}")
#violent_saving = True
if violent_saving:
    plt.savefig(f'{Path_pictures}/losses.pdf',bbox_inches='tight',format='pdf')
plt.show()

In [None]:
# numpy copies of the per-batch curvature and inverse-metric logs
curv_np, g_inv_np = (
    np.array(logged_values)
    for logged_values in (curv_loss, g_inv_norm_mean_train_list)
)

In [None]:
# Three-panel summary of the training curves, optionally saved as a PDF.
if compute_curvature:
    fig, axes = plot3losses(
        mse_loss,
        curv_loss,
        g_inv_norm_mean_train_list,
    )
    if violent_saving:
        plt.savefig(f'{Path_pictures}/3losses.pdf', format='pdf', bbox_inches='tight')

In [None]:
from scipy import signal

In [None]:
def plot9losses(mse_train_list,curv_train_list,g_inv_train_list):
    """Plot three training curves raw and smoothed, on log-scaled y axes.

    Each row shows one curve: the raw series in the first column and the
    series smoothed with a normalized Hann window of length 50 and 200 in
    the second and third columns.

    Args:
        mse_train_list: reconstruction loss per batch.
        curv_train_list: curvature loss per batch.
        g_inv_train_list: mean norm of the inverse metric per batch.

    Returns:
        ``(fig, axes)`` of the 3x3 subplot grid.
    """
    fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(18,18))

    smoothing_windows = (signal.windows.hann(50), signal.windows.hann(200))
    rows = (
        (mse_train_list, 'tab:red', 'MSE'),
        (curv_train_list, 'tab:olive', 'Curvature'),
        # raw string: '\|' is not a valid escape sequence in a normal literal
        (g_inv_train_list, 'tab:blue', r'$\|G^{-1}\|_F$'),
    )
    for row, (series, color, ylabel) in enumerate(rows):
        axes[row,0].semilogy(series, color = color)
        axes[row,0].set_ylabel(ylabel)
        for col, window in enumerate(smoothing_windows, start=1):
            # dividing by the window sum turns the convolution into a weighted average
            smoothed = signal.convolve(series, window, mode='same') / sum(window)
            axes[row,col].semilogy(smoothed, color = color)

    # only the bottom row carries the x label
    for ax in axes[2]:
        ax.set_xlabel('Batches')
    fig.show()
    return fig,axes

In [None]:
# Nine-panel (raw and smoothed) summary of the training curves, optionally
# saved as a PDF.
if compute_curvature:
    fig, axes = plot9losses(
        mse_loss,
        curv_loss,
        g_inv_norm_mean_train_list,
    )
    if violent_saving:
        plt.savefig(f'{Path_pictures}/9losses.pdf', format='pdf', bbox_inches='tight')

In [None]:
import json

# Store the raw per-batch training curves as JSON files in the experiment folder.
if violent_saving:
    curve_dumps = (
        (f"{Path_pictures}/mse_loss", mse_loss),
        (f"{Path_pictures}/curv_loss", curv_loss),
        (f"{Path_pictures}/g_inv", g_inv_norm_mean_train_list),
    )
    for dump_path, curve in curve_dumps:
        with open(dump_path, "w") as dump_file:
            json.dump(curve, dump_file)