In [2]:
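#This code trains a variational autoencoder (VAE) on the MNIST digits 0, 1, 2, ..., 9. Each image is
#encoded as the mean and log-variance of a Gaussian in a lower-dimensional latent space, a latent
#vector is sampled from that Gaussian, and the image is decoded back from the sample. If you run
#the code to the end you will see that the recreated digits are quite blurry.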


import numpy as np
import torch
from torch.autograd import Variable
from matplotlib import pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import random
from PIL import Image

import torchvision
import torchvision.datasets as datasets
# Load the MNIST training and test splits from ./data, downloading them if needed
# (download=True). transform=None applies no transform to the samples; the cells below
# convert each image themselves with np.asarray.
# NOTE(review): mnist_testset is not used by any of the visible cells.
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=None)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


9920512it [00:06, 1640506.51it/s]                                                                                      


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


32768it [00:00, 111931.82it/s]                                                                                         


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


1654784it [00:01, 1050578.23it/s]                                                                                      


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


8192it [00:00, 41080.31it/s]                                                                                           


Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [3]:
def img_to_vector(imgs):
    """Flatten every 28x28 image in ``imgs`` into a (784, 1) int32 column.

    Parameters
    ----------
    imgs : iterable of array-likes
        Anything ``np.asarray`` can turn into an array of 28*28 elements.

    Returns
    -------
    numpy.ndarray
        The columns stacked along a new leading axis, i.e. shape
        ``(len(imgs), 784, 1)`` for a non-empty input.
    """
    columns = []
    for picture in imgs:
        pixels = np.asarray(picture, dtype="int32")
        columns.append(pixels.reshape(28*28, 1))
    return np.array(columns)

In [4]:
# Split the (image, label) samples into two parallel lists.
# Every mnist_trainset[i] lookup goes through the dataset's item access, so each sample is
# fetched exactly once here instead of once per list.
_samples = [mnist_trainset[i] for i in range(len(mnist_trainset))]
mnist_trainset_images = [image for image, _ in _samples]
mnist_trainset_digits = [digit for _, digit in _samples]
del _samples  # temporary only; the two lists above hold the data from here on

In [5]:
# Scale the stacked int32 pixel columns by 1/256 (get_image multiplies by 256 to undo this).
# True division turns the integer array into floats.
mnist_trainset_arrays = img_to_vector(mnist_trainset_images) / 256

In [6]:
# Training input iterated by train(); this is an alias of the scaled array, not a copy.
data = mnist_trainset_arrays

In [7]:
input_size = 28*28
encoded_size=100

class VAE(nn.Module):
    """Variational autoencoder with one hidden layer on each side of the latent space.

    The encoder maps an input row to the mean and log-variance of a diagonal
    Gaussian, a latent vector is drawn from it with the reparameterisation
    trick, and the decoder maps that vector back to values in (0, 1).

    Parameters
    ----------
    in_features : int, optional
        Length of a flattened input vector. Defaults to ``input_size``.
    hidden_features : int, optional
        Width of the hidden layer in both encoder and decoder. Defaults to 100.
    latent_features : int, optional
        Dimension of the latent space. Defaults to ``encoded_size``.
    """

    def __init__(self, in_features=input_size, hidden_features=100, latent_features=encoded_size):
        super().__init__()
        # Layers are created in the original order so that seeded weight
        # initialisation yields the same parameters as before.
        self.e = nn.Linear(in_features, hidden_features)
        self.e_mean = nn.Linear(hidden_features, latent_features)
        self.e_logvar = nn.Linear(hidden_features, latent_features)
        self.d1 = nn.Linear(latent_features, hidden_features)
        self.d2 = nn.Linear(hidden_features, in_features)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def en(self, x):
        """Encode ``x`` and return ``(mean, logvar)`` of the latent Gaussian."""
        # The hidden activation feeds both heads, so compute it once
        # rather than running the first layer twice.
        hidden = self.relu(self.e(x))
        return self.e_mean(hidden), self.e_logvar(hidden)

    def sample(self, mean, logvar):
        """Draw ``mean + eps * std`` with ``eps ~ N(0, I)`` (reparameterisation trick)."""
        std = torch.exp(0.5*logvar)  # logvar = log(sigma^2)  =>  sigma = exp(logvar / 2)
        eps = torch.randn_like(std)
        return mean + eps*std

    def de(self, z):
        """Decode latent vector ``z``; the sigmoid keeps every output in (0, 1)."""
        return self.sigmoid(self.d2(self.relu(self.d1(z))))

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the input ``x``."""
        mu, logvar = self.en(x)
        z = self.sample(mu, logvar)
        return self.de(z), mu, logvar

In [8]:
# Build the network and an Adam optimizer (learning rate 1e-3) over all of its parameters.
# NOTE(review): no random seed is set in the visible cells, so the initial weights
# (and therefore the results) differ between runs.
model = VAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [89]:
def lossf(newx, x, mu, logvar):
    """VAE loss: summed reconstruction error plus KL divergence.

    Parameters
    ----------
    newx : torch.Tensor
        Reconstruction produced by the decoder.
    x : torch.Tensor
        Original input; viewed as rows of ``input_size`` values before comparison.
    mu, logvar : torch.Tensor
        Mean and log-variance returned by the encoder.

    Returns
    -------
    torch.Tensor
        Scalar ``BCE + KLD``, where BCE is the summed binary cross-entropy and
        KLD is ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.
    """
    target = x.view(-1, input_size)
    reconstruction = F.binary_cross_entropy(newx, target, reduction = 'sum')
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    divergence = -0.5*torch.sum(kl_terms)
    return reconstruction + divergence

def train(epochs):
    """Train the global ``model`` on the global ``data`` for ``epochs`` passes.

    Each sample is its own optimisation step (batch size 1). Progress is
    printed every 100 samples and the summed loss is printed after each epoch.

    Parameters
    ----------
    epochs : int
        Number of full passes over ``data``.
    """
    total = len(data)  # report the real dataset size instead of a hard-coded 60000
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for t, sample in enumerate(data, start=1):
            optimizer.zero_grad()
            # img_to_vector stores each sample as a column, so transpose it into a
            # single row for the model. A plain float32 tensor is enough here; the
            # deprecated Variable wrapper added nothing.
            point = torch.as_tensor(sample.T, dtype=torch.float32)
            newpoint, mu, logvar = model(point)
            # point never requires grad, so it can be used directly as the target.
            loss = lossf(newpoint, point, mu, logvar)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
            if t % 100 == 0:
                print('{}/{} done for epoch {}'.format(t, total, epoch))
        print('EPOCH: {} LOSS {}'.format(epoch, train_loss))

In [91]:
# Run one epoch over the whole training set; progress is printed every 100 samples.
train(1)

100/60000 done for epoch 0
200/60000 done for epoch 0
300/60000 done for epoch 0
400/60000 done for epoch 0
500/60000 done for epoch 0
600/60000 done for epoch 0
700/60000 done for epoch 0
800/60000 done for epoch 0
900/60000 done for epoch 0
1000/60000 done for epoch 0
1100/60000 done for epoch 0
1200/60000 done for epoch 0
1300/60000 done for epoch 0
1400/60000 done for epoch 0
1500/60000 done for epoch 0
1600/60000 done for epoch 0
1700/60000 done for epoch 0
1800/60000 done for epoch 0
1900/60000 done for epoch 0
2000/60000 done for epoch 0
2100/60000 done for epoch 0
2200/60000 done for epoch 0
2300/60000 done for epoch 0
2400/60000 done for epoch 0
2500/60000 done for epoch 0
2600/60000 done for epoch 0
2700/60000 done for epoch 0
2800/60000 done for epoch 0
2900/60000 done for epoch 0
3000/60000 done for epoch 0
3100/60000 done for epoch 0
3200/60000 done for epoch 0
3300/60000 done for epoch 0
3400/60000 done for epoch 0
3500/60000 done for epoch 0
3600/60000 done for epoch 0
3

In [10]:
def get_image(vector):
    """Show a flattened 28x28 image whose values were scaled by 1/256.

    Parameters
    ----------
    vector : numpy.ndarray
        Array with 784 elements (any shape that reshapes to 28x28), with
        values expected to lie in [0, 1].

    The image is opened in the system's default viewer via ``Image.show``;
    nothing is returned.
    """
    pixels = vector.reshape(28, 28)*256
    # Use an explicit 8-bit grayscale array. astype(int) yields int64 on most
    # platforms, which Image.fromarray rejects, and a value of exactly 1.0 scales
    # to 256, which is outside the 8-bit range, so clamp before casting.
    img_array = np.clip(pixels, 0, 255).astype(np.uint8)
    img = Image.fromarray(img_array)
    img.show()

In [11]:
# Reconstruct the training sample at index 1 and display the result.
# The stored column is transposed into a single row for the model.
im = torch.tensor(mnist_trainset_arrays[1].T).float()
# Inference only: no_grad avoids building an autograd graph for this forward pass.
# The output still varies between calls because the model samples its latent vector.
with torch.no_grad():
    newdigit = model(im)
# model(...) returns (reconstruction, mu, logvar); element 0 is the decoded image.
newdigit_np = newdigit[0].numpy()
get_image(newdigit_np)