In [1]:
from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import torchvision.utils as vutils




In [2]:
% load_ext autoreload 
% autoreload 2

In [3]:
from modules_tied import VAE
from modules_tied import NetD
from modules_tied import Aux
from modules_tied import loss_function

In [4]:
# Mini-batch size shared by both data loaders and the pre-allocated buffers.
bsz = 100
# Binary cross-entropy criterion; the active losses in the training loop do
# not call it (only commented-out variants reference it).
criterion = nn.BCELoss()


In [5]:
# MNIST loaders. Both splits are converted with ToTensor and served in
# shuffled mini-batches of `bsz`; only the training split requests a download.
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='../data',
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    ),
    batch_size=bsz,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='../data',
        train=False,
        transform=transforms.ToTensor(),
    ),
    batch_size=bsz,
    shuffle=True,
)


In [6]:
# Build the three networks from modules_tied with their default arguments.
# The tuple is evaluated left to right, so construction order is unchanged.
netG, netD, aux = VAE(), NetD(), Aux()

In [7]:
# One Adam optimiser per network (netD, netG, aux), all with lr=1e-4.
optimizerD, optimizerG, optimizer_aux = (
    optim.Adam(net.parameters(), lr=1e-4) for net in (netD, netG, aux)
)

In [8]:
# Pre-allocated batch buffers; the training loop resizes and refills them
# in place for every mini-batch.
input = torch.FloatTensor(bsz,28,28)
label = torch.FloatTensor(bsz)
real_label=1
fake_label=0
# Use the GPU only when one is actually available, so this cell no longer
# crashes on CPU-only hosts. Set to False to force CPU execution.
USE_CUDA = torch.cuda.is_available()
# Scale applied to the L1 distance term inside the discriminator losses.
lamb = 2e-4
l1dist = nn.PairwiseDistance(1)
l2dist = nn.PairwiseDistance(2)
# Negative slope 0: negative inputs are mapped to zero.
LeakyReLU = nn.LeakyReLU(0)

if USE_CUDA:
    # Move the networks, criterion, helper modules and buffers to the GPU.
    netG = netG.cuda()
    netD = netD.cuda()
    aux = aux.cuda()
    criterion = criterion.cuda()
    input, label = input.cuda(), label.cuda()
    l1dist = l1dist.cuda()
    l2dist = l2dist.cuda()
    LeakyReLU = LeakyReLU.cuda()

In [9]:
def get_direct_gradient_penalty(netD, x, gamma, cuda):
    """Return ``gamma`` times the mean per-sample L2 norm of d(netD output)/dx.

    Args:
        netD: callable returning a pair whose second element is the output
            that gets differentiated with respect to ``x``.
        x: input batch with ``requires_grad=True``; the first dimension is
            treated as the batch dimension.
        gamma: scalar weight applied to the penalty.
        cuda: kept for backward compatibility. The all-ones ``grad_outputs``
            tensor is now created with the same device and dtype as the
            output, so this flag is no longer consulted.

    Returns:
        A scalar tensor. It is built with ``create_graph=True`` so that it
        can itself be back-propagated.
    """
    _, output = netD(x)
    grad_output = torch.ones_like(output)

    gradient = torch.autograd.grad(outputs=output, inputs=x, grad_outputs=grad_output,
                                   create_graph=True, retain_graph=True, only_inputs=True)[0]

    # Flatten every non-batch dimension before taking the norm. Reducing over
    # dim=1 of an (N, C, H, W) gradient would only collapse the channel axis
    # instead of producing one norm per sample.
    per_sample_norm = gradient.contiguous().view(gradient.size(0), -1).norm(2, dim=1)

    return per_sample_norm.mean() * gamma


In [None]:
# Training loop. For every mini-batch it updates, in this order:
#   1. netD         - margin losses on reconstructions and prior samples plus
#                     a gradient penalty on the real inputs,
#   2. aux (decoder) - feature-matching MSE plus the negated netD losses,
#   3. netG (encoder) - loss_function on the netD features and (mu, logvar).
# Every 100 batches the current inputs and reconstructions are written to disk.
#
# NOTE(review): optimizerD.step() and optimizer_aux.step() update weights in
# place while graphs built earlier in the iteration are still back-propagated
# afterwards. Newer PyTorch releases may raise a RuntimeError about in-place
# modification here - confirm against the installed version.

gamma = 100  # weight of the feature-matching (MSE) term in the decoder loss

for epoch in range(10000):
    for i, (data, _) in enumerate(train_loader):
        # Honour the USE_CUDA switch instead of moving the batch to the GPU
        # unconditionally, so the loop also runs on CPU-only hosts.
        real_cpu = data.cuda() if USE_CUDA else data
        input.resize_as_(real_cpu).copy_(real_cpu)
        label.resize_(bsz).fill_(real_label)

        dataSize = input.size(0)
        inputv = Variable(input, requires_grad=True)
        # The label Variables are not consumed by the active losses below.
        labelv = Variable(label)

        # ---- discriminator update ----
        for p in netD.parameters():
            p.requires_grad = True
        netD.zero_grad()

        # Encode, draw z = mu + eps * std, then decode.
        mu, logvar = netG(inputv)
        std = logvar.mul(0.5).exp_()
        eps = Variable(std.data.new(std.size()).normal_())
        z = eps.mul(std).add_(mu)
        fake = aux(z)

        x_l_tilde, output_fake = netD(fake)
        x_l, output_real = netD(inputv)

        # The L1 distance between input and reconstruction, scaled by lamb,
        # is added inside the rectified loss term.
        pdist = l1dist(input.view(dataSize, -1), fake.view(dataSize, -1)).mul(lamb)
        errD_fake = LeakyReLU(output_real - output_fake + pdist).mean()
        errD_fake.backward(retain_graph=True)

        # Gradient penalty on the real inputs, weighted by 10.
        gp = get_direct_gradient_penalty(netD, inputv, 10, bool(USE_CUDA))
        gp.backward(retain_graph=True)

        labelv = Variable(label.fill_(fake_label))

        # Same loss for samples decoded from a standard-normal latent draw.
        z_p = Variable(std.data.new(std.size()).normal_())
        fake_aux = aux(z_p)
        x_l_aux, output_aux = netD(fake_aux)
        pdist_aux = l1dist(input.view(dataSize, -1), fake_aux.view(dataSize, -1)).mul(lamb)
        errD_aux = LeakyReLU(output_real - output_aux + pdist_aux).mean()
        errD_aux.backward(retain_graph=True)
        optimizerD.step()

        for p in netD.parameters():
            p.requires_grad = False

        # ---- decoder (aux) update ----
        aux.zero_grad()
        labelv = Variable(label.fill_(real_label))

        L_dec_vae = gamma * F.mse_loss(x_l_tilde, x_l)
        # The decoder is trained on the negated netD losses.
        L_dec_fake = -errD_fake
        L_dec_aux = -errD_aux
        L_dec_vae.backward(retain_graph=True)
        L_dec_fake.backward(retain_graph=True)
        L_dec_aux.backward(retain_graph=True)
        optimizer_aux.step()

        # ---- encoder (netG) update ----
        netG.zero_grad()
        L_enc = loss_function(x_l_tilde, x_l, mu, logvar)
        L_enc.backward()
        optimizerG.step()

        if i % 100 == 0:
            print('real_cpu.size()', real_cpu.size())
            vutils.save_image(real_cpu,
                              './real_samples.png',
                              normalize=True)
            vutils.save_image(fake.data.view(-1, 1, 28, 28),
                              './fake_samples.png',
                              normalize=True)