
The key differences between the VAE and a standard autoencoder are:

The encode function outputs two vectors, mu and logvar. These are the mean and log-variance of the diagonal Gaussian distribution from which the latent vector is sampled.
The reparameterize function takes mu and logvar and applies the "reparameterization trick": it draws noise eps from a standard normal distribution and returns mu + eps * std, so the sample is a differentiable function of mu and logvar and gradients can flow back through the sampling step.
The loss_function is different. It adds the KL divergence (KLD) between the learned latent distribution and the standard normal prior to the reconstruction loss (BCE). The KLD term pulls the latent distribution toward the prior, which is what makes it possible to generate new samples by decoding draws from a standard normal.
The rest of the training and testing loop remains the same, except that the loss now includes the KL divergence term.
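
Written out for a single input image, the objective described above can be sketched as follows. The notation is an assumption made for illustration and does not appear in the notebook: $\hat{x}$ stands for the reconstruction `recon_x`, and $\sigma_j^2 = \exp(\texttt{logvar}_j)$. The sizes 784 (28 × 28 pixels) and 20 (latent dimensions) are taken from the model definition. The notebook's `loss_function` additionally sums this quantity over every image in the batch.

```latex
\mathcal{L}(x)
  = \underbrace{-\sum_{i=1}^{784}
      \Bigl[ x_i \log \hat{x}_i + (1 - x_i) \log (1 - \hat{x}_i) \Bigr]}_{\text{BCE}}
  \;+\;
    \underbrace{-\frac{1}{2} \sum_{j=1}^{20}
      \Bigl( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \Bigr)}_{\text{KLD}}
```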

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch.nn.functional as F

In [2]:
# Prepare the data
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

In [4]:
# Variational Autoencoder Definition
class VAE(nn.Module):
    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(28*28, 400)  # Encoder hidden layer
        self.fc21 = nn.Linear(400, 20)  # Mean vector (mu)
        self.fc22 = nn.Linear(400, 20)  # Log-variance vector (logvar)
        self.fc3 = nn.Linear(20, 400)  # Decoder hidden layer
        self.fc4 = nn.Linear(400, 28*28)  # Decoder output (one value per pixel)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)  # logvar = log(sigma^2), so sigma = exp(0.5*logvar)
        eps = torch.randn_like(std)  # Noise drawn from N(0, I)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 28*28))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
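
To see why the reparameterization step lets gradient descent work, the following minimal sketch repeats the three lines of `reparameterize` on small stand-alone tensors and checks the gradients that autograd produces. The tensor size (3) and the seed are arbitrary choices for illustration, not values from the notebook.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for repeatability

# Stand-alone parameters playing the role of the encoder outputs.
mu = torch.zeros(3, requires_grad=True)
logvar = torch.zeros(3, requires_grad=True)

# Same computation as VAE.reparameterize.
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)  # the noise does not depend on mu or logvar
z = mu + eps * std

# Backpropagate through the sample.
z.sum().backward()

# With eps held fixed: dz/dmu = 1 and dz/dlogvar = 0.5 * eps * std.
print(mu.grad)
print(logvar.grad)
```

Because the randomness enters only through `eps`, both `mu.grad` and `logvar.grad` are well defined, which would not be the case if `z` were drawn directly from a distribution parameterized by `mu` and `logvar` without this rewrite.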

In [5]:
# Instantiate the model
model = VAE()

# Loss function
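def loss_function(recon_x, x, mu, logvar):
    # Reconstruction term: binary cross-entropy summed over all pixels and the batch.
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 28*28), reduction='sum')
    # KL divergence regularization term: closed form for a diagonal Gaussian
    # against a standard normal prior, summed over latent dimensions and the batch.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD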

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
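
As a sanity check on the KLD expression used in `loss_function`, the sketch below compares the closed form for a single latent dimension against a direct numerical integration of the KL divergence between N(mu, exp(logvar)) and N(0, 1). It uses only the standard library. The helper names `kl_closed_form` and `kl_numeric`, the integration range, and the test values of `mu` and `logvar` are all assumptions chosen for illustration.

```python
import math

def kl_closed_form(mu, logvar):
    """Per-dimension KL( N(mu, exp(logvar)) || N(0, 1) ), as in the KLD term."""
    return -0.5 * (1 + logvar - mu ** 2 - math.exp(logvar))

def kl_numeric(mu, logvar, lo=-12.0, hi=12.0, steps=20_000):
    """Midpoint-rule integral of q(z) * (log q(z) - log p(z)) over [lo, hi]."""
    var = math.exp(logvar)
    dz = (hi - lo) / steps
    total = 0.0
    for i in range(steps):
        z = lo + (i + 0.5) * dz
        log_q = -0.5 * (math.log(2 * math.pi * var) + (z - mu) ** 2 / var)
        log_p = -0.5 * (math.log(2 * math.pi) + z ** 2)
        total += math.exp(log_q) * (log_q - log_p) * dz
    return total

mu, logvar = 0.7, -0.4  # arbitrary example values
closed = kl_closed_form(mu, logvar)
numeric = kl_numeric(mu, logvar)
print(f"closed form: {closed:.6f}  numeric: {numeric:.6f}")
```

The two values should agree closely, and the closed form is zero when `mu = 0` and `logvar = 0`, that is, when the latent distribution already equals the prior.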

In [6]:
# Training Loop
def train_loop(dataloader, model, loss_fn, optimizer):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(dataloader):
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_fn(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

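        # Log every 100 batches. Note: the number printed after "Train Epoch:"
        # is the batch index, not the epoch; the loss shown is the per-sample average.
        if batch_idx % 100 == 0: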
            print(f"Train Epoch: {batch_idx} Loss: {loss.item() / len(data):.6f}")

    print(f"====> Epoch: Train loss: {train_loss / len(dataloader.dataset):.6f}")

In [8]:
# Train the VAE
epochs = 10
for epoch in range(1, epochs + 1):
    train_loop(train_dataloader, model, loss_function, optimizer)

Train Epoch: 0 Loss: 550.158569
Train Epoch: 100 Loss: 292.511078
Train Epoch: 200 Loss: 276.104584
Train Epoch: 300 Loss: 278.024261
Train Epoch: 400 Loss: 260.506073
Train Epoch: 500 Loss: 273.031555
Train Epoch: 600 Loss: 259.550385
Train Epoch: 700 Loss: 255.688828
Train Epoch: 800 Loss: 256.657593
Train Epoch: 900 Loss: 261.208160
====> Epoch: Train loss: 274.672749
Train Epoch: 0 Loss: 252.666916
Train Epoch: 100 Loss: 252.590408
Train Epoch: 200 Loss: 254.644073
Train Epoch: 300 Loss: 259.366730
Train Epoch: 400 Loss: 244.908524
Train Epoch: 500 Loss: 260.511841
Train Epoch: 600 Loss: 249.384155
Train Epoch: 700 Loss: 245.076050
Train Epoch: 800 Loss: 247.370789
Train Epoch: 900 Loss: 253.362823
====> Epoch: Train loss: 251.649328
Train Epoch: 0 Loss: 246.994171
Train Epoch: 100 Loss: 246.609955
Train Epoch: 200 Loss: 247.089233
Train Epoch: 300 Loss: 254.366776
Train Epoch: 400 Loss: 241.017242
Train Epoch: 500 Loss: 256.787476
Train Epoch: 600 Loss: 245.493652
Train Epoch: 700

In [9]:
# Testing Loop
def test_loop(dataloader, model, loss_fn):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for data, _ in dataloader:
            recon_batch, mu, logvar = model(data)
            test_loss += loss_fn(recon_batch, data, mu, logvar).item()

    test_loss /= len(dataloader.dataset)
    print(f"====> Test set loss: {test_loss:.6f}")

In [10]:
test_loop(test_dataloader, model, loss_function)

====> Test set loss: 243.320911
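
Once the latent distribution has been pulled toward the standard normal prior, new images can be generated by decoding draws from that prior. The following is a minimal sketch of that step, not part of the original notebook. The helper `sample_images` is hypothetical, and `decoder` is an untrained stand-in with the same layer shapes as `fc3` and `fc4`, included only so the block runs on its own.

```python
import torch
from torch import nn

latent_dim = 20  # matches the output size of fc21 / fc22

# Hypothetical stand-in for the trained decoder (same shapes as fc3 -> fc4).
decoder = nn.Sequential(
    nn.Linear(latent_dim, 400),
    nn.ReLU(),
    nn.Linear(400, 28 * 28),
    nn.Sigmoid(),
)

def sample_images(decode_fn, n, latent_dim):
    """Draw n latent vectors z ~ N(0, I) and decode them to (n, 1, 28, 28) images."""
    with torch.no_grad():
        z = torch.randn(n, latent_dim)
        return decode_fn(z).view(n, 1, 28, 28)

images = sample_images(decoder, 8, latent_dim)
print(images.shape)
```

With the model trained above, the same helper could be called as `sample_images(model.decode, 8, 20)`; the stand-in decoder here produces only noise because it has not been trained.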
