In [1]:
import torch
import torchvision
import torch.optim as optim
import argparse
import matplotlib
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

import torchvision.transforms as transforms

from tqdm import tqdm

from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image

matplotlib.style.use('ggplot')

In [2]:
# define a simple linear VAE
class LinearVAE(nn.Module):
    def __init__(self):
        super(LinearVAE, self).__init__()
 
        # encoder
        self.enc1 = nn.Linear(in_features=784, out_features=512)
        self.enc2 = nn.Linear(in_features=512, out_features=32)
 
        # decoder 
        self.dec1 = nn.Linear(in_features=32, out_features=512)
        self.dec2 = nn.Linear(in_features=512, out_features=784)
    def reparameterize(self, mu, log_var):
        """
        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        """
        std = torch.exp(0.5*log_var) # standard deviation
        eps = torch.randn_like(std) # `randn_like` as we need the same size
        sample = mu + (eps * std) # sampling as if coming from the input space
        return sample
 
    def forward(self, x):
        # encoding
        x = F.relu(self.enc1(x))
        x = self.enc2(x)
        # get `mu` and `log_var`
        mu = x
        log_var = x
        # get the latent vector through reparameterization
        z = self.reparameterize(mu, log_var)
 
        # decoding
        x = F.relu(self.dec1(z))
        reconstruction = torch.sigmoid(self.dec2(x))
        return reconstruction, mu, log_var


In [3]:
epochs = 10
batch_size = 64
lr = 0.0001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
# transforms
transform = transforms.Compose([
    transforms.ToTensor()])

In [7]:
# train and validation data
train_data = datasets.MNIST(
    root='../input/data',
    train=True,
    download=True,
    transform=transform
)
val_data = datasets.MNIST(
    root='../input/data',
    train=False,
    download=True,
    transform=transform
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../input/data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../input/data/MNIST/raw/train-images-idx3-ubyte.gz to ../input/data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../input/data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../input/data/MNIST/raw/train-labels-idx1-ubyte.gz to ../input/data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../input/data/MNIST/raw/t10k-images-idx3-ubyte.gz



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../input/data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../input/data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../input/data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../input/data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../input/data/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [8]:
# training and validation data loaders
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True
)
val_loader = DataLoader(
    val_data,
    batch_size=batch_size,
    shuffle=False
)

In [9]:
model = LinearVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.BCELoss(reduction='sum')

In [10]:
def final_loss(bce_loss, mu, logvar):
    """
    This function will add the reconstruction loss (BCELoss) and the 
    KL-Divergence.
    KL-Divergence = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    :param bce_loss: recontruction loss
    :param mu: the mean from the latent vector
    :param logvar: log variance from the latent vector
    """
    BCE = bce_loss 
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

In [11]:
def fit(model, dataloader):
    model.train()
    running_loss = 0.0
    for i, data in tqdm(enumerate(dataloader), total=int(len(train_data)/dataloader.batch_size)):
        data, _ = data
        data = data.to(device)
        data = data.view(data.size(0), -1)
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(data)
        bce_loss = criterion(reconstruction, data)
        loss = final_loss(bce_loss, mu, logvar)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()
    train_loss = running_loss/len(dataloader.dataset)
    return train_loss

In [12]:
def validate(model, dataloader):
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for i, data in tqdm(enumerate(dataloader), total=int(len(val_data)/dataloader.batch_size)):
            data, _ = data
            data = data.to(device)
            data = data.view(data.size(0), -1)
            reconstruction, mu, logvar = model(data)
            bce_loss = criterion(reconstruction, data)
            loss = final_loss(bce_loss, mu, logvar)
            running_loss += loss.item()
        
            # save the last batch input and output of every epoch
            if i == int(len(val_data)/dataloader.batch_size) - 1:
                num_rows = 8
                both = torch.cat((data.view(batch_size, 1, 28, 28)[:8], 
                                  reconstruction.view(batch_size, 1, 28, 28)[:8]))
                save_image(both.cpu(), f"./output{epoch}.png", nrow=num_rows)
    val_loss = running_loss/len(dataloader.dataset)
    return val_loss

In [13]:
train_loss = []
val_loss = []
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss = fit(model, train_loader)
    val_epoch_loss = validate(model, val_loader)
    train_loss.append(train_epoch_loss)
    val_loss.append(val_epoch_loss)
    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f"Val Loss: {val_epoch_loss:.4f}")

  0%|          | 1/937 [00:00<02:39,  5.86it/s]

Epoch 1 of 10


938it [00:07, 117.66it/s]                         
157it [00:00, 168.63it/s]                         
  1%|          | 10/937 [00:00<00:09, 97.44it/s]

Train Loss: 227.5914
Val Loss: 182.9843
Epoch 2 of 10


938it [00:07, 124.21it/s]                         
157it [00:00, 180.87it/s]                         
  1%|▏         | 12/937 [00:00<00:08, 112.58it/s]

Train Loss: 172.6994
Val Loss: 164.3475
Epoch 3 of 10


938it [00:07, 121.93it/s]                         
157it [00:00, 178.12it/s]                         
  1%|▏         | 14/937 [00:00<00:06, 132.92it/s]

Train Loss: 161.0181
Val Loss: 156.9682
Epoch 4 of 10


938it [00:07, 120.74it/s]                         
157it [00:00, 187.67it/s]                         
  1%|▏         | 14/937 [00:00<00:06, 132.12it/s]

Train Loss: 155.3030
Val Loss: 152.3927
Epoch 5 of 10


938it [00:07, 124.72it/s]                         
157it [00:00, 176.31it/s]                         
  1%|▏         | 13/937 [00:00<00:07, 121.47it/s]

Train Loss: 151.4590
Val Loss: 149.2042
Epoch 6 of 10


938it [00:07, 123.08it/s]                         
157it [00:00, 182.26it/s]                         
  1%|▏         | 13/937 [00:00<00:07, 128.00it/s]

Train Loss: 148.5961
Val Loss: 146.7362
Epoch 7 of 10


938it [00:07, 119.91it/s]                         
157it [00:00, 181.98it/s]                         
  1%|▏         | 13/937 [00:00<00:07, 121.37it/s]

Train Loss: 146.5291
Val Loss: 144.5970
Epoch 8 of 10


938it [00:07, 118.02it/s]                         
157it [00:00, 178.87it/s]                         
  1%|▏         | 13/937 [00:00<00:07, 123.47it/s]

Train Loss: 144.7934
Val Loss: 143.1517
Epoch 9 of 10


938it [00:07, 125.24it/s]                         
157it [00:00, 182.34it/s]                         
  1%|▏         | 13/937 [00:00<00:07, 127.21it/s]

Train Loss: 143.5308
Val Loss: 142.0092
Epoch 10 of 10


938it [00:07, 122.96it/s]                         
157it [00:00, 165.23it/s]                         

Train Loss: 142.2463
Val Loss: 141.1507



