In [None]:
import torch
from torch.autograd import Variable
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import cv2
import pickle
from random import *
import gc
import torch.optim as optim
from time import gmtime, strftime
import os
from PIL import Image
import data_transform
import torchvision.utils
%matplotlib inline

In [None]:
# Timestamp of this notebook run, formatted from gmtime(); later cells prepend it
# to the names of log files, checkpoints and sample directories.
ts_ = strftime("%Y-%m-%d__%Hh%Mm%Ss", gmtime())
# Number of text descriptions per image. The training functions pair image i
# with embedding indices i * n_desc_per_img + j for j in range(n_desc_per_img).
n_desc_per_img = 10

In [None]:
# Binary cross-entropy applied to the discriminators' sigmoid outputs.
loss_criterion = torch.nn.BCELoss()


def _const_targets(pred, value):
    """Return a target tensor filled with `value` that matches `pred`.

    The target takes its shape, dtype and device from `pred`, so the losses
    work on CPU as well as on whichever GPU the predictions live on.
    """
    return torch.full_like(pred, value)


def KL_Loss(mu, sigma):
    """KL divergence between N(mu, sigma^2) and N(0, I), averaged over elements.

    Args:
        mu: tensor of means.
        sigma: tensor of standard deviations (same shape as `mu`, positive).

    Returns:
        Scalar tensor: -0.5 * mean(1 - mu^2 - sigma^2 + log(sigma^2)).
    """
    var = sigma.pow(2)
    kl_ele = 1 - (mu.pow(2) + var) + var.log()
    return torch.mean(kl_ele) * -0.5


def loss_gen(disc_fake, mu, sigma, kl_lambda, batch_size):
    """Generator loss: log(1 - D(fake)) term plus the weighted KL regulariser.

    Args:
        disc_fake: discriminator probabilities for generated images.
        mu, sigma: conditioning-augmentation statistics passed to `KL_Loss`.
        kl_lambda: weight of the KL term.
        batch_size: kept for backward compatibility; the target tensor is now
            sized from `disc_fake` itself, so this value is not read.

    Returns:
        Scalar tensor -BCE(disc_fake, 0) + kl_lambda * KL_Loss(mu, sigma).
    """
    # -BCE(p, 0) == mean(log(1 - p)); minimising it pushes D(fake) towards 1.
    g_fake = -loss_criterion(disc_fake, _const_targets(disc_fake, 0.0))
    kl_loss = KL_Loss(mu, sigma)
    return g_fake + (kl_lambda * kl_loss)


def loss_disc(disc_real, disc_fake, disc_wrong, batch_size):
    """Discriminator loss over real, generated and mismatched-text pairs.

    Args:
        disc_real: probabilities for (real image, matching text).
        disc_fake: probabilities for (generated image, matching text).
        disc_wrong: probabilities for (real image, mismatched text).
        batch_size: kept for backward compatibility; targets are sized from
            the prediction tensors, so this value is not read.

    Returns:
        Scalar tensor BCE(real, 1) + 0.5 * (BCE(fake, 0) + BCE(wrong, 0)).
    """
    d_real = loss_criterion(disc_real, _const_targets(disc_real, 1.0))
    d_fake = loss_criterion(disc_fake, _const_targets(disc_fake, 0.0))
    d_wrong = loss_criterion(disc_wrong, _const_targets(disc_wrong, 0.0))
    return d_real + ((d_fake + d_wrong) * 0.5)

In [None]:
def upSample(c_in, c_out):
    """Build a 2x nearest-neighbour upsampling stage.

    The stage is Upsample(x2) -> 3x3 Conv (stride 1, padding 1, no bias)
    -> BatchNorm2d -> in-place ReLU, mapping `c_in` channels to `c_out`.
    """
    stages = [
        torch.nn.Upsample(scale_factor=2, mode='nearest'),
        torch.nn.Conv2d(c_in, c_out, 3, 1, 1, bias=False),
        torch.nn.BatchNorm2d(c_out),
        torch.nn.ReLU(inplace=True),
    ]
    return torch.nn.Sequential(*stages)

In [None]:
class cond_aug(torch.nn.Module):
    """Conditioning augmentation: map a text embedding to a sampled latent code.

    Two linear heads produce a mean and a log-variance (both passed through
    ReLU, so mu >= 0 and sigma >= 1), and a conditioning vector is drawn with
    the reparameterisation trick: c = mu + sigma * eps, eps ~ N(0, I).
    """

    def __init__(self, embedding_dim=1024, cond_dim=128):
        super(cond_aug,self).__init__()
        self.embedding_dim = embedding_dim
        self.cond_dim = cond_dim

        self.fc_mu = torch.nn.Linear(self.embedding_dim, self.cond_dim)
        self.fc_sigma = torch.nn.Linear(self.embedding_dim, self.cond_dim)

        torch.nn.init.xavier_normal_(self.fc_mu.weight)
        torch.nn.init.xavier_normal_(self.fc_sigma.weight)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        """Return (mu, sigma, c) for a batch of embeddings.

        Args:
            x: tensor of shape (batch, embedding_dim).

        Returns:
            mu: (batch, cond_dim) means.
            sigma: (batch, cond_dim) standard deviations, exp(0.5 * logvar).
            c: (batch, cond_dim) sampled conditioning vectors.
        """
        mu = self.relu(self.fc_mu(x))
        logvar = self.relu(self.fc_sigma(x))
        sigma = logvar.mul(0.5).exp_()
        # Draw independent noise for every sample and dimension, on the same
        # device/dtype as the activations. Previously one vector was sampled
        # and repeated for the whole batch (and forced onto CUDA), so all
        # samples in a batch shared the same perturbation.
        eps = torch.randn_like(sigma)
        c = mu + (sigma * eps)
        return mu, sigma, c

In [None]:
class ResBlock(torch.nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers with an identity skip.

    The channel count and spatial size are preserved; a ReLU is applied after
    the first batch-norm and again after the skip connection is added.
    """

    def __init__(self, channel_num):
        super(ResBlock, self).__init__()

        def conv3x3():
            # Bias-free 3x3 convolution that keeps channels and resolution.
            return torch.nn.Conv2d(channel_num, channel_num, 3, 1, 1, bias=False)

        body = [
            conv3x3(),
            torch.nn.BatchNorm2d(channel_num),
            torch.nn.ReLU(True),
            conv3x3(),
            torch.nn.BatchNorm2d(channel_num),
        ]
        self.block = torch.nn.Sequential(*body)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        """Return relu(block(x) + x)."""
        return self.relu(self.block(x) + x)

In [None]:
class stage1_gen(torch.nn.Module):
    """Stage-I generator: text embedding -> low-resolution image.

    The embedding goes through conditioning augmentation (`cond_aug`), is
    concatenated with Gaussian noise, projected to a
    (ups_input_dim, 4, 4) feature map and upsampled four times (x2 each,
    4 -> 64 pixels) before a 3x3 convolution + Tanh produces a 3-channel image.
    """

    def __init__(self, embedding_dim=1024, cond_dim=128, noise_dim=100, ups_input_dim=1024):
        super(stage1_gen,self).__init__()
        self.embedding_dim = embedding_dim
        self.cond_dim = cond_dim
        self.noise_dim = noise_dim
        self.ups_input_dim = ups_input_dim
        self.conc_dim = self.cond_dim + self.noise_dim

        self.augm = cond_aug(self.embedding_dim, self.cond_dim)
        self.ups_input = torch.nn.Sequential(
                    torch.nn.Linear(self.conc_dim, self.ups_input_dim*4*4, bias=False),
                    torch.nn.BatchNorm1d(self.ups_input_dim*4*4),
                    torch.nn.ReLU(inplace=True))
        self.upsample1 = upSample(self.ups_input_dim,self.ups_input_dim//2)
        self.upsample2 = upSample(self.ups_input_dim//2,self.ups_input_dim//4)
        self.upsample3 = upSample(self.ups_input_dim//4,self.ups_input_dim//8)
        self.upsample4 = upSample(self.ups_input_dim//8,self.ups_input_dim//16)
        self.gen_img = torch.nn.Sequential(
                torch.nn.Conv2d(self.ups_input_dim//16, 3, kernel_size = 3, stride = 1, padding = 1, bias = False),
                torch.nn.Tanh())

    def forward(self, x):
        """Generate images for a batch of text embeddings.

        Args:
            x: tensor of shape (batch, embedding_dim).

        Returns:
            fake_img: (batch, 3, 64, 64) images in [-1, 1] (Tanh output).
            mu, sigma: conditioning-augmentation statistics from `self.augm`.
        """
        batch_size = x.shape[0]
        mu, sigma, c = self.augm(x)
        # Fresh latent noise per sample and per call, on the same device as
        # the conditioning vector. Previously a single vector was sampled once
        # in __init__ and reused for every sample of every forward pass, which
        # made the generator deterministic in z.
        z = torch.randn(batch_size, self.noise_dim, device=c.device, dtype=c.dtype)
        inp = torch.cat((c,z),1)

        x = self.ups_input(inp)
        x = x.view(-1,self.ups_input_dim,4,4)
        x = self.upsample1(x)
        x = self.upsample2(x)
        x = self.upsample3(x)
        x = self.upsample4(x)
        fake_img = self.gen_img(x)

        return fake_img, mu, sigma

In [None]:
class stage1_disc(torch.nn.Module):
    """Stage-I discriminator: scores an (image, conditioning vector) pair.

    Four stride-2 4x4 convolutions reduce the image 16x spatially (a 64x64
    input becomes 4x4 with down_dim*8 channels). The conditioning vector is
    tiled to 4x4, concatenated on the channel axis, and a small head turns the
    joint features into one sigmoid probability per sample.
    """

    def __init__(self, cond_dim=128, down_dim=64):
        super(stage1_disc,self).__init__()
        self.cond_dim = cond_dim
        self.down_dim = down_dim

        nn = torch.nn
        width = self.down_dim

        # First stage has no batch-norm; each later stage doubles the channels.
        encoder = [
            nn.Conv2d(3, width, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for mult in (1, 2, 4):
            encoder.extend([
                nn.Conv2d(width * mult, width * mult * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * mult * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        self.enc_img = nn.Sequential(*encoder)

        head = [
            nn.Conv2d(width * 8 + self.cond_dim, width * 8, 3, 1, 1, bias=False),
            nn.BatchNorm2d(width * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(width * 8, 1, kernel_size=4, stride=4),
            nn.Sigmoid(),
        ]
        self.get_logits = nn.Sequential(*head)

    def forward(self, image, cond_vec):
        """Return a flat tensor with one probability per input pair."""
        features = self.enc_img(image)
        # Tile the conditioning vector over the 4x4 feature grid.
        cond_map = cond_vec.view(-1, self.cond_dim, 1, 1).repeat(1, 1, 4, 4)
        joint = torch.cat((features, cond_map), 1)  # N x (down_dim*8 + cond_dim) x 4 x 4
        return self.get_logits(joint).view(-1)

In [None]:
class stage2_gen(torch.nn.Module):
    """Stage-II generator: refines a Stage-I image using the text embedding.

    The Stage-I image is encoded (two stride-2 convolutions, i.e. 4x spatial
    reduction), joined with a tiled conditioning vector, passed through a stack
    of residual blocks and upsampled four times (x2 each) to the final image.
    """

    def __init__(self, down_dim=128, embedding_dim=1024, cond_dim=128, num_residuals=4):
        super(stage2_gen,self).__init__()
        self.down_dim = down_dim
        self.embedding_dim = embedding_dim
        self.cond_dim = cond_dim

        nn = torch.nn
        base = self.down_dim
        wide = base * 4

        # Image encoder: one resolution-preserving conv, then two halvings.
        encoder = [
            nn.Conv2d(3, base, 3, 1, 1, bias=False),
            nn.ReLU(inplace=True),
        ]
        for c_in, c_out in ((base, base * 2), (base * 2, wide)):
            encoder.extend([
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ])
        self.downsampler = nn.Sequential(*encoder)

        self.augm = cond_aug(self.embedding_dim, self.cond_dim)

        # Fuses image features with the tiled conditioning vector.
        self.joint_proc = nn.Sequential(
            nn.Conv2d(wide + self.cond_dim, wide, 3, 1, 1, bias=False),
            nn.BatchNorm2d(wide),
            nn.ReLU(True))

        # Kept as a plain list attribute as well; the blocks are registered
        # as sub-modules through `self.residual`.
        self.layers = [ResBlock(wide) for _ in range(num_residuals)]
        self.residual = nn.Sequential(*self.layers)

        self.upsample1 = upSample(wide, base * 2)
        self.upsample2 = upSample(base * 2, base)
        self.upsample3 = upSample(base, base // 2)
        self.upsample4 = upSample(base // 2, base // 4)
        self.gen_img = nn.Sequential(
            nn.Conv2d(base // 4, 3, 3, 1, 1, bias=False),
            nn.Tanh())

    def forward(self, text_embedding, stage1_image):
        """Return (fake_img, mu, sigma) for a batch of embeddings and images.

        Shape comments below assume the default down_dim=128 and a 64x64
        Stage-I image; the conditioning map is always tiled to 16x16.
        """
        feats = self.downsampler(stage1_image)                    # N x 512 x 16 x 16
        mu, sigma, c = self.augm(text_embedding)
        cond_map = c.view(-1, self.cond_dim, 1, 1).repeat(1, 1, 16, 16)  # N x 128 x 16 x 16
        hidden = self.joint_proc(torch.cat([feats, cond_map], 1))  # N x 640 -> N x 512 x 16 x 16
        hidden = self.residual(hidden)                             # N x 512 x 16 x 16

        # 16 -> 32 -> 64 -> 128 -> 256 pixels, halving the channels each time.
        for stage in (self.upsample1, self.upsample2, self.upsample3, self.upsample4):
            hidden = stage(hidden)

        fake_img = self.gen_img(hidden)                            # N x 3 x 256 x 256
        return fake_img, mu, sigma

In [None]:
class stage2_disc(torch.nn.Module):
    """Stage-II discriminator: scores a (high-res image, conditioning) pair.

    Six stride-2 4x4 convolutions reduce the image 64x spatially (a 256x256
    input becomes 4x4) while widening to down_dim*32 channels; two 3x3
    convolutions then narrow the features to down_dim*8. The conditioning
    vector is tiled to 4x4, concatenated, and a small head outputs one sigmoid
    probability per sample.
    """

    def __init__(self, cond_dim=128, down_dim=64):
        super(stage2_disc,self).__init__()
        self.cond_dim = cond_dim
        self.down_dim = down_dim

        nn = torch.nn
        width = self.down_dim

        def conv_bn_act(c_in, c_out, kernel, stride):
            # conv (no bias) -> batch-norm -> LeakyReLU(0.2), padding fixed at 1
            return [
                nn.Conv2d(c_in, c_out, kernel, stride, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        encoder = [
            nn.Conv2d(3, width, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Downsampling stages: channels double, resolution halves.
        for mult in (1, 2, 4, 8, 16):
            encoder += conv_bn_act(width * mult, width * mult * 2, 4, 2)
        # Resolution-preserving stages that narrow the channels again.
        for mult in (32, 16):
            encoder += conv_bn_act(width * mult, width * (mult // 2), 3, 1)
        self.enc_img = nn.Sequential(*encoder)   # output: down_dim*8 channels at 4x4

        head = conv_bn_act(width * 8 + self.cond_dim, width * 8, 3, 1)
        head += [
            nn.Conv2d(width * 8, 1, kernel_size=4, stride=4),
            nn.Sigmoid(),
        ]
        self.get_logits = nn.Sequential(*head)

    def forward(self, image, cond_vec):
        """Return a flat tensor with one probability per input pair."""
        features = self.enc_img(image)
        # Tile the conditioning vector over the 4x4 feature grid.
        cond_map = cond_vec.view(-1, self.cond_dim, 1, 1).repeat(1, 1, 4, 4)
        joint = torch.cat((features, cond_map), 1)  # N x (down_dim*8 + cond_dim) x 4 x 4
        return self.get_logits(joint).view(-1)

In [None]:
def train_s1(gen, disc, train_s1_imgs, train_s1_emds, batch_size=64, epochs=10, eta=0.0001, opt=optim.Adam,
             model_name='A_Model_Has_No_Name', kl_lambda=2.0, embedding_dim=1024, cond_dim=128, GPU_List='0'):
    """Train the Stage-I generator/discriminator pair and checkpoint every epoch.

    Args:
        gen, disc: Stage-I generator and discriminator modules (moved to CUDA).
        train_s1_imgs: indexable collection of training images; each item must
            expose a 3-element `.shape` (presumably numpy arrays - verify
            against `data_transform.read_input`).
        train_s1_emds: indexable collection of text embeddings laid out so that
            image i owns indices i * n_desc_per_img ... + n_desc_per_img - 1.
        batch_size: mini-batch size (the last batch of an epoch may be smaller).
        epochs: number of passes over all (image, description) pairs.
        eta: learning rate for both optimisers.
        opt: optimiser class, called as opt(params, lr=eta).
        model_name: tag appended to log and checkpoint file names.
        kl_lambda: weight of the KL term in the generator loss.
        embedding_dim, cond_dim: only used to build a fallback text encoder
            when `gen` has no `augm` attribute.
        GPU_List: comma-separated CUDA device ids used for data parallelism.

    Returns:
        (l_tr_g, l_tr_d): per-epoch generator and discriminator losses, each
        averaged over the training pairs.

    Raises:
        ValueError: if there are not enough embeddings to draw a description
            that belongs to a different image.

    Side effects:
        Appends to ./observations/Loss_S1_<ts>.txt and overwrites
        saved_models/gen_s1_<ts>.pt / disc_s1_<ts>.pt after every epoch.
    """
    gen.cuda()
    disc.cuda()

    gpus = [int(x) for x in GPU_List.split(',')]

    optim_gen = opt(gen.parameters(), lr=eta)
    optim_disc = opt(disc.parameters(), lr=eta)

    print('Training Stage 1...')

    l_tr_g = []
    l_tr_d = []

    desc_shape = train_s1_emds[0].shape
    img_shape = train_s1_imgs[0].shape

    for out_dir in ("./observations/", "./saved_models/"):
        directory = os.path.dirname(out_dir)
        if not os.path.exists(directory):
            os.makedirs(directory)
    ts = ts_ + "_S1_" + model_name

    # Mismatched captions must be encoded by the generator's own (trained)
    # conditioning network. Previously a brand-new, randomly initialised
    # cond_aug was created on every iteration, so the discriminator was
    # conditioned on random projections for the "wrong text" pairs.
    wrong_text_encoder = getattr(gen, 'augm', None)
    if wrong_text_encoder is None:
        wrong_text_encoder = cond_aug(embedding_dim, cond_dim)
        wrong_text_encoder.cuda()

    n_emds = len(train_s1_emds)
    if n_emds <= n_desc_per_img:
        raise ValueError('Need descriptions of at least two images to build mismatched pairs.')

    real_pairs = [(i, i * n_desc_per_img + j)
                  for i in range(len(train_s1_imgs))
                  for j in range(n_desc_per_img)]

    n_train = len(real_pairs)
    print('Training Size = ',n_train)

    def _to_gpu(rows, *dims):
        # Stack once in numpy (fast) and move the float32 batch to the GPU.
        return torch.from_numpy(np.asarray(rows, dtype=np.float32)).view(-1, *dims).cuda()

    for epoch in range(epochs):
        l_tr_g_temp = 0
        l_tr_d_temp = 0

        ############ Shuffle and Gen Fake Pairs ############
        shuffle(real_pairs)
        fake_pairs = []
        for img_idx, _ in real_pairs:
            # Reject any description of the same image, not just the exact
            # paired one, so a "wrong" caption is never a correct one.
            nidx = randint(0, n_emds - 1)
            while nidx // n_desc_per_img == img_idx:
                nidx = randint(0, n_emds - 1)
            fake_pairs.append((img_idx, nidx))

        ############ Training #############
        for start in range(0, n_train, batch_size):
            end = min(start + batch_size, n_train)
            batch_size_ = end - start

            real_img = _to_gpu([train_s1_imgs[p[0]] for p in real_pairs[start:end]],
                               img_shape[0], img_shape[1], img_shape[2])
            real_text = _to_gpu([train_s1_emds[p[1]] for p in real_pairs[start:end]], desc_shape[0])
            fake_text = _to_gpu([train_s1_emds[p[1]] for p in fake_pairs[start:end]], desc_shape[0])

            fake_img, mu, sigma = torch.nn.parallel.data_parallel(gen, (real_text,), gpus)
            with torch.no_grad():
                mu_w, _, _ = wrong_text_encoder(fake_text)

            # ---- Discriminator step ----
            # Generator outputs are detached so this backward pass neither
            # walks the generator graph nor leaves gradients on its weights.
            optim_disc.zero_grad()
            cond = mu.detach()
            disc_real = torch.nn.parallel.data_parallel(disc, (real_img, cond), gpus)
            disc_fake = torch.nn.parallel.data_parallel(disc, (fake_img.detach(), cond), gpus)
            disc_wrong = torch.nn.parallel.data_parallel(disc, (real_img, mu_w), gpus)
            ld_ = loss_disc(disc_real, disc_fake, disc_wrong, batch_size_)
            ld_.backward()
            optim_disc.step()

            # ---- Generator step ----
            # Re-score the fakes with the updated discriminator: back-propagating
            # through the pre-step graph would use weights that optim_disc.step()
            # has already modified in place.
            optim_gen.zero_grad()
            disc_fake_g = torch.nn.parallel.data_parallel(disc, (fake_img, mu), gpus)
            lg_ = loss_gen(disc_fake_g, mu, sigma, kl_lambda, batch_size_)
            lg_.backward()
            optim_gen.step()

            # Losses are batch means; weight by batch size so that dividing by
            # n_train below yields a true per-pair average.
            l_tr_g_temp += float(lg_.detach()) * batch_size_
            l_tr_d_temp += float(ld_.detach()) * batch_size_

        avg_g = l_tr_g_temp / n_train
        avg_d = l_tr_d_temp / n_train
        l_tr_g.append(avg_g)
        l_tr_d.append(avg_d)
        summary = 'Epoch ' + str(epoch) + ': Generator Loss = ' + str(avg_g) + \
               '         Discriminator Loss = ' + str(avg_d)
        with open('./observations/Loss_S1_'+ts+'.txt','a+') as l:
            l.write(summary + '\n')
        print(summary)

        # torch.save overwrites existing files, so no explicit removal is needed.
        torch.save(gen,'saved_models/gen_s1_' + str(ts) + '.pt')
        torch.save(disc,'saved_models/disc_s1_' + str(ts) + '.pt')

    return l_tr_g, l_tr_d
        

In [None]:
def train_s2(gen, disc, gen_s1, train_s2_imgs, train_s2_emds, batch_size=64, epochs=10, eta=0.0001, opt=optim.Adam,
             model_name='A_Model_Has_No_Name', kl_lambda=2.0, embedding_dim=1024, cond_dim=128, GPU_List='0'):
    """Train the Stage-II generator/discriminator pair on top of a Stage-I generator.

    Args:
        gen, disc: Stage-II generator and discriminator modules (moved to CUDA).
        gen_s1: Stage-I generator; only used to produce the low-resolution
            input images and never updated here.
        train_s2_imgs: indexable collection of training images; each item must
            expose a 3-element `.shape` (presumably numpy arrays - verify
            against `data_transform.read_input`).
        train_s2_emds: indexable collection of text embeddings laid out so that
            image i owns indices i * n_desc_per_img ... + n_desc_per_img - 1.
        batch_size: mini-batch size (the last batch of an epoch may be smaller).
        epochs: number of passes over all (image, description) pairs.
        eta: learning rate for both optimisers.
        opt: optimiser class, called as opt(params, lr=eta).
        model_name: tag appended to log and checkpoint file names.
        kl_lambda: weight of the KL term in the generator loss.
        embedding_dim, cond_dim: only used to build a fallback text encoder
            when `gen` has no `augm` attribute.
        GPU_List: comma-separated CUDA device ids used for data parallelism.

    Returns:
        (l_tr_g, l_tr_d): per-epoch generator and discriminator losses, each
        averaged over the training pairs.

    Raises:
        ValueError: if there are not enough embeddings to draw a description
            that belongs to a different image.

    Side effects:
        Appends to ./observations/Loss_S2_<ts>.txt and overwrites
        saved_models/gen_s2_<ts>.pt / disc_s2_<ts>.pt after every epoch.
    """
    gen.cuda()
    disc.cuda()
    gen_s1.cuda()

    gpus = [int(x) for x in GPU_List.split(',')]

    optim_gen = opt(gen.parameters(), lr=eta)
    optim_disc = opt(disc.parameters(), lr=eta)

    print('Training Stage 2...')

    l_tr_g = []
    l_tr_d = []

    desc_shape = train_s2_emds[0].shape
    img_shape = train_s2_imgs[0].shape

    for out_dir in ("./observations/", "./saved_models/"):
        directory = os.path.dirname(out_dir)
        if not os.path.exists(directory):
            os.makedirs(directory)
    # Stage-II artefacts get their own "_S2_" tag; the previous "_S1_" tag
    # (and Loss_S1_ file name) made Stage-II append to the Stage-I loss log.
    ts = ts_ + "_S2_" + model_name

    # Mismatched captions must be encoded by the generator's own (trained)
    # conditioning network. Previously a brand-new, randomly initialised
    # cond_aug was created on every iteration.
    wrong_text_encoder = getattr(gen, 'augm', None)
    if wrong_text_encoder is None:
        wrong_text_encoder = cond_aug(embedding_dim, cond_dim)
        wrong_text_encoder.cuda()

    n_emds = len(train_s2_emds)
    if n_emds <= n_desc_per_img:
        raise ValueError('Need descriptions of at least two images to build mismatched pairs.')

    real_pairs = [(i, i * n_desc_per_img + j)
                  for i in range(len(train_s2_imgs))
                  for j in range(n_desc_per_img)]

    n_train = len(real_pairs)
    print('Training Size = ',n_train)

    def _to_gpu(rows, *dims):
        # Stack once in numpy (fast) and move the float32 batch to the GPU.
        return torch.from_numpy(np.asarray(rows, dtype=np.float32)).view(-1, *dims).cuda()

    for epoch in range(epochs):
        l_tr_g_temp = 0
        l_tr_d_temp = 0

        ############ Shuffle and Gen Fake Pairs ############
        shuffle(real_pairs)
        fake_pairs = []
        for img_idx, _ in real_pairs:
            # Reject any description of the same image, not just the exact
            # paired one, so a "wrong" caption is never a correct one.
            nidx = randint(0, n_emds - 1)
            while nidx // n_desc_per_img == img_idx:
                nidx = randint(0, n_emds - 1)
            fake_pairs.append((img_idx, nidx))

        ############ Training #############
        for start in range(0, n_train, batch_size):
            end = min(start + batch_size, n_train)
            batch_size_ = end - start

            real_img = _to_gpu([train_s2_imgs[p[0]] for p in real_pairs[start:end]],
                               img_shape[0], img_shape[1], img_shape[2])
            real_text = _to_gpu([train_s2_emds[p[1]] for p in real_pairs[start:end]], desc_shape[0])
            fake_text = _to_gpu([train_s2_emds[p[1]] for p in fake_pairs[start:end]], desc_shape[0])

            # The Stage-I generator is not optimised here, so skip building
            # its autograd graph (saves memory and avoids stale gradients).
            with torch.no_grad():
                fake_s1, _, _ = torch.nn.parallel.data_parallel(gen_s1, (real_text,), gpus)
                mu_w, _, _ = wrong_text_encoder(fake_text)
            fake_img, mu, sigma = torch.nn.parallel.data_parallel(gen, (real_text, fake_s1), gpus)

            # ---- Discriminator step ----
            # Generator outputs are detached so this backward pass neither
            # walks the generator graph nor leaves gradients on its weights.
            optim_disc.zero_grad()
            cond = mu.detach()
            disc_real = torch.nn.parallel.data_parallel(disc, (real_img, cond), gpus)
            disc_fake = torch.nn.parallel.data_parallel(disc, (fake_img.detach(), cond), gpus)
            disc_wrong = torch.nn.parallel.data_parallel(disc, (real_img, mu_w), gpus)
            ld_ = loss_disc(disc_real, disc_fake, disc_wrong, batch_size_)
            ld_.backward()
            optim_disc.step()

            # ---- Generator step ----
            # Re-score the fakes with the updated discriminator: back-propagating
            # through the pre-step graph would use weights that optim_disc.step()
            # has already modified in place.
            optim_gen.zero_grad()
            disc_fake_g = torch.nn.parallel.data_parallel(disc, (fake_img, mu), gpus)
            lg_ = loss_gen(disc_fake_g, mu, sigma, kl_lambda, batch_size_)
            lg_.backward()
            optim_gen.step()

            # Losses are batch means; weight by batch size so that dividing by
            # n_train below yields a true per-pair average.
            l_tr_g_temp += float(lg_.detach()) * batch_size_
            l_tr_d_temp += float(ld_.detach()) * batch_size_

        avg_g = l_tr_g_temp / n_train
        avg_d = l_tr_d_temp / n_train
        l_tr_g.append(avg_g)
        l_tr_d.append(avg_d)
        summary = 'Epoch ' + str(epoch) + ': Generator Loss = ' + str(avg_g) + \
               '         Discriminator Loss = ' + str(avg_d)
        with open('./observations/Loss_S2_'+ts+'.txt','a+') as l:
            l.write(summary + '\n')
        print(summary)

        # torch.save overwrites existing files, so no explicit removal is needed.
        torch.save(gen,'saved_models/gen_s2_' + str(ts) + '.pt')
        torch.save(disc,'saved_models/disc_s2_' + str(ts) + '.pt')

    return l_tr_g, l_tr_d
        

In [None]:
def test_s1(gen, test_s1_emds, test_text, test_num = 500, random=True, model_name='A_Model_Has_No_Name', GPU_List='0', batch_size=4):
    if test_num > len(test_s1_emds):
        print("Too many test images!")
        return
    
    desc_shape = train_s1_emds[0].shape
    ts = ts_ + "_test_S1_" + model_name
    store_images_path = "./observations/" + ts + "/"
    directory = os.path.dirname(store_images_path)

    if not os.path.exists(directory):
        os.makedirs(directory)
    
    num_batches = int(test_num/batch_size)
    
    for i in range(num_batches):
        indx=0
        text = []
        indices=[]
        
        for count in range(batch_size):
            if random == True:
                indx=randint(0,len(test_s1_emds)-1)
            else:
                indx = i * batch_size + count
            text.append(test_s1_emds[indx])
            indices.append(indx)
        
        text = Variable(torch.FloatTensor(text).view(-1,desc_shape[0])).cuda()
        imgs, _, _ = gen(text)
        
        imgs = (imgs + 1.0)*127.5
        
        for count in range(batch_size):
            filename = test_text[indices[count]] + str(i * batch_size + count) + ".jpg"
            keepcharacters = (' ','.','_')
            filename = "".join(c for c in filename if c.isalnum() or c in keepcharacters).rstrip()
            if len(filename) > 250:
                filename = filename[:250]
            torchvision.utils.save_image(imgs[count].view(3,64,64), store_images_path + filename)


In [None]:
def test_s2(gen1, gen2, test_emds, test_text, test_num = 500, random=True, model_name='A_Model_Has_No_Name', GPU_List='0', batch_size=4):
    if test_num > len(test_s1_emds):
        print("Too many test images!")
        return
    
    desc_shape = train_s2_emds[0].shape
    ts = ts_ + "_test_S2_" + model_name
    store_images_path = "./observations/" + ts + "/"
    directory = os.path.dirname(store_images_path)

    if not os.path.exists(directory):
        os.makedirs(directory)
    
    num_batches = int(test_num/batch_size)
    
    for i in range(num_batches):
        indx=0
        text = []
        indices=[]
        
        for count in range(batch_size):
            if random == True:
                indx=randint(0,len(test_s1_emds)-1)
            else:
                indx = i * batch_size + count
            text.append(test_s1_emds[indx])
            indices.append(indx)
        
        text = Variable(torch.FloatTensor(text).view(-1,desc_shape[0])).cuda()
        imgs_s1, _, _ = gen1(text)
        imgs_s2, _, _ = gen2(text, imgs_s1)
        
        imgs_s2 = (imgs_s2 + 1.0)*127.5
        
        for count in range(batch_size):
            filename = test_text[indices[count]] + str(i * batch_size + count) + ".jpg"
            keepcharacters = (' ','.','_')
            filename = "".join(c for c in filename if c.isalnum() or c in keepcharacters).rstrip()
            if len(filename) > 250:
                filename = filename[:250]
            torchvision.utils.save_image(imgs_s2[count].view(3,256,256), store_images_path + filename)
        


In [None]:
# Load the pre-processed Stage-I / Stage-II data, building the pickle caches
# on first use.
dataset = 'CUB'
f_s1_name = 'proc_data_transform_stage1.pickle'
f_s2_name = 'proc_data_transform_stage2.pickle'
addr_1 = dataset + '/' + f_s1_name
addr_2 = dataset + '/' + f_s2_name
print('Data Input ==>')
# Rebuild when EITHER cache is missing: both files are read in the else-branch,
# so checking only the stage-1 file crashed when the stage-2 file was absent.
if not (os.path.exists(addr_1) and os.path.exists(addr_2)):
    train_s1_imgs, train_s1_emds, test_s1_imgs, test_s1_emds, train_text, test_text = data_transform.read_input(dataset=dataset, stage=1)
    train_s2_imgs, train_s2_emds, test_s2_imgs, test_s2_emds, train_text, test_text = data_transform.read_input(dataset=dataset, stage=2)
    with open(addr_1,'wb') as f:
        pickle.dump([train_s1_imgs, train_s1_emds, test_s1_imgs, test_s1_emds, train_text, test_text],f)
    with open(addr_2,'wb') as f:
        pickle.dump([train_s2_imgs, train_s2_emds, test_s2_imgs, test_s2_emds, train_text, test_text],f)
else:
    # NOTE: pickle.load executes arbitrary code from the file; only load
    # caches that this notebook produced itself.
    with open(addr_1,'rb') as f:
        train_s1_imgs, train_s1_emds, test_s1_imgs, test_s1_emds, train_text, test_text = pickle.load(f)
    # train_text / test_text are overwritten by the stage-2 copies here.
    with open(addr_2,'rb') as f:
        train_s2_imgs, train_s2_emds, test_s2_imgs, test_s2_emds, train_text, test_text = pickle.load(f)

# Sanity check: sizes of the image sets, then the embedding sets, then the text.
print('Train_S1_Data : ',len(train_s1_imgs))
print('Test_S1_Data : ',len(test_s1_imgs))
print('Train_S2_Data : ',len(train_s2_imgs))
print('Test_S2_Data : ',len(test_s2_imgs))
print('Train_S1_Data : ',len(train_s1_emds))
print('Test_S1_Data : ',len(test_s1_emds))
print('Train_S2_Data : ',len(train_s2_emds))
print('Test_S2_Data : ',len(test_s2_emds))
print('Train_Text : ',len(train_text))
print('Test_Text : ',len(test_text))

In [None]:
# ---- Training hyper-parameters ----
# Mini-batch size, epoch counts and learning rates for Stage I / Stage II.
batch_size = 64
epochs_s1 = 300
epochs_s2 = 300
eta_s1 = 0.0002
eta_s2 = 0.0002
# Optimiser class; the training functions instantiate it as opt(params, lr=eta).
opt = optim.Adam
# ---- Model dimensions ----
# Size of the input text embedding and of the conditioning vector.
embedding_dim = 1024
cond_dim = 128
# Stage-I generator: latent noise size and channel count of its 4x4 seed map.
noise_dim = 100
ups_input_dim = 1024
# Base channel width; the cell that builds the models passes this same value
# to both discriminators and to the Stage-II generator.
# NOTE(review): the discriminator classes default to down_dim=64 - confirm that
# 128 is intended for them as well.
down_dim = 128
# Number of residual blocks in the Stage-II generator.
num_residuals = 4
# Weight of the KL regulariser in the generator loss.
kl_lambda = 2.0
# Tag used in log/checkpoint names, and comma-separated CUDA device ids.
model_name = 'Stg1_Only_Parallel'
GPU_List = '0'

In [None]:
# Build the Stage-I generator/discriminator pair from the shared hyper-parameters.
gen_s1 = stage1_gen(
    embedding_dim=embedding_dim,
    cond_dim=cond_dim,
    noise_dim=noise_dim,
    ups_input_dim=ups_input_dim,
)
disc_s1 = stage1_disc(
    cond_dim=cond_dim,
    down_dim=down_dim,
)

# Build the Stage-II generator/discriminator pair.
gen_s2 = stage2_gen(
    down_dim=down_dim,
    embedding_dim=embedding_dim,
    cond_dim=cond_dim,
    num_residuals=num_residuals,
)
disc_s2 = stage2_disc(
    cond_dim=cond_dim,
    down_dim=down_dim,
)

In [None]:
# Run Stage-I training; returns the per-epoch generator / discriminator losses.
l_tr_g_s1, l_tr_d_s1 = train_s1(
    gen_s1,
    disc_s1,
    train_s1_imgs,
    train_s1_emds,
    batch_size=batch_size,
    epochs=epochs_s1,
    eta=eta_s1,
    opt=opt,
    model_name=model_name,
    kl_lambda=kl_lambda,
    embedding_dim=embedding_dim,
    cond_dim=cond_dim,
    GPU_List=GPU_List,
)

In [None]:
# Run Stage-II training on top of the Stage-I generator; returns the per-epoch
# generator / discriminator losses. Uses the Stage-II epoch count and learning
# rate (epochs_s2 / eta_s2) rather than the Stage-I settings.
l_tr_g_s2, l_tr_d_s2 = train_s2(gen_s2, disc_s2, gen_s1, train_s2_imgs, train_s2_emds, batch_size=batch_size,
            epochs=epochs_s2, eta=eta_s2, opt=opt, model_name=model_name, kl_lambda=kl_lambda,
            embedding_dim=embedding_dim, cond_dim=cond_dim, GPU_List=GPU_List)

In [None]:
# Reload a previously saved Stage-I generator and render the first 100 test
# descriptions in order (random=False, default batch size).
# NOTE(review): the path has no "saved_models/" prefix although train_s1 writes
# its checkpoints there - confirm where this file lives.
gen = torch.load(
    "gen_s1_2018-05-30__07h59m27s_S1_Stg1_Only_Parallel_Trial.pt"
)
test_s1(
    gen,
    test_s1_emds,
    test_text,
    test_num=100,
    random=False,
)