# AutoEncoder (3) Variational AutoEncoder

- [L1aoXingyu@github](https://github.com/L1aoXingyu)の[Variational AutoEncoderの実装](https://github.com/L1aoXingyu/pytorch-beginner/blob/master/08-AutoEncoder/Variational_autoencoder.py)（`nn.Linear`ベースのコード）を，Jupyter Notebook用に適宜修正した．  
- しかし，`nn.Linear`を単純に`nn.Conv2d`へ置き換えてみたところ，学習が上手く行かなかった．そこで，[3ammor@github](https://github.com/3ammor)の[Variational-Autoencoder-pytorch](https://github.com/3ammor/Variational-Autoencoder-pytorch)を元に書き換えることにした．  

In [1]:
import torch
import os

## create the output folder for reconstructed images in advance
## (os.makedirs also creates the missing parent './data' -- which is only
##  created later by the MNIST download -- and exist_ok makes re-running
##  this cell a no-op; os.mkdir would raise FileNotFoundError on a fresh run)
folder = './data/VAE_img'
os.makedirs(folder, exist_ok=True)

## path where the trained weights (state dict) are saved / loaded
model_path = './data/VAE_autoencoder.pth'

## set some constants for learning
num_epochs = 50
batch_size = 128
learning_rate = 1e-3

## (1) Prepare dataset: MNIST hand-written digits

Almost the same as for the simple AutoEncoder.

In [2]:
from torchvision.datasets import MNIST
from torchvision import transforms

## image to tensor
img_transform = transforms.Compose(
    transforms=[
        # PIL image -> float tensor (8-bit pixels become values in [0, 1])
        transforms.ToTensor(),
        # (x - 0.5) / 0.5: rescales [0, 1] to [-1, 1]
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

## tensor to image
def to_img(x):
    """Convert a batch of normalized MNIST tensors back to displayable images.

    `img_transform` normalizes pixels with mean 0.5 / std 0.5, so the data
    (and the model's reconstructions, trained against it) live in [-1, 1].
    This undoes that mapping (x * 0.5 + 0.5), clamps to the valid pixel
    range [0, 1] and reshapes to (N, 1, 28, 28) for `save_image`.
    """
    # Undo Normalize(0.5, 0.5) first; clamping directly would discard
    # every value below 0, i.e. the whole dark half of the range.
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 1, 28, 28)
    return x

## MNIST training split, downloaded to ./data if absent; each image is
## passed through img_transform when it is fetched
dataset_train = MNIST(
    root='./data',
    train=True,
    transform=img_transform,
    download=True,
)

In [3]:
from torch.utils.data import DataLoader

# Mini-batches of `batch_size` samples, reshuffled every epoch.
dataloader = DataLoader(
    dataset=dataset_train,
    batch_size=batch_size,
    shuffle=True,
)

## (2) Prepare model: Variational AutoEncoder

This network is adapted from [aidiary@github](https://github.com/aidiary/conv-vae/blob/master/model.py).

In [4]:
from torch import nn
from torch.autograd import Variable

class VAE(nn.Module):
    """Convolutional variational autoencoder for 1x28x28 (MNIST-sized) images.

    Encoder: four Conv2d+BatchNorm+ReLU stages (28x28 -> 14x14 -> 7x7 with
    16 channels), flattened into a fully connected layer that feeds two
    heads: ``fc21`` (mean ``mu``) and ``fc22`` (log-variance ``logvar``) of
    the latent Gaussian. Decoder: two fully connected layers back to a
    16x7x7 feature map, then transposed convolutions up to 1x28x28. The
    decoder applies no output activation.

    Args:
        latent_dim: size of the latent vector z (default 512, the value
            previously hard-coded).
        hidden_dim: width of the fully connected hidden layers (default 512,
            the value previously hard-coded).

    Submodule names are unchanged, so state dicts saved with the default
    sizes still load.
    """

    # flattened size of the 16 x 7 x 7 feature map between conv and fc layers
    _flat_dim = 7 * 7 * 16

    def __init__(self, latent_dim=512, hidden_dim=512):
        super().__init__()

        # Encoder
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias=False)  # 28 -> 14
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(32)
        self.conv4 = nn.Conv2d(32, 16, kernel_size=3, stride=2, padding=1, bias=False)  # 14 -> 7
        self.bn4 = nn.BatchNorm2d(16)

        self.fc1 = nn.Linear(self._flat_dim, hidden_dim)
        self.fc_bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head

        # Decoder
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc_bn3 = nn.BatchNorm1d(hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, self._flat_dim)
        self.fc_bn4 = nn.BatchNorm1d(self._flat_dim)

        self.conv5 = nn.ConvTranspose2d(16, 32, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)  # 7 -> 14
        self.bn5 = nn.BatchNorm2d(32)
        self.conv6 = nn.ConvTranspose2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn6 = nn.BatchNorm2d(32)
        self.conv7 = nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)  # 14 -> 28
        self.bn7 = nn.BatchNorm2d(16)
        self.conv8 = nn.ConvTranspose2d(16, 1, kernel_size=3, stride=1, padding=1, bias=False)

        self.relu = nn.ReLU()

    def encode(self, x):
        """Map images ``x`` to the (mu, logvar) parameters of the latent Gaussian."""
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.relu(self.bn3(self.conv3(h)))
        h = self.relu(self.bn4(self.conv4(h))).view(-1, self._flat_dim)

        h = self.relu(self.fc_bn1(self.fc1(h)))
        return self.fc21(h), self.fc22(h)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * std with eps ~ N(0, I) (reparameterization trick).

        Sampling happens only in training mode; in eval mode ``mu`` is
        returned unchanged so inference is deterministic.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        # randn_like replaces the deprecated Variable(std.data.new(...).normal_())
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors ``z`` back to (N, 1, 28, 28) images."""
        h = self.relu(self.fc_bn3(self.fc3(z)))
        h = self.relu(self.fc_bn4(self.fc4(h))).view(-1, 16, 7, 7)

        h = self.relu(self.bn5(self.conv5(h)))
        h = self.relu(self.bn6(self.conv6(h)))
        h = self.relu(self.bn7(self.conv7(h)))
        return self.conv8(h).view(-1, 1, 28, 28)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for a batch of images ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [5]:
import torch.nn.functional as F
from torch import optim

## instantiate the model, moving it to the GPU when one is present
## (nn.Module.cuda() returns the module itself)
model = VAE().cuda() if torch.cuda.is_available() else VAE()

## summed (not averaged) squared error over all pixels in the batch;
## reduction='sum' is the non-deprecated spelling of size_average=False
reconstruction_function = nn.MSELoss(reduction='sum')

def loss_function(recon_x, x, mu, logvar):
    """Return the VAE training loss for one batch, summed over the batch.

    Total = reconstruction term + KL term, where the reconstruction term is
    ``reconstruction_function(recon_x, x)`` and the KL term is the divergence
    between N(mu, exp(logvar)) and the standard normal prior:

        KL = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    """
    # KL divergence, closed form for diagonal Gaussians
    kl_terms = 1 + logvar - mu ** 2 - torch.exp(logvar)
    kl_divergence = torch.sum(kl_terms).mul(-0.5)

    # squared-error reconstruction term
    squared_error = reconstruction_function(recon_x, x)

    return squared_error + kl_divergence

## Adam over all model parameters; use the shared `learning_rate`
## hyper-parameter instead of repeating the literal 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate)



In [6]:
from torchsummary import summary

## https://github.com/sksq96/pytorch-summary
## torchsummary runs a dummy forward pass on random inputs. Do it in eval
## mode so that pass does not update the BatchNorm running statistics, then
## restore whichever mode the model was in (even if summary raises).
was_training = model.training
model.eval()
try:
    summary(model, (1, 28, 28))
finally:
    model.train(was_training)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 28, 28]             144
       BatchNorm2d-2           [-1, 16, 28, 28]              32
              ReLU-3           [-1, 16, 28, 28]               0
            Conv2d-4           [-1, 32, 14, 14]           4,608
       BatchNorm2d-5           [-1, 32, 14, 14]              64
              ReLU-6           [-1, 32, 14, 14]               0
            Conv2d-7           [-1, 32, 14, 14]           9,216
       BatchNorm2d-8           [-1, 32, 14, 14]              64
              ReLU-9           [-1, 32, 14, 14]               0
           Conv2d-10             [-1, 16, 7, 7]           4,608
      BatchNorm2d-11             [-1, 16, 7, 7]              32
             ReLU-12             [-1, 16, 7, 7]               0
           Linear-13                  [-1, 512]         401,920
      BatchNorm1d-14                  [

## (3) Training model

In [7]:
from torchvision.utils import save_image

## training
model.train()

for epoch in range(num_epochs):
    train_loss = 0

    for batch_idx, (img, _) in enumerate(dataloader):
        if torch.cuda.is_available():
            img = img.cuda() ## send to GPU

        optimizer.zero_grad()

        ## feed-forward; loss_function returns a sum over the batch
        recon_batch, mu, logvar = model(img)
        loss = loss_function(recon_batch, img, mu, logvar)
        train_loss += loss.item()

        ## backprop
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                batch_idx * len(img),
                len(dataloader.dataset), 100. * batch_idx / len(dataloader),
                loss.item() / len(img)))

    ## per-batch losses are sums, so divide by the number of samples
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(dataloader.dataset)))

    ## save the last batch's reconstructions every 10 epochs and also after
    ## the final epoch (with num_epochs=50, epoch % 10 alone never saves it);
    ## detach() drops the autograd graph before converting to an image
    if epoch % 10 == 0 or epoch == num_epochs - 1:
        save = to_img(recon_batch.detach().cpu())
        save_image(save, '{}/image_{}.png'.format(folder, epoch))

====> Epoch: 0 Average loss: 224.4157
====> Epoch: 1 Average loss: 118.4619
====> Epoch: 2 Average loss: 105.5378
====> Epoch: 3 Average loss: 94.0490
====> Epoch: 4 Average loss: 85.0625
====> Epoch: 5 Average loss: 78.4160
====> Epoch: 6 Average loss: 74.3653
====> Epoch: 7 Average loss: 71.9223
====> Epoch: 8 Average loss: 70.2763
====> Epoch: 9 Average loss: 69.1350
====> Epoch: 10 Average loss: 67.8062
====> Epoch: 11 Average loss: 66.7986
====> Epoch: 12 Average loss: 66.0789
====> Epoch: 13 Average loss: 65.4828
====> Epoch: 14 Average loss: 64.6673
====> Epoch: 15 Average loss: 64.1934
====> Epoch: 16 Average loss: 63.7271
====> Epoch: 17 Average loss: 63.2065
====> Epoch: 18 Average loss: 62.7676
====> Epoch: 19 Average loss: 62.3770
====> Epoch: 20 Average loss: 62.0300
====> Epoch: 21 Average loss: 61.6648
====> Epoch: 22 Average loss: 61.4139
====> Epoch: 23 Average loss: 61.0901
====> Epoch: 24 Average loss: 60.6967
====> Epoch: 25 Average loss: 60.5316
====> Epoch: 26 Ave

====> Epoch: 28 Average loss: 59.9355
====> Epoch: 29 Average loss: 59.7222
====> Epoch: 30 Average loss: 59.4966
====> Epoch: 31 Average loss: 59.3809
====> Epoch: 32 Average loss: 59.1514
====> Epoch: 33 Average loss: 59.0761
====> Epoch: 34 Average loss: 58.7877
====> Epoch: 35 Average loss: 58.7058
====> Epoch: 36 Average loss: 58.6047
====> Epoch: 37 Average loss: 58.4194
====> Epoch: 38 Average loss: 58.2463
====> Epoch: 39 Average loss: 58.1611
====> Epoch: 40 Average loss: 58.0716
====> Epoch: 41 Average loss: 57.8433
====> Epoch: 42 Average loss: 57.7716
====> Epoch: 43 Average loss: 57.8330
====> Epoch: 44 Average loss: 57.5282
====> Epoch: 45 Average loss: 57.5088
====> Epoch: 46 Average loss: 57.4192
====> Epoch: 47 Average loss: 57.2649
====> Epoch: 48 Average loss: 57.1476
====> Epoch: 49 Average loss: 57.0791


In [8]:
## save trained model: only the parameters/buffers (state dict) are written
torch.save(obj=model.state_dict(), f=model_path)

## (4) Testing model

In [9]:
## load trained model
## map_location='cpu' lets a checkpoint written on a GPU machine be loaded on
## a CPU-only one; load_state_dict then copies the tensors into the model's
## existing parameters, which stay on whatever device the model is on.
checkpoint = torch.load(model_path, map_location='cpu')
model.load_state_dict(checkpoint)
model.eval() ## switch to "evaluate" mode

VAE(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(32, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc_bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc21): Linear(in_features=512, out_features=512, bias=True)
  (fc22): Linear(in_features=512, out_features=512, bias=True)
  (

In [10]:
## [TODO] visualize the result ...

(end)