In [1]:
import os
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from torch.utils.data import DataLoader
from torch.utils.data import random_split

import torchvision.transforms as transforms
from torchvision.datasets import MNIST

# Constants

In [52]:
latent_dims = 10
num_epochs = 100
variational_beta = 1
batch_size = 128
capacity = 64
learning_rate = 1e-3

# Load Dataset

In [53]:
img_transforms = transforms.Compose([
    transforms.ToTensor()
])

dataset = MNIST(root='data/', train=True, 
                transform=img_transforms, download=True)

# Training validation & test dataset
train_ds, val_ds = random_split(dataset, [50000, 10000])

# test_ds = MNIST(root='data/', train=False, 
#                 transform=transforms.ToTensor())


# Dataloaders
train_loader = DataLoader(train_ds, batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size*2, num_workers=2, pin_memory=True)

# Move Data and model to device

In [54]:
def get_default_device():
    """Pick GPU if available, else CPU"""
    if torch.cuda.is_available():
        return torch.device('cuda')
    else: 
        return torch.device('cpu')

In [55]:
device = get_default_device()
device

device(type='cpu')

In [56]:
def to_device(data, device):
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

In [57]:
class DeviceDataLoader():

    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        for b in self.dl:
            yield to_device(b, self.device)
    
    def __len__(self):
        return len(self.dl)

In [58]:
train_loader = DeviceDataLoader(train_loader, device)
val_loader = DeviceDataLoader(val_loader, device)

## VAE

In [74]:
class Encoder(nn.Module):
    
    def __init__(self):
        super().__init__() 
        c = capacity
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=c, kernel_size=4, stride=2, padding=1) # c*14*14
        self.conv2 = nn.Conv2d(in_channels=c, out_channels=c*2, kernel_size=4, stride=2, padding=1) # c*7*7
        self.fc_mu = nn.Linear(in_features=c*2*7*7, out_features=latend_dims)
        self.fc_logvar = nn.Linear(in_features=c*2*7*7, out_features=latend_dims)

    
    def forward(self, xb):
        xb = self.conv1(xb)
        xb = F.relu(xb)
        xb = self.conv2(xb)
        xb = F.relu(xb)
        xb = xb.view(xb.size(0), -1)
        x_mu = self.fc_mu(xb)
        x_logvar = self.fc_logvar(xb)
        return x_mu, x_logvar
        

class Decoder(nn.Module):
    
    def __init__(self):
        super().__init__() 
        c = capacity
        self.fc = nn.Linear(in_features=latent_dims, out_features=c*2*7*7)
        self.conv2 = nn.ConvTranspose2d(in_channels=c*2, out_channels=c, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.ConvTranspose2d(in_channels=c, out_channels=1, kernel_size=4, stride=2, padding=1)
            
    def forward(self, xb):
        xb = self.fc(xb)
        xb = xb.view(xb.size(0), capacity*2, 7, 7)
        xb = F.relu(self.conv2(xb))
        xb = torch.sigmoid(self.conv1(xb))
        return xb
    

    
def vae_loss(reconstruct, og, mu, logvar):
    reconstruction_loss = F.binary_cross_entropy(reconstruct.view(-1, 784), og.view(-1, 784), reduction='sum')
    kl_divergence = -0.5 * torch.sum(1 + logvar - torch.pow(mu, 2) - torch.exp(logvar))
    return reconstruction_loss + variational_beta * kl_divergence
    

class VariationalAutoEncoder(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        
    def latent_sample(self, mu, logvar):
        sigma = torch.exp(torch.mul(logvar, 0.5))
        eps = torch.randn_like(sigma)
        sample = mu + (sigma * eps)
        return sample
    
    def forward(self, xb):
        latent_mu, latent_std = self.encoder(xb)
        latent_ = self.latent_sample(latent_mu, latent_std)
        reconstruction = self.decoder(latent_)
        return reconstruction, latent_mu, latent_std
    
    def train_step(self, batch):
        input_batch, _ = batch
        input_batch_reconstruct, batch_mu, batch_std = self(input_batch)
        loss = vae_loss(input_batch_reconstruct, input_batch, batch_mu, batch_std)
        return {"loss": loss}
    
    def valid_step(self, batch):
        with torch.no_grad():
            val_batch, _ = batch
            val_batch_reconstruct, batch_mu, batch_std = self(val_batch)
            loss = vae_loss(val_batch_reconstruct, val_batch, batch_mu, batch_std)
            
        return {"val_loss": loss}

    def get_metrics_epoch_end(self, outputs, validation=True):
        if validation:
            loss_ = 'val_loss'
        else:
            loss_ = 'loss'
            
        batch_losses = [x[f'{loss_}'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   
        return {f'{loss_}': epoch_loss.item()}

    def epoch_end(self, epoch, result, num_epochs):
        if (epoch+1) % 5 == 0 or epoch == num_epochs-1:
            print(f"Epoch [{epoch+1}] -> loss: {result['loss']:.4f}, val_loss: {result['val_loss']:.4f}")


In [75]:
model = VariationalAutoEncoder()
to_device(model, device)
print(model)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Number of parameters: %d' % num_params)

VariationalAutoEncoder(
  (encoder): Encoder(
    (conv1): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (fc_mu): Linear(in_features=6272, out_features=10, bias=True)
    (fc_logvar): Linear(in_features=6272, out_features=10, bias=True)
  )
  (decoder): Decoder(
    (fc): Linear(in_features=10, out_features=6272, bias=True)
    (conv2): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv1): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
)
Number of parameters: 458901


# Train and Test

In [76]:
def evaluate(model, val_loader):
    outputs = [model.valid_step(batch) for batch in val_loader]
    return model.get_metrics_epoch_end(outputs, validation=True)

In [88]:
def fit(epochs, lr, model, train_loader, val_loader, opt_func=None):
    history = []
    
    if not opt_func:
        optimizer = torch.optim.SGD(model.parameters(), lr)
    else:
        optimizer = opt_func
    

    for epoch in range(epochs):
        # Training Phase 
        train_history = []
        model.train()
        
        for batch in train_loader:
            info = model.train_step(batch)
            loss = info['loss']
        
            # contains batch loss for training phase
            train_history.append(info)
            
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        train_result = model.get_metrics_epoch_end(train_history, validation=False)
        val_result = evaluate(model, val_loader)
        result = {**train_result, **val_result}

        model.epoch_end(epoch, result, epochs)
        history.append(result)
    return history

In [89]:
history = [evaluate(model, val_loader)]
history

[{'val_loss': 122643.7734375}]

In [90]:
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=1e-5)

In [None]:
history = fit(10, learning_rate, model, train_loader, val_loader, optimizer)