In [276]:
import os
import numpy as np
import random
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets,transforms
from torchvision.utils import save_image


In [277]:
torch.cuda.empty_cache()

In [278]:
image_path = './images'
channels = 1                    # MNIST has only 1

n_epochs = 30
batch_size = 128
lr = 1e-3

img_size = 28

In [279]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [280]:
transform = transforms.Compose([
            transforms.ToTensor()
            ])

train = datasets.MNIST(root='./data/',train=True,transform=transform,download=True)
test = datasets.MNIST(root='./data/',train=False,transform=transform,download=True)

train_dataloader = torch.utils.data.DataLoader(
            train,
            batch_size=batch_size,
            shuffle=True,

)

test_dataloader = torch.utils.data.DataLoader(
            test,
            batch_size=batch_size,
            shuffle=False,
)

In [281]:
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        
        self.hidden_dim = [input_dim] + hidden_dim
        self.encoder = nn.ModuleList([nn.Linear(self.hidden_dim[idx], self.hidden_dim[idx+1]) 
                                      for idx in range(len(self.hidden_dim)-1)])
        self.mu = nn.Linear(self.hidden_dim[-1], latent_dim)
        self.logvar = nn.Linear(self.hidden_dim[-1], latent_dim)
        self.decoder = nn.ModuleList([nn.Linear(latent_dim, self.hidden_dim[-1])] + [nn.Linear(self.hidden_dim[idx], self.hidden_dim[idx-1]) 
                                                                                for idx in range(len(self.hidden_dim)-1, 0, -1)])
        
        self.init_weights()
        
    def init_weights(self):
        for layer in self.encoder:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)
        for layer in self.decoder:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)
        nn.init.xavier_uniform_(self.mu.weight)
        nn.init.zeros_(self.mu.bias)
        
    def reparameterization(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        
        return mu + eps * std
    
    def forward(self, x):
        for layer in self.encoder:
            x = F.relu(layer(x))
            
        mu = F.relu(self.mu(x))
        logvar = F.relu(self.logvar(x))
        z = self.reparameterization(mu, logvar)
        
        for idx in range(len(self.decoder)):
            if idx == len(self.decoder) -1: 
                z = F.sigmoid(self.decoder[idx](z))
            else:
                z = F.relu(self.decoder[idx](z))
        
        return z, mu, logvar

In [282]:
model = VAE(img_size**2, [256], 128)

In [283]:
print(model)

VAE(
  (encoder): ModuleList(
    (0): Linear(in_features=784, out_features=256, bias=True)
  )
  (mu): Linear(in_features=256, out_features=128, bias=True)
  (logvar): Linear(in_features=256, out_features=128, bias=True)
  (decoder): ModuleList(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=784, bias=True)
  )
)


In [284]:
model = model.to(device)

In [285]:
class VAE_loss(nn.Module):
    def __init__(self):
        super(VAE_loss, self).__init__()
    
    def forward(self, x, x_hat, mu, logvar):
        reconst_loss = F.binary_cross_entropy(x_hat, x, reduction='sum')
        kl_div = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1)
        
        return reconst_loss + kl_div
    
criterion = VAE_loss() 
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, eps=1e-4)

In [286]:
for epoch in range(n_epochs):
    train_loss = 0
    for i, (x, _) in enumerate(train_dataloader):
        # forward
        x = x.view(-1, img_size**2)
        x = x.to(device)
        pred, mu, logvar = model(x)
        reconst_loss = F.binary_cross_entropy(pred, x, reduction='sum')
        kl_divergence = - 0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        
        loss = reconst_loss + kl_divergence
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        
    train_loss /= len(train_dataloader)

    print(f'Epoch {epoch+1}/{n_epochs}, train loss: {train_loss:.4f}')

Epoch 1/30, train loss: 24305.8743
Epoch 2/30, train loss: 20219.7461
Epoch 3/30, train loss: 19470.4025
Epoch 4/30, train loss: 19068.3516
Epoch 5/30, train loss: 18811.7356
Epoch 6/30, train loss: 18624.4605
Epoch 7/30, train loss: 18489.3783
Epoch 8/30, train loss: 18368.9338
Epoch 9/30, train loss: 18275.9860
Epoch 10/30, train loss: 18222.4914
Epoch 11/30, train loss: 18127.2220
Epoch 12/30, train loss: 18063.6287
Epoch 13/30, train loss: 18012.4368
Epoch 14/30, train loss: 17968.5162
Epoch 15/30, train loss: 17907.1699
Epoch 16/30, train loss: 17888.2026
Epoch 17/30, train loss: 17845.9119
Epoch 18/30, train loss: 17811.1885
Epoch 19/30, train loss: 17773.1718
Epoch 20/30, train loss: 17740.0160
Epoch 21/30, train loss: 17710.7172
Epoch 22/30, train loss: 17686.4788
Epoch 23/30, train loss: 17667.2347
Epoch 24/30, train loss: 17635.2107
Epoch 25/30, train loss: 17626.6690
Epoch 26/30, train loss: 17577.5318
Epoch 27/30, train loss: 17566.2183
Epoch 28/30, train loss: 17550.2618
E