In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models

import torchvision
from torchvision import datasets, transforms

In [None]:
# Preprocessing: convert to a tensor, then normalise with mean 0.5 and std 0.5
# (the inverse, x * 0.5 + 0.5, is applied again before plotting).
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# EMNIST "letters" split; both subsets are downloaded into ./data on first use.
trainset = datasets.EMNIST(
    root='./data', split='letters', train=True, download=True, transform=trans
)
testset = datasets.EMNIST(
    root='./data', split='letters', train=False, download=True, transform=trans
)

# Model Definition

In [None]:
class ResBlock(nn.Module):
    """Bottleneck-shaped residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a skip connection.

    The block maps ``in_channels`` to ``out_channels * 4`` channels. The 3x3
    convolution carries the ``stride``, so it is the only layer that changes the
    spatial resolution.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Width of the inner convolutions; the block outputs
            ``out_channels * self.expansion`` channels.
        id_downsample: Optional module applied to the skip connection so that it
            matches the shape of the main branch. When ``None`` the input is
            added unchanged, so its shape must already match the branch output.
        stride: Stride of the 3x3 convolution.

    NOTE(review): the main branch has a single ReLU, placed after the last
    BatchNorm and before the residual addition. The layer layout is kept as-is
    so that the indices inside ``self.layers`` (and therefore the state_dict
    keys of saved checkpoints) do not change.
    """

    def __init__(self, in_channels, out_channels, id_downsample=None, stride=1):
        super(ResBlock, self).__init__()
        self.expansion = 4
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(out_channels*self.expansion),
            nn.ReLU()
        )

        self.identity_downsample = id_downsample

    def forward(self, input):
        """Return ``relu(layers(input) + skip(input))``."""
        identity = input
        output = self.layers(input)

        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)

        # Out-of-place addition: `output` is the result of the trailing ReLU in
        # `self.layers`, and autograd may need that tensor unmodified for the
        # backward pass. An in-place `+=` here can raise
        # "modified by an inplace operation" during backward.
        output = output + identity

        return F.relu(output)

In [None]:
class ResNet(nn.Module):
    """Four-stage residual network built from ``ResBlock`` units.

    Args:
        layers: Sequence of four integers, the number of blocks in each stage.
        img_channels: Number of channels of the input image.
        out_dim: Size of the final linear layer's output.
    """

    def __init__(self, layers, img_channels, out_dim):
        super(ResNet, self).__init__()
        assert len(layers) == 4

        # Channel count of the feature map entering the next stage; updated by
        # `_make_layers` as each stage is created.
        self.channels = 64
        stem = [
            nn.Conv2d(img_channels, self.channels, kernel_size=7, stride=2,  padding=3),
            nn.BatchNorm2d(self.channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        self.entry = nn.Sequential(*stem)

        # Stages are created in order because each one reads and then updates
        # `self.channels`.
        depth1, depth2, depth3, depth4 = layers
        self.layer1 = self._make_layers(depth1, out_channels=64, stride=1)
        self.layer2 = self._make_layers(depth2, out_channels=128, stride=2)
        self.layer3 = self._make_layers(depth3, out_channels=256, stride=2)
        self.layer4 = self._make_layers(depth4, out_channels=512, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # The last stage outputs 512 * 4 channels (block expansion factor of 4).
        self.fc = nn.Linear(512*4, out_dim)

    def _make_layers(self, num_blocks, out_channels, stride):
        """Build one stage of ``num_blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the shapes differ,
        a 1x1 conv + BatchNorm projection on its skip connection. The remaining
        blocks keep the resolution and channel count.
        """
        expanded = out_channels * 4
        shortcut = None

        # A projection is needed whenever the skip connection would not match
        # the block output in resolution or channel count.
        if not (stride == 1 and expanded == self.channels):
            shortcut = nn.Sequential(
                nn.Conv2d(self.channels, expanded, kernel_size=1, stride=stride),
                nn.BatchNorm2d(expanded),
            )

        blocks = [ResBlock(self.channels, out_channels, shortcut, stride)]
        self.channels = expanded

        blocks.extend(
            ResBlock(self.channels, out_channels) for _ in range(num_blocks - 1)
        )

        return nn.Sequential(*blocks)

    def forward(self, input):
        """Run the stem, the four stages, global average pooling and the linear head."""
        features = self.entry(input)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        pooled = self.avgpool(features)
        flat = pooled.reshape(pooled.shape[0], -1)

        return self.fc(flat)

In [None]:
class AutoEncoder(nn.Module):
    """Variational autoencoder whose encoder and decoder are both ``ResNet`` models.

    The encoder maps a single-channel image to 300 features, from which two
    linear heads produce the mean and the log standard deviation of the
    approximate posterior. A latent sample is reshaped into a single-channel
    square image and passed through the decoder, which outputs ``28 * 28``
    values per sample.

    Args:
        embed_dim: Latent dimensionality. It must be a perfect square because
            the latent vector is reshaped to ``(1, side, side)`` before
            decoding (``embed_dim=64`` gives the 8x8 layout).
        device: Stored on the instance as ``self.device``; the module itself is
            not moved here.

    Raises:
        ValueError: If ``embed_dim`` is not a positive perfect square.
    """

    def __init__(self, embed_dim, device):
        super(AutoEncoder, self).__init__()
        # Validate up front: a latent that cannot be reshaped to a square would
        # otherwise fail (or silently mix samples across the batch) in forward.
        side = int(round(embed_dim ** 0.5))
        if embed_dim < 1 or side * side != embed_dim:
            raise ValueError(
                f"embed_dim must be a positive perfect square, got {embed_dim}"
            )

        self.embed_dim = embed_dim
        self.latent_side = side
        self.device = device
        # Encoder: ResNet with two blocks per stage, followed by the two heads.
        self.encoder = ResNet([2, 2, 2, 2], 1, 300)
        self.enc_mu = nn.Linear(300, embed_dim)
        self.enc_ls = nn.Linear(300, embed_dim)

        # Decoder: ResNet with two blocks per stage, one output per pixel.
        self.decoder = ResNet([2, 2, 2, 2], 1, 28*28)
        # Learnable scalar; the training code passes it to reconstruction_loss
        # as the log standard deviation of the Gaussian likelihood.
        self.logsigma = nn.Parameter(torch.Tensor([0.0]))

    def encode(self, x):
        """Return ``(mu, log_sigma)`` of the approximate posterior for ``x``."""
        h = self.encoder(x)
        mu = self.enc_mu(h)
        ls = self.enc_ls(h)
        return mu, ls

    def decode(self, z):
        """Decode latent images ``z`` into flattened reconstructions."""
        x_hat = self.decoder(z)
        return x_hat

    def forward(self, input):
        """Encode, draw a reparameterised sample, and decode.

        Returns:
            Tuple ``(xhat, mu, logsigma)``: the decoder output and the
            posterior parameters produced by ``encode``.
        """
        mu, logsigma = self.encode(input)
        # The encoder head is used as a log standard deviation here.
        q = torch.distributions.Normal(mu, logsigma.exp())
        z = q.rsample()

        # Keep the batch dimension explicit so a mismatched latent size cannot
        # be absorbed into the batch axis.
        z_img = z.view(z.size(0), 1, self.latent_side, self.latent_side)
        xhat = self.decode(z_img)
        return xhat, mu, logsigma


# Training

In [None]:
def reconstruction_loss(xhat, x, logsigma):
    """Mean Gaussian log-likelihood of ``x`` under ``Normal(xhat, exp(logsigma))``.

    The element-wise log-probabilities are summed over dimension 1 and then
    averaged over the remaining entries. The value is a log-likelihood (higher
    is better); the training loop subtracts it from the KL term.

    Args:
        xhat: Predicted means.
        x: Observed values, with the same layout as ``xhat``.
        logsigma: Log of the likelihood's standard deviation.

    Returns:
        Scalar tensor with the mean summed log-probability.
    """
    likelihood = torch.distributions.Normal(loc=xhat, scale=logsigma.exp())
    per_sample = likelihood.log_prob(x).sum(dim=1)
    return per_sample.mean()

def train_vae(batch_size, trainset, epochs=100, lr=0.001, embed_dim=10):
    """Create an ``AutoEncoder`` and train it with Adam on ``trainset``.

    The minimised loss is ``KL(q(z|x) || N(0, I)) - E[log p(x|z)]``, averaged
    over the batch.

    Args:
        batch_size: Mini-batch size for the training loader.
        trainset: Dataset yielding ``(image, label)`` pairs; labels are ignored.
        epochs: Number of passes over ``trainset``.
        lr: Adam learning rate.
        embed_dim: Latent dimensionality passed to ``AutoEncoder``. The model
            reshapes the latent into a square single-channel image, so the
            value must be compatible with that reshape (the notebook trains
            with 64; the default of 10 is not).

    Returns:
        Tuple ``(vae, trajectory, rec_losses)``: the trained model, the mean
        training loss of each epoch, and the validation reconstruction losses
        recorded every 1000 iterations.
    """
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if torch.cuda.is_available():
        print("Starting training on GPU...")
    else:
        print("Starting training on CPU...")

    vae = AutoEncoder(embed_dim, device)
    vae.to(device)

    optimizer = optim.Adam(vae.parameters(), lr=lr)
    trajectory = []
    rec_losses = []
    total_iters = 0

    for epoch in range(epochs):
        print('='*10 + f' Epoch {epoch} ' + '='*10)
        vae.train()
        running_loss = 0.0
        iterations = 0
        for i, (data, _) in enumerate(trainloader):
            optimizer.zero_grad()

            data = data.to(device).float()
            xhat, mu, logsigma = vae(data)
            # Use the actual batch size: the last batch of an epoch can be
            # smaller than `batch_size`.
            data = data.view(data.size(0), -1)

            recon_loss = reconstruction_loss(xhat, data, vae.logsigma)
            # The encoder output is used as a log *standard deviation* when
            # sampling (scale = exp(logsigma)), so the log-variance is
            # 2 * logsigma in the closed-form KL to a standard normal.
            log_var = 2 * logsigma
            kl_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
            loss = kl_loss - recon_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            iterations += 1
            total_iters += 1

            # total_iters is incremented before this check, so it is never 0.
            if total_iters % 1000 == 0:
                # Evaluate in eval mode so BatchNorm uses its running
                # statistics and does not update them with test data.
                vae.eval()
                rec_loss = validate_reconstruction(vae)
                vae.train()
                print(f"RECONSTRUCTION LOSS {rec_loss}")
                rec_losses.append(rec_loss)

            if i % 312 == 0:
                print(f'- Iteration {i} loss: {loss.item()}')

        # Guard against an empty loader, which would otherwise divide by zero.
        trajectory.append(running_loss / max(iterations, 1))

    return vae, trajectory, rec_losses


In [None]:
def resume_training(vae, batch_size, trainset, epochs=100, lr=0.001, embed_dim=10):
    """Continue training an existing ``AutoEncoder`` with a fresh Adam optimiser.

    The minimised loss is ``KL(q(z|x) || N(0, I)) - E[log p(x|z)]``, averaged
    over the batch. Optimiser state from an earlier run is not restored.

    Args:
        vae: Model to keep training; it is moved to the selected device.
        batch_size: Mini-batch size for the training loader.
        trainset: Dataset yielding ``(image, label)`` pairs; labels are ignored.
        epochs: Number of passes over ``trainset``.
        lr: Adam learning rate.
        embed_dim: Unused; kept so existing calls keep working. The latent size
            is determined by ``vae`` itself.

    Returns:
        Tuple ``(vae, trajectory, rec_losses)``: the trained model, the mean
        training loss of each epoch, and the validation reconstruction losses
        recorded every 1000 iterations.
    """
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if torch.cuda.is_available():
        print("Starting training on GPU...")
    else:
        print("Starting training on CPU...")

    vae.to(device)

    optimizer = optim.Adam(vae.parameters(), lr=lr)
    trajectory = []
    rec_losses = []
    total_iters = 0

    for epoch in range(epochs):
        print('='*10 + f' Epoch {epoch} ' + '='*10)
        # The incoming model may have been left in eval mode by the caller.
        vae.train()
        running_loss = 0.0
        iterations = 0
        for i, (data, _) in enumerate(trainloader):
            optimizer.zero_grad()

            data = data.to(device).float()
            xhat, mu, logsigma = vae(data)
            # Use the actual batch size: the last batch of an epoch can be
            # smaller than `batch_size`.
            data = data.view(data.size(0), -1)

            recon_loss = reconstruction_loss(xhat, data, vae.logsigma)
            # The encoder output is used as a log *standard deviation* when
            # sampling (scale = exp(logsigma)), so the log-variance is
            # 2 * logsigma in the closed-form KL to a standard normal.
            log_var = 2 * logsigma
            kl_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
            loss = kl_loss - recon_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            iterations += 1
            total_iters += 1

            # total_iters is incremented before this check, so it is never 0.
            if total_iters % 1000 == 0:
                # Evaluate in eval mode so BatchNorm uses its running
                # statistics and does not update them with test data.
                vae.eval()
                rec_loss = validate_reconstruction(vae)
                vae.train()
                print(f"RECONSTRUCTION LOSS {rec_loss}")
                rec_losses.append(rec_loss)

            if i % 312 == 0:
                print(f'- Iteration {i} loss: {loss.item()}')

        # Guard against an empty loader, which would otherwise divide by zero.
        trajectory.append(running_loss / max(iterations, 1))

    return vae, trajectory, rec_losses


In [None]:
def validate_reconstruction(vae, dataset=None, batch_size=20):
    """Mean squared reconstruction error of ``vae`` over a dataset.

    The model is evaluated in eval mode (so BatchNorm layers use their running
    statistics and are not updated by the evaluation data) and its previous
    train/eval mode is restored afterwards.

    Args:
        vae: Model whose forward pass returns a tuple with the flattened
            reconstruction as its first element.
        dataset: Dataset yielding ``(image, label)`` pairs. Defaults to the
            notebook-level ``testset``.
        batch_size: Evaluation batch size.

    Returns:
        The per-element MSE averaged over all samples, as a Python float.

    Raises:
        ValueError: If the dataset is empty.
    """
    if dataset is None:
        dataset = testset

    # Run on whatever device the model lives on rather than guessing from
    # CUDA availability.
    first_param = next(vae.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Order does not affect a mean, so there is no need to shuffle.
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=False)
    rl = nn.MSELoss()
    total = 0.0
    n_seen = 0

    was_training = vae.training
    vae.eval()
    try:
        with torch.no_grad():
            for img, _ in loader:
                img = img.to(device).float()
                rec, _, _ = vae(img)

                # Flatten with the real batch size so a smaller final batch works,
                # and weight each batch mean by its size.
                n = img.size(0)
                total += rl(rec, img.view(n, -1)).item() * n
                n_seen += n
    finally:
        vae.train(was_training)

    if n_seen == 0:
        raise ValueError("cannot validate on an empty dataset")

    return total / n_seen

In [None]:
# Train from scratch: batch size 100, 81 epochs, learning rate 1e-5 and a
# 64-dimensional latent space.
vae, trajectory, rec_loss = train_vae(
    batch_size=100,
    trainset=trainset,
    epochs=81,
    lr=1e-5,
    embed_dim=64,
)

In [None]:
# Persist only the weights (state_dict); the evaluation section rebuilds the
# model and loads them back from this file.
torch.save(obj=vae.state_dict(), f='./vae_final.pt')

# Test and Evaluation

In [None]:
# Load pretrained model
# Select the device in this cell so it does not depend on `device` having been
# defined by a cell that runs later.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
autoencoder = AutoEncoder(64, device)
# map_location lets a checkpoint saved during a GPU session load on a CPU-only machine.
autoencoder.load_state_dict(torch.load('./vae_final.pt', map_location=device))
autoencoder.to(device)
# Eval mode: BatchNorm layers use their running statistics. The trailing
# semicolon suppresses the long module repr in the cell output.
autoencoder.eval();

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## 2. Reconstruction

In [None]:
from matplotlib.pyplot import imshow, figure
from torchvision.utils import make_grid

# Reconstructions
def create_reconstructions(n_samples=20, model=None):
    """Plot reconstructions of a random test batch and save them to disk.

    Draws one shuffled batch from the notebook-level ``testset``, runs it
    through the model, and shows the reconstructions in a grid with five images
    per row. The figure is written to ``./reconstruction.png``.

    Args:
        n_samples: Number of test images to reconstruct.
        model: Model to use. Defaults to the notebook-level ``autoencoder``
            (the model loaded from the checkpoint), so the function works
            without having trained in the current session.
    """
    if model is None:
        model = autoencoder

    # Run on whatever device the model lives on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    testloader = torch.utils.data.DataLoader(testset, batch_size=n_samples,
                                          shuffle=True)
    # Built-in next(): iterator objects are not guaranteed to expose `.next()`.
    test_imgs, _ = next(iter(testloader))
    test_imgs = test_imgs.to(device).float()

    # Inference in eval mode so BatchNorm uses its running statistics; the
    # previous mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            reconstruction, _, _ = model(test_imgs)
            reconstruction = reconstruction.cpu()
    finally:
        model.train(was_training)

    # Undo data normalization
    mean, std = np.array([0.5]), np.array([0.5])

    # Plot images. Reshape with -1 so a batch smaller than n_samples still works.
    viz = make_grid(reconstruction.reshape(-1, 1, 28, 28), nrow=5, padding = 2).numpy()* std + mean
    # The decoder output is unbounded; clip to the [0, 1] range imshow expects
    # for float images.
    viz = np.clip(viz, 0.0, 1.0)
    fig, ax = plt.subplots(figsize= (8,8), dpi=100)
    ax.imshow(np.transpose(viz, (1,2,0)))
    ax.grid(False)
    fig.savefig('./reconstruction.png')
    
# Render and save the reconstruction grid with the default sample count.
create_reconstructions()

## 3. Scene Categorization

In [None]:
def scene_categorization(model, n_trials=25):
    """Show a grid that randomly mixes real test images and model reconstructions.

    For each of ``n_trials`` grid positions a fair coin decides whether the
    position shows a reconstruction (1) or a real test image (0).
    Reconstructions come from one shuffled batch of the notebook-level
    ``testset`` and real images from the next batch.

    Args:
        model: Model whose forward pass returns a tuple with the flattened
            28x28 reconstruction as its first element.
        n_trials: Number of images in the grid.

    Returns:
        NumPy array of length ``n_trials`` with the answer key:
        1 = reconstruction, 0 = real image.

    Raises:
        ValueError: If ``n_trials`` is smaller than 1.
    """
    if n_trials < 1:
        raise ValueError("n_trials must be a positive integer")

    # Run on whatever device the model lives on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    trials = np.random.binomial(1, 0.5, size=(n_trials, ))
    testloader = torch.utils.data.DataLoader(testset, batch_size=n_trials,
                                          shuffle=True)

    # Built-in next(): iterator objects are not guaranteed to expose `.next()`.
    itertest = iter(testloader)
    fake, _ = next(itertest)
    real, _ = next(itertest)

    # Inference in eval mode so BatchNorm uses its running statistics; the
    # previous mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            fake = fake.to(device).float()
            xhat, _, _ = model(fake)
            xhat = xhat.cpu().view(-1, 1, 28, 28)
    finally:
        model.train(was_training)

    # Walk through the coin flips, taking the next unused image from the
    # matching source, and stack once at the end.
    tiles = []
    fake_idx = 0
    real_idx = 0
    for trial in trials:
        if trial:
            tiles.append(xhat[fake_idx])
            fake_idx += 1
        else:
            tiles.append(real[real_idx])
            real_idx += 1
    final = torch.stack(tiles)

    # Undo data normalization
    mean, std = np.array([0.5]), np.array([0.5])
    # Plot images
    viz = make_grid(final.reshape(n_trials, 1, 28, 28), nrow=5, padding = 2).numpy()* std + mean
    # Reconstructions are unbounded; clip to the [0, 1] range imshow expects
    # for float images.
    viz = np.clip(viz, 0.0, 1.0)
    fig, ax = plt.subplots(figsize= (8,8), dpi=100)
    ax.imshow(np.transpose(viz, (1,2,0)))
    ax.grid(False)

    return trials