In [350]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
import os
from tqdm import tqdm
import plotly.io as pio

import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader,random_split
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.manifold import TSNE
import plotly.express as px

In [351]:
# Local folder that holds the MNIST files.
data_dir = 'Dataset'
# download=True lets torchvision fetch each split automatically and store it
# under `data_dir`; train=True / train=False select the train / test split.
train_dataset = torchvision.datasets.MNIST(root=data_dir, train=True, download=True)
test_dataset = torchvision.datasets.MNIST(root=data_dir, train=False, download=True)

In [352]:
# Both splits use the same pipeline: a single ToTensor step, no augmentation.
# They are kept as two separate objects so either one can be changed on its own.
train_transform = transforms.Compose([transforms.ToTensor()])
test_transform = transforms.Compose([transforms.ToTensor()])

# Attach each pipeline to its dataset.
train_dataset.transform = train_transform
test_dataset.transform = test_transform

In [353]:
m = len(train_dataset)

# Hold out 20% of the training images for validation. The training size is the
# remainder, so the two lengths always sum to m as random_split requires
# (truncating both parts independently can lose a sample for some m).
# For m = 60,000 this gives 48,000 train / 12,000 validation images.
val_size = int(m * 0.2)
train_size = m - val_size

# random_split produces non-overlapping subsets; the dedicated seeded generator
# makes the partition identical after every kernel restart.
train_data, val_data = random_split(
    train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(0),
)

batch_size = 256

# Only the training loader reshuffles every epoch. Validation and test losses
# are computed over the whole split, so their sample order is irrelevant.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(val_data, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

## 1 Define Encoder and Decoder classes

In [354]:
class Encoder(nn.Module):
    """Convolutional encoder that maps an image batch to latent codes.

    Three strided convolutions are followed by a flatten step and a two-layer
    fully connected head. The head expects 3 * 3 * 32 flattened features, which
    is what the convolutional stack yields for 1x28x28 inputs.

    Args:
        encoded_space_dim: number of latent features produced per sample.
        fc2_input_dim: width of the hidden fully connected layer, i.e. the
            input size of the second linear layer.
    """

    def __init__(self, encoded_space_dim, fc2_input_dim):
        super().__init__()

        ### Convolutional section
        self.encoder_cnn = nn.Sequential(
            # First convolutional layer (no batch norm)
            nn.Conv2d(1, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            # Second convolutional layer
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            # Third convolutional layer (no batch norm)
            nn.Conv2d(16, 32, 3, stride=2, padding=0),
            nn.ReLU(True)
        )

        ### Flatten layer: keep the batch dimension, merge everything else
        self.flatten = nn.Flatten(start_dim=1)

        ### Linear section
        # The hidden width comes from fc2_input_dim instead of a hard-coded
        # 128, so the constructor argument actually takes effect.
        self.encoder_lin = nn.Sequential(
            # First linear layer
            nn.Linear(3 * 3 * 32, fc2_input_dim),
            nn.ReLU(True),
            # Second linear layer
            nn.Linear(fc2_input_dim, encoded_space_dim)
        )

    def forward(self, x):
        """Return the latent code for the image batch `x`."""
        # Apply convolutions
        x = self.encoder_cnn(x)
        # Flatten
        x = self.flatten(x)
        # Apply linear layers
        x = self.encoder_lin(x)
        return x

In [355]:
class Decoder(nn.Module):
    """Decoder that maps latent codes back to images.

    A two-layer fully connected section expands the code to 3 * 3 * 32 features,
    which are reshaped to a (32, 3, 3) map and upsampled by three transposed
    convolutions (3 -> 7 -> 14 -> 28 per side). A sigmoid bounds the result
    to [0, 1].

    Args:
        encoded_space_dim: number of latent features per sample.
        fc2_input_dim: width of the hidden fully connected layer, i.e. the
            input size of the second linear layer.
    """

    def __init__(self, encoded_space_dim, fc2_input_dim):
        super().__init__()

        ### Linear section
        # The hidden width comes from fc2_input_dim instead of a hard-coded
        # 128, so the constructor argument actually takes effect.
        self.decoder_lin = nn.Sequential(
            # First linear layer
            nn.Linear(encoded_space_dim, fc2_input_dim),
            nn.ReLU(True),
            # Second linear layer
            nn.Linear(fc2_input_dim, 3 * 3 * 32),
            nn.ReLU(True)
        )

        ### Unflatten
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))

        ### Convolutional section
        self.decoder_conv = nn.Sequential(
            # First transposed convolution
            nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            # Second transposed convolution
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            # Third transposed convolution
            nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1)
        )

    def forward(self, x):
        """Return the reconstructed image batch for the latent codes `x`."""
        # Apply linear layers
        x = self.decoder_lin(x)
        # Unflatten
        x = self.unflatten(x)
        # Apply transposed convolutions
        x = self.decoder_conv(x)
        # Apply a sigmoid to force the output to be between 0 and 1 (valid pixel values)
        x = torch.sigmoid(x)
        return x

## 2 Initialize models, loss and optimizer

In [356]:
### Fix the torch RNG so the weight initialisation below is repeatable
torch.manual_seed(0)

### Build the two halves of the autoencoder
# Size of the latent space shared by encoder and decoder.
d = 4

# Both networks are constructed from the same hyper-parameters.
autoencoder_kwargs = dict(encoded_space_dim=d, fc2_input_dim=128)
encoder = Encoder(**autoencoder_kwargs)
decoder = Decoder(**autoencoder_kwargs)

In [357]:
### Loss function: mean squared error
loss_fn = torch.nn.MSELoss()

### Optimizer settings
lr = 0.001 # Learning rate

# One parameter group per network, so a single optimizer updates both the
# encoder and the decoder.
params_to_optimize = [{'params': net.parameters()} for net in (encoder, decoder)]

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Selected device: {device}')

# NOTE: this rebinds `optim`, which the import cell bound to the torch.optim
# module. The name is kept because later cells pass `optim` as the optimizer;
# the module remains reachable as `torch.optim`.
optim = torch.optim.Adam(params_to_optimize, lr=lr)

# Move both networks to the selected device. The last expression is left bare,
# so the cell displays the decoder's layer summary.
encoder.to(device)
decoder.to(device)

Selected device: cuda


Decoder(
  (decoder_lin): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=128, out_features=288, bias=True)
    (3): ReLU(inplace=True)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(32, 3, 3))
  (decoder_conv): Sequential(
    (0): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(16, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (4): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(8, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  )
)

## 3 Train model

In [358]:
def cov_cpt(img):
    """Compute a covariance matrix for each channel of each sample.

    Every slice ``img[b, c]`` is passed to ``torch.cov`` and the result is
    written to the same position of an output tensor shaped like ``img``.
    The result must therefore fit the slice's own shape, which holds for
    square slices such as 28x28.

    Args:
        img: tensor indexed as (batch, channel, ...); the original author's
            note gives (256, 1, 28, 28) as the intended shape.

    Returns:
        Tensor with the same shape and dtype as ``img`` holding the
        per-slice covariance matrices.
    """
    cov_out = torch.zeros_like(img)
    for sample_idx, sample in enumerate(img):
        for channel_idx, channel in enumerate(sample):
            cov_out[sample_idx, channel_idx] = torch.cov(channel)

    return cov_out

In [359]:
### Training function
def train_epoch_den(encoder, decoder, device, dataloader, optimizer, val_lambda0=1, val_lambda=1, val_lambdaprime=1):
    """Train the autoencoder for one epoch with two latent regularisers.

    For each batch x the function computes z = encoder(x), x_hat = decoder(z)
    and z_hat = encoder(x_hat), then minimises

        val_lambda0 * MSE(x, x_hat)
        + val_lambda * MSE(z, z_hat)
        + val_lambdaprime * MSE(z, 0)

    Args:
        encoder: module mapping an image batch to latent codes.
        decoder: module mapping latent codes back to images.
        device: device every batch is moved to before the forward pass.
        dataloader: iterable of (image_batch, label) pairs; labels are ignored.
        optimizer: optimizer that updates the parameters of both modules.
        val_lambda0: weight of the reconstruction term MSE(x, x_hat).
        val_lambda: weight of the latent consistency term MSE(z, z_hat).
        val_lambdaprime: weight of the latent magnitude term MSE(z, 0).

    Returns:
        Tuple ``(loss1, loss2, loss3, loss)`` with the mean over all batches of
        the reconstruction, consistency, magnitude and weighted total losses.
    """
    # Set train mode for both the encoder and the decoder
    encoder.train()
    decoder.train()

    # The criterion is stateless, so build it once rather than once per batch.
    loss_fn = torch.nn.MSELoss()

    train_loss1 = []
    train_loss2 = []
    train_loss3 = []
    train_loss = []

    # Labels are not needed: this is unsupervised learning.
    for image_batch, _ in dataloader:
        # Move tensor to the proper device
        image_batch = image_batch.to(device)
        # Encode data (z)
        encoded_data = encoder(image_batch)
        # Decode data (x_hat)
        decoded_data = decoder(encoded_data)
        # Re-encode data (z_hat)
        z_hat = encoder(decoded_data)

        # Evaluate the three loss terms and their weighted sum
        loss1 = loss_fn(image_batch, decoded_data)
        loss2 = loss_fn(encoded_data, z_hat)
        loss3 = loss_fn(encoded_data, torch.zeros_like(encoded_data))
        loss = val_lambda0 * loss1 + val_lambda * loss2 + val_lambdaprime * loss3

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() yields a plain Python float detached from the graph,
        # replacing the legacy `.data` attribute.
        batch_loss = loss.item()
        print('\t partial train loss (single batch): %f' % batch_loss)

        train_loss1.append(loss1.item())
        train_loss2.append(loss2.item())
        train_loss3.append(loss3.item())
        train_loss.append(batch_loss)

    return np.mean(train_loss1), np.mean(train_loss2), np.mean(train_loss3), np.mean(train_loss)

In [360]:
### Testing function
def test_epoch_den(encoder, decoder, device, dataloader, loss_fn):
    # Set evaluation mode for encoder and decoder
    encoder.eval()
    decoder.eval()
    with torch.no_grad(): # No need to track the gradients
        # Define the lists to store the outputs for each batch
        conc_out = []
        conc_label = []
        for image_batch, _ in dataloader:
            # Move tensor to the proper device
            image_batch = image_batch.to(device)
            # Encode data
            encoded_data = encoder(image_batch)
            # Decode data
            decoded_data = decoder(encoded_data)
            # Append the network output and the original image to the lists
            conc_out.append(decoded_data.cpu())
            conc_label.append(image_batch.cpu())
        # Create a single tensor with all the values in the lists
        conc_out = torch.cat(conc_out)
        conc_label = torch.cat(conc_label) 
        # Evaluate global loss
        val_loss = loss_fn(conc_out, conc_label)
    return val_loss.data

In [361]:
def plot_ae_outputs(encoder,decoder,n=5):
    """Plot the first `n` test images above their reconstructions.

    Reads the notebook-level `test_dataset` and `device`. Both networks are
    put in eval mode and left there.

    Args:
        encoder: module mapping an image batch to latent codes.
        decoder: module mapping latent codes back to images.
        n: number of test images (columns) to show.
    """
    # eval() is idempotent, so switch modes once instead of on every iteration.
    encoder.eval()
    decoder.eval()

    plt.figure(figsize=(10,4.5))
    for i in range(n):
        # Add a batch dimension so the single image can go through the networks.
        img = test_dataset[i][0].unsqueeze(0).to(device)
        with torch.no_grad():
            rec_img = decoder(encoder(img))

        # Row 0 holds the originals, row 1 the reconstructions.
        rows = ((img, 'Original images'), (rec_img, 'Reconstructed images'))
        for row, (picture, title) in enumerate(rows):
            ax = plt.subplot(2, n, i + 1 + row * n)
            plt.imshow(picture.cpu().squeeze().numpy(), cmap='gist_gray')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            # Only the middle column carries the row title.
            if i == n//2:
                ax.set_title(title)
    plt.show()

In [None]:
### Training cycle
num_epochs = 30
# Weights of the three loss terms: reconstruction, latent consistency, latent magnitude.
val_lambda0 = 1
val_lambda = 0.3
val_lambdaprime = 0.3
history_da={'train_loss':[],'val_loss':[]}
# Per-epoch means of each loss term; the plotting cells below reuse these lists.
loss1 = []
loss2 = []
loss3 = []
loss = []

for epoch in range(num_epochs):
    print('EPOCH %d/%d' % (epoch + 1, num_epochs))
    ### Training (use the training function)
    train_loss1, train_loss2, train_loss3, train_loss=train_epoch_den(
        encoder=encoder,
        decoder=decoder,
        device=device,
        dataloader=train_loader,
        optimizer=optim,
        val_lambda0=val_lambda0,
        val_lambda=val_lambda,
        val_lambdaprime=val_lambdaprime)

    loss.append(train_loss)
    loss1.append(train_loss1)
    loss2.append(train_loss2)
    loss3.append(train_loss3)

    # Redraw the running loss curves after every epoch.
    plt.plot(loss)
    plt.plot(loss1)
    plt.plot(loss2)
    plt.plot(loss3)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(labels=('loss', 'x-x_hat', 'z-z_hat', 'z'))
    plt.show()

    ### Validation
    val_loss = test_epoch_den(
        encoder=encoder,
        decoder=decoder,
        device=device,
        dataloader=valid_loader,
        loss_fn=loss_fn)
    # test_epoch_den returns a tensor; keep plain Python floats in the history
    # so it can be plotted or serialised without any tensors attached.
    val_loss = float(val_loss)
    history_da['train_loss'].append(float(train_loss))
    history_da['val_loss'].append(val_loss)
    # Print the train and validation loss
    print('\n EPOCH {}/{} \t train loss {:.3f} \t val loss {:.3f}'.format(epoch + 1, num_epochs,train_loss,val_loss))
    plot_ae_outputs(encoder,decoder)

In [None]:
# Evaluate on the held-out test set. The value is printed explicitly: a cell
# only displays its last expression, so a bare call at the top would be lost.
test_loss = test_epoch_den(encoder,decoder,device,test_loader,loss_fn).item()
print(f'Test loss: {test_loss:.4f}')

# Originals vs. reconstructions for the first n test images. This reuses
# plot_ae_outputs instead of repeating its body. To export the figure, add a
# plt.savefig call (e.g. './C23_Rec.jpg') before plt.show() inside that helper.
n=5
plot_ae_outputs(encoder, decoder, n=n)

In [None]:
# The total loss is drawn dashed and thicker so it stands out from its three
# components, which use the default line style.
plt.plot(loss, '--', linewidth=2.5)
for component_curve in (loss1, loss2, loss3):
    plt.plot(component_curve)

plt.xlabel('Epoch')
plt.ylabel('Loss')
# Legend entries follow the plotting order above.
plt.legend(labels=['loss', 'x-x_hat', 'z-z_hat', 'z'])
#plt.savefig('./C23.jpg')
plt.show()

In [None]:
# Only the two latent-space terms, shown on their own scale.
for latent_curve in (loss2, loss3):
    plt.plot(latent_curve)

plt.xlabel('Epoch')
plt.ylabel('Loss')
# Legend entries follow the plotting order above.
plt.legend(labels=['z-z_hat', 'z'])
#plt.savefig('./C23_z.jpg')
plt.show()