In [13]:
import argparse
import os
import sys
import random

import numpy as np

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torchvision.utils as vutils
from torch.autograd import Variable

import torch.backends.cudnn as cudnn
cudnn.benchmark = True
cudnn.fastest = True

In [14]:
def conv_block(in_dim, out_dim):
    return nn.Sequential(nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=1, padding=1),
                         nn.ELU(True),
                         nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=1, padding=1),
                         nn.ELU(True),
                         nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=1, padding=0),
                         nn.AvgPool2d(kernel_size=2, stride=2))

def deconv_block(in_dim, out_dim):
    return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=1, padding=1),
                         nn.ELU(True),
                         nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1),
                         nn.ELU(True),
                         nn.UpsamplingNearest2d(scale_factor=2))

In [15]:
class Encoder(nn.Module):
    def __init__(self, n_ch, ndf, hidden_size):
        super(Encoder, self).__init__()

        # 256
        self.conv1 = nn.Sequential(nn.Conv2d(n_ch, ndf, kernel_size=3, stride=1, padding=1),
                                   nn.ELU(True))
        # 256
        self.conv2 = conv_block(ndf, ndf)
        # 128
        self.conv3 = conv_block(ndf, ndf*2)
        # 64
        self.conv4 = conv_block(ndf*2, ndf*3)
        # 32
        self.conv5 = conv_block(ndf*3, ndf*4)
        # 16
        #self.conv6 = conv_block(ndf*4, ndf*4)
        # 8
        self.encode = nn.Conv2d(ndf*4, hidden_size, kernel_size=8, stride=1, padding=0)
        # 1

    def forward(self, x):
        x = self.conv1(x) 
        x = self.conv2(x) 
        x = self.conv3(x) 
        x = self.conv4(x) 
        x = self.conv5(x) 
        #x = self.conv6(x) 
        x = self.encode(x) 
        return x

In [16]:
class Decoder(nn.Module):
    #Used as both decoder in Discriminator, as well as Generator.  They do NOT share weights however.
    def __init__(self, n_ch, ngf, hidden_size, condition=False, condition_size=0):
        super(Decoder, self).__init__()
        self.condition = condition

        self.decode_cond = nn.ConvTranspose2d(condition_size, ngf, kernel_size=8,stride=1,padding=0)
        # 1
        self.decode = nn.ConvTranspose2d(hidden_size, ngf, kernel_size=8,stride=1,padding=0)
        # 8
        self.dconv6 = deconv_block(ngf*2, ngf)
        # 16
        self.dconv5 = deconv_block(ngf, ngf)
        # 32
        self.dconv4 = deconv_block(ngf, ngf)
        # 64 
        self.dconv3 = deconv_block(ngf, ngf)
        # 128 
        #self.dconv2 = deconv_block(ngf, ngf)
        # 256
        self.dconv1 = nn.Sequential(nn.Conv2d(ngf,ngf,kernel_size=3,stride=1,padding=1),
                                    nn.ELU(True),
                                    nn.Conv2d(ngf,ngf,kernel_size=3,stride=1,padding=1),
                                    nn.ELU(True),
                                    nn.Conv2d(ngf, n_ch,kernel_size=3, stride=1,padding=1),
                                    nn.Tanh())

    def forward(self, x, condition_vec=None):
        x = self.decode(x)
        if self.condition:
            # NOTE: embedding condition vector
            x_cond = self.decode_cond(condition_vec)
            # NOTE: concatenation of z and condition vector
            x = torch.cat([x, x_cond], 1)
        x = self.dconv6(x) 
        x = self.dconv5(x) 
        x = self.dconv4(x) 
        x = self.dconv3(x) 
        #x = self.dconv2(x) 
        x = self.dconv1(x) 
        return x

In [23]:
class Discriminator(nn.Module):
    def __init__(self, n_ch, ndf, ngf, hidden_size, condition=False, condition_size=0):
        super(D, self).__init__()

        self.encoder = Encoder(n_ch, ndf, hidden_size)
        self.decoder = Decoder(n_ch, ngf, hidden_size, condition, condition_size)

    def forward(self, x, condition_vec=None):
        h = self.encoder(x)
        # NOTE injecting condition vector in Decoder
        out = self.decoder(h, condition_vec)
        return out

In [22]:
# parser = argparse.ArgumentParser()
# parser.add_argument('--dataset', default='folder',  help='dataset name (It does not need to be modified)')
# parser.add_argument('--dataroot', default='', help='path to trn dataset')
# parser.add_argument('--batchSize', type=int, default=16, help='input batch size')
# parser.add_argument('--valBatchSize', type=int, default=64, help='input batch size')
# parser.add_argument('--originalSize', type=int, default=142, help='the height / width of the original input image')
# parser.add_argument('--imageSize', type=int, default=128, help='the height / width of the cropped input image to network')
# parser.add_argument('--inputChannelSize', type=int, default=3, help='size of the input channels')
# parser.add_argument('--outputChannelSize', type=int, default=3, help='size of the output channels')
# parser.add_argument('--ngf', type=int, default=128)
# parser.add_argument('--ndf', type=int, default=128)
# parser.add_argument('--hidden_size', type=int, default=64, help='bottleneck dimension of Discriminator')
# parser.add_argument('--cond_size', type=int, default=2,  help='Whether to use conditional GAN')
# parser.add_argument('--niter', type=int, default=30, help='number of epochs to train for')
# parser.add_argument('--lrD', type=float, default=0.00005, help='learning rate')
# parser.add_argument('--lrG', type=float, default=0.00005, help='learning rate')
# parser.add_argument('--annealStart', type=int, default=0, help='annealing learning rate start to')
# parser.add_argument('--annealEvery', type=int, default=30, help='epoch to reaching at learning rate of 0')
# parser.add_argument('--lambda_k', type=float, default=0.001, help='learning rate of k')
# parser.add_argument('--gamma', type=float, default=0.7, help='balance bewteen D and G')
# parser.add_argument('--wd', type=float, default=0.0000, help='weight decay in D')
# parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam')
# parser.add_argument('--netG', default='', help="path to netG (to continue training)")
# parser.add_argument('--netD', default='', help="path to netD (to continue training)")
# parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
# parser.add_argument('--exp', default='sample', help='folder to output images and model checkpoints')
# parser.add_argument('--display', type=int, default=5, help='interval for displaying train-logs')
# parser.add_argument('--evalIter', type=int, default=4000, help='interval for evauating(generating) images from valDataroot')
# opt = parser.parse_args()
# print(opt)

# create_exp_dir(opt.exp)

manual_seed = 100
#opt.manualSeed = random.randint(1, 10000)
random.seed(manual_seed)
torch.manual_seed(manual_seed)
torch.cuda.manual_seed_all(manual_seed)
np.random.seed(manual_seed)
print("Random Seed: ", manual_seed)

Random Seed:  100


In [24]:
dataset = 
dataroot = 
workers =

batch_size = 16
original_size = 142
image_size = 128
ngf = 
ndf = 

in_ch_size =
out_ch_size = 
cond_size = 

n_epochs = 100
lr_g = 0.00005
lr_d = 0.00005
beta1 = 
wd =  #weight decay

output_dir = 

g_weights_path = 
d_weights_path = 

SyntaxError: invalid syntax (<ipython-input-24-b9c45f0332e3>, line 1)

In [None]:
# NOTE get dataloader
dataloader = getLoader(dataset, 
                       dataroot, 
                       originalS_size, 
                       image_size, 
                       batchS_size, 
                       workers,
                       mean=(0.5, 0.5, 0.5),
                       std=(0.5, 0.5, 0.5), 
                       split='train',
                       shuffle=True, 
                       seed=manual_seed)

trainLogger = open('%s/train.log' % output_dir, 'w')

In [None]:
g_net = Decoder(in_ch_size, ngf, hidden_size, True, cond_size)
d_net = Discriminator(in_ch_size, ndf, ndf, hidden_size, True, cond_size)

g_net.apply(weights_init)
d_net.apply(weights_init)

if g_weights_path: g_net.load_state_dict(torch.load(gen_weights_path))
if d_weights_path: d_net.load_state_dict(torch.load(d_weights_path))

g_net.train()
d_net.train()

g_net.cuda()
d_net.cuda()

In [None]:
optim_g = optim.Adam(g_net.parameters(), lr = opt.lr_g, betas = (beta1, 0.999), weight_decay=0.0)
optim_d = optim.Adam(d_net.parameters(), lr = opt.lr_d, betas = (beta1, 0.999), weight_decay=wd)

In [None]:
input_d = torch.FloatTensor(batch_size, in_ch_size, image_size, image_size)
input_g = torch.FloatTensor(batch_size, hidden_size, 1, 1)
fixed_noise = torch.FloatTensor(opt.valBatchSize, hidden_size, 1, 1).uniform_(-1, 1)

input_d = input_d.cuda()
input_g = input_g.cuda()
fixed_noise = fixed_noise.cuda()

input_d = Variable(input_d)
input_g = Variable(input_g)
fixed_noise = Variable(fixed_noise, volatile=True)

In [None]:
# NOTE training loop
iters = 0
k = 0 # control how much emphasis is put on L(G(z_D)) during gradient descent.
M_global = AverageMeter()

In [None]:
for epoch in range(n_epochs):
    for i, data in enumerate(dataloader, 0):
        input_cpu, condition = data
        batch_size = input_cpu.size(0)

        input_cpu = input_cpu.cuda(async=True)
        input_d.data.resize_as_(input_cpu).copy_(input_cpu)
        
        # NOTE generate condition vector whose value is 1 or -1 (one-hot)
        cond = Variable(torch.FloatTensor(input_cpu.size(0), cond_size, 1, 1).fill_(-1).cuda(async=True))
        for idx in range(batch_size):
            if condition[idx] == 0: 
                cond.data[idx,0,0,0] = 1
                #cond.data[idx,1,0,0] = -1
            else:
                cond.data[idx,1,0,0] = 1
                #cond.data[idx,0,0,0] = -1

        #######################
        # Train Discriminator #
        #######################
        for p in d_net.parameters(): p.requires_grad = True 
        d_net.zero_grad()

        input_g.data.resize_(batch_size, hidden_size, 1, 1).uniform_(-1, 1)
        recon_real = d_net(input_d, cond)
        gen = g_net(input_g, cond)
        gen = gen.detach()
        recon_gen = d_net(gen, cond)

        err_d_real = torch.mean(torch.abs(recon_real - input_d))
        err_d_gen = torch.mean(torch.abs(recon_gen - gen))
        errD = errD_real - k * errD_gen
        errD.backward()
        optim_d.step()

        #######################
        # Train Generator #
        #######################
        for p in d_net.parameters(): p.requires_grad = False
        g_net.zero_grad()

        # NOTE compute L_G
        input_g.data.resize_(batch_size, opt.hidden_size, 1, 1).uniform_(-1, 1)
        gen = g_net(input_g, cond)
        recon_gen = d_net(gen, cond)
        errG = torch.mean(torch.abs(recon_gen - gen))
        errG.backward()
        optim_g.step()
        iters += 1

        # NOTE compute k_t and M_global
        balance = (opt.gamma * errD_real - errD_gen).data[0]
        k = min(max(k + opt.lambda_k * balance, 0), 1)
        measure = errD_real.data[0] + np.abs(balance)
        M_global.update(measure, input_d.size(0))

        # logging
        if iters % opt.display == 0:
            print('[%d/%d][%d/%d] Ld: %f Lg: %f, M_global: %f(%f), K: %f, balance.: %f lr: %f'
                  % (epoch, n_epochs, i, len(dataloader),
                     errD.data[0], errG.data[0],
                     measure, M_global.avg, k, balance,
                     optimizerG.param_groups[0]['lr']))
            sys.stdout.flush()
            trainLogger.write('%d\t%f\t%f\t%f\t%f\t%f\t%f\n' % \
                              (i, errD.data[0], errG.data[0], measure, M_global.avg, k, balance))
            trainLogger.flush()
        if iters % opt.evalIter == 0:
            cond = Variable(
                torch.FloatTensor(opt.valBatchSize, cond_size, 1, 1).fill_(-1).cuda(async=True), 
                volatile=True)
            for idx in range(opt.valBatchSize):
                if np.random.uniform(0,1) > 0.5:
                    cond.data[idx,0,0,0] = 1
                    #cond.data[idx,1,0,0] = -1
                else:
                    cond.data[idx,1,0,0] = 1
                    #cond.data[idx,0,0,0] = -1
            gen = g_net(fixed_noise, cond)
            recon_gen = d_net(gen, cond)
            vutils.save_image(gen.data, '%s/epoch_%08d_iter%08d_gen.png' % \
                              (opt.exp, epoch, iters), nrow=8, normalize=True)
            vutils.save_image(recon_gen.data, '%s/epoch_%08d_iter%08d_gen_recon.png' % \
                              (opt.exp, epoch, iters), nrow=8, normalize=True)
            vutils.save_image(input_d.data, '%s/epoch_%08d_iter%08d_real.png' % \
                              (opt.exp, epoch, iters), nrow=8, normalize=True)
            vutils.save_image(recon_real.data, '%s/epoch_%08d_iter%08d_real_recon.png' % \
                              (opt.exp, epoch, iters), nrow=8, normalize=True)

  # learning rate annealing
  if epoch >= opt.annealStart:
    adjust_learning_rate(optimizerD, opt.lrD, epoch, opt.annealEvery)
    adjust_learning_rate(optimizerG, opt.lrG, epoch, opt.annealEvery)

  # do checkpointing
  torch.save(g_net.state_dict(), '%s/G_epoch_%d.pth' % (opt.exp, epoch))
  torch.save(d_net.state_dict(), '%s/D_epoch_%d.pth' % (opt.exp, epoch))