In [None]:
import matplotlib.pyplot as plt # plotting library
import numpy as np # this module is useful to work with numerical arrays
import pandas as pd 
import random 
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader,random_split
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

data_dir = 'dataset'

train_dataset = torchvision.datasets.MNIST(data_dir, train=True, download=True)
test_dataset  = torchvision.datasets.MNIST(data_dir, train=False, download=True)

# ToTensor() is the only transform (no normalisation), so pixel values stay in
# [0, 1], the same range as the decoder's sigmoid output.
train_transform = transforms.Compose([
    transforms.ToTensor(),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset.transform = train_transform
test_dataset.transform = test_transform

# Hold out 20% of the training images for validation. The validation size is
# rounded down and the training split takes the remainder, so the two lengths
# always add up to len(train_dataset), as random_split requires.
m = len(train_dataset)
val_size = int(m * 0.2)
# A dedicated seeded generator makes the split reproducible; the global torch
# seed is only set in a later cell.
split_generator = torch.Generator().manual_seed(0)
train_data, val_data = random_split(train_dataset, [m - val_size, val_size], generator=split_generator)
batch_size = 256

# Training batches are reshuffled every epoch. The test loader is shuffled so
# that next(iter(test_loader)) yields a random test batch for the latent plots.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(val_data, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Sanity check: display a summary of the MNIST training dataset.
train_dataset

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder: single-channel 28x28 image -> latent vector.

    Three stride-2 convolutions reduce a 28x28 input to 32 feature maps of
    size 3x3 (the size hard-wired into the first linear layer). These are
    flattened and projected down to ``encoded_space_dim`` values.

    Args:
        encoded_space_dim: length of the latent code returned by ``forward``.
        fc2_input_dim: accepted for interface compatibility; not used.
    """

    def __init__(self, encoded_space_dim, fc2_input_dim):
        super().__init__()

        # Layers are listed in construction order. The Sequential indices, and
        # therefore the state_dict keys, follow this list.
        conv_stack = [
            nn.Conv2d(1, 8, 3, stride=2, padding=1),    # 28x28 -> 14x14
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),   # 14x14 -> 7x7
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=0),  # 7x7 -> 3x3
            nn.ReLU(True),
        ]
        self.encoder_cnn = nn.Sequential(*conv_stack)

        # (N, 32, 3, 3) -> (N, 288)
        self.flatten = nn.Flatten(start_dim=1)

        flat_size = 32 * 3 * 3
        self.encoder_lin = nn.Sequential(
            nn.Linear(flat_size, 128),
            nn.ReLU(True),
            nn.Linear(128, encoded_space_dim),
        )

    def forward(self, x):
        """Map a batch of images ``x`` to latent codes of shape (N, encoded_space_dim)."""
        features = self.encoder_cnn(x)
        return self.encoder_lin(self.flatten(features))
class Decoder(nn.Module):
    """Mirror image of ``Encoder``: latent vector -> single-channel 28x28 image.

    Args:
        encoded_space_dim: length of the latent code accepted by ``forward``.
        fc2_input_dim: accepted for interface compatibility; not used.
    """

    def __init__(self, encoded_space_dim, fc2_input_dim):
        super().__init__()
        flat_size = 3 * 3 * 32
        self.decoder_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, flat_size),
            nn.ReLU(True),
        )

        # (N, 288) -> (N, 32, 3, 3)
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))

        # Layers in construction order; the Sequential indices follow this list.
        upsampling = [
            nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0),            # 3x3 -> 7x7
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),  # 7x7 -> 14x14
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1),   # 14x14 -> 28x28
        ]
        self.decoder_conv = nn.Sequential(*upsampling)

    def forward(self, x):
        """Decode latent codes ``x`` into images squashed into [0, 1] by a sigmoid."""
        x = self.unflatten(self.decoder_lin(x))
        return torch.sigmoid(self.decoder_conv(x))

In [None]:
### Optimisation hyper-parameters
lr = 5e-4  # Adam learning rate (previously 0.001)

### Seed torch's global RNG so the network initialisation below is reproducible
torch.manual_seed(0)

### Latent dimension; the curvature utilities below only handle d = 2
d = 2

### Build the encoder/decoder pair. They are constructed in this order, right
### after seeding, so their random initialisation is reproducible.
encoder = Encoder(encoded_space_dim=d, fc2_input_dim=128)
decoder = Decoder(encoded_space_dim=d, fc2_input_dim=128)

# A single Adam optimizer over both networks, with one parameter group each.
params_to_optimize = [
    {'params': encoder.parameters()},
    {'params': decoder.parameters()},
]
# NOTE: this rebinds the name `optim`, shadowing the `torch.optim` module
# alias imported at the top of the notebook.
optim = torch.optim.Adam(params_to_optimize, lr=lr, weight_decay=1e-05)

# Computation is pinned to the CPU (GPU auto-selection is disabled).
device = torch.device("cpu")

# Move both networks to the selected device.
encoder.to(device)
decoder.to(device)

## Functions for computing the curvature functional `Func`: the integral of the squared scalar curvature of the decoder's pull-back metric $g$, weighted by the area element $\sqrt{\det g}$

In [None]:
# Uniform grid on the latent space (d = 2 here). Its bounds come from a sigma
# rule: 3 sigma would cover almost all of a Gaussian cloud, but 2 sigma is used
# by default to keep the grid tight around the data.
def makegrid(encoded_data, numsteps, n_sigma=2):
    """Build a uniform ``numsteps x numsteps`` grid around a 2-D latent cloud.

    The grid spans ``mean +/- n_sigma * std`` along each of the first two
    latent coordinates. The std is the population std over the rows of
    ``encoded_data``.

    Args:
        encoded_data: latent codes, one per row; only columns 0 and 1 are used.
            They are detached and copied to the CPU.
        numsteps: number of grid points per axis.
        n_sigma: half-width of the grid in units of the per-axis std. The
            default of 2 is the previously hard-coded value.

    Returns:
        CPU tensor of shape (numsteps**2, 2). Row ``k`` is the point
        ``(xs[k % numsteps], ys[k // numsteps])``, so x varies fastest. The
        finite-difference helpers rely on this layout: x-neighbours are 1 row
        apart and y-neighbours are ``numsteps`` rows apart.
    """
    latent = encoded_data.detach().cpu()
    mean = latent.mean(dim=0)
    std = (latent - mean).pow(2).mean(dim=0).sqrt()
    xs = torch.linspace(mean[0] - n_sigma * std[0], mean[0] + n_sigma * std[0], steps=numsteps)
    ys = torch.linspace(mean[1] - n_sigma * std[1], mean[1] + n_sigma * std[1], steps=numsteps)
    # cartesian_prod(ys, xs) yields rows (y, x) with x varying fastest; rolling
    # the two columns by one swaps them into (x, y) order.
    tgrid = torch.cartesian_prod(ys, xs)
    tgrid = tgrid.roll(1, 1)
    return tgrid


In [None]:
#metric on a grid
def g(grid, decoded_grid):
    """Pull-back metric of the decoder on a latent grid, via central differences.

    At each grid point the metric is the Gram matrix of the two partial
    derivatives of the decoded output, ``g_ij = <d psi/dx_i, d psi/dx_j>``.
    The inner product is summed over dims 1, 2 and 3.

    Args:
        grid: (numsteps**2, 2) grid as produced by ``makegrid`` (x fastest).
        decoded_grid: decoder output for ``grid``, one item per grid point;
            it must be 4-D because dims (1, 2, 3) are summed.

    Returns:
        Tensor of shape (numsteps**2, 2, 2). The differences wrap around the
        grid edges (``torch.roll``), so values on the border are not meaningful.
    """
    side = int(np.sqrt(grid.shape[0]))
    span = (grid[side**2 - 1] - grid[0]).abs()
    step_x = float(span[0]) / (side - 1)
    step_y = float(span[1]) / (side - 1)

    psi = decoded_grid
    # x-neighbours are adjacent rows; y-neighbours are `side` rows apart.
    d_x = (psi.roll(-1, 0) - psi.roll(1, 0)) / (2 * step_x)
    d_y = (psi.roll(-side, 0) - psi.roll(side, 0)) / (2 * step_y)

    dims = (1, 2, 3)
    g_xx = (d_x * d_x).sum(dims)
    g_xy = (d_x * d_y).sum(dims)
    g_yy = (d_y * d_y).sum(dims)
    return torch.stack((g_xx, g_xy, g_xy, g_yy), dim=1).view(side * side, 2, 2)

In [None]:
# Central finite differences on a flattened grid (x varies fastest) via torch.roll
def diff_by_x(tensor, numsteps, h):
    """Central difference along x with step ``h``; x-neighbours are adjacent rows.

    ``numsteps`` is not used here. Values wrap around the first and last rows.
    """
    ahead = tensor.roll(-1, 0)
    behind = tensor.roll(1, 0)
    return (ahead - behind) / (2 * h)
def diff_by_y(tensor, numsteps, h):
    """Central difference along y with step ``h``; y-neighbours are ``numsteps`` rows apart.

    Values wrap around the grid edges.
    """
    ahead = tensor.roll(-numsteps, 0)
    behind = tensor.roll(numsteps, 0)
    return (ahead - behind) / (2 * h)
    

In [None]:
#derivatives of the metric on a grid
def dg_grid(grid, decoded_grid):  # dg
    """First derivatives of the pull-back metric on a latent grid.

    The derivatives of the decoded output ``psi`` are taken by central
    differences. These wrap at the grid edges, so border values are not
    meaningful. The metric derivatives then follow from the product rule, e.g.
    ``d_x g_12 = <psi_xx, psi_y> + <psi_x, psi_xy>``.

    Args:
        grid: (numsteps**2, 2) grid as produced by ``makegrid`` (x fastest).
        decoded_grid: decoder output for ``grid`` (4-D; dims 1-3 are summed).

    Returns:
        Tensor ``dg`` of shape (numsteps**2, 2, 2, 2) with
        ``dg[:, a, i, j] = d g_ij / d x_a`` (a = 0 for x, a = 1 for y).
    """
    numsteps = int(np.sqrt(grid.shape[0]))
    hx = float(abs((grid[numsteps**2 - 1] - grid[0])[0])) / (numsteps - 1)
    hy = float(abs((grid[numsteps**2 - 1] - grid[0])[1])) / (numsteps - 1)

    def _central_diff(t, shift, h):
        # Same stencil as diff_by_x (shift=1) and diff_by_y (shift=numsteps).
        return (t.roll(-shift, 0) - t.roll(shift, 0)) / (2 * h)

    psi = decoded_grid
    dpsidx = _central_diff(psi, 1, hx)
    # Differentiate along y. This was previously an x-difference scaled by hy,
    # so every "y" derivative below was really an x derivative.
    dpsidy = _central_diff(psi, numsteps, hy)
    dpsidx_second = _central_diff(dpsidx, 1, hx)
    dpsidx_dy = _central_diff(dpsidx, numsteps, hy)
    dpsidy_second = _central_diff(dpsidy, numsteps, hy)

    dims = (1, 2, 3)
    # d/dx of (g11, g12, g21, g22)
    dg12_dx = (dpsidx_second * dpsidy + dpsidx * dpsidx_dy).sum(dims)
    dgdx = torch.stack((
        2 * (dpsidx * dpsidx_second).sum(dims),
        dg12_dx,
        dg12_dx,
        2 * (dpsidy * dpsidx_dy).sum(dims),
    ), dim=1)
    # d/dy of (g11, g12, g21, g22)
    dg12_dy = (dpsidy_second * dpsidx + dpsidy * dpsidx_dy).sum(dims)
    dgdy = torch.stack((
        2 * (dpsidx * dpsidx_dy).sum(dims),
        dg12_dy,
        dg12_dy,
        2 * (dpsidy * dpsidy_second).sum(dims),
    ), dim=1)
    # (n, 2 [derivative direction], 4 [g11, g12, g21, g22]) -> (n, 2, 2, 2)
    return torch.stack((dgdx, dgdy), dim=1).view(numsteps * numsteps, 2, 2, 2)

In [None]:
#Christoffel symbols on a grid
def Ch_grid(grid, metric_inv, metric_der):
    """Christoffel symbols of the second kind at every grid point.

    Gamma^l_ij = 1/2 * sum_k g^{lk} (d_i g_kj + d_j g_ik - d_k g_ij)

    Args:
        grid: grid points; only ``grid.shape[0]`` (the number of points) is used.
        metric_inv: (n, 2, 2) inverse metric ``g^{lk}``.
        metric_der: (n, 2, 2, 2) with ``metric_der[:, a, i, j] = d_a g_ij``.

    Returns:
        (n, 2, 2, 2) tensor with ``[:, l, i, j] = Gamma^l_ij``. It is allocated
        by torch.zeros with no device argument, i.e. on the CPU.
    """
    n_points = grid.shape[0]
    dg = metric_der
    # bracket[:, k, i, j] = d_i g_kj + d_j g_ik - d_k g_ij
    bracket = dg.permute(0, 2, 1, 3) + dg.permute(0, 3, 2, 1) - dg
    christoffel = torch.zeros((n_points, 2, 2, 2))
    for k in range(2):
        # One contraction term per k, accumulated in place; broadcasting
        # fills every (l, i, j) component at once.
        christoffel += 0.5 * metric_inv[:, :, k, None, None] * bracket[:, None, k]
    return christoffel

### Derivatives of Christoffel symbols on a grid

In [None]:
#derivatives of Christoffel symbols on a grid
def Ch_der_grid(grid, metric_inv, metric_der):
    """Partial derivatives of the Christoffel symbols on the grid.

    Uses ``diff_by_x`` / ``diff_by_y``: central differences that wrap at the
    grid edges, so border values are not meaningful.

    Returns:
        (n, 2, 2, 2, 2) tensor with ``[:, l, i, j, a] = d_a Gamma^l_ij``
        (a = 0 for x, a = 1 for y).
    """
    numsteps = int(np.sqrt(grid.shape[0]))
    hx = float(abs((grid[numsteps**2 - 1] - grid[0])[0])) / (numsteps - 1)
    hy = float(abs((grid[numsteps**2 - 1] - grid[0])[1])) / (numsteps - 1)

    # Compute the symbols once; they were previously rebuilt for each direction.
    Ch = Ch_grid(grid, metric_inv, metric_der)
    Chdx = diff_by_x(Ch, numsteps, hx)
    Chdy = diff_by_y(Ch, numsteps, hy)
    # Stacking on a new trailing axis puts the derivative direction last. This
    # gives the same layout as the previous cat / view / transpose sequence.
    return torch.stack((Chdx, Chdy), dim=-1)



In [None]:
# Riemann curvature tensor (3,1)
def Riem(grid, metric_inv, metric_der):
    """Riemann curvature tensor R^i_jkl of the latent metric at every grid point.

    R^i_jkl = d_k Gamma^i_lj - d_l Gamma^i_kj
              + sum_p (Gamma^i_kp Gamma^p_lj - Gamma^i_lp Gamma^p_kj)

    Args:
        grid: (n, 2) grid points.
        metric_inv: (n, 2, 2) inverse metric.
        metric_der: (n, 2, 2, 2) metric derivatives, ``[:, a, i, j] = d_a g_ij``.

    Returns:
        (n, 2, 2, 2, 2) tensor indexed ``[:, i, j, k, l]``. It is allocated by
        torch.zeros with no device argument, i.e. on the CPU.
    """
    num_points = grid.shape[0]
    dGamma = Ch_der_grid(grid, metric_inv, metric_der)  # [:, l, i, j, a] = d_a Gamma^l_ij
    Gamma = Ch_grid(grid, metric_inv, metric_der)       # [:, l, i, j] = Gamma^l_ij

    R = torch.zeros(num_points, 2, 2, 2, 2)
    for i in range(2):
        for j in range(2):
            for k in range(2):
                for l in range(2):
                    entry = R[:, i, j, k, l]  # view into R, filled in place
                    entry.copy_(dGamma[:, i, l, j, k] - dGamma[:, i, k, j, l])
                    for p in range(2):
                        entry += (Gamma[:, i, k, p] * Gamma[:, p, l, j]
                                  - Gamma[:, i, l, p] * Gamma[:, p, k, j])
    return R



In [None]:
# Ricci curvature tensor via Riemann
# R_ab = Riem^c_acb
# Not called by Func below: Sc contracts the Riemann tensor itself.
def Ric(grid, metric_inv, metric_der):
    """Ricci tensor ``R_ab = R^c_acb`` at every grid point, shape (n, 2, 2).

    The Riemann tensor is evaluated once and then contracted. Previously it was
    recomputed inside the triple loop, i.e. 8 full evaluations per call.
    """
    n = grid.shape[0]
    Riemann = Riem(grid, metric_inv, metric_der)
    Ric = torch.zeros(n, 2, 2)
    for a in range(2):
        for b in range(2):
            for c in range(2):
                Ric[:, a, b] += Riemann[:, c, a, c, b]
    return Ric

In [None]:
# Scalar curvature via Riemann and Ricci
# R_ab = Riem^c_acb
# R = g^ij * R_ij
def Sc(grid, metric_inv, metric_der):
    """Scalar curvature ``R = g^{ab} R_ab`` at every grid point, shape (n,).

    The Ricci tensor is contracted inline from a single Riemann evaluation.
    """
    riemann = Riem(grid, metric_inv, metric_der)

    ricci = torch.zeros(grid.shape[0], 2, 2)
    for a in range(2):
        for b in range(2):
            for c in range(2):
                ricci[:, a, b] += riemann[:, c, a, c, b]

    # Full contraction g^{ab} R_ab: elementwise product summed over both indices.
    return (metric_inv * ricci).sum((1, 2))

In [None]:
#curvature measuring functional

def Func(encoded_data, numsteps=30, border=2):
    """Curvature functional of the decoder over the latent region of ``encoded_data``.

    Builds a ``numsteps x numsteps`` grid around the latent codes
    (``makegrid``), decodes it, and returns

        F = sum over interior grid points of sqrt(det g) * R**2 * hx * hy,

    a Riemann-sum approximation of the integral of the squared scalar
    curvature ``R`` of the decoder's pull-back metric ``g``, taken with
    respect to the area element ``sqrt(det g)``.

    Args:
        encoded_data: latent codes; only the first two coordinates are used.
        numsteps: grid points per axis. The default of 30 is the original value.
        border: number of grid rows/columns dropped on every side, where the
            wrap-around finite differences are invalid. The default of 2 is the
            original value.

    Returns:
        0-dim tensor. Gradients reach the decoder through the decoded grid;
        the grid itself is detached inside ``makegrid``.

    NOTE(review): the global ``decoder`` is run in its current mode. In train
    mode its BatchNorm layers will also see (and track) the grid batch.
    """
    grid = makegrid(encoded_data, numsteps)

    hx = float(abs((grid[numsteps**2 - 1] - grid[0])[0])) / (numsteps - 1)
    hy = float(abs((grid[numsteps**2 - 1] - grid[0])[1])) / (numsteps - 1)

    # Decode every grid point. The graph is kept so F is differentiable with
    # respect to the decoder parameters.
    decoded_grid = decoder(grid.to(device))

    # The curvature helpers allocate their buffers with torch.zeros on the CPU,
    # so all of the geometry is done there (a no-op when device is the CPU).
    metric = g(grid, decoded_grid).cpu()
    metric_der = dg_grid(grid, decoded_grid).cpu()
    metric_inv = torch.inverse(metric)

    scalar_curv = Sc(grid, metric_inv, metric_der)

    # Both factors below are laid out as [y_index, x_index] (x varies fastest
    # in the grid) and trimmed by `border` on every side. Previously only the
    # curvature was transposed, so each curvature value was weighted by
    # sqrt(det g) taken at the mirrored grid point.
    inner = slice(border, numsteps - border)
    scalar_curv = scalar_curv.view(numsteps, numsteps)[inner, inner]
    metric_inner = metric.reshape(numsteps, numsteps, 2, 2)[inner, inner]
    # g is a Gram matrix, so det g >= 0. Clamp round-off below zero so the
    # square root cannot produce NaN.
    det_sqrt = torch.sqrt(torch.det(metric_inner).clamp(min=0))

    return (det_sqrt * scalar_curv**2 * hx * hy).sum()

In [None]:
# Show which device the networks run on (the setup cell forces the CPU).
device

In [None]:
# building my loss function
def myloss(decoded_data, image_batch, curv_w=0.0):
    """Reconstruction MSE, optionally plus a weighted curvature penalty.

    Args:
        decoded_data: reconstructed images produced by the decoder.
        image_batch: the input images; moved to ``decoded_data``'s device.
        curv_w: weight of the curvature functional ``Func``. With the default
            of 0 the loss is plain MSE. That is also what this function
            effectively returned before, because the curvature term was
            computed and then discarded.

    Returns:
        0-dim loss tensor.
    """
    oldloss = torch.nn.MSELoss()
    image_batch = image_batch.to(decoded_data.device)
    loss = oldloss(decoded_data, image_batch)
    if curv_w != 0:
        # Func needs latent codes, but this signature only receives images, so
        # they are re-encoded. This is skipped when the term is unused: it is
        # expensive, and in train mode the extra encoder/decoder forward passes
        # also update the BatchNorm running statistics.
        encoded_data = encoder(image_batch)
        loss = loss + curv_w * Func(encoded_data)
    return loss

### Define the loss function
loss_fn = myloss

In [None]:
### Training function
mseloss = torch.nn.MSELoss()

def train_epoch(encoder, decoder, device, dataloader, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Labels are ignored (unsupervised reconstruction). For every batch the
    combined loss ``loss_fn`` is back-propagated and the optimizer is stepped.
    The plain MSE is tracked alongside it, so the extra (curvature) part can be
    read off as ``loss - mse``. Per-batch values are printed as training runs.

    Returns:
        ``(train_loss, mse_loss)``: two lists with one float per batch.
    """
    encoder.train()
    decoder.train()
    batch_losses, batch_mses = [], []

    for images, _labels in dataloader:
        images = images.to(device)
        reconstruction = decoder(encoder(images))

        loss = loss_fn(reconstruction, images)
        mse_only = mseloss(reconstruction, images)
        extra = loss.data - mse_only.data

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print('\t partial train loss (single batch): {:.6} \t curv_loss {:.6} \t mse {:.6}'.format(loss.data, extra, mse_only.data))

        batch_losses.append(loss.item())
        batch_mses.append(mse_only.item())

    return batch_losses, batch_mses

In [None]:
# Number of mini-batches per training epoch: len(train_data) / batch_size,
# rounded up, because the DataLoader keeps the final partial batch.
len(train_loader)

In [None]:
def plot_ae_outputs(encoder,decoder,n=10):
    """Plot the first test image with each label 0..n-1 (top row) above its
    reconstruction ``decoder(encoder(img))`` (bottom row).

    Side effect: both networks are left in eval mode.
    """
    def _hide_axes(axis):
        axis.get_xaxis().set_visible(False)
        axis.get_yaxis().set_visible(False)

    plt.figure(figsize=(16,4.5))
    labels = test_dataset.targets.numpy()
    # Position of the first test sample carrying each label.
    first_of_label = {digit: np.where(labels == digit)[0][0] for digit in range(n)}
    for col in range(n):
        ax = plt.subplot(2, n, col + 1)
        img = test_dataset[first_of_label[col]][0].unsqueeze(0).to(device)
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            rec_img = decoder(encoder(img))
        plt.imshow(img.cpu().squeeze().numpy(), cmap='gist_gray')
        _hide_axes(ax)
        if col == n // 2:
            ax.set_title('Original images')

        ax = plt.subplot(2, n, col + 1 + n)
        plt.imshow(rec_img.cpu().squeeze().numpy(), cmap='gist_gray')
        _hide_axes(ax)
        if col == n // 2:
            ax.set_title('Reconstructed images')
    plt.show()

In [None]:
#manifold plot
def show_image(img):
    """Display a (C, H, W) tensor with matplotlib, moving the channel axis last."""
    plt.imshow(img.numpy().transpose(1, 2, 0))
def make_manifold_plot(encoded_data, size):
    """Decode a ``size x size`` latent grid around ``encoded_data`` and show it as a mosaic.

    Mosaic row r holds the grid points with the r-th y value, with x increasing
    from left to right (``makegrid`` orders points with x varying fastest).
    The global ``decoder`` is used in whatever train/eval mode it is in.
    """
    mygrid = makegrid(encoded_data, size)
    # Pure visualisation: no autograd graph is needed.
    with torch.no_grad():
        result = decoder(mygrid.to(device)).cpu()
    fig, ax = plt.subplots(figsize=(20, 8.5))
    # One mosaic row per grid row. The mosaic was previously fixed at 10
    # columns and truncated to 100 images, which scrambled or cut off grids
    # whenever size != 10.
    show_image(torchvision.utils.make_grid(result, nrow=size, padding=5))
    plt.show()

In [None]:
num_epochs = 20
diz_loss = {'train_loss': [], 'mse_loss': []}

for epoch in range(num_epochs):
    train_info = train_epoch(encoder, decoder, device, train_loader, loss_fn, optim)

    # Latent codes of one (shuffled) test batch, used for the manifold plot.
    # Switch to eval mode first: train_epoch leaves the encoder in train mode,
    # where BatchNorm would normalise with this test batch's statistics and
    # fold them into its running averages.
    encoder.eval()
    with torch.no_grad():
        images, _ = next(iter(test_loader))
        images = images.to(device)
        latent = encoder(images)
        latent = latent.cpu()

    train_loss = np.mean(train_info[0])
    curv_func = np.mean(train_info[0]) - np.mean(train_info[1])  # train_loss - mse_loss = curv_loss

    print('\n EPOCH {}/{} \t train loss {} \t Curvature {}'.format(epoch + 1, num_epochs, train_loss, curv_func))
    diz_loss['train_loss'].append(train_info[0])
    diz_loss['mse_loss'].append(train_info[1])
    plot_ae_outputs(encoder, decoder, n=10)
    make_manifold_plot(latent, 10)

# Per-batch losses of all epochs as flat 1-D arrays (empty if no epochs ran).
for key in ('train_loss', 'mse_loss'):
    diz_loss[key] = np.concatenate(diz_loss[key]) if diz_loss[key] else np.array([])

In [None]:
# Plot losses per batch
fig, ax = plt.subplots(figsize=(10, 8))
ax.semilogy(diz_loss['train_loss'], label='Train_loss')
# The curvature part is recovered as total loss minus MSE.
# NOTE(review): with loss_fn = myloss as currently written, the total equals
# the MSE, so this curve is identically zero (it cannot be drawn on a log
# axis), and the weight in the title does not describe the loss in use.
ax.semilogy(diz_loss['train_loss'] - diz_loss['mse_loss'], label='Curv_loss')
ax.set_title('Losses with weight 0.00001')
ax.set_xlabel('Batch')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

In [None]:
# Generate random samples from a Gaussian fitted to the latent codes and
# visualize them. This is not a traversal of the latent space, just samples.
# Uses show_image from the manifold-plot cell; it was previously re-defined
# here, a verbatim duplicate.

encoder.eval()
decoder.eval()

with torch.no_grad():
    # Per-dimension mean and (population) std of the latent codes of one test batch.
    images, labels = next(iter(test_loader))
    images = images.to(device)
    latent = encoder(images)
    latent = latent.cpu()

    mean = latent.mean(dim=0)
    print(mean)
    std = (latent - mean).pow(2).mean(dim=0).sqrt()
    print(std)

    # Sample latent vectors from an axis-aligned normal distribution with that mean/std.
    latent = torch.randn(128, d)*std + mean

    # Reconstruct images from the random latent vectors.
    latent = latent.to(device)
    img_recon = decoder(latent)
    img_recon = img_recon.cpu()

    fig, ax = plt.subplots(figsize=(20, 8.5))
    show_image(torchvision.utils.make_grid(img_recon[:100],10,5))
    plt.show()

In [None]:
#saving the model
# NOTE(review): these file names hard-code "curw_w=0.0001" and "5epochs", which
# do not match this run (num_epochs = 20, and myloss as written adds no
# curvature term). Confirm before relying on them.
PATH_enc = 'encoder_curw_w=0.0001_5epochs_30x30grid.pt'
PATH_dec = 'decoder_curw_w=0.0001_5epochs_30x30grid.pt'
for module, path in ((encoder, PATH_enc), (decoder, PATH_dec)):
    torch.save(module.state_dict(), path)

# Point plot: 2-D latent codes of the test set, coloured by digit label

In [None]:
!pip install tqdm

In [None]:
from tqdm import tqdm

In [None]:
# Encode every test image individually and tabulate its latent coordinates and label.
encoder.eval()
rows = []
with torch.no_grad():
    for image, label in tqdm(test_dataset):
        code = encoder(image.unsqueeze(0).to(device)).flatten().cpu().numpy()
        row = {f"Enc. Variable {i}": enc for i, enc in enumerate(code)}
        row['label'] = label
        rows.append(row)
encoded_samples = pd.DataFrame(rows)
encoded_samples

In [None]:
import plotly.express as px

# Scatter the two latent coordinates of every test image, coloured by its label.
px.scatter(
    encoded_samples,
    x='Enc. Variable 0',
    y='Enc. Variable 1',
    color=encoded_samples['label'].astype(str),
    opacity=0.7,
)