In [1]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import torchvision.datasets as dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Hyperparameters shared by the cells below.
# Number of samples per mini-batch handed to both DataLoaders.
batch_size = 128
# Step size passed to the Adam optimizer.
learning_rate = 0.0005
# Number of full passes over train_loader in the training loop.
num_epoch = 20
# Width of the latent code: output size of the encoder heads and input size of the decoder.
hidden_size = 100

In [3]:
# MNIST train/test splits, downloaded on first use into ./data_CVAE and
# converted to tensors by ToTensor; labels are left untransformed.
mnist_root = "./data_CVAE"
to_tensor = transforms.ToTensor()

mnist_train = dataset.MNIST(
    root=mnist_root,
    train=True,
    transform=to_tensor,
    download=True,
)
mnist_test = dataset.MNIST(
    root=mnist_root,
    train=False,
    transform=to_tensor,
    download=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data_CVAE/MNIST/raw/train-images-idx3-ubyte.gz


100.1%

Extracting ./data_CVAE/MNIST/raw/train-images-idx3-ubyte.gz to ./data_CVAE/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data_CVAE/MNIST/raw/train-labels-idx1-ubyte.gz


113.5%

Extracting ./data_CVAE/MNIST/raw/train-labels-idx1-ubyte.gz to ./data_CVAE/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data_CVAE/MNIST/raw/t10k-images-idx3-ubyte.gz


100.4%

Extracting ./data_CVAE/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data_CVAE/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data_CVAE/MNIST/raw/t10k-labels-idx1-ubyte.gz


180.4%

Extracting ./data_CVAE/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data_CVAE/MNIST/raw
Processing...
Done!


In [4]:
# Both splits use the same batching policy: shuffled, two worker processes,
# and the final incomplete batch dropped so every batch has batch_size items.
loader_options = dict(
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)
train_loader = DataLoader(mnist_train, **loader_options)
test_loader = DataLoader(mnist_test, **loader_options)

In [6]:
class Encoder(nn.Module):
    """Convolutional VAE encoder producing a diagonal-Gaussian posterior.

    A three-stage conv stack (the first two stages followed by 2x2 max-pooling)
    turns a ``(batch, 1, 28, 28)`` image into a ``(batch, 32, 7, 7)`` feature
    map. Two independent linear heads then emit the posterior mean and
    log-variance of the latent code.

    Args:
        latent_size: Width of the latent code. ``None`` (the default) falls
            back to the notebook-level ``hidden_size`` setting, which is what
            the zero-argument constructor has always used.
    """

    def __init__(self, latent_size=None):
        super(Encoder, self).__init__()
        if latent_size is None:
            latent_size = hidden_size
        self.fc1 = nn.Sequential(
                        nn.Conv2d(1, 8, 3, padding=1),   # batch x 8 x 28 x 28
                        nn.BatchNorm2d(8),
                        nn.ReLU(),
                        nn.MaxPool2d(2, 2),
                        nn.Conv2d(8, 16, 3, padding=1),  # batch x 16 x 14 x 14
                        nn.BatchNorm2d(16),
                        nn.ReLU(),
                        nn.MaxPool2d(2, 2),
                        nn.Conv2d(16, 32, 3, padding=1), # batch x 32 x 7 x 7
                        nn.ReLU(),
        )
        # Head for the posterior mean.
        self.fc2_1 = nn.Sequential(
                            nn.Linear(32*7*7, 800),
                            nn.Linear(800, latent_size),
        )
        # Head for the posterior log-variance.
        self.fc2_2 = nn.Sequential(
                            nn.Linear(32*7*7, 800),
                            nn.Linear(800, latent_size),
        )
        # No longer applied in encode(); kept so `encoder.relu` still exists.
        self.relu = nn.ReLU()

    def encode(self, x):
        """Return ``(mu, log_var)``, each of shape ``(batch, latent_size)``."""
        out = self.fc1(x)
        # Flatten using the actual batch dimension of the input rather than
        # the global batch_size, so partial or differently sized batches work.
        # fc1 already ends in ReLU, so no further activation is needed here.
        out = out.view(x.size(0), -1)
        mu = self.fc2_1(out)
        log_var = self.fc2_2(out)

        return mu, log_var

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        The noise is created with the same shape, dtype and device as the
        standard deviation, so the module also works off-CPU.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x`` and return ``(mu, logvar, z)`` with ``z`` a posterior sample."""
        mu, logvar = self.encode(x)
        reparam = self.reparameterize(mu, logvar)
        return mu, logvar, reparam
    
# Module-level encoder instance: its parameters are handed to the optimizer
# and it is called on each image batch in the training loop.
encoder = Encoder()

In [9]:
class Decoder(nn.Module):
    """Convolutional VAE decoder mapping latent codes back to images.

    A two-layer MLP expands the latent code to 1568 (= 32 * 7 * 7) features,
    which are reshaped to ``(batch, 32, 7, 7)`` and upsampled by transposed
    convolutions to a single 28 x 28 channel squashed by a sigmoid.

    Args:
        latent_size: Width of the incoming latent code. ``None`` (the default)
            falls back to the notebook-level ``hidden_size`` setting, which is
            what the zero-argument constructor has always used.
    """

    def __init__(self, latent_size=None):
        super(Decoder, self).__init__()
        if latent_size is None:
            latent_size = hidden_size
        self.fc1 = nn.Sequential(
                        nn.Linear(latent_size, 800),
                        nn.BatchNorm1d(800),
                        nn.ReLU(),
                        nn.Linear(800, 1568),  # 1568 = 32 * 7 * 7
                        nn.ReLU(),
        )
        self.fc2 = nn.Sequential(
                        nn.ConvTranspose2d(32, 16, 3, 2, 1, 1),  # 7 x 7 -> 14 x 14
                        nn.ReLU(),
                        nn.BatchNorm2d(16),
                        nn.ConvTranspose2d(16, 8, 3, 2, 1, 1),   # 14 x 14 -> 28 x 28
                        nn.ReLU(),
                        nn.BatchNorm2d(8),
                        nn.ConvTranspose2d(8, 1, 3, 1, 1),       # 28 x 28, one channel
                        nn.BatchNorm2d(1),
        )
        self.sigmoid = nn.Sigmoid()
        # No longer applied in forward(); kept so `decoder.relu` still exists.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Decode ``x`` of shape ``(batch, latent_size)``.

        Returns:
            Tensor of shape ``(batch, 28, 28, 1)`` with values in ``[0, 1]``.
        """
        # Use the real batch dimension instead of the global batch_size so
        # partial or differently sized batches decode correctly.
        n = x.size(0)
        # fc1 already ends in ReLU, so no further activation is needed here.
        out = self.fc1(x)
        out = out.view(n, 32, 7, 7)
        out = self.fc2(out)
        out = self.sigmoid(out)
        # With a single channel this view is equivalent to a channels-last layout.
        out = out.view(n, 28, 28, 1)

        return out
    
# Module-level decoder instance: its parameters are handed to the optimizer
# and it is called on the sampled latent codes in the training loop.
decoder = Decoder()

In [10]:
# Binary cross-entropy summed over every pixel of every sample.
# `reduction='sum'` is the supported spelling of the deprecated
# `size_average=False`.
reconstruction_function = nn.BCELoss(reduction='sum')


def loss_function(recon_x, x, mu, logvar):
    """Negative ELBO of a VAE with a standard-normal prior, summed over the batch.

    Args:
        recon_x: Reconstructed images with values in ``[0, 1]``. May be laid out
            differently from ``x`` as long as it has the same number of
            elements in the same memory order.
        x: Target images with values in ``[0, 1]``.
        mu: Posterior means.
        logvar: Posterior log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor ``BCE + KLD``.
    """
    # The decoder emits (batch, 28, 28, 1) while the input images are
    # (batch, 1, 28, 28). BCELoss needs identical shapes, so align the
    # reconstruction with the target. This is a pure relabelling of axes
    # because the channel axis has size 1.
    if recon_x.shape != x.shape:
        recon_x = recon_x.reshape(x.shape)
    BCE = reconstruction_function(recon_x, x)

    # KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

# Train encoder and decoder jointly: one Adam instance over both parameter sets.
parameters = [*encoder.parameters(), *decoder.parameters()]
optimizer = optim.Adam(parameters, lr=learning_rate)



In [None]:
checkpoint_path = './model_CVAE/conv_variational_autoencoder.pkl'
# Create the checkpoint directory up front so the first save in the loop
# does not fail on a fresh checkout.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Best-effort resume from a previous run.
try:
    # SECURITY: torch.load unpickles arbitrary Python objects; only load
    # checkpoints you created yourself.
    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True, which rejects pickled modules like these; confirm the
    # target version or migrate the checkpoint to state_dicts.
    restored_encoder, restored_decoder = torch.load(checkpoint_path)
    # Copy the weights into the existing modules instead of rebinding the
    # names: the optimizer holds references to the parameters of the current
    # `encoder` / `decoder`, so rebinding would leave it updating stale
    # objects while the restored ones were never trained.
    encoder.load_state_dict(restored_encoder.state_dict())
    decoder.load_state_dict(restored_decoder.state_dict())
    print("\n----------model restored----------\n")
except Exception as restore_error:
    # Still best-effort, but no longer swallows KeyboardInterrupt/SystemExit
    # and reports why the restore was skipped.
    print("\n----------model not restored----------\n")
    print("reason:", repr(restore_error))

for i in range(num_epoch):
    for j, (image, label) in enumerate(train_loader):
        optimizer.zero_grad()

        # Encode to the posterior, sample a latent code, decode it.
        mu, log_var, reparam = encoder(image)
        output = decoder(reparam)

        loss = loss_function(output, image, mu, log_var)
        loss.backward()
        optimizer.step()

        # Every 10th batch: checkpoint both modules and log the current loss.
        if j % 10 == 0:
            torch.save([encoder, decoder], checkpoint_path)
            print(loss)


----------model not restored----------



  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


tensor(80250.8359, grad_fn=<AddBackward0>)
tensor(66831.0312, grad_fn=<AddBackward0>)
tensor(63816.9102, grad_fn=<AddBackward0>)
tensor(62233.7266, grad_fn=<AddBackward0>)
tensor(61621.0352, grad_fn=<AddBackward0>)
tensor(60327.5391, grad_fn=<AddBackward0>)
tensor(59597.8203, grad_fn=<AddBackward0>)
tensor(58976.2656, grad_fn=<AddBackward0>)
tensor(58019.7656, grad_fn=<AddBackward0>)
tensor(57510.2344, grad_fn=<AddBackward0>)
tensor(57318.5469, grad_fn=<AddBackward0>)
tensor(56948.6055, grad_fn=<AddBackward0>)
tensor(55920.7812, grad_fn=<AddBackward0>)
tensor(56729.0078, grad_fn=<AddBackward0>)
tensor(54851.8984, grad_fn=<AddBackward0>)
tensor(55863.8047, grad_fn=<AddBackward0>)
tensor(55040.3242, grad_fn=<AddBackward0>)
tensor(54555.1016, grad_fn=<AddBackward0>)
tensor(54736.7656, grad_fn=<AddBackward0>)
tensor(53938.2031, grad_fn=<AddBackward0>)
tensor(53846.7852, grad_fn=<AddBackward0>)
tensor(53883.1328, grad_fn=<AddBackward0>)
tensor(53090.6172, grad_fn=<AddBackward0>)
tensor(5333

In [None]:
# Take the reconstructions from the last training batch, detach them from
# the autograd graph, and drop the singleton axes.
out_img = output.detach().cpu().squeeze()
print(out_img.size())

# Show the first reconstructed image in grayscale.
plt.imshow(out_img[0].numpy(), cmap='gray')
plt.show()