In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib
from numpy import linalg as LA
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, datasets
import torch.optim as optim
from PIL import Image
import os

In [2]:
class gaussian_sample_layer(nn.Module):

    def __init__(self, latent_dim):
        super(gaussian_sample_layer, self).__init__()
        self.latent_dim = latent_dim
        self.L = 1
        
    def forward(self, mu, sigma):
        epsilon_dist = torch.distributions.MultivariateNormal(torch.zeros(self.latent_dim),torch.eye(self.latent_dim))
        epsilon = epsilon_dist.sample((self.L,))
        epsilon = torch.sum(epsilon, dim=0) / self.L
        a = mu + epsilon * sigma
        return a

In [3]:
class variational_autoencoder(torch.nn.Module):

    def __init__(self):
        super(variational_autoencoder, self).__init__()
        input_size = 784
        output_size = 784
        self.latent_dim = 16
        self.mlp1 = nn.Linear(input_size, 128)
        self.mu = nn.Linear(128, self.latent_dim)
        self.sigma = nn.Linear(128, self.latent_dim)
        self.gaussian = gaussian_sample_layer(self.latent_dim)
        self.mlp4 = nn.Linear(self.latent_dim, 128)
        self.out = nn.Linear(128, output_size)
        self.batch_size = 8

    def forward(self, x):
        h1 = F.sigmoid(self.mlp1(x))
        mu = F.sigmoid(self.mu(h1))
        sigma = F.sigmoid(self.sigma(h1))
        z = self.gaussian(mu, sigma)
        h4 = F.tanh(self.mlp4(z))
        y_hat = F.relu(self.out(h4))
        return y_hat, mu, sigma
    
    def loss(self, x, y, beta=0.0001):
        y_hat, mu, sigma = self.forward(x)
        c = nn.MSELoss()
        l = c(y_hat, y) - beta * 1/2 * torch.sum(1 + torch.log(sigma**2) - mu**2 - sigma**2)
        return l
    
    def decoder(self, x):
        f = F.relu(self.out(F.tanh(self.mlp4(x))))
        return f

In [4]:
vae = variational_autoencoder()
optimizer = optim.Adam(vae.parameters(), lr=3e-4)

In [5]:
batch_size = 8
train = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
trainset = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)

In [6]:
for epoch in range(100):
    cnt = 0
    l = 0
    for data in trainset:
        data = data[0].squeeze()
        x = torch.reshape(data, (batch_size, 784))
        y = x.clone()
        optimizer.zero_grad()
        y_hat, _, _ = vae(x)
        loss = vae.loss(x, y)
        loss.backward()
        optimizer.step()
        l += loss
        cnt += 1
    print(l)
    '''if cnt % 600 == 0:
            cnt = 0
            print(l / 600)
            l = 0'''



tensor(473.1535, grad_fn=<AddBackward0>)
tensor(416.5661, grad_fn=<AddBackward0>)
tensor(402.3413, grad_fn=<AddBackward0>)
tensor(395.5591, grad_fn=<AddBackward0>)
tensor(388.6825, grad_fn=<AddBackward0>)
tensor(382.5290, grad_fn=<AddBackward0>)
tensor(376.9836, grad_fn=<AddBackward0>)
tensor(371.5956, grad_fn=<AddBackward0>)
tensor(366.5996, grad_fn=<AddBackward0>)
tensor(362.1712, grad_fn=<AddBackward0>)
tensor(358.1848, grad_fn=<AddBackward0>)
tensor(355.2062, grad_fn=<AddBackward0>)
tensor(352.0761, grad_fn=<AddBackward0>)
tensor(349.6087, grad_fn=<AddBackward0>)
tensor(347.6811, grad_fn=<AddBackward0>)
tensor(345.6353, grad_fn=<AddBackward0>)
tensor(344.0138, grad_fn=<AddBackward0>)
tensor(342.5093, grad_fn=<AddBackward0>)
tensor(340.7488, grad_fn=<AddBackward0>)
tensor(339.8383, grad_fn=<AddBackward0>)
tensor(338.6263, grad_fn=<AddBackward0>)
tensor(337.5218, grad_fn=<AddBackward0>)
tensor(336.7500, grad_fn=<AddBackward0>)
tensor(335.5693, grad_fn=<AddBackward0>)
tensor(334.8258,

In [7]:
test = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.Compose([transforms.ToTensor()]))
testset = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=True)

In [8]:
results = []
cnt = 0
os.makedirs("img", exist_ok=True)
for test in testset:
    data = test[0].squeeze()
    x = torch.reshape(data, (batch_size, 784))
    out = vae(x)[0].detach().numpy()
    cnt += 1
    for i in range(data.shape[0]):
        plt.imsave('./img/' + str(cnt) + str(i) + 'org.png', data[i], cmap='gray')
        plt.imsave('./img/' + str(cnt) + str(i) + 'recovered.png', out[i].reshape(28, 28), cmap='gray')

In [9]:
weights = []
for w in vae.parameters():
    weights.append(w)
weights = np.array(weights)

In [37]:
os.makedirs("generated", exist_ok=True)
for i in range(1000):
    dist = torch.distributions.MultivariateNormal(torch.zeros(16),torch.eye(16))
    z = dist.sample()
    plt.imsave('./generated/' + str(i)+'.png', vae.decoder(z).detach().numpy().reshape(28, 28), cmap='gray')