In [10]:
# Load IPython's autoreload extension so that edits to imported project
# modules (e.g. model.py) are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: re-import modules automatically before each cell is executed.
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from model import VAE, loss_function

# Create datasets and data loaders

In [21]:
# Hyper-parameters and run configuration shared by the cells below.
batch_size = 128   # samples per mini-batch for both data loaders
num_epochs = 10    # number of train/test passes over the dataset
num_images = 8     # how many originals/reconstructions to save per epoch

# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing at the first .to(device).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [22]:
# MNIST training split: downloaded into ../data if missing, converted to
# tensors, and reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='../data',
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    ),
    batch_size=batch_size,
    shuffle=True,
)

# MNIST test split: read from the same root directory, kept in a fixed order.
test_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='../data',
        train=False,
        transform=transforms.ToTensor(),
    ),
    batch_size=batch_size,
    shuffle=False,
)

# Create training loop


In [23]:
def train(net, optimizer, num_epoch, train_loader):
    """Run one optimisation pass over ``train_loader`` and print progress.

    Args:
        net: model called as ``net(x)`` on inputs flattened to 784 features;
            must return ``(output, mu, logvar)``.
        optimizer: optimizer stepped once per batch.
        num_epoch: epoch number, used only in the printed messages.
        train_loader: iterable of ``(data, labels)`` batches; labels are ignored.

    Relies on the notebook-level ``device`` and ``loss_function``.
    The epoch summary divides the accumulated loss by the dataset size, so it
    presumably expects ``loss_function`` to return a summed (not averaged)
    batch loss -- confirm in model.py.
    """
    rule = '-----------------------------------------------------------------------------------------'

    net.train()
    epoch_total = 0
    for step, (images, _) in enumerate(train_loader):
        images = images.to(device)
        flat_images = images.view(-1, 784)

        optimizer.zero_grad()
        recon, mu, logvar = net(flat_images)
        batch_loss = loss_function(images, recon, mu, logvar)
        batch_loss.backward()
        optimizer.step()

        batch_loss_value = batch_loss.item()
        epoch_total += batch_loss_value

        # Only report every 100th batch to keep the cell output short.
        if step % 100 != 0:
            continue
        print(f'{batch_loss_value / images.size(0)} Epoch {num_epoch} {step} / {len(train_loader)}')

    print(rule)
    print(f'Train Epoch {num_epoch} loss: {epoch_total / len(train_loader.dataset)}')
    print(rule)

In [24]:
def test(net, num_epoch, test_loader):
    """Evaluate ``net`` on ``test_loader`` and save a reconstruction image.

    Args:
        net: model called as ``net(x)`` on inputs flattened to 784 features;
            must return ``(output, mu, logvar)``.
        num_epoch: epoch number, used in the printed messages and in the
            saved image's file name.
        test_loader: iterable of ``(data, labels)`` batches; labels are ignored.

    Side effects:
        Prints the per-sample loss of the first batch and the epoch loss, and
        writes ``results/reconstruction_<num_epoch>.png`` containing the first
        ``num_images`` inputs (top row) above their reconstructions (bottom row).

    Relies on the notebook-level ``device``, ``num_images``, ``loss_function``
    and ``save_image``.
    """
    import os  # local import so this cell does not depend on editing the import cell

    net.eval()
    test_loss = 0
    # No parameters are updated here, so skip autograd bookkeeping entirely.
    with torch.no_grad():
        for idx, (data, _) in enumerate(test_loader):
            data = data.to(device)
            output, mu, logvar = net(data.view(-1, 784))
            loss = loss_function(data, output, mu, logvar)
            test_loss += loss.item()
            if idx == 0:
                print (f'{loss.item() / data.size(0)} Epoch {num_epoch} {idx} / {len(test_loader)}')
                # The first batch may hold fewer than num_images samples.
                shown = min(num_images, data.size(0))
                # Infer the batch dimension rather than assuming the global
                # batch_size, which breaks on a short batch.
                recon_images = output.view(-1, 1, 28, 28)[:shown]
                image_and_pred = torch.cat([data[:shown], recon_images])
                os.makedirs('results', exist_ok=True)
                # nrow is images per row: `shown` per row puts the originals
                # directly above their reconstructions.
                save_image(image_and_pred.cpu(),
                           'results/reconstruction_' + str(num_epoch) + '.png', nrow=shown)
    print ('-----------------------------------------------------------------------------------------')
    print (f'Test Epoch {num_epoch} loss: {test_loss / len(test_loader.dataset)}')
    print('-----------------------------------------------------------------------------------------')

# Run train and test loops

In [25]:
# Build the VAE, move it to the configured device, and optimise all of its
# parameters with Adam.
net = VAE()
net = net.to(device)
optimizer = optim.Adam(params=net.parameters(), lr=1e-3)

# One training pass followed by one evaluation pass per epoch.
for epoch in range(num_epochs):
    train(net=net, optimizer=optimizer, num_epoch=epoch, train_loader=train_loader)
    test(net=net, num_epoch=epoch, test_loader=test_loader)

> [0;32m/home/siddhesh1793/code/algos_from_scratch/vae/model.py[0m(40)[0;36mloss_function[0;34m()[0m
[0;32m     38 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     39 [0;31m[0;34m[0m[0m
[0m[0;32m---> 40 [0;31m    [0mkl_divergence[0m [0;34m=[0m [0;34m-[0m[0;36m0.5[0m [0;34m*[0m [0mtorch[0m[0;34m.[0m[0msum[0m[0;34m([0m[0;36m1[0m [0;34m+[0m [0mlogvar[0m [0;34m-[0m [0mmu[0m[0;34m.[0m[0mpow[0m[0;34m([0m[0;36m2[0m[0;34m)[0m [0;34m-[0m [0mlogvar[0m[0;34m.[0m[0mexp[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     41 [0;31m[0;34m[0m[0m
[0m[0;32m     42 [0;31m    [0;32mreturn[0m [0mloss[0m [0;34m+[0m [0mkl_divergence[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> loss.shape
torch.Size([128, 784])
ipdb> loss.sum()
tensor(69753.6562, device='cuda:0', grad_fn=<SumBackward0>)
ipdb> loss.mean()

BdbQuit: 