In [1]:
import os, sys
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import Parameter
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

imageSize = 64
batch_size = 64

trans = transforms.Compose([transforms.Resize(imageSize),
                            transforms.CenterCrop(imageSize),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

dataset = 'CIFAR10'

if dataset == 'CIFAR10':
    train_set = datasets.CIFAR10('../datasets/cifar10', train=True, download=True, transform=trans)
elif dataset == 'STL10':
    train_set = datasets.FashionMNIST('../datasets/stl10', train=True, download=True, transform=trans)


train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=batch_size,
                 shuffle=True)

print('==>>> total trainning batch number: {}'.format(len(train_loader)))

Files already downloaded and verified
==>>> total trainning batch number: 782


In [2]:
os.environ["CUDA_VISIBLE_DEVICES"]="1"

device = torch.device("cuda:0")
ngpu = 1 # number of gpu to use
nz = 100 # size of the latent z vector int(opt.nz)
ngf = 64
ndf = 64
nc = 3

In [3]:
class CVAE(nn.Module):
    def __init__(self, latent_size, num_classes):
        super(CVAE, self).__init__()
        
        self.latent_size = latent_size
        self.num_classes = num_classes
        
        # encoder
        self.enc = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, nz, 4, 1, 0, bias=False)
        )
        
        self.h_to_mu = nn.Linear(latent_size, latent_size)
        self.h_to_logvar = nn.Sequential(nn.Linear(latent_size, latent_size),
                                    nn.Sigmoid())

        # decoder
        self.dec = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(     latent_size + num_classes, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2,     ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(    ngf,      nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )
        
    def encode(self, xb):
        h = self.enc(xb)
        h = h.view(-1, self.latent_size)
        z_mu = self.h_to_mu(h)
        z_logvar = self.h_to_logvar(h)
        return z_mu, z_logvar
    
    def decode(self, zb):
        zb = zb.view(-1, nz + self.num_classes, 1, 1)
        xb_hat = self.dec(zb)
        return xb_hat
        
    def reparametrize(self, mu, logvar):
        std = torch.sqrt(torch.exp(logvar))
        eps = torch.cuda.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)
    
    def forward(self, xb, yb):
        zb_mu, zb_logvar = self.encode(xb)
        zb = self.reparametrize(zb_mu, zb_logvar)
        
        zyb = torch.cat([zb, yb], dim=1)
        
        xb_hat = self.decode(zyb)
        return xb_hat, zb_mu, zb_logvar

In [4]:
os.environ["CUDA_VISIBLE_DEVICES"]="1"

num_classes = 10

model = CVAE(nz, num_classes)
model.cuda()

CVAE(
  (enc): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace)
    (11): Conv2d(512, 100, kernel_size=(4, 4), stride=(1, 1), bias=False)
  )
  (h_to_mu): Linear(in_features=100, out_features=100, bias=True)
  (h_to_logvar

In [5]:
def generate_digit_vector(zb, digit, num_classes):
    target_digit = digit * torch.ones(z_samples.size(0)).type(torch.LongTensor)
    target_digit = torch.eye(num_classes)[target_digit]
    target_digit = Variable(target_digit).cuda()
    return torch.cat([zb, target_digit], dim=1)

mse_loss = nn.MSELoss(size_average=False)

def kl_loss(mu, logvar):
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element, dim=1)
    KLD = torch.mean(KLD).mul_(-0.5)
    return KLD

optimizer = optim.Adam(model.parameters(), lr=0.001)

kl_weight = 0.
kl_steps = 1. / 5000.

with open('logs/loss.cvae.log', 'w') as log_fn:
    
    log_fn.write('epoch,reconstr_error,kldiv_error,num_batches\n')

    num_epochs = 50

    z_samples = torch.randn(24, nz)
    z_samples = Variable(z_samples).cuda()
    
    for epoch in range(num_epochs):
        acc_loss = []
        for xb, yb in tqdm(train_loader):
            xb = Variable(xb).cuda()
            batch_size = xb.size(0)
            
            yb = torch.eye(num_classes)[yb-1]
            yb = Variable(yb).cuda()
            
            kl_weight += kl_steps
            kl_weight = min(1.0, kl_weight)
            
            optimizer.zero_grad()
            xb_hat, zb_mu, zb_logvar = model(xb, yb)
            reconstr_loss = mse_loss(xb_hat, xb)
            kldiv_loss = kl_weight * kl_loss(zb_mu, zb_logvar)
            loss = reconstr_loss + kldiv_loss

            log_fn.write("{},{},{},{}\n".format(epoch, reconstr_loss.item(), kldiv_loss.item(), batch_size))
            
            loss.backward()
            optimizer.step()
        
            acc_loss.append(loss.item())
            
        if epoch % 5 == 0:
            for digit in range(0, 10):
                test_images = model.decode(generate_digit_vector(z_samples, digit, num_classes))
                np.save('img/generated_img.cvae.{}.epoch{}'.format(digit, epoch), test_images.data.cpu().numpy())
        tqdm.write("epoch:{} loss: {:.4f}".format(epoch, np.mean(acc_loss)))


    

100%|██████████| 782/782 [00:55<00:00, 14.18it/s]
  0%|          | 2/782 [00:00<00:50, 15.37it/s]

epoch:0 loss: 53972.6931


100%|██████████| 782/782 [00:54<00:00, 14.48it/s]
  0%|          | 2/782 [00:00<00:48, 16.20it/s]

epoch:1 loss: 30267.3541


100%|██████████| 782/782 [00:54<00:00, 14.38it/s]
  0%|          | 2/782 [00:00<00:59, 13.20it/s]

epoch:2 loss: 25342.2752


100%|██████████| 782/782 [00:54<00:00, 14.44it/s]
  0%|          | 2/782 [00:00<00:52, 14.99it/s]

epoch:3 loss: 23054.3336


100%|██████████| 782/782 [00:54<00:00, 14.29it/s]
  0%|          | 2/782 [00:00<00:50, 15.52it/s]

epoch:4 loss: 21813.6391


100%|██████████| 782/782 [00:53<00:00, 14.55it/s]
  0%|          | 2/782 [00:00<00:49, 15.89it/s]

epoch:5 loss: 20903.7836


100%|██████████| 782/782 [00:55<00:00, 14.18it/s]
  0%|          | 2/782 [00:00<01:10, 11.07it/s]

epoch:6 loss: 20223.0635


100%|██████████| 782/782 [00:54<00:00, 14.24it/s]
  0%|          | 2/782 [00:00<00:54, 14.41it/s]

epoch:7 loss: 19597.3707


100%|██████████| 782/782 [00:55<00:00, 14.16it/s]
  0%|          | 2/782 [00:00<00:51, 15.08it/s]

epoch:8 loss: 19123.3391


100%|██████████| 782/782 [00:54<00:00, 14.39it/s]
  0%|          | 2/782 [00:00<00:52, 14.99it/s]

epoch:9 loss: 18694.0964


100%|██████████| 782/782 [00:54<00:00, 14.26it/s]
  0%|          | 2/782 [00:00<00:58, 13.41it/s]

epoch:10 loss: 18206.5472


100%|██████████| 782/782 [00:54<00:00, 14.36it/s]
  0%|          | 2/782 [00:00<00:51, 15.04it/s]

epoch:11 loss: 17955.5494


100%|██████████| 782/782 [00:53<00:00, 14.63it/s]
  0%|          | 2/782 [00:00<00:59, 13.02it/s]

epoch:12 loss: 17565.5346


100%|██████████| 782/782 [00:54<00:00, 14.38it/s]
  0%|          | 2/782 [00:00<00:46, 16.68it/s]

epoch:13 loss: 17349.0232


100%|██████████| 782/782 [00:54<00:00, 14.30it/s]
  0%|          | 2/782 [00:00<01:14, 10.54it/s]

epoch:14 loss: 17164.7787


100%|██████████| 782/782 [00:54<00:00, 14.42it/s]
  0%|          | 1/782 [00:00<01:21,  9.53it/s]

epoch:15 loss: 16897.5991


100%|██████████| 782/782 [00:54<00:00, 14.42it/s]
  0%|          | 2/782 [00:00<00:53, 14.53it/s]

epoch:16 loss: 16716.4351


100%|██████████| 782/782 [00:54<00:00, 14.31it/s]
  0%|          | 2/782 [00:00<01:08, 11.43it/s]

epoch:17 loss: 16479.3442


100%|██████████| 782/782 [00:54<00:00, 14.42it/s]
  0%|          | 2/782 [00:00<00:53, 14.60it/s]

epoch:18 loss: 16332.7161


100%|██████████| 782/782 [00:54<00:00, 14.23it/s]
  0%|          | 2/782 [00:00<01:11, 10.98it/s]

epoch:19 loss: 16135.4354


100%|██████████| 782/782 [00:54<00:00, 14.38it/s]
  0%|          | 2/782 [00:00<00:53, 14.44it/s]

epoch:20 loss: 15989.4553


100%|██████████| 782/782 [00:54<00:00, 14.38it/s]
  0%|          | 2/782 [00:00<00:57, 13.59it/s]

epoch:21 loss: 15899.8344


100%|██████████| 782/782 [00:54<00:00, 14.24it/s]
  0%|          | 2/782 [00:00<00:52, 14.86it/s]

epoch:22 loss: 15720.8883


100%|██████████| 782/782 [00:54<00:00, 14.38it/s]
  0%|          | 2/782 [00:00<01:02, 12.57it/s]

epoch:23 loss: 15596.7752


100%|██████████| 782/782 [00:53<00:00, 14.60it/s]
  0%|          | 2/782 [00:00<00:54, 14.27it/s]

epoch:24 loss: 15543.9935


100%|██████████| 782/782 [00:54<00:00, 14.44it/s]
  0%|          | 2/782 [00:00<01:03, 12.29it/s]

epoch:25 loss: 15361.8651


100%|██████████| 782/782 [00:54<00:00, 14.46it/s]
  0%|          | 2/782 [00:00<00:51, 15.01it/s]

epoch:26 loss: 15272.0650


100%|██████████| 782/782 [00:54<00:00, 14.33it/s]
  0%|          | 2/782 [00:00<00:53, 14.54it/s]

epoch:27 loss: 15115.8306


100%|██████████| 782/782 [00:54<00:00, 14.33it/s]
  0%|          | 2/782 [00:00<00:52, 14.75it/s]

epoch:28 loss: 15038.7777


100%|██████████| 782/782 [00:55<00:00, 14.14it/s]
  0%|          | 2/782 [00:00<00:57, 13.59it/s]

epoch:29 loss: 14931.1919


100%|██████████| 782/782 [00:54<00:00, 14.34it/s]
  0%|          | 2/782 [00:00<00:52, 14.85it/s]

epoch:30 loss: 14822.8893


100%|██████████| 782/782 [00:53<00:00, 14.57it/s]
  0%|          | 2/782 [00:00<00:53, 14.55it/s]

epoch:31 loss: 14747.3904


100%|██████████| 782/782 [00:55<00:00, 14.03it/s]
  0%|          | 2/782 [00:00<00:53, 14.57it/s]

epoch:32 loss: 14644.2226


100%|██████████| 782/782 [00:54<00:00, 14.29it/s]
  0%|          | 2/782 [00:00<01:10, 11.05it/s]

epoch:33 loss: 14544.3850


100%|██████████| 782/782 [00:54<00:00, 14.34it/s]
  0%|          | 2/782 [00:00<00:58, 13.25it/s]

epoch:34 loss: 14448.1163


100%|██████████| 782/782 [00:55<00:00, 14.14it/s]
  0%|          | 2/782 [00:00<00:52, 14.98it/s]

epoch:35 loss: 14417.7737


100%|██████████| 782/782 [00:54<00:00, 14.44it/s]
  0%|          | 2/782 [00:00<00:52, 14.98it/s]

epoch:36 loss: 14308.7313


100%|██████████| 782/782 [00:54<00:00, 14.32it/s]
  0%|          | 2/782 [00:00<00:52, 14.95it/s]

epoch:37 loss: 14284.3574


100%|██████████| 782/782 [00:54<00:00, 14.33it/s]
  0%|          | 2/782 [00:00<01:08, 11.34it/s]

epoch:38 loss: 14145.7330


100%|██████████| 782/782 [00:55<00:00, 14.05it/s]
  0%|          | 2/782 [00:00<00:54, 14.29it/s]

epoch:39 loss: 14116.2152


100%|██████████| 782/782 [00:54<00:00, 14.41it/s]
  0%|          | 2/782 [00:00<00:56, 13.82it/s]

epoch:40 loss: 14056.9724


100%|██████████| 782/782 [00:54<00:00, 14.37it/s]
  0%|          | 2/782 [00:00<00:52, 14.97it/s]

epoch:41 loss: 14003.3363


100%|██████████| 782/782 [00:54<00:00, 14.37it/s]
  0%|          | 2/782 [00:00<00:54, 14.42it/s]

epoch:42 loss: 13943.2894


100%|██████████| 782/782 [00:54<00:00, 14.48it/s]
  0%|          | 2/782 [00:00<00:53, 14.57it/s]

epoch:43 loss: 13837.0415


100%|██████████| 782/782 [00:54<00:00, 14.33it/s]
  0%|          | 2/782 [00:00<00:52, 14.81it/s]

epoch:44 loss: 13804.1402


100%|██████████| 782/782 [00:54<00:00, 14.30it/s]
  0%|          | 2/782 [00:00<00:58, 13.35it/s]

epoch:45 loss: 13752.6690


100%|██████████| 782/782 [00:54<00:00, 14.23it/s]
  0%|          | 2/782 [00:00<00:53, 14.53it/s]

epoch:46 loss: 13677.9805


100%|██████████| 782/782 [00:55<00:00, 14.07it/s]
  0%|          | 2/782 [00:00<00:52, 14.99it/s]

epoch:47 loss: 13640.6337


100%|██████████| 782/782 [00:53<00:00, 14.59it/s]
  0%|          | 2/782 [00:00<01:11, 10.92it/s]

epoch:48 loss: 13605.4789


100%|██████████| 782/782 [00:55<00:00, 14.09it/s]

epoch:49 loss: 13539.5442





In [6]:
epoch = 50
for digit in range(0, num_classes):
    test_images = model.decode(generate_digit_vector(z_samples, digit, num_classes))
    np.save('img/generated_img.cvae.{}.epoch{}'.format(digit, epoch), test_images.data.cpu().numpy())