# I. Train and test datasets

In [None]:
# --- Synthetic dataset hyperparameters (passed to RR.SyntheticDataset) ---
D = 28 * 28     # ambient dimension of every sample
d = 2           # latent space dimension
k = 3           # number of 2d planes embedded in dimension D
n = 6_000       # number of points sampled in each plane
shift_class = 0
# `intercl_var`: per the original author's note, this creates a Gaussian,
# i.e. a random shift proportional to its value.
intercl_var = 0.1

In [None]:
# --- Data-loader hyperparameters ---
split_ratio = 0.2   # fraction of the dataset held out for testing
batch_size = 32     # previously 16

# Optional global seed for reproducibility (disabled here):
# torch.manual_seed(0)

In [None]:
# --- Loss weights ---
mse_w = 1.0    # weight on the reconstruction (MSE) term
curv_w = 1.0   # weight on the curvature term

# The reconstruction loss is computed by hand in the training loop, so no
# torch.nn loss object (e.g. torch.nn.MSELoss) is instantiated here.

# --- Optimizer and training schedule (shared by the encoder and the decoder) ---
lr = 2e-6
momentum = 0.8
num_epochs = 2
batches_per_plot = 50  # draw diagnostic plots every this many batches (was 200)

# The latent dimension `d` is defined once, with the dataset hyperparameters;
# it is deliberately not redefined here so the two values cannot drift apart.

In [None]:
# adding path to the set generating package
import sys
sys.path.append('../') # have to go 1 level up

import ricci_regularization

In [None]:
# Minimal imports
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

import ricci_regularization as RR
#import torchvision

# Generate the synthetic dataset through the RR.SyntheticDataset class.
# (Old functional API:
#  ricci_regularization.generate_dataset(D, k, n, shift_class=shift_class, intercl_var=intercl_var))
torch.manual_seed(0)  # reproducibility of the dataset and of the split below
my_dataset = RR.SyntheticDataset(k=k, n=n, d=d, D=D,
                                 shift_class=shift_class, intercl_var=intercl_var)

# NOTE(review): `create` is accessed as an attribute, not called -- presumably a
# property returning a torch Dataset; confirm against RR.SyntheticDataset.
train_dataset = my_dataset.create

# Compute the test size first and give the remainder to training, so the two
# lengths always sum to len(train_dataset). random_split rejects integer lengths
# that do not; truncating both products independently can drop a sample
# (e.g. m=10, split_ratio=0.25 -> 7 + 2 = 9).
m = len(train_dataset)
test_size = int(m * split_ratio)
train_size = m - test_size
train_data, test_data = torch.utils.data.random_split(train_dataset, [train_size, test_size])

test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

# test_data[:][0] gives the data vectors (without labels) of the test split.

# II. Declaration of AE

In [None]:
# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
cuda_on = torch.cuda.is_available()
device = torch.device("cuda" if cuda_on else "cpu")
print(f'Selected device: {device}')

In [None]:
class Encoder(nn.Module):
    """MLP encoder mapping `input_dim` features to a `hidden_dim` latent code.

    Architecture: input_dim -> 512 -> 256 -> 128 -> hidden_dim, with a sine
    activation after every linear layer, including the last one, so every
    latent coordinate lies in [-1, 1].
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Creation order fixes the parameter initialisation under a given seed,
        # and the attribute names are the state_dict keys: keep both as is.
        self.linear1 = nn.Linear(input_dim, 512)
        self.linear2 = nn.Linear(512, 256)
        self.linear3 = nn.Linear(256, 128)
        self.linear4 = nn.Linear(128, hidden_dim)
        # Alternative activation that was tried: nn.ELU()
        self.activation = torch.sin

    def forward(self, x):
        """Encode `x` (last dimension = input_dim) into latent codes."""
        hidden = x
        for layer in (self.linear1, self.linear2, self.linear3, self.linear4):
            hidden = self.activation(layer(hidden))
        return hidden

In [None]:
class Decoder(nn.Module):
    """MLP decoder mapping a `hidden_dim` latent code to `output_dim` features.

    Architecture: hidden_dim -> 128 -> 256 -> 512 -> output_dim, with a sine
    activation after every linear layer, so every output lies in [-1, 1].
    """

    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        # Creation order fixes the parameter initialisation under a given seed,
        # and the attribute names are the state_dict keys: keep both as is.
        self.linear1 = nn.Linear(hidden_dim, 128)
        self.linear2 = nn.Linear(128, 256)
        self.linear3 = nn.Linear(256, 512)
        self.linear4 = nn.Linear(512, output_dim)
        self.activation = torch.sin

    def forward(self, x):
        """Decode latent codes `x` (last dimension = hidden_dim)."""
        out = x
        for layer in (self.linear1, self.linear2, self.linear3, self.linear4):
            out = self.activation(layer(out))
        # Alternative output head that was tried: torch.sigmoid on the last layer.
        return out

In [None]:
# Alternative, simplified 2-layer decoder: ReLU hidden layer of width 512 and a
# sigmoid output. It is disabled: the class sits inside a string literal, so it
# is never defined and the sine-activated Decoder class defined earlier is used.
"""
class Decoder(nn.Module):
    def __init__(self, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.linear1 = nn.Linear(hidden_dim, 512)
        self.linear2 = nn.Linear(512, output_dim)
        
    def forward(self, z):
        z = torch.nn.functional.relu(self.linear1(z))
        z = torch.sigmoid(self.linear2(z))
        #return z.reshape((-1, 1, 28, 28))
        return z
"""

# III. Initialization 

In [None]:
# Classical autoencoder: R^D -> R^d -> R^D (D and d from the dataset hyperparameters).
encoder = Encoder(input_dim=D, hidden_dim=d)
decoder = Decoder(hidden_dim=d, output_dim=D)

# Move both networks to the selected device *before* constructing the
# optimizer, as the torch.optim documentation recommends.
encoder.to(device)
decoder.to(device)

# One optimizer over the parameters of both the encoder and the decoder.
params_to_optimize = [
    {'params': encoder.parameters()},
    {'params': decoder.parameters()}
]

optimizer = torch.optim.RMSprop(params_to_optimize, lr=lr, momentum=momentum, weight_decay=0.0)

# III*. Adding Curvature functional

In [None]:
import sys
sys.path.append('../') # have to go 1 level up
import ricci_regularization as RR

In [None]:
#RR.metric_jacfwd_vmap(torch.rand(10,2),function=decoder).shape[0]

### Choice of curvature functional

In [None]:
def Func(encoded_data):
    """Curvature penalty of the global `decoder`, evaluated on a batch of latent points.

    Returns the batch mean of sqrt(det g) * Sc**2 (the code calls it the
    integral of Sc). Here g is the metric from RR.metric_jacfwd_vmap and Sc the
    scalar curvature from RR.Sc_jacfwd_vmap, both computed with
    `function=decoder` at the points of `encoded_data`.

    Args:
        encoded_data: batch of latent codes; presumably shaped (N, d) -- TODO confirm.

    Returns:
        A 0-dim tensor, used in the training loss and back-propagated there.
    """
    metric_on_data = RR.metric_jacfwd_vmap(encoded_data,
                                           function=decoder)
    det_on_data = torch.det(metric_on_data)
    Sc_on_data = RR.Sc_jacfwd_vmap(encoded_data,
                                   function=decoder)
    N = metric_on_data.shape[0]
    # NOTE(review): assuming g is a pullback (Gram) metric, det g >= 0 in exact
    # arithmetic. Round-off can make it slightly negative for near-degenerate
    # Jacobians, and sqrt would then return NaN and poison the whole loss.
    # abs() leaves non-negative determinants unchanged.
    volume_element = torch.sqrt(det_on_data.abs())
    Integral_of_Sc = (1/N)*(volume_element*torch.square(Sc_on_data)).sum()
    return Integral_of_Sc

"""
# minimizing |g-I|_F
def Func(encoded_data):
    metric_on_data = RR.metric_jacfwd_vmap(encoded_data,
                                           function=decoder)
    N = metric_on_data.shape[0]
    func = (1/N)*(metric_on_data-torch.eye(d)).norm(dim=(1,2)).sum()
    return func
"""

# III**. Plotting tools

In [None]:
# borrowed from https://gist.github.com/jakevdp/91077b0cae40f8f8244a
def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map.

    Args:
        N: number of discrete colors.
        base_cmap: colormap name, None (Matplotlib's default colormap), or a
            Colormap instance.

    Returns:
        A LinearSegmentedColormap named "<base name><N>" with N colors sampled
        evenly from `base_cmap`.
    """
    # Local import: only this helper needs it.
    from matplotlib.colors import LinearSegmentedColormap

    # `plt.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
    # `plt.get_cmap` accepts the same inputs (str, None or a Colormap).
    base = plt.get_cmap(base_cmap)
    color_list = base(np.linspace(0, 1, N))
    cmap_name = base.name + str(N)
    # Call from_list on LinearSegmentedColormap itself: ListedColormap bases
    # (e.g. 'viridis') do not provide this method.
    return LinearSegmentedColormap.from_list(cmap_name, color_list, N)

In [None]:
def point_plot(encoder, data, batch_idx):
    """Scatter-plot the latent codes of `data`, colored by class label.

    Args:
        encoder: network applied to the flattened inputs; the first two latent
            coordinates are plotted.
        data: indexable dataset where data[:][0] are the inputs and
            data[:][1] the labels (e.g. the `test_data` split).
        batch_idx: training batch index, used only in the title.

    Returns:
        The `matplotlib.pyplot` module with the new figure current, so the
        caller can `savefig` / `show` it.
    """
    labels = data[:][1]
    inputs = data[:][0]

    # Encode without tracking gradients, then restore the encoder's previous
    # train/eval mode so calling this mid-training does not leave it in eval mode.
    was_training = encoder.training
    encoder.eval()
    try:
        with torch.no_grad():
            inputs = inputs.view(-1, D).to(device)  # flatten each sample to D features
            encoded_data = encoder(inputs)
    finally:
        encoder.train(was_training)

    # Move to NumPy for Matplotlib
    latent = encoded_data.cpu().numpy()
    labels = labels.numpy()

    # Plot
    plt.figure(figsize=(8, 6))
    plt.scatter(latent[:, 0], latent[:, 1], c=labels, alpha=0.5, marker='o',
                edgecolor='none', cmap=discrete_cmap(k, 'jet'))
    plt.title(f'''Latent space for test data in AE at batch {batch_idx}''')
    plt.colorbar(ticks=range(k))
    plt.grid(True)

    return plt

In [None]:
# Reality check of the train split: batches per epoch times the batch size
# should be just above the number of training points (the last batch may be partial).
print(
    "Reality check of batch splitting: \n"
    f"-- Batches per epoch {len(train_loader)}\n"
    f"batch size: {batch_size}\n"
    f"product:  {len(train_loader) * batch_size}\n"
    f"-- To be compared to: {(1.0 - split_ratio) * n * k}"
)


In [None]:
# Reality check of the test split: batches per pass times the batch size
# should be just above the number of test points (the last batch may be partial).
print(
    "Reality check of batch splitting: \n"
    f"-- Batches per epoch {len(test_loader)}\n"
    f"batch size: {batch_size}\n"
    f"product:  {len(test_loader) * batch_size}\n"
    f"-- To be compared to: {split_ratio * n * k}"
)

# IV. Training

In [None]:
import os

# Directory for the latent-space snapshots; created up front so savefig does
# not fail with FileNotFoundError on a fresh checkout.
pointplot_dir = '../plots/pointplots_in_training_testdata'
os.makedirs(pointplot_dir, exist_ok=True)

num_plots = 0  # to enumerate the saved plots
batch_idx = 0  # global batch counter across epochs

mse_loss = []             # train MSE, one entry per batch
curv_loss = []            # train curvature penalty, one entry per batch
test_mse_loss_list = []   # test MSE, one test batch per train batch
test_curv_loss_list = []  # test curvature penalty, one test batch per train batch

# One test batch is drawn per train batch; the test iterator is restarted
# whenever it is exhausted, so every test batch is visited in turn.
iter_test_loader = iter(test_loader)

for epoch in range(num_epochs):
    epoch_start = len(mse_loss)  # index of this epoch's first loss entry

    # Set train mode for both the encoder and the decoder
    encoder.train()
    decoder.train()

    # Unsupervised learning: the labels (second element of each batch) are ignored.
    for image_batch, _ in train_loader:
        # Flatten each sample to D features and move it to the device.
        image_batch = image_batch.view(-1, D).to(device)
        # True batch size (the last batch may be smaller than batch_size)
        true_batch_size = image_batch.shape[0]

        optimizer.zero_grad()

        # Forward pass
        encoded_data = encoder(image_batch)
        decoded_data = decoder(encoded_data)
        # Loss: per-sample summed squared error plus weighted curvature penalty
        mse_loss_batch = torch.sum((decoded_data - image_batch)**2) / true_batch_size
        curvature_train_batch = Func(encoded_data)
        loss = mse_w*mse_loss_batch + curv_w*curvature_train_batch

        # Backward pass
        loss.backward()
        optimizer.step()

        # Next test batch, restarting the test iterator once it is exhausted.
        try:
            test_batch = next(iter_test_loader)
        except StopIteration:
            iter_test_loader = iter(test_loader)
            test_batch = next(iter_test_loader)
        test_images = test_batch[0].view(-1, D).to(device)
        true_test_batch_size = test_images.shape[0]

        # Evaluation only: no autograd graph is needed for the test forward pass.
        with torch.no_grad():
            encoded_test_data = encoder(test_images)
            decoded_test_data = decoder(encoded_test_data)
            test_mse_loss = torch.sum((decoded_test_data - test_images)**2) / true_test_batch_size
            test_mse_loss_list.append(test_mse_loss.detach().cpu().numpy())
            test_curv_loss = Func(encoded_test_data)
            test_curv_loss_list.append(test_curv_loss.detach().cpu().numpy())

        mse_loss.append(float(mse_loss_batch.detach().cpu().numpy()))
        curv_loss.append(float(curvature_train_batch.detach().cpu().numpy()))

        # Periodic diagnostics
        if batch_idx % batches_per_plot == 0:
            # Latent-space snapshot of the test data
            plot = point_plot(encoder, test_data, batch_idx)
            plot.savefig(os.path.join(pointplot_dir, 'pp{0}.png'.format(num_plots)))
            num_plots += 1
            plot.show()
            # point_plot may switch the encoder to eval mode; resume training in train mode.
            encoder.train()

            # Loss curves so far
            if batch_idx > 0:
                fig, ax1 = plt.subplots()

                # Left axis: reconstruction losses (linear scale)
                ax1.set_xlabel('Batches')
                ax1.set_ylabel('MSE')
                ax1.plot(mse_loss, label='train_MSE_loss', color='tab:orange')
                ax1.plot(test_mse_loss_list, label='test_MSE_loss', color='tab:red')
                ax1.tick_params(axis='y')
                plt.legend(loc='lower left')

                # Right axis, sharing the x-axis: curvature penalties (log scale)
                ax2 = ax1.twinx()
                ax2.set_ylabel('Curvature')
                ax2.semilogy(curv_loss, label='train_Curv_loss', color='tab:olive')
                ax2.semilogy(test_curv_loss_list, label='test_Curv_loss', color='tab:green')
                ax2.tick_params(axis='y')
                plt.legend(loc='lower right')
                fig.tight_layout()  # otherwise the right y-label is slightly clipped
                plt.show()
        batch_idx += 1

    # Mean losses over this epoch's batches only
    print('\n EPOCH {}/{} \t MSE loss: {}, Curvature loss: {}'.format(
        epoch + 1, num_epochs,
        np.mean(mse_loss[epoch_start:]), np.mean(curv_loss[epoch_start:])))

In [None]:
# Larger fonts (size 17). rcParams is global, so this also affects every
# figure drawn afterwards.
plt.rcParams.update({'font.size': 17})

fig, ax1 = plt.subplots()

# Left axis: reconstruction (MSE) losses per batch
ax1.set_xlabel('Batches')
ax1.set_ylabel('MSE')
ax1.plot(mse_loss, label='train_MSE_loss', color='tab:orange')
ax1.plot(test_mse_loss_list, label='test_MSE_loss', color='tab:red')
ax1.tick_params(axis='y')
plt.legend(loc='lower left')

# Right axis, sharing the x-axis: curvature penalties per batch (log scale)
ax2 = ax1.twinx()
ax2.set_ylabel('Curvature')
ax2.semilogy(curv_loss, label='train_Curv_loss', color='tab:olive')
ax2.semilogy(test_curv_loss_list, label='test_Curv_loss', color='tab:green')
ax2.tick_params(axis='y')
plt.legend(loc='upper right')
fig.tight_layout()  # otherwise the right y-label is slightly clipped
# "\\lambda": a bare "\l" is an invalid escape sequence (SyntaxWarning since
# Python 3.12); the rendered title text is unchanged.
plt.title("Params: lr={0}, batch_size={1},\n $\\lambda_r$={2}, $\\lambda_c$={3}".format(lr, batch_size, mse_w, curv_w))
# To save the figure, call savefig *before* plt.show(), e.g.:
#plt.savefig('../plots/losses_2layers_decoder.png')
plt.show()

In [None]:
"""
REDO LATER $\lambda_{recon}$ inside a string
tex_string = "$\lambda_{recon}$"
s = f''' toto {tex_string}'''
print( s )
"""

# Level sets of Gaussians

In [None]:
# Read back, from the dataset object, the rotation matrices phi_j and the
# shifts y_j used to construct the k planes.
phi, shifts = my_dataset.rotations, my_dataset.shifts

### Disc of circles

In [None]:
import matplotlib.pyplot as plt
import torch
plt.title( "Canonical version of disc" )
radius_array = torch.sqrt( torch.linspace(0, 1, 40) )
theta_array  = torch.linspace(0, 1-1/100, 100)
x = torch.cos( 2*torch.pi*theta_array )
y = torch.sin( 2*torch.pi*theta_array )
for r in radius_array:
    r = r.item() # extracting the value of r
    plt.scatter( r*x, r*y, c=theta_array, marker='.', cmap='hsv', alpha=1-r )
# end for 
plt.show()

### Disk with colormap by polar angle

In [None]:
# For each data plane j: map concentric circles into R^D via phi_j @ circle + y_j,
# encode them, and plot the latent points colored by polar angle (opacity fades
# with radius).
num_circles = 40  # circles per plane
numpoints = 100   # points per circle
maxrad = 3

radius_array = maxrad*np.sqrt(np.linspace(0, 1, num_circles))
# numpoints angles in [0, 2*pi), endpoint excluded; shared by every circle.
theta = torch.linspace(0., 2*torch.pi*(numpoints-1)/numpoints, numpoints)

for plane_idx in range(k):
    # "\\lambda": a bare "\l" is an invalid escape sequence in a normal string.
    plt.title(f'''Embedding of {num_circles} circles of radius up to {maxrad} in \n the latent space in plane # {plane_idx} with $\\lambda_c=${curv_w}''')
    for rad in radius_array:
        x_array = rad*torch.cos(theta).unsqueeze(0).T  # (numpoints, 1)
        y_array = rad*torch.sin(theta).unsqueeze(0).T  # (numpoints, 1)
        circle = torch.cat((x_array, y_array), dim=-1)  # (numpoints, 2)
        circle_in_D = torch.matmul(phi[plane_idx], circle.T).T + shifts[plane_idx].squeeze()
        # The encoder lives on `device`; inference only, so no autograd graph.
        with torch.no_grad():
            encoded_circle = encoder(circle_in_D.to(device)).cpu()
        plt.scatter(encoded_circle[:, 0], encoded_circle[:, 1], c=theta, marker='.', cmap='hsv', alpha=1-rad/maxrad)
    plt.grid(True)
    # For equal axis scales: plt.gca().set_aspect('equal')
    plt.show()

In [None]:
# For each data plane j: map concentric circles into R^D via phi_j @ circle + y_j,
# encode them, and plot each circle with the next color of the default cycle.
num_circles = 40  # circles per plane
numpoints = 100   # points per circle
maxrad = 3

radius_array = maxrad*np.sqrt(np.linspace(0, 1, num_circles))
# numpoints angles in [0, 2*pi), endpoint excluded; shared by every circle.
theta = torch.linspace(0., 2*torch.pi*(numpoints-1)/numpoints, numpoints)

for plane_idx in range(k):
    # "\\lambda": a bare "\l" is an invalid escape sequence in a normal string.
    plt.title(f'''Embedding of {num_circles} circles of radius up to {maxrad} in \n the latent space in plane # {plane_idx} with $\\lambda_c=${curv_w}''')
    for rad in radius_array:
        x_array = rad*torch.cos(theta).unsqueeze(0).T  # (numpoints, 1)
        y_array = rad*torch.sin(theta).unsqueeze(0).T  # (numpoints, 1)
        circle = torch.cat((x_array, y_array), dim=-1)  # (numpoints, 2)
        circle_in_D = torch.matmul(phi[plane_idx], circle.T).T + shifts[plane_idx].squeeze()
        # The encoder lives on `device`; inference only, so no autograd graph.
        with torch.no_grad():
            encoded_circle = encoder(circle_in_D.to(device)).cpu()
        plt.scatter(encoded_circle[:, 0], encoded_circle[:, 1])
    plt.grid(True)
    # For equal axis scales: plt.gca().set_aspect('equal')
    plt.show()

In [None]:
# All k planes on a single figure: concentric circles from every plane are
# mapped into R^D via phi_j @ circle + y_j, encoded, and plotted, each circle
# with the next color of the default color cycle.
num_circles = 30  # circles per plane
numpoints = 100   # points per circle
maxrad = 20

radius_array = maxrad*np.sqrt(np.linspace(0, 1, num_circles))
# numpoints angles in [0, 2*pi), endpoint excluded; shared by every circle.
theta = torch.linspace(0., 2*torch.pi*(numpoints-1)/numpoints, numpoints)

# "\\lambda": a bare "\l" is an invalid escape sequence in a normal string.
plt.title(f'''Embedding of {num_circles} circles of radius up to {maxrad} in each \n of the {k} planes in the latent space with $\\lambda_c=${curv_w}''')
for plane_idx in range(k):
    for rad in radius_array:
        x_array = rad*torch.cos(theta).unsqueeze(0).T  # (numpoints, 1)
        y_array = rad*torch.sin(theta).unsqueeze(0).T  # (numpoints, 1)
        circle = torch.cat((x_array, y_array), dim=-1)  # (numpoints, 2)
        circle_in_D = torch.matmul(phi[plane_idx], circle.T).T + shifts[plane_idx].squeeze()
        # The encoder lives on `device`; inference only, so no autograd graph.
        with torch.no_grad():
            encoded_circle = encoder(circle_in_D.to(device)).cpu()
        plt.scatter(encoded_circle[:, 0], encoded_circle[:, 1])
        # To color by polar angle instead:
        #plt.scatter(encoded_circle[:,0],encoded_circle[:,1],c=theta, marker='.', cmap='hsv', alpha=1-rad/maxrad)
plt.grid(True)
# For equal axis scales: plt.gca().set_aspect('equal')
plt.show()

# Histogram of eigenvalues

In [None]:
# N points sampled uniformly from the square [-1, 1)^d. The encoder ends with a
# sine activation, so every latent code lies in [-1, 1]^d; torch.rand alone
# would only cover [0, 1).
N = 1000
torch.manual_seed(0)
samples_over_latent_space = 2*torch.rand(N, d) - 1.0

# Visualize the samples (first two latent coordinates)
plt.scatter(samples_over_latent_space[:, 0], samples_over_latent_space[:, 1])
plt.title(f"{N} samples uniformly distributed over the latent space")
plt.show()

In [None]:
# The decoder lives on `device`: evaluate the metric there, then bring it back
# to the CPU for the eigen-decomposition and plotting.
metric_at_samples = RR.metric_jacfwd_vmap(samples_over_latent_space.to(device), function=decoder).detach().cpu()
# eigvalsh assumes symmetric (Hermitian) input and returns eigenvalues in ascending order.
eigenvalues = torch.linalg.eigvalsh(metric_at_samples)

In [None]:
# eigvalsh sorts eigenvalues in ascending order along the last dimension, so
# the first / last columns are the smallest / largest eigenvalue at each sample
# (valid for any latent dimension d).
min_eigenvalues = eigenvalues[:, 0]
max_eigenvalues = eigenvalues[:, -1]
# Every eigenvalue, grouped by rank: all smallest first, then the next, ..., all
# largest last (for d = 2 this is exactly min followed by max).
all_eigenvalues = eigenvalues.T.reshape(-1)

In [None]:
import math
plt.title(f"Histogram of eigenvalues of metric $G$ evaluated \n at {N} samples uniformly distributed \n over the latent space with penalty \n on curvature $\lambda_c=${curv_w}")
plt.hist(all_eigenvalues, bins=round(math.sqrt(N)))
plt.xlabel("All eigenvalues of metric")
plt.show()

In [None]:
import math
plt.title(f"Histogram of minimal eigenvalues of metric $G$ evaluated at {N} \n samples uniformly distributed over the latent space ")
plt.hist(min_eigenvalues, bins=round(math.sqrt(N)))
plt.xlabel("Min eigenvalues of metric")
plt.show()

In [None]:
# Histogram of the largest metric eigenvalue at each sample (square-root bin rule).
plt.title(f"Histogram of maximal eigenvalues of metric $G$ evaluated at {N} \n samples uniformly distributed over the latent space ")
plt.xlabel("Max eigenvalues of metric")
plt.hist(max_eigenvalues, bins=round(math.sqrt(N)))
plt.show()