# Conditional VAE-GAN

[![Package badge]][github]
[![Open In Colab]][notebook]

[github]:https://github.com/tarepan/VAE-GAN
[notebook]:https://colab.research.google.com/github/tarepan/VAE-GAN/blob/main/vaeGAN/cvaegan.ipynb
[Package badge]:https://img.shields.io/badge/GitHub-vaegan-9cf.svg
[Open In Colab]:https://colab.research.google.com/assets/colab-badge.svg

## Setup

In [None]:
!git clone https://github.com/tarepan/VAE-GAN.git
%cd "./VAE-GAN/vaeGAN"

Cloning into 'VAE-GAN'...
remote: Enumerating objects: 34420, done.[K
remote: Counting objects: 100% (7/7), done.[K
remote: Compressing objects: 100% (7/7), done.[K
remote: Total 34420 (delta 0), reused 3 (delta 0), pack-reused 34413[K
Receiving objects: 100% (34420/34420), 118.19 MiB | 17.87 MiB/s, done.
Resolving deltas: 100% (57/57), done.


## Training

In [None]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [None]:
# Load (or reload) IPython's autoreload extension; mode 2 re-imports changed modules
# (e.g. the project's vaegan package) before each cell executes.
%reload_ext autoreload
%autoreload 2

In [None]:
from vaegan import VAE
from vaegan import NetD
from vaegan import Aux
from vaegan import loss_function
import os


bsz = 128
n_epochs = 200
gamma = 1.0          # weight of loss_function(...) inside the decoder (aux) objective
real_label = 1.0
fake_label = 0.0
results_dir = './results/cvaegan results'
save_models = False  # set True to write the trained networks into ./pretrained_models

dataset_train = datasets.MNIST('../data', download=True,                transform=transforms.ToTensor())
dataset_test  = datasets.MNIST('../data',                train = False, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=bsz, shuffle=True)
test_loader  = torch.utils.data.DataLoader(dataset_test,  batch_size=bsz, shuffle=True)

netG = VAE().cuda()   # encoder: netG(x, y) -> (mu, logvar)
netD = NetD().cuda()  # discriminator: netD(x, y) -> (features fed to loss_function, score for BCELoss)
aux  = Aux().cuda()   # decoder: aux(z, y) -> image batch
criterion = nn.BCELoss().cuda()

optimizerD    = optim.Adam(netD.parameters(), lr=1e-4)
optimizerG    = optim.Adam(netG.parameters(), lr=1e-4)
optimizer_aux = optim.Adam(aux.parameters(),  lr=1e-4)

os.makedirs(results_dir, exist_ok=True)


def _grads_for(loss, module, retain_graph=True):
    """Return (param, grad) pairs of `loss` w.r.t. the trainable parameters of `module`.

    Uses torch.autograd.grad, so no `.grad` buffer of any network is touched.
    Parameters the loss does not depend on get a grad of None.
    """
    params = [p for p in module.parameters() if p.requires_grad]
    grads = torch.autograd.grad(loss, params, retain_graph=retain_graph, allow_unused=True)
    return list(zip(params, grads))


def _apply(optimizer, param_grads):
    """Install precomputed gradients and take one optimizer step (None grads are skipped)."""
    for p, g in param_grads:
        p.grad = g
    optimizer.step()


for epoch in range(n_epochs):
    for i, (real, y) in enumerate(train_loader):
        real, y = real.cuda(), y.cuda()
        # The last batch can be smaller than bsz, so size the targets from the batch itself.
        n = real.size(0)
        ones  = torch.full((n,), real_label, device=real.device)
        zeros = torch.full((n,), fake_label, device=real.device)

        # Common forward: encode, reparameterize, decode
        mu, logvar = netG(real, y)
        std = torch.exp(0.5 * logvar)
        z = mu + torch.randn_like(std) * std
        fake = aux(z, y)

        # Discriminator on reconstructions and real images
        x_l_tilde, output_fake = netD(fake, y)
        x_l,       output_real = netD(real, y)
        # Discriminator on samples decoded from prior noise
        z_p = torch.randn_like(std)
        fake_aux = aux(z_p, y)
        _, output_aux = netD(fake_aux, y)

        # Losses, all built from this single forward pass
        L_D = (criterion(output_real.squeeze(1), ones)
               + criterion(output_fake.squeeze(1), zeros)
               + criterion(output_aux.squeeze(1), zeros))
        L_vae = loss_function(x_l_tilde, x_l, mu, logvar)
        L_dec = (gamma * L_vae
                 + criterion(output_fake.squeeze(1), ones)
                 + criterion(output_aux.squeeze(1), ones))
        L_enc = L_vae

        # Compute every gradient before any optimizer step: optimizer.step() mutates the
        # parameters in place, which invalidates the shared graph for any later backward
        # (PyTorch raises "modified by an inplace operation").
        grads_D   = _grads_for(L_D,   netD)
        grads_aux = _grads_for(L_dec, aux)
        grads_G   = _grads_for(L_enc, netG, retain_graph=False)
        _apply(optimizerD,    grads_D)
        _apply(optimizer_aux, grads_aux)
        _apply(optimizerG,    grads_G)

        # Logging
        if i % 100 == 0:
            print(f"epoch {epoch} iteration {i}: batch {tuple(real.size())} "
                  f"L_D={L_D.item():.4f} L_dec={L_dec.item():.4f} L_enc={L_enc.item():.4f}")
            save_image(real,                             os.path.join(results_dir, 'real_samples2.png'), normalize=True)
            save_image(fake.detach().view(-1,1,28,28),   os.path.join(results_dir, 'fake_samples2.png'), normalize=True)

    if epoch % 25 == 0:
        print('epoch: ', epoch)
        save_image(fake.detach().view(-1,1,28,28), os.path.join(results_dir, 'fake_samples2_{0}.png'.format(epoch)), normalize=True)

if save_models:
    os.makedirs('pretrained_models', exist_ok=True)
    torch.save(netG, 'pretrained_models/netG3.pth')
    torch.save(netD, 'pretrained_models/netD3.pth')
    torch.save(aux,  'pretrained_models/aux3.pth')


## Inference

In [None]:
# These checkpoints are whole pickled nn.Module objects. Since PyTorch 2.6, torch.load defaults
# to weights_only=True, which cannot restore them. weights_only=False executes pickle code, so
# only use it on trusted files such as these checkpoints shipped in the cloned repo.
netG = torch.load('pretrained_models/netG2.pth', weights_only=False)
netD = torch.load('pretrained_models/netD2.pth', weights_only=False)
aux  = torch.load('pretrained_models/aux2.pth',  weights_only=False)

bsz = 128  # same as training; defined here so this section runs without the training cell
test_dataset = datasets.MNIST('../data', train=False, download=True, transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=bsz, shuffle=True)

# DataLoader iterators no longer expose a `.next()` method; use the builtin next().
data, y = next(iter(test_loader))
# NOTE: despite the file name, this writes the *real* test batch.
save_image(data.view(-1,1,28,28), './fake.png', normalize=True)

In [None]:
import os


def _sample_z(mu, logvar):
    """Reparameterized sample z = mu + eps * exp(0.5 * logvar), with eps ~ N(0, I)."""
    std = torch.exp(0.5 * logvar)
    return mu + torch.randn_like(std) * std


out_dir = './results/cvae results'
os.makedirs(out_dir, exist_ok=True)

x = data.cuda()
# aux must get the labels on the GPU too (netG here and aux during training both receive CUDA labels).
y = y.cuda()
# NOTE(review): the extra conditioning arguments (tensor([8]) and the 0.5 / 1 scalars) are
# forwarded unchanged to the loaded netG/aux; their meaning is defined in vaegan — confirm there.
cond = torch.tensor([8]).cuda()

with torch.no_grad():
    mu, logvar = netG(x, y, cond, .5)
    fake = aux(_sample_z(mu, logvar), y, cond, .5)
    save_image(fake.view(-1,1,28,28), os.path.join(out_dir, 'generated2.png'), normalize=True)

    # Re-encode the generated batch with the original labels and decode it again.
    mu, logvar = netG(fake, y)
    fake2 = aux(_sample_z(mu, logvar), y)
    save_image(fake2.view(-1,1,28,28), os.path.join(out_dir, 'generated3.png'), normalize=True)

    mu, logvar = netG(x, y, cond, 1)
    fake = aux(_sample_z(mu, logvar), y, cond, 1)
    save_image(fake.view(-1,1,28,28), os.path.join(out_dir, 'generated.png'), normalize=True)