In [2]:
from __future__ import print_function
import argparse
import os

import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [8]:
# Original command-line configuration, kept only for reference. It is wrapped
# in a string literal, so none of it executes.
# argparse's parse_args() reads sys.argv, which inside a Jupyter kernel holds the
# kernel's own launch arguments -- presumably why `args` is built by hand in a
# later cell instead. TODO(review): consider deleting this dead cell.
"""
parser = argparse.ArgumentParser(description='VAE MNIST Example')
parser.add_argument('--batch-size', type=int, default=128, metavar='N', help='input batch size for training (default : 128)')
parser.add_argument('--epochs', type=int, default=10, metavar='N', help='number of epochs to train (default : 10)')
parser.add_argument('--no-cuda', action='store_true', default=False, help='enables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default : 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
"""

"\nparser = argparse.ArgumentParser(description='VAE MNIST Example')\nparser.add_argument('--batch-size', type=int, default=128, metavar='N', help='input batch size for training (default : 128)')\nparser.add_argument('--epochs', type=int, default=10, metavar='N', help='number of epochs to train (default : 10)')\nparser.add_argument('--no-cuda', action='store_true', default=False, help='enables CUDA training')\nparser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default : 1)')\nparser.add_argument('--log-interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status')\nargs = parser.parse_args()\nargs.cuda = not args.no_cuda and torch.cuda.is_available()\n"

In [24]:
# Notebook replacement for the argparse CLI. argparse.Namespace provides the
# same attribute-style access (args.batch_size, args.cuda, ...) as easydict,
# without an extra third-party dependency.
args = argparse.Namespace(
    batch_size=128,    # mini-batch size for both the train and test loaders
    epochs=10,         # number of training epochs
    no_cuda=False,     # set True to force CPU even when CUDA is available
    seed=1,            # seed for torch.manual_seed
    log_interval=10,   # batches between training-progress log lines
)
# Use the GPU only when it is both allowed and actually available.
args.cuda = not args.no_cuda and torch.cuda.is_available()

In [25]:
# Seed PyTorch's random number generator so repeated runs start from the same state.
torch.manual_seed(args.seed)

if args.cuda:
    device = torch.device("cuda")
    # One background loading worker plus pinned host memory for GPU transfers.
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    device = torch.device("cpu")
    kwargs = {}

In [28]:
# MNIST training split, converted to tensors by ToTensor; downloaded into
# ./data the first time this cell runs.
mnist_train = datasets.MNIST('./data',
                             train=True,
                             download=True,
                             transform=transforms.ToTensor())

# Shuffled mini-batches for training; kwargs carries the CUDA-only loader options.
train_loader = torch.utils.data.DataLoader(mnist_train,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           **kwargs)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100.1%

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


113.5%

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100.4%

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


180.4%

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!


In [30]:
# Held-out MNIST test split used for evaluation.
# - shuffle=False: evaluation does not need a random order, and a fixed order
#   makes each epoch's reconstruction image show the same digits, so epochs
#   can be compared visually.
# - download=True: a no-op when the files already exist in ./data, and lets
#   this cell work even if the training-loader cell has not been run.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data',
                   train=False,
                   download=True,
                   transform=transforms.ToTensor()),
    batch_size=args.batch_size,
    shuffle=False,
    **kwargs)

##### def encode(self, x)
784 차원 입력 벡터 -> 400 차원 히든 벡터 -> 액티베이션 함수 (ReLU)  
가우시안 분포에 대한 평균(mean)과 분산(variance)을 각각 20차원 벡터로 내보낸다  
Mean은 self.fc21(h1)이며 linear 연산을 통한 output 이다  
Variance는 항상 0보다 커야 하는데, linear 연산을 한 self.fc22(h1)의 값은 음수가 될 수 있다.  
따라서 이 값을 variance가 아닌 log variance($\log\sigma^2$)로 해석한다. $\sigma^2 = \exp(\log\sigma^2)$는 항상 양수이다
  
##### reparameterize(self, mu, logvar)
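mean과 log variance를 구했다면 reparameterization trick으로 latent vector z를 샘플링 할 수 있다  
표준 정규분포 noise에서 eps를 뽑고, std($\sigma = \exp(0.5 \cdot \log\sigma^2)$)를 곱한 뒤 mean을 더해준다: $z = \mu + \epsilon \odot \sigma$, $\epsilon \sim \mathcal{N}(0, I)$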
  
##### decode(self, z)
z를 구하고나면 이 latent 값으로부터 decoder를 통해 data의 각 pixel에 대한 베르누이 분포를 출력할 수 있다  
베르누이 분포의 파라미터(확률)는 0에서 1 사이이므로 sigmoid 함수를 output layer의 activation 함수로 사용한다
  
##### forward
x.view(-1, 784)는 28x28 이미지를 784 차원의 벡터로 평탄화(flatten)하는 부분이다


In [None]:
class VAE(nn.Module):
    """Fully connected variational autoencoder with a Bernoulli decoder.

    Encoder: input_dim -> hidden_dim (ReLU) -> two linear heads of size
    latent_dim giving the mean and log-variance of a diagonal Gaussian q(z|x).
    Decoder: latent_dim -> hidden_dim (ReLU) -> input_dim followed by a
    sigmoid, so each output is a probability in [0, 1].

    Args:
        input_dim: size of the flattened input (784 = 28 * 28 for MNIST).
        hidden_dim: width of the single hidden layer in encoder and decoder.
        latent_dim: dimensionality of the latent vector z.

    The defaults reproduce the original hard-coded 784 / 400 / 20 network.
    The layers are created in the original order, so a seeded run initialises
    identical weights.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        self.fc1 = nn.Linear(input_dim, hidden_dim)    # encoder hidden layer
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head
        self.fc3 = nn.Linear(latent_dim, hidden_dim)   # decoder hidden layer
        self.fc4 = nn.Linear(hidden_dim, input_dim)    # decoder output layer

    def encode(self, x):
        """Map flattened inputs of shape (batch, input_dim) to (mu, logvar).

        The second head is read as log(sigma^2), not sigma^2. A linear layer
        can output negative values, but a variance cannot be negative.
        """
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * sigma with eps ~ N(0, I).

        Expressing the sample this way keeps it differentiable with respect
        to mu and logvar (the reparameterization trick).
        """
        std = torch.exp(0.5 * logvar)  # sigma = exp(log(sigma^2) / 2)
        eps = torch.randn_like(std)
        return mu + (eps * std)

    def decode(self, z):
        """Map latent vectors of shape (batch, latent_dim) to per-pixel probabilities."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample and decode a batch.

        The input is flattened to (batch, input_dim), so image batches such as
        (batch, 1, 28, 28) are accepted directly. reshape (unlike view) also
        works on non-contiguous tensors.

        Returns:
            (reconstruction, mu, logvar)
        """
        mu, logvar = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [33]:
# Instantiate the VAE, move it to the selected device, and optimise all of its
# parameters with Adam (learning rate 1e-3).
model = VAE()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [34]:
# Reconstruction + KL Divergence Losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    """Negative ELBO of a VAE with a Bernoulli decoder, summed over the batch.

    Args:
        recon_x: decoder output (Bernoulli probabilities in [0, 1]).
        x: target batch. It is reshaped to recon_x's shape, so an image batch
            (B, 1, 28, 28) matches flattened reconstructions (B, 784).
        mu: mean of q(z|x).
        logvar: log-variance of q(z|x).

    Returns:
        Scalar tensor: reconstruction BCE + KL(q(z|x) || N(0, I)). Both terms
        are summed over all elements and the batch, not averaged.
    """
    # Match the target to the reconstruction rather than hard-coding 784, so
    # the loss also works for models built with a different input size.
    BCE = F.binary_cross_entropy(recon_x, x.reshape(recon_x.shape), reduction='sum')
    # Closed-form KL divergence between a diagonal Gaussian and N(0, I):
    # KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

##### train( epoch )
model.train()으로 현재 학습할 것이라는 선언을 한다  
train_loader로 mini_batch를 추출한다  
model에 data를 입력으로 넣어 출력을 받는다  
출력으로 loss function 값을 계산하고 loss.backward()로 back-propagation을 수행해 각 parameter의 gradient 값을 구한다  
optimizer(Adam Optimizer)를 통해 parameter를 update 한다

In [35]:
def train(epoch):
    """Run one training epoch over ``train_loader``.

    For each mini-batch the function runs a forward pass, computes the VAE
    loss, backpropagates and takes an optimizer step. Every
    ``args.log_interval`` batches it prints the current batch's per-sample
    loss. At the end it prints the epoch's summed loss divided by the dataset
    size (average loss per training sample).
    """
    model.train()
    running_total = 0
    dataset_size = len(train_loader.dataset)
    num_batches = len(train_loader)

    for step, (images, _) in enumerate(train_loader):
        images = images.to(device)
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(images)
        batch_loss = loss_function(reconstruction, images, mu, logvar)
        batch_loss.backward()
        running_total += batch_loss.item()
        optimizer.step()

        if step % args.log_interval != 0:
            continue
        seen = step * len(images)
        progress = 100. * step / num_batches
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, dataset_size, progress,
            batch_loss.item() / len(images)))

    print('====> Epoch: {} Average Loss: {:.4f}'.format(
            epoch, running_total / dataset_size))
    

##### test ( epoch ) - 학습 과정을 evaluate
model.eval()으로 평가 중이라는 것을 선언한다  
test dataset에 대해 reconstruction을 계산한다 (gradient는 계산하지 않는다)  
test sample 당 평균 loss 값을 출력해서 학습이 어떻게 진행되고 있는지 평가한다  
매 평가마다 첫 번째 mini-batch의 원본 이미지(최대 8개)와 그 reconstruction을 나란히 붙인 비교 이미지를 저장한다

In [40]:
def test(epoch):
    """Evaluate the model on ``test_loader`` and save a reconstruction grid.

    Switches to eval mode and accumulates the summed VAE loss over the test
    set without tracking gradients. It then prints that loss divided by the
    test-set size (average loss per test sample). For the first batch it
    writes results/reconstruction_<epoch>.png: up to 8 inputs in the first
    row and their reconstructions in the second.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                # Reshape with the real batch size (-1), not args.batch_size.
                # A first batch smaller than batch_size (e.g. batch_size larger
                # than the test set) would otherwise make view() raise.
                comparison = torch.cat([data[:n],
                                        recon_batch.view(-1, 1, 28, 28)[:n]])
                # save_image does not create missing directories.
                os.makedirs('results', exist_ok=True)
                save_image(comparison.cpu(),
                           'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss : {:.4f}'.format(test_loss))

##### main 루프
epoch 마다 train을 한다  
test를 한 다음에 64개의 sample 이미지를 생성한다  
각 sample의 latent는 20차원 vector이며, 표준 정규분포(standard normal distribution)에서 sampling 한다  
sampling 된 latent variable을 Decoder에 통과시키면 Decoder는 이미지를 생성한다

In [42]:
if __name__ == "__main__":
    # save_image writes into results/ but does not create the directory.
    os.makedirs('results', exist_ok=True)
    # Take the latent size from the decoder's input layer, so prior samples
    # always match the model's latent dimension.
    latent_dim = model.fc3.in_features
    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test(epoch)
        with torch.no_grad():
            # Decode 64 draws from the standard-normal prior into 28x28 images.
            sample = torch.randn(64, latent_dim).to(device)
            sample = model.decode(sample).cpu()
            save_image(sample.view(64, 1, 28, 28),
                       'results/sample_' + str(epoch) + '.png')

====> Epoch: 1 Average Loss: 114.4909
====> Test set loss : 111.7891
====> Epoch: 2 Average Loss: 111.4638
====> Test set loss : 109.4725
====> Epoch: 3 Average Loss: 109.7196
====> Test set loss : 108.2173
====> Epoch: 4 Average Loss: 108.5881
====> Test set loss : 107.3034
====> Epoch: 5 Average Loss: 107.7394
====> Test set loss : 106.7229
====> Epoch: 6 Average Loss: 107.0488
====> Test set loss : 106.1067
====> Epoch: 7 Average Loss: 106.6178
====> Test set loss : 105.8466
====> Epoch: 8 Average Loss: 106.1671
====> Test set loss : 105.6512
====> Epoch: 9 Average Loss: 105.8369
====> Test set loss : 105.1079
====> Epoch: 10 Average Loss: 105.5223
====> Test set loss : 105.2993
