# I. Train and test datasets

In [None]:
# --- Dataset hyperparameters ---
D = 784            # ambient dimension of the data
k = 3              # number of 2d planes embedded in dimension D
n = 6_000          # number of points drawn on each plane
shift_class = 0    # forwarded unchanged to the dataset generator


# --- Data loader hyperparameters ---
batch_size = 16    # samples per training / test batch
split_ratio = 0.2  # fraction of the dataset held out as test data

# Uncomment to fix the random seed for reproducibility
# torch.manual_seed(0)

In [None]:
# Minimal imports
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

import ricci_regularization

# Generate dataset
train_dataset = ricci_regularization.generate_dataset(D, k, n, shift_class=shift_class)

# Split into train / test parts.
# The test size is truncated once and the train size is derived from it, so the
# two lengths always add up to len(train_dataset). Truncating both
# `m - m*split_ratio` and `m*split_ratio` independently can lose one sample to
# floating-point rounding, in which case random_split raises.
m = len(train_dataset)
test_size = int(m * split_ratio)
train_size = m - test_size
train_data, test_data = torch.utils.data.random_split(train_dataset, [train_size, test_size])

# Only the training loader is shuffled; the test loader keeps a fixed order.
test_loader  = torch.utils.data.DataLoader(test_data , batch_size=batch_size)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

# test_data[:][0] will give the vectors of data without labels from the test part of the dataset

# II. Declaration of AE

In [None]:
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU
cuda_on = torch.cuda.is_available()
device = torch.device("cuda" if cuda_on else "cpu")
print(f'Selected device: {device}')

In [None]:

class Encoder(nn.Module):
    """Fully connected encoder: input_dim -> 512 -> 256 -> 128 -> hidden_dim.

    Every linear layer except the last one is followed by a LeakyReLU; the
    latent code itself is returned without any activation.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        widths = [input_dim, 512, 256, 128]
        layers = []
        # Hidden stack: Linear + LeakyReLU for each consecutive pair of widths
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out, bias=True))
            layers.append(nn.LeakyReLU())
        # Final projection onto the latent space (no activation)
        layers.append(nn.Linear(widths[-1], hidden_dim, bias=True))
        self.encoder = nn.Sequential(*layers)
        self.kl = 0 # For compatibility

    def forward(self, x):
        """Return the latent code produced by the layer stack for `x`."""
        return self.encoder(x)

class Decoder(nn.Module):
    """Fully connected decoder: hidden_dim -> 128 -> 256 -> 512 -> output_dim.

    Every linear layer, including the output layer, is followed by a LeakyReLU.
    """

    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        widths = [hidden_dim, 128, 256, 512, output_dim]
        layers = []
        # One Linear + LeakyReLU pair per consecutive pair of widths
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out, bias=True))
            layers.append(nn.LeakyReLU())
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return the reconstruction produced by the layer stack for code `x`."""
        return self.decoder(x)

In [None]:
### Reconstruction criterion (mean squared error)
loss_fn = torch.nn.MSELoss()

### Optimizer settings (shared by the encoder and the decoder)
lr, momentum = 2e-5, 0.8
num_epochs = 5
batches_per_plot = 50  # a latent-space plot is drawn every this many batches

### Uncomment to fix the random seed for reproducible results
# torch.manual_seed(0)

### Dimension of the latent space used by both networks
d = 2

In [None]:
# Classical autoencoder.
# The input/output width is taken from the dataset hyperparameter D instead of
# a literal 784, so the networks stay consistent with the generated data.
# NOTE: the reshapes elsewhere in this file still hard-code 28*28, so D must
# remain 784 until those are generalized as well.
#model = Autoencoder(hidden_dim=hidden_dim)
encoder = Encoder(input_dim=D, hidden_dim=d)
decoder = Decoder(hidden_dim=d, output_dim=D)

# VAE
#encoder = VariationalEncoder(input_dim=D, hidden_dim=d, cuda=cuda_on)
#decoder = Decoder(hidden_dim=d, output_dim=D)

# Move both the encoder and the decoder to the selected device *before*
# constructing the optimizer, as recommended by the torch.optim documentation.
encoder.to(device)
decoder.to(device)

# A single optimizer updates the parameters of both networks.
params_to_optimize = [
    {'params': encoder.parameters()},
    {'params': decoder.parameters()}
]

optimizer = torch.optim.RMSprop(params_to_optimize, lr=lr, momentum=momentum, weight_decay=0.0)

# III. Training

In [None]:
# borrowed from https://gist.github.com/jakevdp/91077b0cae40f8f8244a
def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map.

    Parameters
    ----------
    N : int
        Number of discrete colors to sample from the base colormap.
    base_cmap : str, Colormap or None
        Name of a registered colormap, a colormap instance, or None for
        matplotlib's default colormap.

    Returns
    -------
    A colormap with N levels, named after the base colormap followed by N.
    """
    # Local import: the file only imports matplotlib.pyplot at the top.
    from matplotlib.colors import LinearSegmentedColormap

    # plt.get_cmap accepts a name, None, or a colormap instance. It replaces
    # plt.cm.get_cmap, which is no longer available in recent matplotlib.
    base = plt.get_cmap(base_cmap)
    color_list = base(np.linspace(0, 1, N))
    cmap_name = base.name + str(N)
    # Call from_list on the class rather than on `base`: listed colormaps
    # (e.g. the default one) do not provide a from_list method.
    return LinearSegmentedColormap.from_list(cmap_name, color_list, N)

In [None]:
def point_plot(encoder, data, batch_idx):
    """Scatter-plot the first two latent coordinates of `data`, colored by label.

    Parameters
    ----------
    encoder : torch.nn.Module
        Network mapping flattened inputs to latent codes.
    data : dataset
        Must support `data[:]` returning an `(inputs, labels)` pair of CPU
        tensors (as the dataset split used in this file does).
    batch_idx : int
        Index of the current training batch; only used in the plot title.

    Returns
    -------
    The `matplotlib.pyplot` module, so the caller can invoke `.show()`.

    The encoder is evaluated in eval mode without gradients, and its previous
    train/eval mode is restored afterwards, so calling this helper in the
    middle of a training loop does not leave the network in eval mode.
    """
    # Materialize the whole dataset once instead of once per field
    inputs, labels = data[:][:2]

    # Encode
    was_training = encoder.training
    encoder.eval()
    try:
        with torch.no_grad():
            inputs = inputs.view(-1,28*28) # reshape the img
            inputs = inputs.to(device)
            encoded_data = encoder(inputs)
    finally:
        # Put the encoder back in the mode the caller had selected
        encoder.train(was_training)

    # Record codes
    latent = encoded_data.cpu().numpy()
    labels = labels.numpy()

    #Plot
    plt.figure(figsize=(8, 6))
    plt.scatter( latent[:,0], latent[:,1], c=labels, alpha=0.5, marker='o', edgecolor='none', cmap=discrete_cmap(k, 'jet'))
    plt.title( f'''Latent space for test data in AE at batch {batch_idx}''')
    plt.colorbar(ticks=range(k))
    plt.grid(True)

    return plt

In [None]:
# Sanity check: (batches per epoch) x (batch size) should be close to the
# number of samples kept for training
batches_per_epoch = len(train_loader)
expected_train_samples = (1.0-split_ratio)*n*k
print( "Reality check of batch splitting: ")
print( "-- Batches per epoch", batches_per_epoch )
print( "batch size:", batch_size )
print( "product: ", batches_per_epoch*batch_size )
print( "-- To be compared to:", expected_train_samples)


In [None]:
# Per-batch training losses, concatenated over all epochs
diz_loss = {'train_loss': []}

for epoch in range(num_epochs):

    # Set train mode for both the encoder and the decoder
    encoder.train()
    decoder.train()
    mse_loss = []

    # Iterate the dataloader (labels are ignored: this is unsupervised learning)
    for batch_idx, (image_batch, _) in enumerate(train_loader):
        # Flatten each sample and move the batch to the proper device
        image_batch = image_batch.view(-1,28*28)
        image_batch = image_batch.to(device)
        # True batch size (the last batch of an epoch may be smaller)
        true_batch_size = image_batch.shape[0]

        optimizer.zero_grad()

        # Forward pass: encode, decode, then evaluate the reconstruction loss
        encoded_data = encoder(image_batch)
        decoded_data = decoder(encoded_data)
        # Sum of squared errors, averaged over the samples of the batch
        loss = torch.sum( (decoded_data-image_batch)**2 )/true_batch_size

        # Backward pass
        loss.backward()
        optimizer.step()

        # Read the scalar loss once as a Python float
        batch_loss = loss.item()
        print('\t Partial train loss (single batch): %f' % batch_loss)
        mse_loss.append(batch_loss)

        # Periodically plot the latent representation of the test data
        if batch_idx % batches_per_plot == 0:
            plot = point_plot(encoder, test_data, batch_idx)
            plot.show()
            # The plotting helper evaluates the encoder in eval mode; make
            # sure training continues in train mode.
            encoder.train()

    train_loss = np.mean(mse_loss)

    print('\n EPOCH {}/{} \t train loss {}'.format(epoch + 1, num_epochs, train_loss))

    # Keep one flat list of batch losses across epochs
    diz_loss['train_loss'].extend(mse_loss)

diz_loss['train_loss'] = np.array(diz_loss['train_loss'])

In [None]:
# Plot the per-batch training loss recorded during training

plt.figure(figsize=(10, 8))
plt.plot(diz_loss['train_loss'], label='Train_loss')

# Axis decoration
plt.xlabel('Batch')
plt.ylabel('Loss')
plt.title('Loss')
plt.legend()
plt.grid()

plt.show()

# IV. Plotting

In [None]:
from tqdm import tqdm

print( "Computing latent variables for train dataset" )

# Large batches: inference only, no gradients are kept
inference_batch_size    = 1024
inference_train_loader  = torch.utils.data.DataLoader( train_data , batch_size=inference_batch_size)

# Fetch data for plot
latent_variables = []   # one latent code (numpy row) per sample
label_variables  = []   # matching label per sample

# Switch to eval mode once, outside the loop, and disable gradient tracking
encoder.eval()
with torch.no_grad():
    for (data, labels) in tqdm( inference_train_loader, position=0 ):
        # Flatten each sample; no extra leading axis is needed before view()
        data = data.to(device).view(-1,28*28)
        encoded_data = encoder(data)
        # Record codes
        latent_variables.extend( encoded_data.cpu().numpy() )
        label_variables.extend( labels.numpy() )
#
print("Reality check:")
print(latent_variables[:2])

In [None]:
# Stack the per-sample codes into one array (one row per sample)
code_array = np.array(latent_variables)

# Scatter plot of the first two latent coordinates, colored by label
plt.figure(figsize=(8, 6))
plt.scatter(
    code_array[:, 0],
    code_array[:, 1],
    c=label_variables,
    alpha=0.5,
    marker='o',
    edgecolor='none',
    cmap=discrete_cmap(k, 'jet'),
)
plt.title("Latent space for train data in AE")
plt.colorbar(ticks=range(k))
axes = plt.gca()
plt.grid(True)
plt.show()

In [None]:
print( "Computing latent variables for test dataset" )

# Large batches: inference only, no gradients are kept
inference_batch_size   = 1024
inference_test_loader  = torch.utils.data.DataLoader(test_data , batch_size=inference_batch_size)

# Fetch data for plot
latent_variables = []   # one latent code (numpy row) per sample
label_variables  = []   # matching label per sample

# Switch to eval mode once, outside the loop, and disable gradient tracking
encoder.eval()
with torch.no_grad():
    for (data, labels) in tqdm( inference_test_loader, position=0 ):
        # Flatten each sample; no extra leading axis is needed before view()
        data = data.to(device).view(-1,28*28)
        encoded_data = encoder(data)
        # Record codes
        latent_variables.extend( encoded_data.cpu().numpy() )
        label_variables.extend( labels.numpy() )
#
print("Reality check:")
print(latent_variables[:2])



In [None]:
# Stack the per-sample codes into one array (one row per sample)
code_array = np.array(latent_variables)

# Scatter plot of the first two latent coordinates, colored by label
plt.figure(figsize=(8, 6))
plt.scatter(
    code_array[:, 0],
    code_array[:, 1],
    c=label_variables,
    alpha=0.5,
    marker='o',
    edgecolor='none',
    cmap=discrete_cmap(k, 'jet'),
)
plt.title("Latent space for test data in AE")
plt.colorbar(ticks=range(k))
axes = plt.gca()
plt.grid(True)
plt.show()

## IV. Plotting with plotly

In [None]:
# One record per test sample: its latent coordinates and its label
encoded_samples     = []   # list of dicts, later turned into a DataFrame
encoded_samples_raw = []   # list of [code_0, code_1, label] arrays

# Switch to eval mode once and disable gradient tracking for the whole pass
encoder.eval()
with torch.no_grad():
    #for sample in tqdm(train_dataset):
    for sample in tqdm(test_data):
        # Add a batch axis, move to the device and flatten the sample
        img = sample[0].unsqueeze(0).to(device)
        img = img.view(-1,28*28) # reshape the img
        encoded_img = encoder(img).flatten().cpu().numpy()

        # Store labels as plain Python numbers: a 0-d tensor would be rendered
        # as e.g. "tensor(0)" once the label column is converted to strings.
        label = sample[1]
        if torch.is_tensor(label) and label.numel() == 1:
            label = label.item()

        # Append to list
        encoded_sample = {f"Enc. Variable {i}": enc for i, enc in enumerate(encoded_img)}
        encoded_sample['label'] = label
        encoded_sample_raw = np.array( [encoded_img[0], encoded_img[1], label] )
        encoded_samples.append(encoded_sample)
        encoded_samples_raw.append( encoded_sample_raw )

In [None]:
import pandas as pd

# One row per encoded sample, one column per latent coordinate plus 'label'
encoded_samples_df = pd.DataFrame(encoded_samples)
encoded_samples_df

import plotly.express as px

# Interactive scatter of the first two latent coordinates; labels are cast to
# strings so that plotly treats them as discrete categories
px.scatter(
    encoded_samples_df,
    x='Enc. Variable 0',
    y='Enc. Variable 1',
    color=encoded_samples_df.label.astype(str),
    opacity=0.5,
)