# Lecture 1 - AE, VAE, Latent space

# Theoretical Questions

### TQ1. How could you implement an AutoEncoder (or VAE) for Images? And time series?


Response -


### TQ2. Aside from the limitations listed in class, what other limitations do AutoEncoders have?

Response -


### TQ3. What effect does the bottleneck have in the network?

Response -


### TQ4. Imagine we want to build an AutoEncoder to fix missing data. Assume we have a reasonably sized dataset with enough non-faulty data to train this AutoEncoder and a massive dataset to fix.

Response -


# Code Exercises

**Dependencies**

In [None]:
!pip install pytorch-lightning

Load your drive using this (if necessary):

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Set up dependencies of the notebook

In [None]:
import datetime

import torch
import torch.nn as nn

import pytorch_lightning as pl
import torchmetrics
from pytorch_lightning import seed_everything

import numpy as np

import pandas as pd

from sklearn.preprocessing import OneHotEncoder

import matplotlib.pyplot as plt

### Data Loading
Data loading is done here to simplify the coding exercises. The Data Loading and most of the training will remain the same. However, some changes will be required.

In [None]:
DATA_PATH = 'mushrooms.csv'
SEED = 42
seed_everything(seed=SEED) # Set seed for reproducibility

Load the dataset

In [None]:
mush_df = pd.read_csv(DATA_PATH) # Load the raw data; the class column is dropped later
mush_df.head()

Now, quickly perform the One Hot Encoding with `sklearn.preprocessing.OneHotEncoder`



In [None]:
ohe = OneHotEncoder(sparse_output=False)

mush_df_train = mush_df.drop('class', axis=1).copy()
new_data = ohe.fit_transform(mush_df_train)
mush_df_train = pd.DataFrame(new_data, columns=ohe.get_feature_names_out())
mush_df_train


Torch dataset

In [None]:
class MushroomDataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        return torch.tensor(self.df.iloc[idx].values, dtype=torch.float32)

Finally, declare the `pl.LightningDataModule`

In [None]:
class MushroomModule(pl.LightningDataModule):
    def __init__(self, df, batch_size=16, val_size=0.2):
        super().__init__()
        self.batch_size = batch_size
        self.val_size = val_size
        self.full_len = len(df)
        self.df = df

    def setup(self, stage=None):
        # We can't perform numerical evaluations! Only train + val
        train_end = int((1 - self.val_size) * self.full_len)
        self.train_data = MushroomDataset(self.df[:train_end])
        self.val_data = MushroomDataset(self.df[train_end:])

    def collate_fn(self, batch):
        features = torch.stack(batch, dim=0)  # [batch_size, input_size]
        return features

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, collate_fn=self.collate_fn)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_data, batch_size=self.batch_size, shuffle=False, collate_fn=self.collate_fn)

## CQ1 (✰). Improving the autoencoder

In the lecture we built a basic AutoEncoder. However, it is severely flawed and can be improved with better design decisions across the Encoder-Decoder architecture.

1.   The targets are one-hot encodings, so every value is 0 or 1; if the network outputs conform to this range, it can be expected to perform better, especially with MSE loss. There's a catch to this, though...
2.   ... Several positions of the output vector belong to a single feature. MSE is a poor fit because the output really encodes several classification tasks at once; we could use one loss term per classification task and combine them.

Introduce any desired changes to the network. Generalize the number of layers it may have, include regularization... include your own improvements.
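As a rough illustration of the "one loss term per classification task" idea, the sketch below splits the output vector into one block per original feature and averages a cross-entropy term over the blocks. The helper name `per_feature_cross_entropy` and the `group_sizes` argument are assumptions made for this example, not part of the lecture code, and this is only one of several reasonable designs.

```python
import torch
import torch.nn.functional as F


def per_feature_cross_entropy(logits, targets, group_sizes):
    """Average one cross-entropy term per original categorical feature.

    logits, targets: [batch_size, sum(group_sizes)]; targets are one-hot per group.
    group_sizes: number of one-hot columns belonging to each feature (hypothetical argument).
    """
    losses = []
    # torch.split cuts the column axis into consecutive one-hot groups.
    for logit_group, target_group in zip(
        torch.split(logits, group_sizes, dim=1),
        torch.split(targets, group_sizes, dim=1),
    ):
        # cross_entropy expects class indices, so recover them from the one-hot block.
        losses.append(F.cross_entropy(logit_group, target_group.argmax(dim=1)))
    return torch.stack(losses).mean()


# Toy batch: two features with 3 and 2 categories respectively.
group_sizes = [3, 2]
targets = torch.tensor([[1., 0., 0., 0., 1.],
                        [0., 0., 1., 1., 0.]])
confident_logits = (targets * 2 - 1) * 20.0   # large margin toward the true class
uniform_logits = torch.zeros_like(targets)    # no preference at all

loss_confident = per_feature_cross_entropy(confident_logits, targets, group_sizes)
loss_uniform = per_feature_cross_entropy(uniform_logits, targets, group_sizes)
```

In this notebook, the group sizes could plausibly be derived as `[len(c) for c in ohe.categories_]`, assuming the column order produced by the `OneHotEncoder` is kept. Note that with this loss the decoder should output raw logits, since `cross_entropy` applies the softmax itself.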



In [None]:
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_dim, bottleneck):
        super(Encoder, self).__init__()

    def forward(self, x):
        pass

In [None]:
class Decoder(nn.Module):
    def __init__(self, bottleneck, hidden_dim, output_size):
        super(Decoder, self).__init__()

    def forward(self, x):
        pass

In [None]:
class AutoEncoder(nn.Module):
    def __init__(self, input_size, hidden_dim, bottleneck):
        super(AutoEncoder, self).__init__()

    def forward(self, x):
        pass

### Training

Introduce any new change you want in the `compute_batch` method.

In [None]:
class MushroomCompressor(pl.LightningModule):
    def __init__(self, model, categories, learning_rate=1e-3, weight_decay=0.):
        super(MushroomCompressor,self).__init__()
        self.save_hyperparameters() # Save Hyperparams
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.model = model
        self.criterion = nn.CrossEntropyLoss()
        self.nrmse = torchmetrics.NormalizedRootMeanSquaredError()
        self.mse = nn.MSELoss()
        self.categories = categories # New property, required for advanced train

    def forward(self, x):
        return self.model(x)

    def compute_batch(self, batch, split='train'):
        pass

    def training_step(self, batch, batch_idx):
        return self.compute_batch(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self.compute_batch(batch, 'val')

    def predict_step(self, batch, batch_idx):
        return self(batch)[-1]

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate,
                                 weight_decay=self.weight_decay) # self.parameters() are the model's parameters

In [None]:
#@title Hyper-parameters

LEARNING_RATE = 1e-3 #@param {type:"number"}
WEIGHT_DECAY = 0. #@param {type:"number"}
BATCH_SIZE = 8 # @param ["2","4","8","16","32"] {"type":"raw"}
MAX_EPOCHS = 10 # @param {"type":"slider","min":0,"max":100,"step":1}
HIDDEN_DIM = 32 # @param {"type":"slider","min":0,"max":128,"step":1}
LATENT_DIM = 3 # @param {"type":"slider","min":0,"max":16,"step":1}

SAVE_DIR = f'lightning_logs/sales/{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'


#### Train loop

In [None]:
# DataModule
data_module = MushroomModule(mush_df_train, batch_size=BATCH_SIZE)

# Model
model = AutoEncoder(input_size=mush_df_train.values.shape[1], hidden_dim=HIDDEN_DIM, bottleneck=LATENT_DIM)

# LightningModule
module = MushroomCompressor(model, ohe.categories_, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Callbacks
early_stopping_callback = pl.callbacks.EarlyStopping(
    monitor='val_loss', # monitor the loss on the validation set
    mode='min',
    patience=5, # number of epochs without improvement before stopping
    verbose=False, # whether to print early-stopping status messages
)
model_checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_loss', # monitor the loss on the validation set
    mode='min', # we want to minimize the loss
    save_top_k=1, # keep only the best model
    dirpath=SAVE_DIR, # directory where the models are saved
    filename='best_model' # file name
)

callbacks = [early_stopping_callback, model_checkpoint_callback]

# Loggers
csv_logger = pl.loggers.CSVLogger(
    save_dir=SAVE_DIR,
    name='metrics',
    version=None
)

loggers = [csv_logger] # several loggers can be used (see the documentation)

# Trainer
trainer = pl.Trainer(max_epochs=MAX_EPOCHS, accelerator='gpu', callbacks=callbacks,
                     logger=loggers, precision='bf16')

trainer.fit(module, data_module)

### Visualization
Let's explore the latent space!

In [None]:
embeddings = trainer.predict(module, torch.utils.data.DataLoader(MushroomDataset(mush_df_train), batch_size=16))
embeddings = torch.cat(embeddings).to(torch.float).cpu().numpy()
embeddings

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# You can visualize a 3d space with this!
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

classes = mush_df['class']

scatter = ax.scatter(embeddings[:, 0], embeddings[:, 1], embeddings[:, 2], c=classes.astype('category').cat.codes)

ax.set_xlabel("Latent Dimension 1")
ax.set_ylabel("Latent Dimension 2")
ax.set_zlabel("Latent Dimension 3")
ax.set_title("Latent Space Representation of Mushrooms")

handles, labels = scatter.legend_elements()
ax.legend(handles, list(classes.astype('category').cat.categories), title="Classes")  # same order as the category codes

plt.show()

## CQ2 (✰✰). Implement the VAE for the mushroom dataset

We have seen that VAEs overcome the problems of AEs by generating a continuous latent space. Here we implement such a VAE. Instead of a single point, the VAE encoder outputs the mean and deviation of a distribution, so we need an encoder that performs exactly that task.

We have to build $σ(x)$ and $μ(x)$. These transformations operate on $x$, the latent representation that the encoder previously produced as a single point. Use one Linear layer to let the model decide the mean for $x$ and another Linear layer to let it decide the deviation.

**Note**: For better results the model predicts the deviation on a logarithmic scale (read here as the log-variance, consistent with the $e^{0.5 \cdot σ(x)}$ factor used when sampling). This, however, does not affect the encoder.
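To make the two-head idea concrete, here is a minimal sketch of a pair of Linear heads sitting on top of a shared encoder output. The class name `GaussianHeads` and the attribute name `log_var` are hypothetical choices for this example (the exercise skeleton calls the second head `sigma`), and the sketch assumes the second head is interpreted as a log-variance.

```python
import torch
import torch.nn as nn


class GaussianHeads(nn.Module):
    """Hypothetical helper: two Linear heads on top of a shared encoder output."""

    def __init__(self, bottleneck):
        super().__init__()
        self.mu = nn.Linear(bottleneck, bottleneck)        # predicts the mean
        self.log_var = nn.Linear(bottleneck, bottleneck)   # predicts the log-variance

    def forward(self, h):
        # Both heads read the same hidden vector h and return tensors of the same shape.
        return self.mu(h), self.log_var(h)


heads = GaussianHeads(bottleneck=3)
h = torch.randn(4, 3)          # stand-in for the encoder output of a batch of 4
mu, log_var = heads(h)
```

Predicting a logarithm lets the head output any real number while the implied deviation stays positive after exponentiation, which is one common motivation for this parametrization.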



### VAE definition

In [None]:
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_dim, bottleneck):
        super(Encoder, self).__init__()
        self.encoder_layer = nn.Sequential(
            nn.Linear(input_size, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, bottleneck),
            nn.BatchNorm1d(bottleneck),
        )
        self.mu = nn.Linear(bottleneck, bottleneck)
        self.sigma = nn.Linear(bottleneck, bottleneck)

    def forward(self, x):
        pass

The decoder stays the same as before: it receives a latent representation and generates a mushroom.

In [None]:
class Decoder(nn.Module):
    def __init__(self, bottleneck, hidden_dim, output_size):
        super(Decoder, self).__init__()
        self.decoder_layer = nn.Sequential(
            nn.Linear(bottleneck, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_size),
        )

    def forward(self, x):
        pass

However, there is a mismatch now. The encoder outputs a distribution of points, while the decoder takes a single point. Thus, the Encoder-Decoder should sample the distribution of the encoder using $μ(x)$ and $σ(x)$!

However, sampling is non-differentiable, so we need the **reparametrization trick**. Normally we would sample with `torch.normal(mean=μ(x), std=σ(x))`, but here we use a different expression.

The reparametrization trick uses `torch.randn(bottleneck)` to draw a random vector, which is then combined with the mean and deviation:

$$x_{sampled} = μ(x) + N(0,1) \cdot e^{0.5 \cdot σ(x)}$$

In other words, we multiply the standard deviation (here $e^{0.5 \cdot \sigma(x)}$, because $\sigma(x)$ is predicted on a logarithmic scale) by a sample from a standard normal distribution ($N(0,1)$) and then add the mean $\mu(x)$.

To train we also need to retrieve the deviation and mean!
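The sampling step described above can be sketched as a small standalone function. The name `reparameterize` and the argument name `log_var` are assumptions for this example; the sketch treats the deviation head's output as a log-variance and uses `torch.randn_like` so the noise matches the batch shape.

```python
import torch


def reparameterize(mu, log_var):
    """Draw z = mu + eps * std with eps ~ N(0, 1), keeping gradients w.r.t. mu and log_var."""
    std = torch.exp(0.5 * log_var)   # log-variance -> standard deviation
    eps = torch.randn_like(std)      # the only random part; it carries no gradient
    return mu + eps * std


torch.manual_seed(0)
mu = torch.zeros(10000, 3, requires_grad=True)
log_var = torch.full((10000, 3), -2.0, requires_grad=True)
z = reparameterize(mu, log_var)
z.sum().backward()   # gradients flow back to mu and log_var
```

Because the randomness lives entirely in `eps`, the sample is a deterministic, differentiable function of `mu` and `log_var`, which is what lets backpropagation reach the encoder.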

In [None]:
class VAE(nn.Module):
    def __init__(self, input_size, hidden_dim, bottleneck):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_size, hidden_dim, bottleneck)
        self.decoder = Decoder(bottleneck, hidden_dim, input_size)

    def forward(self, x):
        pass

### Training

The new loss term is different from before. We use:

$$ℒ(x',x) = ℒ_{rec}(x', x) + λℒ_{kl}(\mu(x), \sigma(x))$$

This is the sum of the reconstruction loss (whichever loss we choose: MSE, CE, etc.) and the Kullback–Leibler loss (KL Div.), weighted by $λ$. $x$ are the inputs, $x'$ is the reconstruction.

Modify `_rec_loss(self, preds, targets)` and  `_kl_loss(self, mu, sigma)` to build the new `compute_batch(self, batch)`.

You can use either the torch implementation of the [KL-Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html#torch.nn.KLDivLoss) or code it yourself!
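If you code it yourself, one common option is the closed-form KL divergence between a diagonal Gaussian and the standard normal prior. The sketch below assumes the deviation head outputs a log-variance; the function names `gaussian_kl` and `vae_loss` are hypothetical. Keep in mind that `nn.KLDivLoss` compares two distributions given as (log-)probabilities, so it does not take `mu` and `sigma` directly.

```python
import torch


def gaussian_kl(mu, log_var):
    """KL( N(mu, exp(log_var)) || N(0, 1) ), summed over latent dims, averaged over the batch."""
    kl_per_sample = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
    return kl_per_sample.mean()


def vae_loss(rec_loss, mu, log_var, lambd=1.0):
    """Total loss: reconstruction term plus the weighted KL term."""
    return rec_loss + lambd * gaussian_kl(mu, log_var)


# Sanity checks on toy values.
zeros = torch.zeros(2, 3)
kl_standard = gaussian_kl(zeros, zeros)             # posterior equals the prior
kl_shifted = gaussian_kl(torch.ones(2, 3), zeros)   # mean shifted by 1 in each of 3 dims
total = vae_loss(torch.tensor(0.5), torch.ones(2, 3), zeros, lambd=2.0)
```

The KL term is zero exactly when the encoder outputs the prior (zero mean, unit variance) and grows as the posterior drifts away from it, so `lambd` trades reconstruction quality against how closely the latent space follows the prior.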

In [None]:
class MushroomGenerator(pl.LightningModule):
    def __init__(self, model, categories, learning_rate=1e-3, weight_decay=0., lambd=1.):
        super(MushroomGenerator,self).__init__()
        self.save_hyperparameters() # Save Hyperparams
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.model = model
        self.reconstruction = nn.CrossEntropyLoss()
        self.nrmse = torchmetrics.NormalizedRootMeanSquaredError()
        self.mse = nn.MSELoss()
        self.lambd = lambd # Weights the influence of the KLDiv
        self.categories = categories

    def forward(self, x):
        return self.model(x)

    def _rec_loss(self, preds, targets):
        pass

    def _kl_loss(self, mu, sigma):
        pass

    def compute_batch(self, batch, split='train'):
        pass

    def training_step(self, batch, batch_idx):
        return self.compute_batch(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self.compute_batch(batch, 'val')

    def predict_step(self, batch, batch_idx):
        return self(batch)[-1][0]

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate,
                                 weight_decay=self.weight_decay)

In [None]:
#@title Hyper-parameters

LEARNING_RATE = 1e-3 #@param {type:"number"}
WEIGHT_DECAY = 0. #@param {type:"number"}
BATCH_SIZE = 8 # @param ["2","4","8","16","32"] {"type":"raw"}
MAX_EPOCHS = 10 # @param {"type":"slider","min":0,"max":100,"step":1}
HIDDEN_DIM = 32 # @param {"type":"slider","min":0,"max":128,"step":1}
LATENT_DIM = 3 # @param {"type":"slider","min":0,"max":16,"step":1}
LAMBDA = 1 #@param {type:"number"}

SAVE_DIR = f'lightning_logs/sales/{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'


### Run!

In [None]:
# DataModule
data_module = MushroomModule(mush_df_train, batch_size=BATCH_SIZE)

# Model
model = VAE(input_size=mush_df_train.values.shape[1], hidden_dim=HIDDEN_DIM, bottleneck=LATENT_DIM)

# LightningModule
module = MushroomGenerator(model, ohe.categories_,
                           learning_rate=LEARNING_RATE,
                           weight_decay=WEIGHT_DECAY,
                           lambd=LAMBDA)

# Callbacks
early_stopping_callback = pl.callbacks.EarlyStopping(
    monitor='val_loss', # monitor the loss on the validation set
    mode='min',
    patience=5, # number of epochs without improvement before stopping
    verbose=False, # whether to print early-stopping status messages
)
model_checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_loss', # monitor the loss on the validation set
    mode='min', # we want to minimize the loss
    save_top_k=1, # keep only the best model
    dirpath=SAVE_DIR, # directory where the models are saved
    filename='best_model' # file name
)

callbacks = [early_stopping_callback, model_checkpoint_callback]

# Loggers
csv_logger = pl.loggers.CSVLogger(
    save_dir=SAVE_DIR,
    name='metrics',
    version=None
)

loggers = [csv_logger] # several loggers can be used (see the documentation)

# Trainer
trainer = pl.Trainer(max_epochs=MAX_EPOCHS, accelerator='gpu', callbacks=callbacks,
                     logger=loggers, precision='bf16')

trainer.fit(module, data_module)

In [None]:
embeddings = trainer.predict(module, torch.utils.data.DataLoader(MushroomDataset(mush_df_train), batch_size=16))
embeddings = torch.cat(embeddings).to(torch.float).cpu().numpy()
embeddings

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# You can visualize a 3d space with this!
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

classes = mush_df['class']

scatter = ax.scatter(embeddings[:, 0], embeddings[:, 1], embeddings[:, 2], c=classes.astype('category').cat.codes)

ax.set_xlabel("Latent Dimension 1")
ax.set_ylabel("Latent Dimension 2")
ax.set_zlabel("Latent Dimension 3")
ax.set_title("Latent Space Representation of Mushrooms")

handles, labels = scatter.legend_elements()
ax.legend(handles, list(classes.astype('category').cat.categories), title="Classes")  # same order as the category codes

plt.show()

At this point you won't notice any difference between the AE and the VAE... and that's okay. The benefit of the VAE is hard to see with categorical data; it will become apparent in the next lecture, with images.