In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [2]:
import torchvision
from torchvision import datasets
from torchvision import transforms

# One stateless image->tensor transform shared by both splits.
to_tensor = transforms.ToTensor()

# download=False: the MNIST files must already exist under each root.
# NOTE(review): the two roots use different directory names
# ('data_mnist' vs 'mnist_data') -- confirm this is intentional.
train_data = datasets.MNIST(
    root='./data_mnist/train/',
    train=True,
    transform=to_tensor,
    download=False,
)
test_data = datasets.MNIST(
    root='./mnist_data/test/',
    train=False,
    transform=to_tensor,
    download=False,
)

# Mini-batches of 8; only the training stream is shuffled.
tr = torch.utils.data.DataLoader(train_data, batch_size=8, shuffle=True)
tst = torch.utils.data.DataLoader(test_data, batch_size=8, shuffle=False)

In [3]:
from model import VAE,VAEloss

# Layer widths from the input (28 * 28 = 784 values per flattened image)
# down to the 20-wide bottleneck; describe() below prints the resulting
# encoder/decoder layer shapes.
layers = [28 * 28, 640, 320, 80, 20]
vae = VAE(layers)

# Move the model to the GPU when one is present.
gpu_present = torch.cuda.is_available()
if gpu_present:
    vae.cuda()

vae.describe()

linear_enc1 (784, 640)
linear_enc2 (640, 320)
linear_enc3 (320, 80)
linear_enc5_mean (80, 20)
linear_enc5_std (80, 20)
linear_dec4 (20, 80)
linear_dec3 (80, 320)
linear_dec2 (320, 640)
linear_dec1 (640, 784)


In [4]:
import torch.optim as optim

# Loss callable imported from the project's model module; the training loop
# invokes it as criterion(reconstruction, target, mean, logstd).
criterion = VAEloss

# Adam over all model parameters, library-default hyperparameters.
optimizer = optim.Adam(params=vae.parameters())

In [5]:
# Train the VAE for 50 epochs, recording the summed batch loss of each epoch
# in `loss_history` (one entry per completed epoch).
loss_history = []

# CUDA availability does not change mid-run, so query it once.
use_cuda = torch.cuda.is_available()

for epoch in range(50):
    train_loss = 0.0
    for i, (images, _labels) in enumerate(tr):
        # Labels are not used by the VAE objective, so only the images are
        # moved to the GPU.
        if use_cuda:
            images = images.cuda()
        # Flatten each image to a 784-vector once and reuse it as both the
        # model input and the reconstruction target.
        inputs = images.view(-1, 784)

        optimizer.zero_grad()
        y_pred, mean, logstd = vae(inputs)
        loss = criterion(y_pred, inputs, mean, logstd)

        # Keep the Python float under its own name; the original stored it in
        # `x`, clobbering the input batch.
        batch_loss = loss.item()
        train_loss += batch_loss
        if i % 40 == 0:
            print('(%d, %5d) item-loss: %.9f' % (epoch + 1, i + 1, batch_loss))

        loss.backward()
        optimizer.step()

    print('[%d] epoch-loss: %.9f' % (epoch + 1, train_loss))
    loss_history.append(train_loss)
print('Finished Training')

(1,     1) item-loss: 4354.654785156
(1,    41) item-loss: 1981.650512695
(1,    81) item-loss: 1543.547241211
(1,   121) item-loss: 1679.371215820
(1,   161) item-loss: 1789.319946289
(1,   201) item-loss: 1644.952392578
(1,   241) item-loss: 1349.015869141
(1,   281) item-loss: 1624.691894531
(1,   321) item-loss: 1610.260986328
(1,   361) item-loss: 1578.882934570
(1,   401) item-loss: 1445.746948242
(1,   441) item-loss: 1583.377441406
(1,   481) item-loss: 1613.934448242
(1,   521) item-loss: 1515.623168945
(1,   561) item-loss: 1736.087280273
(1,   601) item-loss: 1526.037109375
(1,   641) item-loss: 1449.536987305
(1,   681) item-loss: 1622.219604492
(1,   721) item-loss: 1352.725341797
(1,   761) item-loss: 1303.525634766
(1,   801) item-loss: 1571.272216797
(1,   841) item-loss: 1308.579101562
(1,   881) item-loss: 1347.767700195
(1,   921) item-loss: 1602.302246094
(1,   961) item-loss: 1291.191772461
(1,  1001) item-loss: 1353.031616211
(1,  1041) item-loss: 1361.282348633
(

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt
# Plot the per-epoch summed training loss collected by the training loop.
# loss_history only has entries for epochs that ran to completion.
fig, ax = plt.subplots()
ax.plot(range(1, len(loss_history) + 1), loss_history)
ax.set(title='VAE training loss', xlabel='epoch', ylabel='summed batch loss')
plt.show()

In [None]:
##saving the model
PATH = 'vae_mnist.pth'
torch.save(vae.state_dict(), PATH)
##loading the model
# Build a fresh network from the VAE class. The original called the trained
# instance (`vae(layers)`), which runs a forward pass on a Python list instead
# of constructing a model.
model = VAE(layers)
# map_location lets a checkpoint saved from the GPU load on a CPU-only
# machine; load_state_dict then copies the weights into `model`.
model.load_state_dict(torch.load(PATH, map_location='cpu'))
model.eval()

In [None]:
# Imported here because no earlier cell brings these names into scope.
# NOTE(review): `Image` is assumed to be IPython.display.Image -- confirm.
from torchvision.utils import save_image
from IPython.display import Image

# Draw latent codes on the same device as the model's weights; a CPU sample
# fed to a CUDA model raises a device-mismatch error.
device = next(vae.parameters()).device

# Sampling needs no gradients. torch.no_grad() replaces the deprecated
# Variable wrapper and the `.data` access.
with torch.no_grad():
    # 128 latent vectors of width 20 (the bottleneck size in `layers`).
    sample = torch.randn(128, 20, device=device)
    recon_x = vae.decode(sample)

# Reshape the flat decoder outputs to 1x28x28 images and write them as a grid.
save_image(recon_x.view(recon_x.size(0), 1, 28, 28).cpu(), 'sample_image.png')
Image('sample_image.png')