# Import tools

In [None]:
# built-in utilities
import copy
import os
import time
import datetime
import pickle

# data tools
import numpy as np

# pytorch 
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torchvision
from torchvision.utils import save_image
from torchvision import datasets, models, transforms
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from IPython.display import Image
from IPython.core.display import Image, display


# visualization
import matplotlib.pyplot as plt
%matplotlib inline

# Report

Question 2 of homework 4 asks us to use a variational auto-encoder (VAE) to determine what augmentation was applied to a set of images of handwritten 7s from the MNIST dataset. We are given a training set containing 6,265 7s and a test set containing 1,028 additional 7s. We are not told which augmentation was applied to the images.
The VAE architecture is as follows:
-	Encoder
    -	2-D convolutional layer
        -	Input = 1, Output = 64, Kernel size = 4, Stride = 2, Padding = 1
    -	2-D convolutional layer
        -	Input = 64, Output = 128, Kernel size = 4, Stride = 2, Padding = 1
    -	2-D convolutional layer
        -	Input = 128, Output = 256, Kernel size = 3, Stride = 2, Padding = 1
    -	2-D convolutional layer
        -	Input = 256, Output = 1024, Kernel size = 4, Stride = 1, Padding = 0
-	Decoder
    -	2-D transposed convolutional layer
        -	Input = 1024, Output = 512, Kernel size = 4, Stride = 1, Padding = 0
    - 2-D transposed convolutional layer
        - Input = 512, Output = 256, Kernel size = 3, Stride = 2, Padding = 1
    - 2-D transposed convolutional layer
        - Input = 256, Output = 128, Kernel size = 4, Stride = 2, Padding = 1
    - 2-D transposed convolutional layer
        - Input = 128, Output = 1, Kernel size = 4, Stride = 2, Padding = 1
- Fully connected layers
    - Linear layer 1
        - Input = 1,024, Output = 512
    - Linear layer 2 (mean)
        - Input = 512, Output = 3
    - Linear layer 2 (standard deviation)
        - Input = 512, Output = 3
    - Linear layer 3
        - Input = 3, Output = 512
    - Linear layer 4
        - Input = 512, Output = 1024

The number of latent dimensions is 3. Over 25 epochs, the training loss decreases from 187.2966 after the first epoch to 101.7138 at the 25th epoch, and the test loss decreases from 122.6691 after the first epoch to 100.9801 at the 25th epoch. Most of the loss decrease occurs for the binary cross entropy side of the loss function, as opposed to the KLD side of the loss function.

As for reconstructing the images to determine the effect that was added to the image, I am going to have to attempt this outside of the scope of the homework, as I have run out of time.

# Question 1

## Question 1, part 1 - 3




In [None]:
# NOTE: pickle.load can execute arbitrary code while unpickling; only open
# files that come from a trusted source.
# Each file is closed as soon as its contents are loaded; the shape is
# reported afterwards.
with open('hw4_tr7.pkl', 'rb') as f:
    train_data_raw = pickle.load(f)
print("Train data shape: {}".format(train_data_raw.shape))

with open('hw4_te7.pkl', 'rb') as f:
    test_data_raw = pickle.load(f)
print("Test data shape: {}".format(test_data_raw.shape))
    

In [None]:
class FullDataset(Dataset):
    """Map-style dataset over a NumPy array in which every sample is labelled 7.

    ``data`` is converted with ``torch.from_numpy`` and given an extra axis at
    position 1 (presumably the channel axis of an ``(N, H, W)`` image stack --
    confirm against the pickled arrays). ``target`` is a float tensor holding
    the constant label 7 for each of the ``data.shape[0]`` samples.
    """

    def __init__(self, data):
        sample_count = data.shape[0]
        # from_numpy keeps the array's dtype; unsqueeze(1) inserts a size-1 axis.
        self.data = torch.from_numpy(data).unsqueeze(1)
        # One constant label per sample, in the default float dtype.
        self.target = torch.full((sample_count,), 7.0)

    def __len__(self):
        # Number of samples is the size of the leading axis.
        return self.data.shape[0]

    def __getitem__(self, index):
        # Return the (sample, label) pair stored at ``index``.
        return self.data[index], self.target[index]

# Wrap the raw train and test arrays in the same dataset class (train first).
train_data, test_data = (
    FullDataset(raw) for raw in (train_data_raw, test_data_raw)
)


In [None]:
## load data
# Shuffled loaders over the two datasets (32 samples per training batch,
# 128 per test batch).
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=128, shuffle=True)

# Pull a single training batch, write it out as an image grid with 8 images
# per row, and display the saved file. Despite its name, "test.png" holds
# *training* samples.
train_samples, _ = next(iter(train_loader))
save_image(train_samples.data, "test.png", nrow=8)
Image("test.png")


## Question 1, part 4 - 6

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class VAE(nn.Module):
    """Convolutional variational auto-encoder for 1 x 28 x 28 images.

    The encoder reduces an image to a 1024-vector, ``fc1`` maps it to 512
    features, and ``fc2m`` / ``fc2s`` produce the mean and log-variance of a
    ``latent_dim``-dimensional Gaussian. A latent sample is expanded back to
    1024 features by ``fc3`` / ``fc4`` and turned into an image by the
    transposed-convolution decoder.

    Args:
        latent_dim: Size of the latent vector (default 3).
    """

    def __init__(self, latent_dim=3):
        super().__init__()

        self.latent_dim = latent_dim

        # Shape comments below assume a (batch, 1, 28, 28) input.
        self.encoder = nn.Sequential(
            # -> batch x 64 x 14 x 14
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            # -> batch x 128 x 7 x 7
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            # -> batch x 256 x 4 x 4
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            # -> batch x 1024 x 1 x 1
            nn.Conv2d(in_channels=256, out_channels=1024, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.decoder = nn.Sequential(
            # input: batch x 1024 x 1 x 1

            # -> batch x 512 x 4 x 4
            nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # -> batch x 256 x 7 x 7
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # -> batch x 128 x 14 x 14
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # -> batch x 1 x 28 x 28
            nn.ConvTranspose2d(in_channels=128, out_channels=1, kernel_size=4, stride=2, padding=1),
            # Sigmoid keeps the reconstruction in [0, 1], which the binary
            # cross-entropy reconstruction loss requires.
            nn.Sigmoid()
        )

        # Encoder head: 1024 -> 512 -> (mean, log-variance).
        self.fc1 = nn.Linear(1024, 512)
        self.fc2m = nn.Linear(512, self.latent_dim)
        self.fc2s = nn.Linear(512, self.latent_dim)

        # Decoder stem: latent -> 512 -> 1024.
        self.fc3 = nn.Linear(self.latent_dim, 512)
        self.fc4 = nn.Linear(512, 1024)

    def encode(self, x):
        """Encode ``x`` and return ``(z, mu, logvar)``.

        ``z`` is a sample drawn via :meth:`reparameterize`; ``mu`` and
        ``logvar`` are the outputs of ``fc2m`` and ``fc2s``.
        """
        h = self.encoder(x)
        # Flatten the 1024 x 1 x 1 feature map before the linear layers.
        h = self.fc1(h.view(-1, 1024))

        mu, logvar = self.fc2m(h), self.fc2s(h)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def decode(self, z):
        """Map latent vectors ``z`` to reconstructed images."""
        z = F.relu(self.fc3(z))
        z = self.fc4(z)

        # Restore the 1024 x 1 x 1 layout the transposed convolutions expect.
        z = z.view(-1, 1024, 1, 1)
        return self.decoder(z)

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * std`` with ``eps ~ N(0, I)``.

        The noise is created with ``torch.randn_like`` so that it always has
        the shape, dtype and device of ``mu``; the method therefore does not
        rely on any module-level ``device`` variable.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(mu)
        return eps * std + mu

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the input batch ``x``."""
        z, mu, logvar = self.encode(x)
        recon = self.decode(z)
        return recon, mu, logvar

In [None]:
## load data
# Hyper-parameters shared by the loaders and the training loop.
batch_size, epochs = 32, 25

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

# NOTE(review): the model is built with 4 latent dimensions here, while the
# report text and the class default say 3 -- confirm which one is intended.
model = VAE(latent_dim=4)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def loss_fn(recon_x, x, mu, logvar, fn="BCE"):
    """Return the VAE loss as ``(total, reconstruction, kld)``.

    Both terms are summed over the whole batch so they share one scale and
    callers can divide by the number of samples to get per-sample values.

    Args:
        recon_x: Reconstructed batch (values in [0, 1] when ``fn="BCE"``).
        x: Target batch, same shape as ``recon_x``.
        mu: Latent means.
        logvar: Latent log-variances, same shape as ``mu``.
        fn: Reconstruction loss, ``"BCE"`` (binary cross-entropy) or
            ``"MSE"`` (squared error).

    Raises:
        ValueError: If ``fn`` is neither ``"BCE"`` nor ``"MSE"``.
    """
    if fn == "BCE":
        recon_loss = F.binary_cross_entropy(recon_x, x, reduction="sum")
    elif fn == "MSE":
        recon_loss = F.mse_loss(recon_x, x, reduction="sum")
    else:
        raise ValueError('fn must be "BCE" or "MSE", got {!r}'.format(fn))

    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed (not averaged) over
    # batch and latent dimensions to match the summed reconstruction term.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kld, recon_loss, kld

def train(epoch):
    """Run one training pass over ``train_loader`` and checkpoint the model.

    Uses the module-level ``model``, ``optimizer``, ``train_loader``,
    ``batch_size`` and ``device``. Progress is printed every 50 batches, the
    epoch's average loss per sample is printed at the end, and the model's
    state dict is saved to ``hw4_2.pt``.

    Args:
        epoch: Zero-based epoch index (printed as ``epoch + 1``).
    """
    model.train()
    running_loss = 0
    log_template = "Train Epoch: {} [{}/{} ({:.0f}%)]\tTotal Loss: {:.5f}  \tBCE Loss: {:.5f} \tKLD Loss: {:.5f}"

    for step, (batch, _) in enumerate(train_loader):
        batch = batch.to(device)

        # Clear stale gradients, then forward / backward / update.
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(batch)
        loss, bce, kld = loss_fn(reconstruction, batch, mu, logvar)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if step % 50 != 0:
            continue

        # Per-batch losses are reported divided by the nominal batch size.
        seen = step * batch_size
        percent_done = 100. * step / len(train_loader)
        print(log_template.format(
            epoch + 1,
            seen,
            len(train_loader.dataset),
            percent_done,
            loss.item() / batch_size,
            bce.item() / batch_size,
            kld.item() / batch_size,
        ))

    average_loss = running_loss / len(train_loader.dataset)
    print('\nEpoch: {} Average train loss: {:.4f}'.format(epoch + 1, average_loss))
    # Overwrite the checkpoint after every epoch.
    torch.save(model.state_dict(), 'hw4_2.pt')

def test(epoch):
    """Evaluate the model on ``test_loader`` and print the average loss.

    Uses the module-level ``model``, ``test_loader`` and ``device``. The
    model is put in eval mode and gradients are disabled. The loss is
    accumulated as a Python float (as ``train`` does) rather than as a
    0-d tensor on the compute device.

    Args:
        epoch: Zero-based epoch index (printed as ``epoch + 1``).
    """
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(device)

            recon_images, mu, logvar = model(images)
            loss, _, _ = loss_fn(recon_images, images, mu, logvar)

            test_loss += loss.item()

    # Average over samples, matching how the training loss is reported.
    test_loss /= len(test_loader.dataset)
    print('Epoch: {} Average test loss: {:.4f}\n\n'.format(epoch + 1, test_loss))

# Alternate one training pass and one evaluation pass for each epoch.
for epoch in range(epochs):
    train(epoch)
    test(epoch)

## Question 1, part 7 - 11
