# HOMEWORK 2

## Autoencoders

In [None]:
import matplotlib.pyplot as plt # plotting library
import numpy as np # this module is useful to work with numerical arrays
import pandas as pd # this module is useful to work with tabular data
import random # this module will be used to select random samples from a collection
import os # this module will be used just to create directories in the local filesystem
from tqdm import tqdm # this module is useful to plot progress bars

import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch import nn
from sklearn.decomposition import PCA
import plotly.express as px

In [None]:
import os
# Presumably a debugging aid: synchronous CUDA kernel launches make errors surface at the
# offending call. NOTE(review): this slows GPU training — consider removing when not debugging.
os.environ.update({'CUDA_LAUNCH_BLOCKING': '1'})

In [None]:
# When True, every training cell is skipped and the weights saved on disk are loaded instead.
trained = True

# Seed the Python, NumPy and PyTorch RNGs so that sampling, shuffling and weight
# initialisation are reproducible (the three generators are independent of each other).
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Selected device: {device}')

### FashionMNIST Dataset 

In [None]:
### Download the data and create the datasets
data_dir = 'classifier_data'
# The train split (first) and the test split (second) are downloaded automatically and
# stored in the local "data_dir" directory.
train_dataset, test_dataset = (
    torchvision.datasets.FashionMNIST(data_dir, train=is_train, download=True)
    for is_train in (True, False)
)

In [None]:
### Plot some samples
# Verbose class names, indexed by the numerical FashionMNIST label
label_names = ['t-shirt/top', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt',
               'sneaker', 'bag', 'ankle boot']
fig, axs = plt.subplots(5, 5, figsize=(8, 8))
for ax in axs.flat:
    # random.choice works on anything indexable with a length, such as our dataset
    img, label = random.choice(train_dataset)
    ax.imshow(np.array(img), cmap='gist_gray')
    ax.set_title(f'Label: {label_names[label]} [{label}]')
    ax.set_xticks([])
    ax.set_yticks([])
plt.tight_layout()

In [None]:
# Single step: convert each image into a float tensor (ToTensor rescales pixels to [0, 1])
transform = transforms.Compose([transforms.ToTensor()])

In [None]:
# Check what the first training sample looks like before the transform is attached
print(train_dataset[0][0])

# Use the same transform for the train and the test split
train_dataset.transform = test_dataset.transform = transform

# Check what the same sample looks like after the transform
print(train_dataset[0][0])


### Train/Validation/Test Split

In [None]:
batch_size = 256
validation_split = .2

# Creating data indices for training and validation splits: the first 20% of the
# training set (in dataset order) is held out for validation.
dataset_size = len(train_dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
train_indices, val_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders (each sampler draws its subset in random order):
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

# Define train dataloader
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
# Define validation dataloader (`shuffle` must stay False when a sampler is given)
val_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=valid_sampler, shuffle = False)
# Define test dataloader
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle = False)

# Sanity check: shape of one batch from each loader
batch_data, batch_labels = next(iter(train_dataloader))
print("TRAIN BATCH SHAPE")
print(f"\t Data: {batch_data.shape}")
print(f"\t Labels: {batch_labels.shape}")

batch_data, batch_labels = next(iter(val_dataloader))
print("VALIDATION BATCH SHAPE")
print(f"\t Data: {batch_data.shape}")
print(f"\t Labels: {batch_labels.shape}")

batch_data, batch_labels = next(iter(test_dataloader))
print("TEST BATCH SHAPE")
print(f"\t Data: {batch_data.shape}")
print(f"\t Labels: {batch_labels.shape}")

### Encoder

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder: maps a batch of 1-channel images to `encoded_space_dim` values.

    The shape comments below assume (N, 1, 28, 28) inputs; the first linear layer is
    hard-coded to 64*3*3 features, which is what the conv stack produces for 28x28 images.
    Attribute names and layer order define the keys of the saved state_dict checkpoints,
    so they must not change.
    """
    
    def __init__(self, encoded_space_dim):
        super().__init__()
        
        ### Convolutional section
        # Conv output size: (H + 2*padding - kernel_size) // stride + 1
        self.encoder_cnn = nn.Sequential(
            # First convolutional layer   [16x14x14]   (28 + 4 - 5) // 2 + 1 = 14
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=2, padding=2),
            nn.ReLU(True),
            nn.BatchNorm2d(16),
            
            # Second convolutional layer  [32x5x5]     (14 + 0 - 5) // 2 + 1 = 5
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=2, padding=0),
            nn.ReLU(True),
            nn.BatchNorm2d(32),
            
            # Third convolutional layer   [64x3x3]     (5 + 2 - 3) // 2 + 1 = 3
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(64)
        )
        
        ### Flatten layer
        self.flatten = nn.Flatten(start_dim=1) #1 to exclude the batch size

        ### Linear section
        self.encoder_lin = nn.Sequential(
            # First linear layer: 64*3*3 = 576 flattened conv features -> 64
            nn.Linear(in_features=64*3*3, out_features=64),
            nn.ReLU(True),
            # Second linear layer: 64 -> latent code (no activation, so values are unbounded)
            nn.Linear(in_features=64, out_features=encoded_space_dim)
        )
        
    def forward(self, x):
        """Return the (N, encoded_space_dim) latent codes of the image batch `x`."""
        # Apply convolutions
        x = self.encoder_cnn(x)
        # Flatten
        x = self.flatten(x)
        # Apply linear layers
        x = self.encoder_lin(x)
        return x

### Decoder

In [None]:
class Decoder(nn.Module):
    """Decoder that mirrors `Encoder`: maps (N, encoded_space_dim) codes to (N, 1, 28, 28) images.

    The output goes through a sigmoid, so every pixel lies in (0, 1).
    Attribute names and layer order define the keys of the saved state_dict checkpoints,
    so they must not change.
    """
    
    def __init__(self, encoded_space_dim):
        super().__init__()

        ### Linear section
        self.decoder_lin = nn.Sequential(
            # First linear layer
            nn.Linear(in_features = encoded_space_dim, out_features=64),
            nn.ReLU(True),
            # Second linear layer
            nn.Linear(in_features=64, out_features=64*3*3),
            nn.ReLU(True)
        )

        # Reshape the 576 features into a 64x3x3 feature map for the transposed convolutions
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(64, 3, 3))

        ### Convolutional section
        # Transposed-conv output size: (H - 1)*stride - 2*padding + kernel_size + output_padding
        self.decoder_conv = nn.Sequential(
            # First transposed convolution  [32x6x6]     (3-1)*2 - 2 + 3 + 1 = 6
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(32),
            
            # Second transposed convolution  [16x14x14]  (6-1)*2 - 2 + 5 + 1 = 14
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=5, stride=2, padding=1, output_padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(16),
            
            # Third transposed convolution   [1x28x28]   (14-1)*2 - 4 + 5 + 1 = 28
            nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=5, stride=2, padding=2, output_padding=1)
        )
        
    def forward(self, x):
        """Return the reconstructed image batch for the latent codes `x`."""
        # Apply linear layers
        x = self.decoder_lin(x)
        # Unflatten
        x = self.unflatten(x)
        # Apply transposed convolutions
        x = self.decoder_conv(x)
        # Apply a sigmoid to force the output to be between 0 and 1 (valid pixel values)
        x = torch.sigmoid(x)
        return x

In [None]:
### Initialize the two networks with a 20-dimensional latent space
encoded_space_dim = 20
# The tuple is built left to right, so the encoder is still initialised before the decoder
# (this keeps the RNG draws for the initial weights unchanged).
encoder, decoder = (
    Encoder(encoded_space_dim=encoded_space_dim),
    Decoder(encoded_space_dim=encoded_space_dim),
)

In [None]:
### Some examples: push one test image through the (still untrained) networks to check shapes
# Take an input image
img, _ = test_dataset[0]
img = img.unsqueeze(0) # Add the batch dimension in the first axis
print('Original image shape:', img.shape)
# This is only a shape check. Eval mode keeps the BatchNorm running statistics untouched,
# and no_grad avoids building an autograd graph. train_epoch() switches both networks
# back to train mode.
encoder.eval()
decoder.eval()
with torch.no_grad():
    # Encode the image
    img_enc = encoder(img)
    print('Encoded image shape:', img_enc.shape)
    # Decode the image
    dec_img = decoder(img_enc)
    print('Decoded image shape:', dec_img.shape)

## Training

In [None]:
### Loss: pixel-wise mean squared reconstruction error
loss_fn = nn.MSELoss()

### A single optimizer updates both networks; each network gets its own parameter group
lr = 0.01  # Learning rate
params_to_optimize = [
    {'params': encoder.parameters()},
    {'params': decoder.parameters()},
]
# Plain SGD baseline (an Adam + ExponentialLR variant was also considered here)
optim = torch.optim.SGD(params_to_optimize, lr=lr)

# Move both the encoder and the decoder to the selected device
encoder.to(device)
decoder.to(device)

### Training function

In [None]:
### Training function
def train_epoch(encoder, decoder, device, dataloader, loss_fn, optimizer, scheduler=None, printer=False):
    """Train the autoencoder (encoder + decoder) for one epoch.

    Args:
        encoder, decoder: the two networks; both are put in train mode.
        device: device the image batches are moved to.
        dataloader: iterable of (images, labels) batches; the labels are ignored.
        loss_fn: reconstruction loss, called as loss_fn(reconstruction, images).
        optimizer: optimizer holding the parameters to update.
        scheduler: optional LR scheduler, stepped once at the end of the epoch.
        printer: if True, print the mean training loss.

    Returns:
        The mean of the per-batch losses.
    """
    # Set train mode for both the encoder and the decoder
    encoder.train()
    decoder.train()
    train_loss = []
    for image_batch, _ in dataloader:
        image_batch = image_batch.to(device)
        # Encode, then decode, then compare the reconstruction with the input itself
        decoded_data = decoder(encoder(image_batch))
        loss = loss_fn(decoded_data, image_batch)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
    # Epoch-level scheduler: stepped once, after all optimizer steps of this epoch
    if scheduler is not None:
        scheduler.step()
    mean_loss = np.mean(train_loss)
    if printer:
        print('\t train loss: %f' % (mean_loss))
    return mean_loss

### Test function

In [None]:
### Testing function
def test_epoch(encoder, decoder, device, dataloader, loss_fn):
    # Set evaluation mode for encoder and decoder
    encoder.eval()
    decoder.eval()
    with torch.no_grad(): # No need to track the gradients
        # Define the lists to store the outputs for each batch
        conc_out = []
        conc_label = []
        for image_batch, _ in dataloader:
            # Move tensor to the proper device
            image_batch = image_batch.to(device)
            # Encode data
            encoded_data = encoder(image_batch)
            # Decode data
            decoded_data = decoder(encoded_data)
            # Append the network output and the original image to the lists
            conc_out.append(decoded_data.cpu())
            conc_label.append(image_batch.cpu())
        # Create a single tensor with all the values in the lists
        conc_out = torch.cat(conc_out)
        conc_label = torch.cat(conc_label) 
        # Evaluate global loss
        test_loss = loss_fn(conc_out, conc_label)
    return test_loss.data

In [None]:
if not trained:
    ### Training cycle (plain SGD baseline)
    train_loss_log = []
    val_loss_log = []

    num_epochs = 10
    for epoch in range(num_epochs):
        print('EPOCH %d/%d' % (epoch + 1, num_epochs))
        ### Training (use the training function)
        train_loss = train_epoch(
            encoder=encoder,
            decoder=decoder,
            device=device,
            dataloader=train_dataloader,
            loss_fn=loss_fn,
            optimizer=optim,
            scheduler=None,
            printer=True)
        train_loss_log.append(train_loss)
        ### Validation on the held-out validation split. The test set is reserved for the
        ### final evaluation, so it must not drive any training decision.
        val_loss = test_epoch(
            encoder=encoder,
            decoder=decoder,
            device=device,
            dataloader=val_dataloader,
            loss_fn=loss_fn)
        val_loss_log.append(val_loss)
        # Print validation loss
        print('\n\n\t VALIDATION - EPOCH %d/%d - loss: %f\n\n' % (epoch + 1, num_epochs, val_loss))

        ### Plot progress
        # Get the output of a specific image (the test image at index 0 in this case)
        img = test_dataset[0][0].unsqueeze(0).to(device)
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            rec_img  = decoder(encoder(img))
        # Plot the reconstructed image
        fig, axs = plt.subplots(1, 2, figsize=(12,6))
        axs[0].imshow(img.cpu().squeeze().numpy(), cmap='gist_gray')
        axs[0].set_title('Original image')
        axs[1].imshow(rec_img.cpu().squeeze().numpy(), cmap='gist_gray')
        axs[1].set_title('Reconstructed image (EPOCH %d)' % (epoch + 1))
        plt.tight_layout()
        plt.pause(0.1)
        # Save figures
        os.makedirs('autoencoder_progress_%d_features' % encoded_space_dim, exist_ok=True)
        fig.savefig('autoencoder_progress_%d_features/epoch_%d.jpg' % (encoded_space_dim, epoch + 1))
        plt.show()
        plt.close()

        # Save network parameters
        torch.save(encoder.state_dict(), 'encoder_params_simple.pth')
        torch.save(decoder.state_dict(), 'decoder_params_simple.pth')
else:
    # Load network parameters
    encoder_state_dict = torch.load('encoder_params_simple.pth', map_location=device)
    decoder_state_dict = torch.load('decoder_params_simple.pth', map_location=device)
    
    encoder.load_state_dict(encoder_state_dict)
    decoder.load_state_dict(decoder_state_dict)

In [None]:
## Some examples of reconstructed images: 8 random test samples
indices = np.random.randint(len(test_dataset), size=8)
subset = torch.utils.data.Subset(test_dataset, indices)
testloader_subset = DataLoader(subset, batch_size=1, num_workers=0, shuffle=False)

encoder.eval()
decoder.eval()
fig, axs = plt.subplots(2, 4, figsize=(15,15))
axs = axs.flatten()
## Iterate through the selected test samples, one subplot each
for ax_n, (data, label) in enumerate(testloader_subset):
    with torch.no_grad():
        latent = encoder(data.to(device))
        out = decoder(latent)
    ax = axs[ax_n]
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(out.cpu().squeeze().numpy(), cmap='gist_gray')
    ax.set_title(label_names[label[0].item()], fontsize = 18)
    plt.tight_layout()
#plt.savefig("simple_reconstruction_examples.pdf", format='pdf')

### Optimization

In [None]:
if not trained:
        ###########OPTIMIZATION###########
        ! pip install optuna
        import optuna
        from optuna.integration import PyTorchLightningPruningCallback
        EPOCHS = 10

        def objective(trial):

            encoded_space_dim = 20
            encoder = Encoder(encoded_space_dim=encoded_space_dim).to(device)
            decoder = Decoder(encoded_space_dim=encoded_space_dim).to(device)

            parameters = [{'params': encoder.parameters()}, {'params': decoder.parameters()}]

            # Type of optimizer algorithm
            optim_name = ['Adam', 'SGD', 'Adagrad', 'RMSprop']
            optim_algorithm = trial.suggest_categorical("optimizer", optim_name)

            # Weight decay 
            w_decay = trial.suggest_float("weight_decay", 0, 1e-6)

            # Scheduling factor
            gamma = trial.suggest_float("gamma", 0, 1)

            optimizer = getattr(torch.optim, optim_algorithm)(parameters, lr = 0.01, weight_decay=w_decay)
            scheduler.gamma = gamma

            loss_func = nn.MSELoss()

            for epoch in range(EPOCHS):
                loss = train_epoch(encoder, decoder, device, train_dataloader, loss_fn, optimizer, scheduler)
                test_epoch(encoder, decoder, device, val_dataloader, loss_fn)

                trial.report(loss, epoch)

                # Handle pruning based on the intermediate value.
                if trial.should_prune():
                    raise optuna.exceptions.TrialPruned()

            return loss

        pruner: optuna.pruners.BasePruner = optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=0, interval_steps=1)

        study = optuna.create_study(study_name="Encoder", direction="minimize", pruner=pruner)
        study.optimize(objective, n_trials=50)

        print("Number of finished trials: {}".format(len(study.trials)))

        print("Best trial:")
        trial = study.best_trial

        print("  Value: {}".format(trial.value))

        print("  Params: ")
        for key, value in trial.params.items():
            print("    {}: {}".format(key, value))


In [None]:
# Best hyper-parameters found by the Optuna study, hard-coded so that the notebook can be
# re-run without repeating the search.
# NOTE(review): weight_decay below lies outside the [0, 1e-6] range sampled in the Optuna
# cell, so these values presumably come from an earlier search space — confirm.
print(*[
    '#############',
    "BEST TRIAL",
    '#############',
    "\nParameters:",
    "weight_decay: 2.9505970645836693e-05",
    "gamma: 0.495944901064486",
    "optimizer: Adagrad",
], sep='\n')

In [None]:
### Re-initialise both networks (fresh weights) for the tuned training run
encoded_space_dim = 20
encoder = Encoder(encoded_space_dim=encoded_space_dim).to(device)
decoder = Decoder(encoded_space_dim=encoded_space_dim).to(device)

### One parameter group per network; the optimizer itself is created in the training cell
lr = 0.01  # Learning rate
params_to_optimize = [{'params': net.parameters()} for net in (encoder, decoder)]

In [None]:
if not trained:
    ### Training cycle with the tuned hyper-parameters (Adagrad + exponential LR decay)
    train_loss_log = []
    val_loss_log = []
    lr_log = []
    best_loss = np.inf  # `np.infty` was removed in NumPy 2.0

    optim = torch.optim.Adagrad(params_to_optimize, lr=lr, weight_decay=2.9505970645836693e-05)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.495944901064486)

    num_epochs = 50
    max_patience = 3  # consecutive epochs without validation improvement before stopping
    patience = max_patience
    for epoch in range(num_epochs):
        print('EPOCH %d/%d' % (epoch + 1, num_epochs))
        ### Training (use the training function)
        train_loss = train_epoch(
            encoder=encoder,
            decoder=decoder,
            device=device,
            dataloader=train_dataloader,
            loss_fn=loss_fn,
            optimizer=optim,
            scheduler=scheduler)
        train_loss_log.append(train_loss)
        lr_log.append(scheduler.get_last_lr())
        ### Validation on the held-out validation split. Early stopping on the test set
        ### would leak test information into model selection.
        val_loss = float(test_epoch(
            encoder=encoder,
            decoder=decoder,
            device=device,
            dataloader=val_dataloader,
            loss_fn=loss_fn))
        val_loss_log.append(val_loss)
        # Print validation loss
        print('\n\n\t VALIDATION - EPOCH %d/%d - loss: %f\n\n' % (epoch + 1, num_epochs, val_loss))

        ### Plot progress
        # Get the output of a specific image (the test image at index 0 in this case)
        img = test_dataset[0][0].unsqueeze(0).to(device)
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            rec_img  = decoder(encoder(img))
        # Plot the reconstructed image
        fig, axs = plt.subplots(1, 2, figsize=(12,6))
        axs[0].imshow(img.cpu().squeeze().numpy(), cmap='gist_gray')
        axs[0].set_title('Original image')
        axs[1].imshow(rec_img.cpu().squeeze().numpy(), cmap='gist_gray')
        axs[1].set_title('Reconstructed image (EPOCH %d)' % (epoch + 1))
        plt.tight_layout()
        plt.pause(0.1)
        # Save figures
        os.makedirs('autoencoder_progress_%d_features' % encoded_space_dim, exist_ok=True)
        fig.savefig('autoencoder_progress_%d_features/epoch_%d.jpg' % (encoded_space_dim, epoch + 1))
        plt.show()
        plt.close()

        # Early stopping. Checkpoint only on improvement, so the saved files always hold the
        # best model (they are reloaded below and by the trained=True branch).
        if val_loss < best_loss:
            best_loss = val_loss
            patience = max_patience
            torch.save(encoder.state_dict(), 'encoder_params.pth')
            torch.save(decoder.state_dict(), 'decoder_params.pth')
        else:
            patience -= 1
            if patience == 0:
                print("#################\nLearning stopped because the validation error was not improving\n#################")
                break

    # Continue with the best checkpoint rather than the weights of the last epoch
    encoder.load_state_dict(torch.load('encoder_params.pth', map_location=device))
    decoder.load_state_dict(torch.load('decoder_params.pth', map_location=device))
else:
    # Load network parameters
    encoder_state_dict = torch.load('encoder_params.pth', map_location=device)
    decoder_state_dict = torch.load('decoder_params.pth', map_location=device)
    
    encoder.load_state_dict(encoder_state_dict)
    decoder.load_state_dict(decoder_state_dict)

In [None]:
if not trained:
    ## Reconstruction loss of both splits per epoch, on a logarithmic y axis
    plt.figure(figsize=(12, 8))
    plt.semilogy(train_loss_log, label='Train loss')
    plt.semilogy(val_loss_log, label='Validation loss')
    plt.gca().set(xlabel='Epoch', ylabel='Loss')
    plt.grid()
    plt.legend()
    plt.show()

In [None]:
if not trained:
    # Learning rate of each parameter group, logged after every epoch's scheduler step
    plt.figure(figsize=(12, 8))
    plt.plot(lr_log)
    plt.gca().set(xlabel='Epoch', ylabel='Learning rate')
    plt.grid()
    plt.show()

## Network Analysis

In [None]:
## Some examples of reconstructed images from the tuned model: 8 random test samples
indices = np.random.randint(len(test_dataset), size=8)
subset = torch.utils.data.Subset(test_dataset, indices)
testloader_subset = DataLoader(subset, batch_size=1, num_workers=0, shuffle=False)

encoder.eval()
decoder.eval()
fig, axs = plt.subplots(2, 4, figsize=(15,15))
axs = axs.flatten()
## Iterate through the selected test samples, one subplot each
for ax_n, (data, label) in enumerate(testloader_subset):
    with torch.no_grad():
        latent = encoder(data.to(device))
        out = decoder(latent)
    ax = axs[ax_n]
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(out.cpu().squeeze().numpy(), cmap='gist_gray')
    ax.set_title(label_names[label[0].item()])
    plt.tight_layout()
#plt.savefig("Examples_adv.pdf", format='pdf')

### Latent space analysis-PCA

In [None]:
### Get the encoded representation of every test sample (one row per sample)
# Loop-invariant setup, done once instead of once per sample: eval mode and no_grad
encoder.eval()
encoded_samples = []
with torch.no_grad():
    for img, label in tqdm(test_dataset):
        # Encode the image (after adding the batch dimension)
        encoded_img = encoder(img.unsqueeze(0).to(device)).flatten().cpu().numpy()
        # One dict per sample: one column per latent variable, plus the class label
        encoded_sample = {f"Enc. Variable {i}": enc for i, enc in enumerate(encoded_img)}
        encoded_sample['label'] = label
        encoded_samples.append(encoded_sample)

encoded_samples = pd.DataFrame(encoded_samples)
encoded_samples

In [None]:
#encoded_samples = encoded_samples.drop('label', axis=1) # this to remove the last column (label)
pca = PCA(n_components=2).fit(encoded_samples[:-2])
encoded_samples_pca = pca.transform(encoded_samples)

In [None]:
# This function creates a dictionary in which a key is the numerical label and the corresponding value is the 
# verbose label
def dictionary_classes(label_names):
    dic = {}
    i = 0
    for label in label_names:
        dic[str(i)] = label
        i += 1
    return dic
dictionary = dictionary_classes(label_names)

In [None]:
fig = px.scatter(encoded_samples_pca, x=0, y=1,
                 color=encoded_samples.label.astype(str),
                 labels={'0': 'Feature 1', '1': 'Feature 2', 'Color': 'Class'})

fig.for_each_trace(lambda t: t.update(name = dictionary[t.name]))
#fig.write_image("PCA.pdf")

### Generate new examples

In [None]:
# Based on the pca find two internal representation
t_shirt = np.array([-6, 1]) 
ankle_boot = np.array([6, -2])
sample1 = torch.tensor(np.dot(t_shirt, pca.components_) + pca.mean_)[:-1]
sample2 = torch.tensor(np.dot(ankle_boot, pca.components_) + pca.mean_)[:-1]
samples = [sample1, sample2]
encoder.eval()
decoder.eval()
encoder.float()
decoder.float()
fig, axs = plt.subplots(1, 2, figsize=(15,15))
axs = axs.flatten()
ax_n = 0
while ax_n < 2:
    with torch.no_grad():
        out = decoder(samples[ax_n].float().unsqueeze(0).to(device)) # add the batch dimension
    axs[ax_n].set_xticks([])
    axs[ax_n].set_yticks([])
    axs[ax_n].imshow(out.cpu().squeeze().numpy(), cmap='gist_gray')
    ax_n += 1
    plt.tight_layout()
#plt.savefig("new_examples.pdf", format='pdf')

## Autoencoder fine tuning

In [None]:
class Classifier(nn.Module):
    """MLP head that maps an encoder latent vector to 10 class logits.

    Layout: Linear(encoder_output->512) -> ReLU -> BN -> Linear(512->128) -> ReLU -> BN
    -> Linear(128->10). The attribute name `linear` and the layer order define the keys of
    the state_dict stored in 'classifier.torch', so they must not change.
    """

    def __init__(self, encoder_output):
        super().__init__()
        self.linear = nn.Sequential(
            nn.Linear(in_features=encoder_output, out_features=512),
            nn.ReLU(True),
            nn.BatchNorm1d(512),
            nn.Linear(in_features=512, out_features=128),
            nn.ReLU(True),
            nn.BatchNorm1d(128),
            nn.Linear(in_features=128, out_features=10),
        )

    def forward(self, x):
        """Return the (N, 10) unnormalised class scores for the latent batch `x`."""
        return self.linear(x)

In [None]:
# Classification head, placed on the selected device (Module.to returns the module itself)
classifier = Classifier(encoder_output=encoded_space_dim).to(device)
# Fresh encoder with the weights saved by the tuned autoencoder training
encoder = Encoder(encoded_space_dim).to(device)
encoder.load_state_dict(torch.load('encoder_params.pth', map_location=device))
# Cross-entropy on the raw logits produced by the classifier
loss_function = nn.CrossEntropyLoss()

In [None]:
# Optimizer over the classifier's parameters only: the pretrained encoder is not updated
lr = 1e-4  # learning rate
optim = torch.optim.Adam(params=classifier.parameters(), lr=lr, weight_decay=5e-4)

### Training and testing functions

In [None]:
## The train/validation functions are redefined here because, unlike the autoencoder
## versions above (which ignore the labels), they must use the class labels.

def training_step(encoder, classifier, train_loader, loss_fn, optimizer, train_loss_log, printer=True):
    """Train `classifier` for one epoch on top of the (frozen) `encoder`.

    Appends the mean batch loss to `train_loss_log` in place. If `printer` is True, prints
    that loss and the accuracy (%) over the samples drawn by `train_loader.sampler`.
    Batches are moved to the module-level `device`.
    """
    # The optimizer only holds classifier parameters, so the encoder is a fixed feature
    # extractor. Keep it in eval mode so its BatchNorm running statistics are not
    # overwritten, and do not build its autograd graph.
    encoder.eval()
    classifier.train()
    train_loss = []
    train_correct = 0
    for x_batch, label_batch in train_loader:
        # Move data to device
        x_batch = x_batch.to(device)
        label_batch = label_batch.to(device)

        # Forward pass
        with torch.no_grad():
            features = encoder(x_batch)
        out = classifier(features)

        # Compute loss
        loss = loss_fn(out, label_batch)

        # Backpropagation and weight update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Save train loss and number of correct predictions for this batch
        train_loss.append(loss.item())
        train_correct += (out.argmax(dim=1) == label_batch).sum().item()

    # Save average train loss over the batches
    train_loss = np.mean(train_loss)
    if printer:
        print(f"AVERAGE TRAIN LOSS: {train_loss}")
        print(f"TRAINING ACCURACY: {train_correct*100/len(train_loader.sampler)}")
    train_loss_log.append(train_loss)
    
def validation_step(encoder, classifier, val_loader, loss_fn, val_loss_log, printer = True):
    """Evaluate `encoder` + `classifier` on `val_loader` without updating any weights.

    Appends the mean batch loss to `val_loss_log` in place. If `printer` is True, prints
    that loss and the accuracy (%) over the samples drawn by `val_loader.sampler`.
    Batches are moved to the module-level `device`.
    """
    val_loss = []
    val_correct = 0
    # Evaluation mode for both networks: the encoder's BatchNorm layers must use their
    # running statistics, not the statistics of the current batch.
    encoder.eval()
    classifier.eval()
    with torch.no_grad():
        for x_batch, label_batch in val_loader:
            x_batch = x_batch.to(device)
            label_batch = label_batch.to(device)

            # Predict using the current model
            y_pred = classifier(encoder(x_batch))

            # Compute and save the val_loss for this batch
            val_loss.append(loss_fn(y_pred, label_batch).item())

            # Accuracy for this batch
            val_correct += (y_pred.argmax(dim=1) == label_batch).sum().item()

        # Save average validation loss over the batches
        val_loss = np.mean(val_loss)
        if printer:
            print(f"AVERAGE VALIDATION LOSS: {val_loss}")
            print(f"VALIDATION ACCURACY: {val_correct*100/len(val_loader.sampler)}")
        val_loss_log.append(val_loss)

In [None]:
if not trained:
    num_epochs = 50
    train_loss_log = []
    validation_loss_log = []
    best_loss = np.inf  # `np.infty` was removed in NumPy 2.0
    max_patience = 3  # consecutive non-improving epochs tolerated before stopping
    patience = max_patience
    for i in range(num_epochs):
        print('#################')
        print(f'# EPOCH {i}')
        print('#################')
        #Train pass
        training_step(encoder, classifier, train_dataloader, loss_function, optim, train_loss_log, printer=True)
        #Validation pass
        validation_step(encoder, classifier, val_dataloader, loss_function, validation_loss_log, printer = True)

        # Early stopping. Checkpoint only on improvement, so 'classifier.torch' always
        # holds the best classifier.
        if validation_loss_log[-1] < best_loss:
            best_loss = validation_loss_log[-1]
            patience = max_patience
            torch.save(classifier.state_dict(), 'classifier.torch')
        else:
            patience -= 1
            if patience == 0:
                print("#################\nLearning stopped because the validation error was not improving\n#################")
                break

# Either way, continue with the checkpointed (best) classifier. When trained=True this is
# the file written by a previous run; otherwise it is the one just saved above.
classifier.load_state_dict(torch.load('classifier.torch', map_location=device))

In [None]:
def test_step(encoder, model, dataloader, loss_fn):
    """Classify every batch of `dataloader` with `encoder` + `model` on the CPU.

    Side effect: both networks are moved to the CPU and left in eval mode.

    Returns:
        (outputs, labels, loss): the concatenated model outputs and labels of all batches,
        and loss_fn evaluated once on them (``.data`` of the result).
    """
    encoder.cpu().eval()
    model.cpu().eval()
    with torch.no_grad():
        batches = [(model(encoder(images)), targets) for images, targets in tqdm(dataloader)]
        outputs = torch.cat([scores for scores, _ in batches])
        labels = torch.cat([targets for _, targets in batches])
        error = loss_fn(outputs, labels)
    return outputs, labels, error.data
    
test_outputs, test_labels, test_loss = test_step(
    encoder=encoder,
    model=classifier,
    dataloader=test_dataloader,
    loss_fn=nn.CrossEntropyLoss())

# Accuracy (%) = correctly classified test samples / number of test samples
_, predictions = torch.max(test_outputs.data, 1)
accuracy = (predictions == test_labels).sum().item() / len(test_dataloader.sampler) * 100
print(f"\n\nTEST LOSS : {test_loss}")
print(f"\nTEST ACCURACY : {accuracy}")

In [None]:
from sklearn import metrics
import seaborn as sn
# True and predicted test labels as NumPy arrays
y_true = test_labels.cpu().numpy()
y_pred = test_outputs.cpu().argmax(dim=1).numpy()
# Confusion matrix: rows are true labels, columns are predicted labels
cm = metrics.confusion_matrix(y_true, y_pred)
CM_df = pd.DataFrame(cm)

# Heatmap of the confusion matrix. NOTE: vmax=450 caps the colour scale, so every cell
# with more than 450 samples gets the same darkest colour.
fig, ax = plt.subplots(figsize=(8, 6))
ax = sn.heatmap(CM_df, annot=True, cmap='rocket_r', vmax=450, fmt='d')
ax.set_xlabel("Predicted label", fontsize=12)
ax.set_ylabel("True label", fontsize=12)
#plt.savefig("confusion_matrix.pdf", format='pdf', bbox_inches = 'tight')

## Variational autoencoder

In [None]:
####  VARIATIONAL AUTOENCODER ####
class VariationalEncoder(nn.Module):
    """VAE encoder: same conv stack as `Encoder`, followed by two heads for q(z|x).

    forward() returns [mu, log_var], each of shape (N, encoded_space_dim). The head stored
    as `self.var` is interpreted as a log-variance by the training loop, which computes
    std = exp(log_var / 2). Shape comments assume (N, 1, 28, 28) inputs, which is what the
    hard-coded 3*3*64 linear input requires.
    """
    def __init__(self, encoded_space_dim):  
        super(VariationalEncoder, self).__init__()
        
        ### Convolutional section
        # Conv output size: (H + 2*padding - kernel_size) // stride + 1
        self.encoder_cnn = nn.Sequential(
            # First convolutional layer   [16x14x14]   (28 + 4 - 5) // 2 + 1 = 14
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=2, padding=2),
            nn.ReLU(True),
            nn.BatchNorm2d(16),
            
            # Second convolutional layer  [32x5x5]     (14 + 0 - 5) // 2 + 1 = 5
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=2, padding=0),
            nn.ReLU(True),
            nn.BatchNorm2d(32),
            
            # Third convolutional layer   [64x3x3]     (5 + 2 - 3) // 2 + 1 = 3
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(64)
        )
        

        self.linear1 = nn.Linear(3*3*64, 256)
        self.linear2 = nn.Linear(256, 128)
        self.mu = nn.Linear(128, encoded_space_dim)  # mean of the latent Gaussian
        self.var = nn.Linear(128, encoded_space_dim)  # named "var", but used as the log-variance (see class docstring)
        
        # Activation function
        self.act = nn.ReLU(True)

        # Flatten layer after the convolutions
        self.flatten = nn.Flatten(start_dim=1)
        
        # Learnable log standard deviation of the Gaussian likelihood p(x|z); the training
        # loop passes it to gaussian_likelihood() as `logscale`.
        self.log_scale = nn.Parameter(torch.Tensor([0.0]))
            
        
    def forward(self, x):
        """Return [mu, log_var] for the image batch `x`."""
        
        x = self.encoder_cnn(x)
        x = self.flatten(x)
        x = self.act(self.linear1(x))
        x = self.act(self.linear2(x))
        mu =  self.mu(x)
        var = self.var(x)
        return [mu, var]

In [None]:
# Latent space dimension
d = 20

var_encoder = VariationalEncoder(encoded_space_dim=d)
decoder = Decoder(encoded_space_dim=d)  # same decoder architecture as the plain autoencoder

# Adam over both networks (one parameter group each); this includes the encoder's log_scale
lr = 1e-3
params_to_optimize = [{'params': net.parameters()} for net in (var_encoder, decoder)]
optimizer = torch.optim.Adam(params_to_optimize, lr=lr, weight_decay=1e-5)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Selected device: {device}')

var_encoder.to(device)
decoder.to(device)

In [None]:
def gaussian_likelihood(x_hat, logscale, x):
    """Log-likelihood of `x` under a Gaussian with mean `x_hat` and std exp(logscale).

    The element-wise log-probabilities are summed over dims 1, 2 and 3, which gives one
    value per sample for 4-D batches.
    """
    # log-probability of observing the image under p(x|z)
    log_prob = torch.distributions.Normal(x_hat, torch.exp(logscale)).log_prob(x)
    return log_prob.sum(dim=(1, 2, 3))

def KL_divergence(z, mu, std):
    """Single-sample Monte-Carlo estimate of KL(q(z|x) || p(z)).

    The approximate posterior q is Normal(mu, std); the prior p is a standard
    Normal with the same shape as ``mu``. The estimate
    ``log q(z) - log p(z)`` is evaluated at the given sample ``z`` and summed
    over the last dimension.
    """
    # Prior forced to a standard normal, matching the latent shape
    prior = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
    posterior = torch.distributions.Normal(mu, std)

    log_ratio = posterior.log_prob(z) - prior.log_prob(z)
    return log_ratio.sum(-1)

In [None]:
### Training function
def train_epoch(var_encoder, decoder, device, dataloader, optimizer):
    """Run one optimisation pass of the VAE over ``dataloader``.

    For every batch the encoder yields (mean, log-variance), a
    reparameterised sample z is drawn and decoded, and the batch mean of
    ``-log p(x|z) + KL`` (negative ELBO estimate) is minimised.

    Prints and returns the mean of the per-batch losses.
    """
    var_encoder.train()
    decoder.train()
    batch_losses = []
    # Labels are ignored: this is unsupervised learning
    for images, _ in dataloader:
        images = images.to(device)
        mean, log_var = var_encoder(images)
        # Convert log-variance to a (strictly positive) standard deviation
        std = torch.exp(log_var / 2)
        # rsample() applies the reparameterisation trick, so gradients flow
        # back into the encoder through z
        z = torch.distributions.Normal(mean, std).rsample()
        # Decode the latent sample back to image space
        reconstruction = decoder(z)

        log_likelihood = gaussian_likelihood(reconstruction, var_encoder.log_scale, images)
        kl = KL_divergence(z, mean, std)
        loss = (-log_likelihood + kl).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    epoch_loss = np.mean(batch_losses)
    print('\t train loss: %f' % (epoch_loss))
    return epoch_loss

In [None]:
### Evaluation function

def test_epoch(var_encoder, decoder, device, dataloader):
    """Estimate the VAE loss on ``dataloader`` without updating any weights.

    Both networks are switched to eval mode and gradients are disabled. The
    loss is the same negative-ELBO estimate used in training, computed from
    one random sample of z per input, so the result is stochastic.

    Returns the mean of the per-batch losses.
    """
    var_encoder.eval()
    decoder.eval()
    batch_losses = []
    with torch.no_grad():
        for images, _ in dataloader:
            images = images.to(device)

            mean, log_var = var_encoder(images)
            # log-variance -> positive standard deviation
            std = torch.exp(log_var / 2)
            z = torch.distributions.Normal(mean, std).rsample()
            reconstruction = decoder(z)

            log_likelihood = gaussian_likelihood(reconstruction, var_encoder.log_scale, images)
            kl = KL_divergence(z, mean, std)
            batch_losses.append((-log_likelihood + kl).mean().item())

    return np.mean(batch_losses)
    

### Training loop

In [None]:
if not trained:
    ### Training cycle
    # Early stopping: training stops after `max_patience` consecutive epochs
    # without a new best validation loss.
    max_patience = 3
    patience = max_patience
    best = np.inf  # `np.infty` was removed in NumPy 2.0
    num_epochs = 100
    for epoch in range(num_epochs):
        print('EPOCH %d/%d' % (epoch + 1, num_epochs))
        ### Training (use the training function)
        train_epoch(
            var_encoder=var_encoder,
            decoder=decoder,
            device=device,
            dataloader=train_dataloader,
            optimizer=optimizer)
        ### Validation  (use the testing function)
        val_loss = test_epoch(
            var_encoder=var_encoder,
            decoder=decoder,
            device=device,
            dataloader=test_dataloader)
        # Print validation loss
        print('\n\n\t VALIDATION - EPOCH %d/%d - loss: %f\n\n' % (epoch + 1, num_epochs, val_loss))
        if val_loss < best:
            best = val_loss
            patience = max_patience
            # Checkpoint only on improvement, so the saved files hold the
            # parameters with the best validation loss rather than whatever
            # epoch happened to run last before stopping.
            torch.save(var_encoder.state_dict(), 'var_encoder_params.pth')
            torch.save(decoder.state_dict(), 'var_decoder_params.pth')
        else:
            patience -= 1
        if patience == 0:
            print("Learning stopped")
            break
        ### Plot progress
        # Reconstruct a fixed image (the test image at index 0) to follow training visually
        img = test_dataset[0][0].unsqueeze(0).to(device)
        var_encoder.eval()
        decoder.eval()
        with torch.no_grad():
            mean, log_var = var_encoder(img)
            std = torch.exp(log_var / 2)  # log-variance -> positive std
            z = torch.distributions.Normal(mean, std).rsample()
            rec_img = decoder(z)

        # Plot the original and the reconstructed image side by side
        fig, axs = plt.subplots(1, 2, figsize=(12,6))
        axs[0].imshow(img.cpu().squeeze().numpy(), cmap='gist_gray')
        axs[0].set_title('Original image')
        axs[1].imshow(rec_img.cpu().squeeze().numpy(), cmap='gist_gray')
        axs[1].set_title('Reconstructed image (EPOCH %d)' % (epoch + 1))
        plt.tight_layout()
        plt.pause(0.1)
        # Save figures
        os.makedirs('autoencoder_progress_%d_features' % d, exist_ok=True)
        #fig.savefig('autoencoder_progress_%d_features/epoch_%d.jpg' % (d, epoch + 1))
        plt.show()
        plt.close()
else:
    # Load previously saved network parameters instead of training
    var_encoder.load_state_dict(torch.load('var_encoder_params.pth', map_location=device))
    decoder.load_state_dict(torch.load('var_decoder_params.pth', map_location=device))

In [None]:
## Some examples of reconstructed images
indices = np.random.randint(len(test_dataset), size=8)
subset = torch.utils.data.Subset(test_dataset, indices)
testloader_subset = DataLoader(subset, batch_size=1, num_workers=0, shuffle=False)

var_encoder.eval()
decoder.eval()
fig, axs = plt.subplots(2, 4, figsize=(15,15))
axs = axs.flatten()
## Iterate through the sampled test images, one image per axis
for ax, (data, label) in zip(axs, testloader_subset):
    with torch.no_grad():
        # The networks live on `device`; the DataLoader yields CPU tensors,
        # so the batch must be moved explicitly (otherwise this fails on GPU).
        data = data.to(device)
        # Encode, sample z ~ Normal(mean, std) and decode
        mean, log_var = var_encoder(data)
        std = torch.exp(log_var / 2)  # log-variance -> positive std
        z = torch.distributions.Normal(mean, std).rsample()
        out = decoder(z)

    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(out.cpu().squeeze().numpy(), cmap='gist_gray')
    ax.set_title(label_names[label[0].item()], fontsize=14)
plt.tight_layout()
#plt.savefig("Examples_variational.pdf", format='pdf')

### Latent space analysis

In [None]:
# Encode every test image and collect the latent means in a table,
# one column per latent variable plus the class label.
var_encoder.eval()
records = []
for image, target in tqdm(test_dataset):
    batch = image.unsqueeze(0).to(device)  # add the batch dimension
    with torch.no_grad():
        # Keep only the mean; the second (log-variance) output is discarded
        latent_mean = var_encoder(batch)[0]
    latent_mean = latent_mean.flatten().cpu().numpy()
    row = {f"Enc. Variable {i}": enc for i, enc in enumerate(latent_mean)}
    row['label'] = target
    records.append(row)

encoded_samples = pd.DataFrame(records)
encoded_samples

In [None]:
pca = PCA(n_components=2).fit(encoded_samples[:-2])
encoded_samples_pca = pca.transform(encoded_samples)

In [None]:
fig = px.scatter(encoded_samples_pca, x=0, y=1,
                 color=encoded_samples.label.astype(str),
                 labels={'0': 'Feature 1', '1': 'Feature 2', 'Color': 'Class'})

fig.for_each_trace(lambda t: t.update(name = dictionary[t.name]))
#fig.write_image("PCA_variational.pdf")

In [None]:
# Based on the pca find two internal representation
coat = np.array([0.5, 1.3]) 
boot = np.array([-5.5, -1])
sample1 = torch.tensor(np.dot(coat, pca.components_) + pca.mean_)[:-1]
sample2 = torch.tensor(np.dot(boot, pca.components_) + pca.mean_)[:-1]
samples = [sample1, sample2]
encoder.eval()
decoder.eval()
encoder.float()
decoder.float()
fig, axs = plt.subplots(1, 2, figsize=(15,15))
axs = axs.flatten()
ax_n = 0
while ax_n < 2:
    with torch.no_grad():
        out = decoder(samples[ax_n].float().unsqueeze(0).to(device)) # add the batch dimension
    axs[ax_n].set_xticks([])
    axs[ax_n].set_yticks([])
    axs[ax_n].imshow(out.cpu().squeeze().numpy(), cmap='gist_gray')
    ax_n += 1
    plt.tight_layout()
#plt.savefig("new_examples_variational.pdf", format='pdf')