In [1]:
import os, sys
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import Parameter
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
train_set = datasets.MNIST('../datasets/mnist', train=True, download=True, transform=trans)
test_set = datasets.MNIST('../datasets/mnist', train=False, download=True, transform=trans)

batch_size = 100

train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=batch_size,
                 shuffle=True)

print('==>>> total trainning batch number: {}'.format(len(train_loader)))

==>>> total trainning batch number: 600


In [2]:
class Reshape(nn.Module):
    def __init__(self, *shape):
        super(Reshape, self).__init__()
        self.shape = shape
 
    def forward(self, input):
        return input.view(*self.shape)
    
class CVAE(nn.Module):
    def __init__(self, latent_size):
        super(CVAE, self).__init__()
        
        self.latent_size = latent_size
            
        # encoder
        self.enc = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
                    nn.LeakyReLU(inplace=True),
                    torch.nn.ZeroPad2d((0, 1, 0, 1)),
                    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                    nn.LeakyReLU(inplace=True),
                    Reshape(-1, 64*4*4))
        
        self.h_to_mu = nn.Linear(64*4*4, latent_size)
        self.h_to_logvar = nn.Sequential(nn.Linear(64*4*4, latent_size),
                                         nn.Sigmoid())

        # discriminator
#         self.d = nn.Sequential(nn.Conv2d(1, latent_dim, kernel_size=3, stride=2, padding=1),
#                     nn.BatchNorm2d(64),
#                     nn.LeakyReLU(inplace=True),
#                     nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
#                     nn.BatchNorm2d(128),
#                     nn.LeakyReLU(inplace=True),
#                     Reshape(-1, 128*7*7),
#                     nn.Linear(128*7*7, 128),
#                     nn.LeakyReLU(inplace=True),
#                     nn.Linear(128, 10),
#                     nn.Softmax(dim=1))
        
        # decoder
        self.dec = nn.Sequential(nn.Linear(latent_size + 10, 64*4*4),
                                 nn.LeakyReLU(),
                                 nn.Linear(64*4*4, 64*4*4),
                    Reshape(-1, 64, 4, 4),
                    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),
                    nn.LeakyReLU(inplace=True),
                    torch.nn.ZeroPad2d((0, -1, 0, -1)),
                    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1),
                    nn.Tanh())
        
    def encode(self, xb):
        h = self.enc(xb)
        z_mu = self.h_to_mu(h)
        z_logvar = self.h_to_logvar(h)
        return z_mu, z_logvar
    
    def decode(self, zb):
        xb_hat = self.dec(zb)
        return xb_hat
        
    def reparametrize(self, mu, logvar):
        std = torch.sqrt(torch.exp(logvar))
        eps = torch.cuda.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)
    
    def forward(self, xb, yb):
        zb_mu, zb_logvar = self.encode(xb)
        zb = self.reparametrize(zb_mu, zb_logvar)
        
        zyb = torch.cat([zb, yb], dim=1)
        
        xb_hat = self.decode(zyb)
        return xb_hat, zb_mu, zb_logvar

In [3]:
latent_size = 64
os.environ["CUDA_VISIBLE_DEVICES"]="1"

model = CVAE(latent_size)
model.cuda()

CVAE(
  (enc): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.01, inplace)
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): LeakyReLU(negative_slope=0.01, inplace)
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): LeakyReLU(negative_slope=0.01, inplace)
    (6): ZeroPad2d(padding=(0, 1, 0, 1), value=0)
    (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): LeakyReLU(negative_slope=0.01, inplace)
    (9): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (10): LeakyReLU(negative_slope=0.01, inplace)
    (11): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): LeakyReLU(negative_slope=0.01, inplace)
    (13): Reshape()
  )
  (h_to_mu): Linear(in_features=1024, out_features=64, bias=True)
  (h_to_logvar): Sequential(
    (0): Linear(in_features=1024, out_features=64, bias=Tr

In [4]:
def generate_digit_vector(zb, digit, num_classes=10):
    target_digit = digit * torch.ones(z_samples.size(0)).type(torch.LongTensor)
    target_digit = torch.eye(num_classes)[target_digit]
    target_digit = Variable(target_digit).cuda()
    return torch.cat([zb, target_digit], dim=1)

mse_loss = nn.MSELoss(size_average=False)

def kl_loss(mu, logvar):
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element, dim=1)
    KLD = torch.mean(KLD).mul_(-0.5)
    return KLD

optimizer = optim.Adam(model.parameters(), lr=0.001)

kl_weight = 0.
kl_steps = 1. / 2000.

with open('loss.cvae.log', 'w') as log_fn:
    
    log_fn.write('epoch,reconstr_error,kldiv_error,num_batches\n')

    num_epochs = 50
    num_classes = 10
    
    z_samples = torch.randn(16, latent_size)
    z_samples = Variable(z_samples).cuda()
    
    for epoch in range(num_epochs):
        for xb, yb in tqdm(train_loader):
            xb = Variable(xb).cuda()
            batch_size = xb.size(0)
            
            yb = torch.eye(num_classes)[yb]
            yb = Variable(yb).cuda()
            
            kl_weight += kl_steps
            kl_weight = min(1.0, kl_weight)
            
            optimizer.zero_grad()
            xb_hat, zb_mu, zb_logvar = model(xb, yb)
            reconstr_loss = mse_loss(xb_hat, xb)
            kldiv_loss = kl_weight * kl_loss(zb_mu, zb_logvar)
            loss = reconstr_loss + kldiv_loss

            log_fn.write("{},{},{},{}\n".format(epoch, reconstr_loss.item(), kldiv_loss.item(), batch_size))
            
            loss.backward()
            optimizer.step()
        
        if epoch % 10 == 0:
            for digit in range(0, 10):
                test_images = model.decode(generate_digit_vector(z_samples, digit))
                np.save('img/generated_img.cvae.{}.epoch{}'.format(digit, epoch), test_images.data.cpu().numpy())
            tqdm.write("loss: {:.4f}".format(loss.item()))

epoch = 50
for digit in range(0, 10):
    test_images = model.decode(generate_digit_vector(z_samples, digit))
    np.save('img/generated_img.cvae.{}.epoch{}'.format(digit, epoch), test_images.data.cpu().numpy())
    

100%|██████████| 600/600 [00:12<00:00, 46.42it/s]
  1%|          | 6/600 [00:00<00:10, 55.91it/s]

loss: 1282.8469


100%|██████████| 600/600 [00:11<00:00, 54.14it/s]
100%|██████████| 600/600 [00:11<00:00, 53.02it/s]
100%|██████████| 600/600 [00:11<00:00, 53.07it/s]
100%|██████████| 600/600 [00:11<00:00, 54.46it/s]
100%|██████████| 600/600 [00:11<00:00, 53.15it/s]
100%|██████████| 600/600 [00:11<00:00, 52.86it/s]
100%|██████████| 600/600 [00:11<00:00, 54.43it/s]
100%|██████████| 600/600 [00:11<00:00, 54.08it/s]
100%|██████████| 600/600 [00:11<00:00, 53.60it/s]
100%|██████████| 600/600 [00:11<00:00, 53.63it/s]
  1%|          | 6/600 [00:00<00:10, 54.90it/s]

loss: 820.7140


100%|██████████| 600/600 [00:11<00:00, 54.42it/s]
100%|██████████| 600/600 [00:10<00:00, 55.47it/s]
100%|██████████| 600/600 [00:10<00:00, 54.90it/s]
100%|██████████| 600/600 [00:10<00:00, 54.76it/s]
100%|██████████| 600/600 [00:10<00:00, 54.58it/s]
100%|██████████| 600/600 [00:10<00:00, 54.65it/s]
100%|██████████| 600/600 [00:10<00:00, 56.05it/s]
100%|██████████| 600/600 [00:10<00:00, 56.40it/s]
100%|██████████| 600/600 [00:11<00:00, 54.31it/s]
100%|██████████| 600/600 [00:10<00:00, 55.68it/s]
  1%|          | 6/600 [00:00<00:10, 54.10it/s]

loss: 785.8838


100%|██████████| 600/600 [00:10<00:00, 54.72it/s]
100%|██████████| 600/600 [00:10<00:00, 54.63it/s]
100%|██████████| 600/600 [00:11<00:00, 53.96it/s]
100%|██████████| 600/600 [00:11<00:00, 54.22it/s]
100%|██████████| 600/600 [00:11<00:00, 54.54it/s]
100%|██████████| 600/600 [00:10<00:00, 55.99it/s]
100%|██████████| 600/600 [00:11<00:00, 54.49it/s]
100%|██████████| 600/600 [00:11<00:00, 54.30it/s]
100%|██████████| 600/600 [00:11<00:00, 54.27it/s]
100%|██████████| 600/600 [00:11<00:00, 54.44it/s]
  1%|          | 6/600 [00:00<00:11, 53.99it/s]

loss: 696.4124


100%|██████████| 600/600 [00:11<00:00, 54.52it/s]
100%|██████████| 600/600 [00:11<00:00, 54.17it/s]
100%|██████████| 600/600 [00:11<00:00, 54.48it/s]
100%|██████████| 600/600 [00:10<00:00, 54.69it/s]
100%|██████████| 600/600 [00:11<00:00, 52.78it/s]
100%|██████████| 600/600 [00:11<00:00, 54.41it/s]
100%|██████████| 600/600 [00:11<00:00, 52.21it/s]
100%|██████████| 600/600 [00:11<00:00, 54.22it/s]
100%|██████████| 600/600 [00:11<00:00, 53.67it/s]
100%|██████████| 600/600 [00:11<00:00, 54.09it/s]
  1%|          | 6/600 [00:00<00:11, 52.34it/s]

loss: 673.1128


100%|██████████| 600/600 [00:10<00:00, 54.62it/s]
100%|██████████| 600/600 [00:10<00:00, 56.07it/s]
100%|██████████| 600/600 [00:10<00:00, 56.10it/s]
100%|██████████| 600/600 [00:10<00:00, 56.01it/s]
100%|██████████| 600/600 [00:10<00:00, 55.42it/s]
100%|██████████| 600/600 [00:10<00:00, 54.85it/s]
100%|██████████| 600/600 [00:10<00:00, 56.35it/s]
100%|██████████| 600/600 [00:10<00:00, 56.24it/s]
100%|██████████| 600/600 [00:10<00:00, 55.42it/s]
