In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

In [2]:
# Hyperparameters shared by the model definition and the training loop
image_size = 28 * 28  # MNIST images are 28 x 28 pixels, flattened to 784 values
hidden_dim = 400      # width of the single hidden layer in encoder and decoder
latent_dim = 20       # size of the latent code z
batch_size = 128      # examples per mini-batch for both data loaders
epochs = 30           # number of full passes over the training set

In [3]:
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

In [4]:
# MNIST dataset: images are converted to float tensors by ToTensor()
train_dataset = torchvision.datasets.MNIST(root='../../data',train=True,transform=transforms.ToTensor(),download=True)
# Training batches are reshuffled every epoch
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
test_dataset = torchvision.datasets.MNIST(root='../../data',train=False,transform=transforms.ToTensor())
# Evaluation must not be shuffled: test() saves reconstructions of the first
# batch, and a fixed order keeps those images comparable from epoch to epoch.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)

100%|██████████| 9.91M/9.91M [00:00<00:00, 16.1MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 479kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.46MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.95MB/s]


In [5]:
# Directory that receives the reconstructed and sampled images
sample_dir = 'results'
# exist_ok=True makes the cell idempotent without a separate existence check
os.makedirs(sample_dir, exist_ok=True)

In [20]:
# VAE model
class VAE(nn.Module):
  """Fully connected variational autoencoder with one hidden layer on each side.

  Encoder: input -> hidden (ReLU) -> (mean, log-variance) of the latent code.
  Decoder: latent -> hidden (ReLU) -> input-sized output squashed by a sigmoid.

  Args:
    input_dim: number of features per flattened example. Defaults to the
      notebook-level ``image_size``.
    hidden_units: width of the hidden layers. Defaults to ``hidden_dim``.
    latent_units: size of the latent code. Defaults to ``latent_dim``.
  """

  def __init__(self, input_dim=None, hidden_units=None, latent_units=None):
    super().__init__()
    # Fall back to the notebook hyperparameters so a bare VAE() behaves as before
    if input_dim is None:
      input_dim = image_size
    if hidden_units is None:
      hidden_units = hidden_dim
    if latent_units is None:
      latent_units = latent_dim
    self.fc1 = nn.Linear(input_dim, hidden_units)
    self.fc2_mean = nn.Linear(hidden_units, latent_units)
    self.fc2_logvar = nn.Linear(hidden_units, latent_units)
    self.fc3 = nn.Linear(latent_units, hidden_units)
    self.fc4 = nn.Linear(hidden_units, input_dim)

  def encode(self, x):
    """Map flattened inputs to the mean and log-variance of the latent code."""
    h = F.relu(self.fc1(x))
    mu = self.fc2_mean(h)
    log_var = self.fc2_logvar(h)
    return mu, log_var

  def reparameterize(self, mu, logvar):
    """Draw z = mu + eps * std with eps ~ N(0, I), keeping the sample differentiable."""
    # the encoder outputs log(sigma^2), so sigma = exp(logvar / 2)
    std = torch.exp(logvar/2)
    # noise has the same shape as std, so the product below is element-wise
    eps = torch.randn_like(std)
    return mu + eps * std

  def decode(self, z):
    """Map latent codes back to input space; the sigmoid keeps outputs in [0, 1]."""
    h = F.relu(self.fc3(z))
    out = torch.sigmoid(self.fc4(h))
    return out

  def forward(self, x):
    """Encode, sample and decode; returns (reconstruction, mu, logvar)."""
    # Flatten to (batch, features) using the width the encoder was built with,
    # so the model does not depend on a global that may have changed since.
    mu, logvar = self.encode(x.view(-1, self.fc1.in_features))
    z = self.reparameterize(mu, logvar)
    reconstructed = self.decode(z)
    return reconstructed, mu, logvar

# Build the VAE, place it on the selected device, and let Adam update all of its weights
model = VAE()
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [21]:
# Define loss = bce + kld
def loss_function(reconstructed_image, original_image, mu, logvar):
  """Summed VAE loss: reconstruction BCE plus KL divergence to a standard normal.

  Args:
    reconstructed_image: decoder output with values in [0, 1].
    original_image: target with the same number of elements as
      ``reconstructed_image``; it is reshaped to match it.
    mu: mean of the approximate posterior.
    logvar: log-variance of the approximate posterior.

  Returns:
    Scalar tensor, summed (not averaged) over every element of the batch.
  """
  # Reshape the target to the reconstruction's shape instead of a hard-coded 784
  target = original_image.reshape(reconstructed_image.shape)
  bce = F.binary_cross_entropy(reconstructed_image, target, reduction='sum')
  # KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * sum(sigma^2 + mu^2 - 1 - log(sigma^2))
  kld = 0.5 * torch.sum(logvar.exp() + mu.pow(2) - 1 - logvar)
  return bce + kld


def train(epoch):
  """Run one optimisation pass over train_loader and print progress.

  Prints the per-example loss of every 100th batch and, at the end, the
  summed loss divided by the number of examples in the training dataset.

  Args:
    epoch: epoch number, used only in the printed messages.
  """
  model.train()
  num_batches = len(train_loader)
  running_total = 0
  for step, (inputs, _) in enumerate(train_loader):
    inputs = inputs.to(device)

    # clear gradients left over from the previous step so they do not accumulate
    optimizer.zero_grad()
    outputs, mean, log_variance = model(inputs)
    batch_loss = loss_function(outputs, inputs, mean, log_variance)
    batch_loss.backward()
    optimizer.step()

    batch_value = batch_loss.item()
    running_total += batch_value

    if step % 100 == 0:
      per_example = batch_value / len(inputs)
      print("Train Epoch {} [Batch {}/{}]\tLoss: {:.3f}".format(epoch, step, num_batches, per_example))

  epoch_average = running_total / len(train_loader.dataset)
  print("=======> Epoch {}, Average Loss: {:.3f}".format(epoch, epoch_average))

def test(epoch):
  """Evaluate the model on test_loader and save a reconstruction grid.

  Gradients are disabled. For the first batch, up to five inputs and their
  reconstructions are written to
  ``<sample_dir>/reconstruction_<epoch>.png`` (inputs on the top row,
  reconstructions below). The summed loss divided by the number of test
  examples is printed at the end.

  Args:
    epoch: epoch number, used in the output file name.
  """
  model.eval()
  test_loss = 0
  with torch.no_grad():
    for batch_idx, (images, _) in enumerate(test_loader):
      images = images.to(device)
      reconstructed, mu, logvar = model(images)
      loss = loss_function(reconstructed, images, mu, logvar)
      test_loss += loss.item()
      if batch_idx == 0: # save one comparison grid per epoch
        # Shape the reconstruction like the input batch instead of assuming
        # exactly `batch_size` examples of 1x28x28; a shorter batch no longer fails.
        n_shown = min(5, images.size(0))
        comparison = torch.cat([images[:n_shown], reconstructed.view_as(images)[:n_shown]])
        out_path = os.path.join(sample_dir, 'reconstruction_' + str(epoch) + '.png')
        save_image(comparison.cpu(), out_path, nrow=n_shown)

  print("=======> Average Test Loss: {:.3f}".format(test_loss/len(test_loader.dataset)))


In [22]:
# One train + evaluation pass per epoch, then a grid of freshly generated digits
for epoch in range(1,epochs+1):
  train(epoch)
  test(epoch)
  with torch.no_grad():
    # Skip the encoder: draw latent codes from the standard normal prior and
    # decode them. The width comes from latent_dim so it cannot drift from the
    # hyperparameter cell.
    sample = torch.randn(64, latent_dim).to(device)
    generated = model.decode(sample).cpu()
    # Write next to the reconstructions, in the directory created for them
    save_image(generated.view(64,1,28,28), os.path.join(sample_dir, 'sample_' + str(epoch) + '.png'))

Train Epoch 1 [Batch 0/469]	Loss: 546.414
Train Epoch 1 [Batch 100/469]	Loss: 182.216
Train Epoch 1 [Batch 200/469]	Loss: 154.341
Train Epoch 1 [Batch 300/469]	Loss: 142.859
Train Epoch 1 [Batch 400/469]	Loss: 130.953
Train Epoch 2 [Batch 0/469]	Loss: 127.327
Train Epoch 2 [Batch 100/469]	Loss: 125.744
Train Epoch 2 [Batch 200/469]	Loss: 124.687
Train Epoch 2 [Batch 300/469]	Loss: 119.464
Train Epoch 2 [Batch 400/469]	Loss: 115.358
Train Epoch 3 [Batch 0/469]	Loss: 117.581
Train Epoch 3 [Batch 100/469]	Loss: 115.210
Train Epoch 3 [Batch 200/469]	Loss: 118.075
Train Epoch 3 [Batch 300/469]	Loss: 118.622
Train Epoch 3 [Batch 400/469]	Loss: 113.291
Train Epoch 4 [Batch 0/469]	Loss: 110.082
Train Epoch 4 [Batch 100/469]	Loss: 115.408
Train Epoch 4 [Batch 200/469]	Loss: 115.385
Train Epoch 4 [Batch 300/469]	Loss: 112.523
Train Epoch 4 [Batch 400/469]	Loss: 112.028
Train Epoch 5 [Batch 0/469]	Loss: 109.684
Train Epoch 5 [Batch 100/469]	Loss: 115.846
Train Epoch 5 [Batch 200/469]	Loss: 111.84

In [23]:
# Mount Google Drive at /content/drive so the results archive can be copied to it.
# NOTE(review): google.colab is presumably only importable inside a Colab
# runtime — confirm before running this notebook elsewhere.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [24]:
# Recursively archive the results directory into results.zip in the working directory
!zip -r results.zip results

  adding: results/ (stored 0%)
  adding: results/reconstruction_7.png (deflated 1%)
  adding: results/reconstruction_3.png (deflated 2%)
  adding: results/reconstruction_2.png (deflated 3%)
  adding: results/reconstruction_18.png (deflated 2%)
  adding: results/sample_1.png (deflated 2%)
  adding: results/reconstruction_23.png (deflated 3%)
  adding: results/reconstruction_15.png (deflated 2%)
  adding: results/reconstruction_26.png (deflated 2%)
  adding: results/sample_22.png (deflated 4%)
  adding: results/sample_18.png (deflated 4%)
  adding: results/reconstruction_14.png (deflated 2%)
  adding: results/sample_25.png (deflated 4%)
  adding: results/sample_3.png (deflated 3%)
  adding: results/reconstruction_20.png (deflated 2%)
  adding: results/sample_13.png (deflated 4%)
  adding: results/reconstruction_1.png (deflated 1%)
  adding: results/sample_20.png (deflated 4%)
  adding: results/sample_26.png (deflated 4%)
  adding: results/sample_8.png (deflated 4%)
  adding: results/reco

In [25]:
# Move the archive into the MyDrive folder under the /content/drive mount point
!mv results.zip /content/drive/MyDrive/