# Imports

In [None]:
# prerequisites
%matplotlib inline
from functorch import jacrev,jacfwd
import sklearn
from sklearn import datasets
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import math
import numpy as np
from tqdm.notebook import tqdm
import json
import os

## Hyperparameters

In [None]:
# --- Experiment identification ------------------------------------------------
experiment_name = "MNIST_torus_AE"
experiment_number = 0
# Documented by the author as "if False it will not save plots".
# NOTE(review): none of the cells shown in this notebook reads this flag -- the
# savefig calls further down run unconditionally; confirm the intended use.
violent_saving = False

# --- Output locations ---------------------------------------------------------
# NOTE(review): absolute, machine-specific paths; adjust before running elsewhere.
Path_experiments = "/home/alazarev/CodeProjects/Experiments/"
Path_pictures = f"/home/alazarev/CodeProjects/Experiments/{experiment_name}/experiment{experiment_number}"
# makedirs(exist_ok=True) also creates the missing `{experiment_name}` parent
# folder (os.mkdir raised FileNotFoundError in that case) and is safe to re-run.
os.makedirs(Path_pictures, exist_ok=True)
Path_weights = "/home/alazarev/CodeProjects/ricci_regularization/"

# --- Dataset selection: exactly one of the following must be active -----------
#set_name = "Swissroll"
#set_name = "Synthetic"
set_name = "MNIST"

d = 2         # latent space dimension
weights_loaded = False  # True -> the "Loading the saved weights" cell restores a checkpoint

In [None]:
# Weights of the three loss terms combined in train():
#   loss = mse_w*MSE + unif_w*uniform_penalty + curv_w*curvature
mse_w = 1e4
unif_w = 0 # 4e1
curv_w = 0 # if 0 the curvature term is only evaluated when compute_curvature is True
compute_curvature = False

### Optimizer settings (one optimizer covers both the encoder and the decoder)
lr         = 1e-3
momentum   = 0.8   # NOTE(review): not passed to the Adam optimizer created later
num_epochs = 10

# Hyperparameters for data loaders
batch_size  = 256 # was 16 initially
split_ratio = 0.2 # fraction of the dataset held out by random_split

# Seed for reproducibility. The call used to be commented out, so the random
# split, the shuffling and the weight initialisation changed on every run.
# Set `seed = None` to get non-deterministic runs again.
seed = 0
if seed is not None:
    torch.manual_seed(seed)

In [None]:
import json
import os

# Snapshot of the settings of this run, stored next to the experiment outputs
hyperparameters = {
    "experiment_name": experiment_name,
    "experiment_number":experiment_number,
    "set_name": set_name,
    "learning_rate": lr,
    "batch_size": batch_size,
    "num_epochs": num_epochs,
    "mse_w": mse_w,
    "unif_w": unif_w,
    "curv_w": curv_w,
    "Path_pictures": Path_pictures,
    "Path_weights": Path_weights
}

# Save dictionary to JSON file.
# The target folder is created first: open(..., 'w') does not create missing
# directories and would raise FileNotFoundError on a fresh machine.
json_dir = f'{Path_experiments}json_files'
os.makedirs(json_dir, exist_ok=True)
with open(f'{json_dir}/hyperparameters_exp{experiment_number}.json', 'w') as json_file:
    json.dump(hyperparameters, json_file, indent=4)

## Dataset loading

In [None]:
import sys
sys.path.append('../') # have to go 1 level up
import ricci_regularization as RR

In [None]:
# Number of workers in DataLoader
# NOTE(review): only referenced by the commented-out loaders below; the active
# loaders at the end of this cell use the DataLoader default.
num_workers = 10

if set_name == "MNIST":
    D = 784
    k = 10 # number of classes
    #MNIST_SIZE = 28
    # MNIST Dataset
    train_dataset = datasets.MNIST(root='../datasets/', train=True, transform=transforms.ToTensor(), download=True)
    test_dataset  = datasets.MNIST(root='../datasets/', train=False, transform=transforms.ToTensor(), download=False)

    # Data Loader (Input Pipeline)
    #train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    #test_loader  = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
elif set_name == "Synthetic":
    D = 784       #dimension
    k = 3         # num of 2d planes in dim D
    n = 6*(10**3) # num of points in each plane
    shift_class = 0.0
    var_class = 1.0
    intercl_var = 0.1 # this has to be greater than 0.04
    # this creates a gaussian,
    # i.e.random shift
    # proportional to the value of intercl_var
    # Generate dataset
    # via classes
    torch.manual_seed(0) # reproducibility
    my_dataset = RR.SyntheticDataset(k=k,n=n,d=d,D=D,
                                        shift_class=shift_class, intercl_var=intercl_var, var_class=var_class)

    # NOTE(review): `create` is read without being called; if it is a method
    # rather than a property this needs `()` -- confirm in ricci_regularization.
    train_dataset = my_dataset.create
elif set_name == "Swissroll":
    D = 3
    sr_noise = 1e-6
    sr_numpoints = 18000 #k*n
    # sklearn is addressed by its full module path because the bare name
    # `datasets` is rebound to torchvision.datasets by the import cell.
    train_dataset =  sklearn.datasets.make_swiss_roll(n_samples=sr_numpoints, noise=sr_noise)
    sr_points = torch.from_numpy(train_dataset[0]).to(torch.float32)
    #sr_points = torch.cat((sr_points,torch.zeros(sr_numpoints,D-3)),dim=1)
    sr_colors = torch.from_numpy(train_dataset[1]).to(torch.float32)
    from torch.utils.data import TensorDataset
    train_dataset = TensorDataset(sr_points,sr_colors)

m = len(train_dataset)
# The train size is derived from the test size so that both always add up to m.
# Truncating m-m*split_ratio and m*split_ratio independently loses a sample
# whenever m*split_ratio is not an integer (e.g. m=7, ratio=0.2 -> 5+1), and
# random_split then raises because the lengths do not sum to the dataset size.
test_size = int(m*split_ratio)
train_size = m - test_size
train_data, test_data = torch.utils.data.random_split(train_dataset, [train_size, test_size])

# Both loaders are built from the split of `train_dataset`.
test_loader  = torch.utils.data.DataLoader(test_data , batch_size=batch_size)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

# VAE structure

In [None]:
class VAE(nn.Module):
    """Autoencoder whose latent code lives on a `z_dim`-dimensional torus.

    The encoder maps an input to `z_dim` angles and returns them embedded as
    (cos, sin) coordinates, so the latent code `z` has `2*z_dim` columns laid
    out as `[cos(phi_1..phi_z), sin(phi_1..phi_z)]`. The decoder consumes
    that embedding. Despite the class name, nothing is sampled: the forward
    pass is deterministic.

    The layer sizes given to the constructor are the only source of
    dimensions; the class does not read the notebook globals `D` and `d`.

    Args:
        x_dim: size of a flattened input sample.
        h_dim1: width of the outer hidden layers (fc1 / fc5).
        h_dim2: width of the inner hidden layers (fc2 / fc4).
        z_dim: number of angles, i.e. dimension of the torus.
    """
    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(VAE, self).__init__()
        # Plain ints: they are not parameters and do not enter the state_dict,
        # so previously saved checkpoints still load.
        self.x_dim = x_dim
        self.z_dim = z_dim
        # Non-linearity
        self.non_linearity = torch.sin
        # Which of sin/cos comes first is a convention; this class always
        # uses (cos, sin), see _circle_embedding.
        self.non_linearity2 = torch.cos
        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc3 = nn.Linear(h_dim2, z_dim)
        # decoder part
        # Double dimension as circle is mimicked using sin and cos charts
        self.fc4 = nn.Linear(2*z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def _circle_embedding(self, angles):
        """Embed angles of shape (batch, z_dim) as (batch, 2*z_dim) = [cos, sin]."""
        # Concatenated along dimension 1, as dimension 0 is the batch dimension
        return torch.cat( (self.non_linearity2(angles), self.non_linearity(angles)), 1)

    def encoder_torus(self, x):
        """Return the raw (unwrapped) output of the last encoder layer.

        These values are not reduced modulo 2*pi; the original author warns
        that they are feature-space values and should not be used as latent
        coordinates. Use `encoder2lifting` for wrapped angles.
        """
        h = self.non_linearity(self.fc1(x))
        h = self.non_linearity(self.fc2(h))
        return self.fc3(h)

    def encoder(self, x):
        """Return the latent code z = [cos(phi), sin(phi)], shape (batch, 2*z_dim)."""
        return self._circle_embedding(self.encoder_torus(x))

    def encoder2lifting(self, x):
        """Return the latent angles phi wrapped to [-pi, pi], shape (batch, z_dim).

        atan2(sin, cos) replaces acos(cos)*sign(sin): it yields the same angle
        wherever sin != 0, does not collapse the sin == 0 case to 0, and has a
        well-conditioned derivative near cos = +-1.
        """
        h = self.encoder(x)
        cosphi = h[:, :self.z_dim]
        sinphi = h[:, self.z_dim:]
        return torch.atan2(sinphi, cosphi)

    def decoder(self, z):
        """Decode a (cos, sin) latent code of shape (batch, 2*z_dim)."""
        h = self.non_linearity( self.fc4(z))
        h = self.non_linearity( self.fc5(h))
        return self.non_linearity( self.fc6(h) )

    def decoder_torus(self, z):
        """Decode angles of shape (batch, z_dim) by embedding them on the torus first."""
        return self.decoder(self._circle_embedding(z))

    def forward(self, x):
        """Flatten x to (batch, x_dim) and return (reconstruction, latent code)."""
        z = self.encoder(x.view(-1, self.x_dim))
        return self.decoder(z), z

# Model used by every later cell
vae = VAE(x_dim=D, h_dim1= 512, h_dim2=256, z_dim=d)
# Much smaller alternative architecture
#vae = VAE(x_dim=D, h_dim1= 3, h_dim2=2, z_dim=d)

# The model deliberately stays on the CPU. The training, evaluation, plotting
# and latent-space cells all feed it CPU tensors (`data.cpu()` or raw
# DataLoader batches), so moving only the weights to CUDA made the first
# forward pass fail with a device mismatch on any machine that has a GPU.
vae.cpu()

### Loading the saved weights

In [None]:

if weights_loaded:
    PATH_vae = f'../nn_weights/exp{experiment_number}.pt'
    # map_location="cpu" lets a checkpoint written from a GPU be read on a
    # CPU-only machine; load_state_dict then copies the values into the
    # model's existing parameters, whatever device they live on.
    vae.load_state_dict(torch.load(PATH_vae, map_location="cpu"))
    vae.eval()


## Optimizer and loss function


In [None]:
# Single Adam optimizer over all parameters of `vae` (encoder and decoder layers).
# NOTE(review): the `momentum` hyperparameter defined earlier is not used here.
optimizer = optim.Adam(vae.parameters(),lr=lr)
# Legacy loss: reconstruction error + KL divergence
def loss_functionm_old(recon_x, x, mu, log_var):
    """Return binary cross-entropy plus the KL term of a diagonal Gaussian.

    Args:
        recon_x: reconstruction with values in [0, 1].
        x: original data holding the same number of elements as `recon_x`;
            it is viewed with the shape of `recon_x` (this used to be
            hard-coded to 784 columns, i.e. MNIST only).
        mu: latent means.
        log_var: latent log-variances.

    Returns:
        Scalar tensor `BCE + KLD`, with BCE averaged over all elements and
        KLD = -0.5 * sum(1 + log_var - mu^2 - exp(log_var)).
    """
    BCE = F.binary_cross_entropy(recon_x, x.view_as(recon_x), reduction='mean')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD

# Legacy loss. The MSE uses mean reduction so that it is of the same order as
# the KLD term.
# Inputs:
#   recon_data: reconstructed data via decoder
#   data: original data
#   z: latent variable (accepted but not used)
#   mu, Sigma: mean vector and covariance matrix entering the KLD term
def loss_function_old2(recon_data, data, z, mu, Sigma):
    """Return 1e4 * (MSE + KLD).

    KLD = 0.5 * (trace(Sigma) + |mu|^2 - d - logdet(Sigma)), where `d` and the
    flattened input width `D` are read from the notebook globals.
    """
    flattened_target = data.view(-1, D)
    reconstruction_term = F.mse_loss(recon_data, flattened_target, reduction='mean')
    trace_term = torch.trace(Sigma)
    squared_mean_norm = mu.norm().pow(2)
    divergence_term = 0.5 * ( trace_term + squared_mean_norm - d - Sigma.logdet() )
    return (reconstruction_term + divergence_term)*1e4

def curv_func(encoded_data, function):
    """Average of sqrt(det(metric)) * Sc^2 over the given latent points.

    Args:
        encoded_data: batch of latent points handed to the RR helpers.
        function: map (here a decoder) passed through to the RR helpers.

    Returns:
        (1/N) * sum_i sqrt(det g_i) * Sc_i^2, with N the size of the first
        dimension of the metric tensor.

    NOTE(review): `RR.metric_jacfwd_vmap` / `RR.Sc_jacfwd_vmap` are presumably
    the per-point metric and scalar curvature induced by `function`; confirm
    against the ricci_regularization package.
    """
    metric = RR.metric_jacfwd_vmap(encoded_data, function=function)
    scalar_curvature = RR.Sc_jacfwd_vmap(encoded_data, function=function)
    volume_density = torch.sqrt(torch.det(metric))
    num_points = metric.shape[0]
    return (1/num_points)*(volume_density*torch.square(scalar_curvature)).sum()
    
# Loss = MSE + Penalization + curv_loss
#  where the penalization uses the moduli of Fourier modes of the empirical
#  latent distribution. This requires the batch size to be in the range of CLT.
def loss_function(recon_data, data, z):
    """Return the three unweighted loss terms (MSE, penalization, curv_loss).

    Args:
        recon_data: reconstructed data via decoder.
        data: original data; flattened to (batch, D) using the global `D`.
        z: latent code laid out as [cos(phi), sin(phi)] with `d` columns each
            (the layout produced by `VAE.encoder`; `d` is the global latent
            dimension).

    Returns:
        MSE: mean squared reconstruction error.
        penalization: squared moduli of the first and second empirical
            Fourier modes of the latent angles, summed over the d angles.
        curv_loss: `curv_func` evaluated on the detached latent angles when
            `curv_w > 0` or `compute_curvature` is set, otherwise a zero
            tensor of shape (1,).
    """
    MSE = F.mse_loss(recon_data, data.view(-1, D), reduction='mean')
    #
    # Split cosines and sines. The encoder emits the cosines first; the two
    # halves used to be named the other way round. The penalization is
    # unaffected up to rounding because it only uses squared batch means and
    # the symmetric product sin*cos.
    z_cos = z[:, 0:d]
    z_sin = z[:, d:2*d]
    #
    # Compute empirical first mode: batch means of cos(phi) and sin(phi)
    mode1 = torch.mean( z, dim = 0)
    mode1 = torch.sum( mode1*mode1 )
    #
    # Compute empirical second mode: batch means of cos(2 phi) and sin(2 phi)
    mode2_1 = torch.mean( 2*z_cos*z_cos-1, dim = 0)
    mode2_1 = torch.sum( mode2_1*mode2_1)
    mode2_2 = torch.mean( 2*z_sin*z_cos, dim = 0)
    mode2_2 = torch.sum( mode2_2*mode2_2 )
    mode2 = mode2_1 + mode2_2
    #
    penalization = mode1 + mode2
    if curv_w > 0 or compute_curvature:
        # detach(): the latent points are held fixed, so the curvature term
        # only sends gradients through the decoder.
        encoded_points = vae.encoder2lifting(data.view(-1, D)).detach()
        curv_loss = curv_func(encoded_points,function=vae.decoder_torus)
    else:
        # Placeholder on the same device/dtype as the other terms so that the
        # weighted sum in the training loop never mixes devices.
        curv_loss = MSE.new_zeros(1)
    return MSE, penalization, curv_loss

## Plotting tools

In [None]:
def plot_ae_outputs(encoder,decoder,n=10):
    """Show one test image per label 0..n-1 (top row) above its reconstruction.

    Args:
        encoder: callable applied to a flattened (1, D) image.
        decoder: callable applied to the encoder output; its result is
            reshaped to (1, 28, 28).
        n: number of columns, i.e. labels 0..n-1 taken from the global
            `test_dataset`.
    """
    def _hide_ticks(axis):
        axis.get_xaxis().set_visible(False)
        axis.get_yaxis().set_visible(False)

    plt.figure(figsize=(16,4.5))
    all_labels = test_dataset.targets.numpy()
    # Index of the first test sample carrying each label
    first_index = {label: np.where(all_labels==label)[0][0] for label in range(n)}
    title_column = n//2
    for column in range(n):
        original = test_dataset[first_index[column]][0].unsqueeze(0)

        top = plt.subplot(2,n,column+1)
        with torch.no_grad():
            reconstruction = decoder(encoder(original.reshape(1,D))).reshape(1,28,28)
        plt.imshow(original.cpu().squeeze().numpy(), cmap='gist_gray')
        _hide_ticks(top)
        if column == title_column:
            top.set_title('Original images')

        bottom = plt.subplot(2, n, column + 1 + n)
        plt.imshow(reconstruction.cpu().squeeze().numpy(), cmap='gist_gray')
        _hide_ticks(bottom)
        if column == title_column:
            bottom.set_title('Reconstructed images')
    plt.show()
# borrowed from https://gist.github.com/jakevdp/91077b0cae40f8f8244a
def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map.

    Args:
        N: number of discrete colours.
        base_cmap: colormap name, Colormap instance, or None for the
            Matplotlib default.

    Returns:
        A LinearSegmentedColormap named `<base name><N>` with N colours
        sampled evenly from the base map.
    """
    # Local import: the notebook only imports matplotlib.pyplot at the top.
    from matplotlib.colors import LinearSegmentedColormap

    # plt.get_cmap accepts a string, None or a colormap instance.
    # plt.cm.get_cmap is deprecated and removed in recent Matplotlib releases.
    base = plt.get_cmap(base_cmap)
    color_list = base(np.linspace(0, 1, N))
    cmap_name = base.name + str(N)
    # from_list is called on the class: ListedColormap bases such as
    # 'viridis' do not provide it on the instance.
    return LinearSegmentedColormap.from_list(cmap_name, color_list, N)
def plot3losses(mse_train_list,uniform_train_list,curv_train_list):
    """Plot the three per-batch loss histories on stacked log-scale axes.

    Returns:
        (fig, axes) of the 3x1 figure, after it has been shown.
    """
    fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(6,18))

    # (history, line colour, y-axis label) for each panel, top to bottom
    panels = (
        (mse_train_list, 'tab:red', 'MSE'),
        (uniform_train_list, 'tab:olive', 'Uniform loss'),
        (curv_train_list, 'tab:blue', 'Curvature'),
    )
    for axis, (history, line_color, label) in zip(axes, panels):
        axis.semilogy(history, color = line_color)
        axis.set_ylabel(label)
        axis.set_xlabel('Batches')

    plt.show()
    return fig,axes

In [None]:
# From https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/5
def cov(m, rowvar=True, inplace=False):
    '''Estimate the covariance matrix of a set of variables.

    Entry `C_{ij}` of the result is the covariance of variables `i` and `j`;
    the diagonal entry `C_{ii}` is the variance of variable `i`. The estimate
    is normalised by (number of observations - 1).

    Args:
        m: A 1-D or 2-D tensor. A 1-D tensor is treated as a single variable.
        rowvar: If True, rows are variables and columns are observations.
            If False, the roles are swapped (columns are variables), unless
            `m` has a single row.
        inplace: If True, the mean is subtracted from `m` in place.

    Returns:
        The covariance matrix, with singleton dimensions squeezed out.

    Raises:
        ValueError: if `m` has more than 2 dimensions.
    '''
    ndim = m.dim()
    if ndim > 2:
        raise ValueError('m has more than 2 dimensions')
    samples = m.view(1, -1) if ndim < 2 else m
    if not rowvar and samples.size(0) != 1:
        samples = samples.t()
    # samples = samples.type(torch.double)  # uncomment this line if desired
    normalisation = 1.0 / (samples.size(1) - 1)
    variable_means = torch.mean(samples, dim=1, keepdim=True)
    if inplace:
        samples -= variable_means
    else:
        samples = samples - variable_means
    # if complex: use samples.t().conj()
    return normalisation * samples.matmul(samples.t()).squeeze()

# Looking under the hood

In [None]:
# Disabled scratch experiment, kept as an inert string literal (never executed):
# a stand-alone training loop on 1e4*MSE + 1e3*curvature that re-creates the
# model and optimizer and plots reconstructions after every batch.
"""
x = torch.rand(1,784)
vae = VAE(x_dim=D, h_dim1= 512, h_dim2=256, z_dim=d)
optimizer = optim.Adam(vae.parameters(),lr=lr)
vae.train()
for (data,label) in tqdm(train_loader):
    optimizer.zero_grad()
    data = data.view(-1, D).cpu()
    recon_data = vae(data)[0]
    #recon_data = vae.decoder_torus(vae.encoder2lifting(data))
    mse_loss = F.mse_loss(recon_data, data, reduction='mean')
    curv_loss = curv_func(vae.encoder2lifting(data).detach(),function=vae.decoder_torus) # use detach() to fix the points
    
    myloss = 1e4*mse_loss + 1e3*curv_loss
    #myloss = 1e3*curv_func(vae.encoder2lifting(data.view(-1,D)).detach(),function=vae.decoder_torus)
    #myloss = vae.encoder2lifting(data.view(-1,D)).norm()
    #myloss = vae.encoder2lifting(x).norm()
    #print("\n 4d repr:", vae.encoder(x))
    #myloss = 1e3*curv_func(vae.encoder_torus(data.view(-1,D)),function=vae.decoder_torus)
    
    myloss.backward()
    optimizer.step()
    print(myloss)
    plot_ae_outputs(vae.encoder2lifting,vae.decoder_torus)
"""

In [None]:
def train(epoch, mse_loss_array=None, uniform_loss_array=None, curvatue_loss_array=None):
    """Run one training epoch over `train_loader` and record per-batch losses.

    Uses the notebook globals `vae`, `optimizer`, `train_loader`,
    `loss_function` and the weights `mse_w`, `unif_w`, `curv_w`.

    Args:
        epoch: epoch number, only used for the printed header.
        mse_loss_array: list extended in place with the MSE of every batch;
            a fresh list is created when omitted.
        uniform_loss_array: same, for the uniformity penalization.
        curvatue_loss_array: same, for the curvature term (the parameter
            name keeps its historical spelling for keyword callers).

    Returns:
        The three lists (mse, uniform, curvature), in that order.
    """
    # Fresh lists per call: mutable default arguments are created once and
    # would silently accumulate values across calls that omit them.
    if mse_loss_array is None:
        mse_loss_array = []
    if uniform_loss_array is None:
        uniform_loss_array = []
    if curvatue_loss_array is None:
        curvatue_loss_array = []

    vae.train()
    # Feed the batches to whatever device the model lives on
    device = next(vae.parameters()).device
    train_loss = 0
    print("Epoch %d"%epoch)
    t = tqdm( train_loader, position=0 )
    for batch_idx, (data, labels) in enumerate(t):
        data = data.to(device)
        optimizer.zero_grad()
        # Forward
        recon_batch, z = vae(data)
        mse_loss, uniform_loss, curvature_loss = loss_function(recon_batch, data, z)
        loss = mse_w*mse_loss + unif_w*uniform_loss + curv_w*curvature_loss
        # Backpropagate
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        mse_loss_array.append(mse_loss.item())
        uniform_loss_array.append(uniform_loss.item())
        curvatue_loss_array.append(curvature_loss.item())
        # Progress bar. Every batch loss is already a mean over its samples,
        # so the running average divides by the number of batches seen so
        # far, not by the number of samples in the dataset.
        t.set_description_str(desc="Average train loss: %.6f"% (train_loss / (batch_idx + 1)) )
        # Per-batch values go to the progress bar instead of one printed line
        # per batch, which flooded the cell output.
        t.set_postfix_str(
            f"batch:{batch_idx}, MSE:{mse_loss.item():.4e}, "
            f"Uniform:{uniform_loss.item():.4e}, Curvature:{curvature_loss.item():.4e}"
        )
    # end for

    return mse_loss_array, uniform_loss_array, curvatue_loss_array

def test():
    """Evaluate the weighted loss on `test_loader` and print its batch average.

    Uses the notebook globals `vae`, `test_loader`, `loss_function` and the
    weights `mse_w`, `unif_w`, `curv_w`. The model is left in eval mode.
    """
    vae.eval()
    device = next(vae.parameters()).device
    test_loss = 0
    num_batches = 0
    with torch.no_grad():
        t = tqdm( test_loader, desc="Test", position=1 )
        for data, _ in t:
            data = data.to(device)
            recon, z = vae(data)
            # loss_function returns three separate terms; calling .item() on
            # that tuple raised AttributeError. Combine them with the same
            # weights as the training objective.
            mse_loss, uniform_loss, curvature_loss = loss_function(recon, data, z)
            batch_loss = mse_w*mse_loss + unif_w*uniform_loss + curv_w*curvature_loss
            # sum up batch loss
            test_loss += batch_loss.item()
            num_batches += 1

    # Batch losses are means over their samples, so average over batches;
    # max(..., 1) guards against an empty loader.
    test_loss /= max(num_batches, 1)
    print('====> Total test set loss: {:.4f}'.format(test_loss))
    print("")

# Training

In [None]:
# Per-batch loss histories, accumulated across all epochs
mse_loss_array=[]
uniform_loss_array=[]
curvatue_loss_array = []
# Launch: after every epoch, plot the loss histories so far and a row of
# test-image reconstructions obtained through the angle parametrisation.
for epoch in range(1, num_epochs + 1):
  mse_loss_array,uniform_loss_array,curvatue_loss_array = train(epoch, mse_loss_array, uniform_loss_array, curvatue_loss_array)
  plot3losses(mse_loss_array,uniform_loss_array,curvatue_loss_array)
  plot_ae_outputs(vae.encoder2lifting,vae.decoder_torus)
  # Evaluation on the held-out split is currently disabled
  #test()

## Saving the model

In [None]:
PATH_vae = f'../nn_weights/exp{experiment_number}.pt'
# torch.save does not create missing folders; make sure the target exists so
# that a finished training run is not lost on a fresh checkout.
os.makedirs(os.path.dirname(PATH_vae), exist_ok=True)
torch.save(vae.state_dict(), PATH_vae)

## Losses plot

In [None]:
import os
# Folder receiving the figures of this experiment
Path_pictures = f"/home/alazarev/CodeProjects/Experiments/{experiment_name}/experiment{experiment_number}"
# makedirs(exist_ok=True) creates missing parent folders as well and can be
# re-run safely; os.mkdir failed when the experiment's parent folder was absent.
os.makedirs(Path_pictures, exist_ok=True)

In [None]:
# Loss plotting: redraw the three loss histories and store the figure as a PDF
# in the experiment's picture folder.
fig,axes = plot3losses(mse_loss_array,uniform_loss_array,curvatue_loss_array)
fig.savefig(f"{Path_pictures}/losses_exp{experiment_number}.pdf",bbox_inches='tight',format="pdf")

## Torus latent space

In [None]:
# Inert string literal (never executed): the stand-alone helper that served as
# the inspiration for vae.encoder2lifting. It converts a (cos, sin) latent
# tensor back to angles.
"""
def circle2anglevectorized(zLatentTensor,d = d):
    cosphi = zLatentTensor[:, 0:d]
    sinphi = zLatentTensor[:, d:2*d]
    phi = torch.acos(cosphi)*torch.sgn(torch.asin(sinphi))
    return phi
"""


In [None]:
# Pass the whole training split through the model once and collect, batch by
# batch: the inputs, their reconstructions, the latent angles and the labels.
# NOTE(review): this runs with autograd enabled, so the stored reconstructions
# and angles keep their computation graphs, and the encoder layers are
# evaluated twice per batch (once in vae(data), once in encoder2lifting).
#zlist = []
colorlist = []
enc_list = []
input_dataset_list = []
recon_dataset_list = []
for (data, labels) in tqdm( train_loader, position=0 ):
#for (data, labels) in train_loader:
    input_dataset_list.append(data)
    recon_dataset_list.append(vae(data)[0])
    #zlist.append(vae(data)[1])
    enc_list.append(vae.encoder2lifting(data.view(-1,D)))
    colorlist.append(labels) 

In [None]:
# Concatenate the per-batch lists collected above into single tensors.
#x = torch.cat(zlist)
#enc = circle2anglevectorized(x).detach()
input_dataset = torch.cat(input_dataset_list)
recon_dataset = torch.cat(recon_dataset_list)
encoded_points = torch.cat(enc_list)
# Graph-free copy of the latent angles, used for plotting
encoded_points_no_grad = encoded_points.detach()
color_array = torch.cat(colorlist).detach()
#assert torch.equal(enc,enc_tensor)

In [None]:
# Earlier variants of this plot, kept for reference:
#angleLatentviatorch = circle2anglevectorized(zLatent_tensor)/math.pi
#plt.scatter(angleLatentviatorch[:,0],angleLatentviatorch[:,1], c=labels, marker='o', edgecolor='none', cmap=discrete_cmap(N, 'jet'))
#enc = vae.encoder2lifting(train_dataset.data.reshape(-1,784).to(dtype = torch.float32)).detach()
#enc = vae.encoder2lifting(train_dataset.data.reshape(-1,784)/256).detach() # this works!!!
#enc = vae.encoder2lifting(train_dataset.data.reshape(-1,784)/256).detach()
#enc = vae.encoder_torus(train_dataset.data.reshape(-1,784)/256).detach()
#plt.scatter(enc[:,0],enc[:,1], c=train_dataset.targets, marker='o', edgecolor='none', cmap=discrete_cmap(N, 'jet'))
# Scatter plot of the first two latent angles, coloured by label with a
# k-bin discrete colormap, saved as a PDF in the experiment's picture folder.
# NOTE(review): `k` is only assigned in the MNIST and Synthetic branches of
# the dataset cell, so this cell presumably fails for set_name == "Swissroll".
plt.figure(figsize=(8, 6))
plt.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1], c=color_array, marker='o', edgecolor='none', cmap=discrete_cmap(k, 'jet'))
plt.colorbar(ticks=range(k))
plt.grid(True)
plt.savefig(f"{Path_pictures}/latent_space_exp{experiment_number}.pdf",bbox_inches='tight',format="pdf")
#plt.show()

In [None]:
# Inert string literal (never executed): example of evaluating RG.Sc at a
# random 4-dimensional latent point through RR.RiemannianGeometry, with
# vae.decoder as the function.
"""
RG = RR.RiemannianGeometry(latent_space_dim=4,function=vae.decoder,AD_method=jacfwd,eps=0.01)
RG.Sc(torch.rand(1,4))
"""