# Conditional VAE-GAN

[![Package badge]][github]
[![Open In Colab]][notebook]

[github]:https://github.com/tarepan/VAE-GAN
[notebook]:https://colab.research.google.com/github/tarepan/VAE-GAN/blob/main/vaeGAN/cvaegan.ipynb
[Package badge]:https://img.shields.io/badge/GitHub-vaegan-9cf.svg
[Open In Colab]:https://colab.research.google.com/assets/colab-badge.svg

## Setup

In [None]:
# Fetch the project sources, then switch the notebook's working directory to
# the `vaeGAN` folder so that the later `from vaegan import ...` and the
# relative `../data` / `../results` paths resolve from there.
!git clone https://github.com/tarepan/VAE-GAN.git
%cd "./VAE-GAN/vaeGAN"

Cloning into 'VAE-GAN'...
remote: Enumerating objects: 34420, done.
remote: Counting objects: 100% (7/7), done.
remote: Compressing objects: 100% (7/7), done.
remote: Total 34420 (delta 0), reused 3 (delta 0), pack-reused 34413
Receiving objects: 100% (34420/34420), 118.19 MiB | 17.87 MiB/s, done.
Resolving deltas: 100% (57/57), done.


## Training

In [None]:
!mkdir ../results/cvaegan_results

In [None]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [None]:
# Enable IPython's autoreload extension. Mode 2 re-imports all modules before
# each cell is executed, so edits to local modules are picked up without
# restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
from vaegan import VAE
from vaegan import NetD
from vaegan import Aux
from vaegan import loss_function


bsz = 128

dataset_train = datasets.MNIST('../data', download=True,                transform=transforms.ToTensor())
dataset_test  = datasets.MNIST('../data',                train = False, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=bsz, shuffle=True, drop_last=True)
test_loader  = torch.utils.data.DataLoader(dataset_test,  batch_size=bsz, shuffle=True, drop_last=True)

encoder = VAE()
decoder = Aux()
disc    = NetD()
criterion = nn.BCELoss()

optim_disc = optim.Adam(disc.parameters(), lr=1e-4)
optim_enc  = optim.Adam(encoder.parameters(), lr=1e-4)
optim_dec  = optim.Adam(decoder.parameters(),  lr=1e-4)

encoder, decoder = encoder.cuda(), decoder.cuda()
disc = disc.cuda()
criterion =criterion.cuda()

ones  = torch.ones(bsz).cuda()
zeros = torch.zeros(bsz).cuda()

%reload_ext autoreload
%autoreload 2

# Conditional VAE-GAN training loop.
#
# Each step:
#   1. encodes the real batch, samples a posterior latent (z_q) and a prior
#      latent (z_p), and decodes both into images;
#   2. updates the discriminator on real vs. (detached) generated images;
#   3. re-evaluates the *updated* discriminator on the generated images and
#      updates encoder + decoder with the adversarial and VAE losses.
#
# The discriminator is evaluated twice on purpose: `optim_disc.step()` changes
# its weights in place, so a graph built before that step must not be reused
# for the generator backward pass.
for epoch in range(200):
    print(f'start epoch #{epoch}')

    for i, (data, y) in enumerate(train_loader):
        #### Step ################################################
        # Data - digit image batch & digit labels, moved to the GPU
        real, y = data.cuda(), y.cuda()

        # Common_Forward
        ## Encode
        mu, logvar = encoder(real, y)
        ## Sampling - Posterior and Prior
        ### Prior: z_p ~ N(0, I), same shape as mu
        z_p = torch.empty_like(mu).normal_()
        ### Posterior: z_q = mu + eps * std, with std = exp(logvar / 2)
        eps = torch.empty_like(mu).normal_()
        z_q = mu + eps * torch.exp(logvar * 0.5)
        ## Decode
        fake     = decoder(z_q, y)  # decoded from the posterior sample
        fake_aux = decoder(z_p, y)  # decoded from the prior sample

        # D_Forward
        # Generated images are detached: the discriminator update needs no
        # gradient through encoder/decoder, and no graph has to be retained.
        _, d_real     = disc(real,              y)
        _, d_fake     = disc(fake.detach(),     y)
        _, d_fake_aux = disc(fake_aux.detach(), y)
        # D_Loss/Backward/Optim
        # `ones` / `zeros` have length `bsz`, so every batch must be full-size.
        loss_adv_d_real = criterion(d_real.squeeze(1), ones)
        loss_adv_d_fake = criterion(d_fake.squeeze(1), zeros)
        loss_adv_d = loss_adv_d_real + loss_adv_d_fake
        loss_adv_d_fake_aux = criterion(d_fake_aux.squeeze(1), zeros)
        loss_aux = loss_adv_d_fake_aux
        loss_disc = loss_adv_d + loss_aux
        # Also clears gradients left on `disc` by the previous generator step.
        disc.zero_grad()
        loss_disc.backward()
        optim_disc.step()

        # G_Forward through the updated discriminator (fresh graph)
        d_feat_fake, d_fake_zq = disc(fake,     y)
        d_feat_real, _         = disc(real,     y)
        _,           d_fake_zp = disc(fake_aux, y)

        # G_Loss/Backward/Optim, with adversarial encoder learning
        # (Variant without adversarial encoder learning: step the decoder on
        #  loss_adv_g_zq + loss_adv_g_zp + loss_vae and the encoder on
        #  loss_vae alone.)
        loss_adv_g_zq = criterion(d_fake_zq.squeeze(1), ones)
        loss_adv_g_zp = criterion(d_fake_zp.squeeze(1), ones)
        loss_vae = loss_function(d_feat_fake, d_feat_real, mu, logvar)
        loss_g = loss_adv_g_zq + loss_adv_g_zp + loss_vae
        encoder.zero_grad()
        decoder.zero_grad()
        loss_g.backward()
        optim_dec.step()
        optim_enc.step()

        # Logging - overwrite the latest real / reconstructed sample grids
        if i % 2000 == 0:
            save_image(real,                           '../results/cvaegan_results/train2_real_samples2.png', normalize=True)
            save_image(fake.detach().view(-1,1,28,28), '../results/cvaegan_results/train2_fake_samples2.png', normalize=True)
        #### /Step ###############################################

    # Keep a per-epoch snapshot (from the last batch) every 25 epochs
    if epoch % 25 == 0:
        save_image(fake.detach().view(-1,1,28,28), f'../results/cvaegan_results/train2_fake_samples2_{epoch}.png', normalize=True)

# torch.save(encoder, './pretrained models/encoder3.pth')
# torch.save(disc,    './pretrained models/disc3.pth')
# torch.save(decoder, './pretrained models/decoder3.pth')


## Inference

In [None]:
# Load the pretrained networks.
# NOTE(security): torch.load un-pickles arbitrary Python objects; only load
# checkpoint files that come from a trusted source.
netG = torch.load('pretrained_models/netG2.pth')
netD = torch.load('pretrained_models/netD2.pth')
aux  = torch.load('pretrained_models/aux2.pth')

# MNIST test split; `bsz` is the batch size defined in the training cell.
test_dataset = datasets.MNIST('../data', train=False, transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=bsz, shuffle=True)

# Take a single (shuffled) batch. The built-in next() is the Python 3 iterator
# protocol; the Python 2 style `.next()` method is not available on current
# DataLoader iterators.
data, y = next(iter(test_loader))
# Saves the *real* test images, despite the file name.
save_image(data.view(-1,1,28,28), './fake.png', normalize=True)

In [None]:
# Generate images with the loaded models and save them as image grids.
# Each pass: encode with `netG` -> (mu, logvar), sample
# z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I), decode with `aux`.
#
# NOTE(review): the meaning of the extra positional arguments
#   (`torch.tensor([8])` and `.5` / `1`) is defined by the pickled model
#   classes, which are not visible here - confirm against their forward().
# NOTE(review): `y` is moved to the GPU for `netG` but passed as-is to `aux`
#   - confirm `aux` accepts a CPU label tensor.
# NOTE(review): the output folder './results/cvae results/' is not created
#   anywhere in this notebook and differs from the training output folder
#   '../results/cvaegan_results' - confirm it exists before running.
# NOTE(review): `Variable` is the legacy autograd wrapper; presumably a no-op
#   on current PyTorch - verify before removing.
%reload_ext autoreload
%autoreload 2

# Pass 1: encode the test batch, extra arguments (tensor([8]), .5)
mu,logvar = netG(Variable(data).cuda(), Variable(y).cuda(), Variable(torch.tensor([8])).cuda(), .5)
std = logvar.mul(0.5).exp_()                        # std = exp(logvar / 2)
eps = Variable(std.data.new(std.size()).normal_())  # eps ~ N(0, I), same shape/device as std
z=eps.mul(std).add_(mu)                             # z = mu + eps * std
fake = aux(z, y, Variable(torch.tensor([8])).cuda(), .5)
save_image(fake.data.view(-1,1,28,28), './results/cvae results/generated2.png', normalize=True)

# Pass 2: re-encode the images generated in pass 1 (two-argument call form)
# and decode them again.
mu,logvar = netG(Variable(fake), Variable(y).cuda())
std = logvar.mul(0.5).exp_()
eps = Variable(std.data.new(std.size()).normal_())
z=eps.mul(std).add_(mu)
fake2 = aux(z, y)
save_image(fake2.data.view(-1,1,28,28), './results/cvae results/generated3.png', normalize=True)

# Pass 3: same as pass 1 but with the last argument set to 1 instead of .5;
# overwrites `fake`.
mu,logvar = netG(Variable(data).cuda(), Variable(y).cuda(), Variable(torch.tensor([8])).cuda(), 1)
std = logvar.mul(0.5).exp_()
eps = Variable(std.data.new(std.size()).normal_())
z=eps.mul(std).add_(mu)
fake = aux(z, y, Variable(torch.tensor([8])).cuda(), 1)
save_image(fake.data.view(-1,1,28,28), './results/cvae results/generated.png', normalize=True)