In [None]:
import torch
from torch.autograd import Variable
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import cv2
import pickle
from random import *
import gc
from torch.distributions.multivariate_normal import MultivariateNormal
import torch.optim as optim
from time import gmtime, strftime
import os
%matplotlib inline

## Load the Dataset from Disk

In [None]:
def _load_split(desc_fname, files_fname, img_dir, split_label):
    """Read one split and return its (stage-1, stage-2) lists of (image, embedding, label) tuples.

    For every image i and caption j, a matched tuple ``(img, embed[i, j, :], 1)`` is appended,
    immediately followed by a mismatched tuple ``(img, embed[k, m, :], 0)`` whose caption comes
    from a randomly chosen different image k != i. Images are resized with INTER_AREA to 64x64
    (stage 1) and 256x256 (stage 2) and transposed to channel-first order.
    """
    print('Loading ' + split_label + ' Data...')
    embed = np.load(desc_fname)
    n_images, n_captions = embed.shape[0], embed.shape[1]
    # The negative-sampling loop below redraws until it hits a different image; with fewer
    # than two images it could never terminate.
    if n_images < 2:
        raise ValueError('need at least two images in ' + desc_fname + ' to sample mismatched captions')

    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted dataset files.
    with open(files_fname, 'rb') as file:
        filenames = pickle.load(file)

    s1_data = []
    s2_data = []
    for i, name in enumerate(filenames):
        path = img_dir + name + '.jpg'
        img = cv2.imread(path, 1)
        if img is None:  # cv2.imread reports failure by returning None rather than raising
            raise FileNotFoundError('could not read image ' + path)
        img_s1 = np.transpose(cv2.resize(img, (64, 64), interpolation=cv2.INTER_AREA), (2, 0, 1))
        img_s2 = np.transpose(cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA), (2, 0, 1))
        for j in range(n_captions):
            neg_img = randint(0, n_images - 1)
            while neg_img == i:
                neg_img = randint(0, n_images - 1)
            neg_idx = randint(0, n_captions - 1)
            pos_emb = embed[i, j, :]
            neg_emb = embed[neg_img, neg_idx, :]
            s1_data.append((img_s1, pos_emb, 1))
            s1_data.append((img_s1, neg_emb, 0))
            s2_data.append((img_s2, pos_emb, 1))
            s2_data.append((img_s2, neg_emb, 0))

        if i % 1000 == 0:
            print(i)

    return s1_data, s2_data


def read_input(dataset='CUB', process='resize'):
    """Load the train and test splits as lists of (image, text embedding, label) tuples.

    Files are read from ``<dataset>/desc/{train,test}/char-CNN-RNN-embeddings.npy``,
    ``<dataset>/desc/{train,test}/filenames.pickle`` and ``<dataset>/images/<name>.jpg``.
    Each matched pair (label 1) is immediately followed by a mismatched pair (label 0) that
    shares the same image but uses a caption from a different, randomly chosen image.

    Args:
        dataset: root directory of the dataset.
        process: image preprocessing mode; only ``'resize'`` is implemented.

    Returns:
        ``(train_s1_data, test_s1_data, train_s2_data, test_s2_data)``; the ``s1`` lists hold
        64x64 images and the ``s2`` lists hold 256x256 images, both channel-first.

    Raises:
        ValueError: if ``process`` is not ``'resize'`` or a split has fewer than two images.
        FileNotFoundError: if an image file cannot be read.
    """
    # Validate before touching any file: no other preprocessing mode produces images.
    if process != 'resize':
        raise ValueError("unsupported process %r; only 'resize' is implemented" % (process,))

    train_s1_data, train_s2_data = _load_split(
        dataset + '/desc/train/char-CNN-RNN-embeddings.npy',
        dataset + '/desc/train/filenames.pickle',
        dataset + '/images/',
        'Training')
    test_s1_data, test_s2_data = _load_split(
        dataset + '/desc/test/char-CNN-RNN-embeddings.npy',
        dataset + '/desc/test/filenames.pickle',
        dataset + '/images/',
        'Testing')

    return train_s1_data, test_s1_data, train_s2_data, test_s2_data

In [None]:
# Load both splits; every list holds (image, text embedding, label) tuples.
train_s1_data, test_s1_data, train_s2_data, test_s2_data = read_input(dataset='CUB', process='resize')

In [None]:
# Number of (image, embedding, label) tuples per split, in the order returned by read_input.
for split in (train_s1_data, test_s1_data, train_s2_data, test_s2_data):
    print(len(split))

In [None]:
# Explicitly run the garbage collector after the large data-loading step above.
gc.collect()

## StackGAN Implementation

### Loss Functions

In [None]:
# Shared binary cross-entropy (mean reduction) on sigmoid probabilities.
loss_criterion = torch.nn.BCELoss()

def KL_Loss(mu, sigma):
    """KL divergence between N(mu, sigma^2) and N(0, I), averaged over all elements.

    Computes ``-0.5 * mean(1 + log(sigma^2) - mu^2 - sigma^2)``.

    Args:
        mu: mean tensor.
        sigma: standard-deviation tensor with the same shape as ``mu``; must be > 0.

    Returns:
        A scalar tensor.
    """
    # Written out-of-place (no add_/mul_/log_) so no intermediate buffer is overwritten.
    var = sigma.pow(2)
    return -0.5 * torch.mean(1 + torch.log(var) - mu.pow(2) - var)

def loss_gen(disc_fake, mu, sigma, kl_lambda):
    """Generator loss: adversarial term plus ``kl_lambda`` times the KL regulariser.

    The adversarial term is ``-BCE(D(fake), 0) = mean(log(1 - D(fake)))`` (the minimax form);
    minimising it pushes D(fake) towards 1.
    NOTE(review): the non-saturating variant ``BCE(D(fake), 1)`` is often preferred in practice.

    Args:
        disc_fake: discriminator probabilities for generated images, any shape.
        mu: conditioning-augmentation mean.
        sigma: conditioning-augmentation standard deviation.
        kl_lambda: weight of the KL term.

    Returns:
        A scalar tensor.
    """
    # zeros_like matches disc_fake's shape, dtype and device: any batch size, no forced CUDA.
    g_fake = -loss_criterion(disc_fake, torch.zeros_like(disc_fake))
    return g_fake + kl_lambda * KL_Loss(mu, sigma)

def loss_disc(disc_real, disc_fake, disc_wrong):
    """Discriminator loss ``BCE(real, 1) + 0.5 * (BCE(fake, 0) + BCE(wrong, 0))``.

    Args:
        disc_real: probabilities for real images with matching text.
        disc_fake: probabilities for generated images.
        disc_wrong: probabilities for real images with mismatched text.

    Returns:
        A scalar tensor.
    """
    d_real = loss_criterion(disc_real, torch.ones_like(disc_real))
    d_fake = loss_criterion(disc_fake, torch.zeros_like(disc_fake))
    d_wrong = loss_criterion(disc_wrong, torch.zeros_like(disc_wrong))
    return d_real + (d_fake + d_wrong) * 0.5

### Model

In [None]:
def upSample(c_in, c_out):
    """Build a 2x upsampling stage: nearest-neighbour upsample, 3x3 conv, BatchNorm, ReLU.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.

    Returns:
        A ``torch.nn.Sequential`` that doubles height and width.
    """
    stages = [
        torch.nn.Upsample(scale_factor=2, mode='nearest'),
        torch.nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False),
        torch.nn.BatchNorm2d(c_out),
        torch.nn.ReLU(inplace=True),
    ]
    return torch.nn.Sequential(*stages)

In [None]:
class cond_aug(torch.nn.Module):
    """Conditioning augmentation: map a text embedding to a sampled conditioning vector.

    Two linear heads produce ``mu`` and a log-variance. Both pass through ReLU, so
    ``mu >= 0`` and ``sigma = exp(0.5 * logvar) >= 1``. The conditioning vector is
    ``c = mu + sigma * eps`` with ``eps ~ N(0, I)`` drawn independently for every element.
    """
    def __init__(self, embedding_dim=1024, cond_dim=128):
        super(cond_aug, self).__init__()
        self.embedding_dim = embedding_dim
        self.cond_dim = cond_dim

        self.fc_mu = torch.nn.Linear(self.embedding_dim, self.cond_dim)
        self.fc_sigma = torch.nn.Linear(self.embedding_dim, self.cond_dim)

        torch.nn.init.xavier_normal_(self.fc_mu.weight)
        torch.nn.init.xavier_normal_(self.fc_sigma.weight)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``(mu, sigma, c)``, each shaped like ``fc_mu(x)``.

        Args:
            x: text embedding(s) whose last dimension is ``embedding_dim``
               (presumably ``(N, embedding_dim)`` -- TODO confirm with callers).
        """
        mu = self.relu(self.fc_mu(x))
        logvar = self.relu(self.fc_sigma(x))
        sigma = logvar.mul(0.5).exp_()
        # randn_like gives one independent standard-normal draw per element (so every sample
        # in a batch gets its own noise) on sigma's device and dtype, without forcing CUDA.
        eps = torch.randn_like(sigma)

        c = mu + (sigma * eps)
        return mu, sigma, c

In [None]:
class ResBlock(torch.nn.Module):
    """Residual block: conv3x3-BN-ReLU-conv3x3-BN plus an identity shortcut, then ReLU.

    The channel count is preserved, so input and output shapes are identical.
    """
    def __init__(self, channel_num):
        super(ResBlock, self).__init__()

        def conv3x3():
            # Same-padding 3x3 convolution that keeps the channel count.
            return torch.nn.Conv2d(channel_num, channel_num, kernel_size=3, stride=1, padding=1, bias=False)

        self.block = torch.nn.Sequential(
            conv3x3(),
            torch.nn.BatchNorm2d(channel_num),
            torch.nn.ReLU(True),
            conv3x3(),
            torch.nn.BatchNorm2d(channel_num))
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.block(x) + x)

In [None]:
class stage1_gen(torch.nn.Module):
    """Stage-I generator: text embedding + noise -> 64x64 RGB image in [-1, 1].

    The conditioning vector (from ``cond_aug``) is concatenated with Gaussian noise, projected
    to a ``ups_input_dim`` x 4 x 4 map, and upsampled four times (4 -> 8 -> 16 -> 32 -> 64).
    """
    def __init__(self, embedding_dim=1024, cond_dim=128, noise_dim=100, ups_input_dim=1024 ):
        super(stage1_gen,self).__init__()
        self.embedding_dim = embedding_dim
        self.cond_dim = cond_dim
        self.noise_dim = noise_dim
        self.ups_input_dim = ups_input_dim
        self.conc_dim = self.cond_dim + self.noise_dim

        self.augm = cond_aug(self.embedding_dim, self.cond_dim)
        # No BatchNorm1d here (it was disabled in earlier revisions); NOTE(review): BatchNorm1d
        # in training mode rejects batches of size 1, which is what the training loops use.
        self.ups_input = torch.nn.Sequential(
                    torch.nn.Linear(self.conc_dim, self.ups_input_dim*4*4, bias=False),
                    torch.nn.ReLU(inplace=True))
        self.upsample1 = upSample(self.ups_input_dim,self.ups_input_dim//2)
        self.upsample2 = upSample(self.ups_input_dim//2,self.ups_input_dim//4)
        self.upsample3 = upSample(self.ups_input_dim//4,self.ups_input_dim//8)
        self.upsample4 = upSample(self.ups_input_dim//8,self.ups_input_dim//16)
        self.gen_img = torch.nn.Sequential(
                torch.nn.Conv2d(self.ups_input_dim//16, 3, kernel_size = 3, stride = 1, padding = 1, bias = False),
                torch.nn.Tanh())

    def forward(self, x):
        """Generate images from text embeddings.

        Args:
            x: text embedding(s) with last dimension ``embedding_dim``
               (presumably ``(N, embedding_dim)`` -- TODO confirm with callers).

        Returns:
            ``(fake_img, mu, sigma)`` where ``fake_img`` is ``(N, 3, 64, 64)``.
        """
        mu, sigma, c = self.augm(x)
        # One row per sample, so batches larger than 1 work.
        c = c.reshape(-1, self.cond_dim)
        # Standard-normal noise, one vector per sample, created on c's device and dtype.
        z = torch.randn(c.size(0), self.noise_dim, device=c.device, dtype=c.dtype)
        inp = torch.cat((c, z), 1)

        x = self.ups_input(inp)
        x = x.view(-1, self.ups_input_dim, 4, 4)
        x = self.upsample1(x)
        x = self.upsample2(x)
        x = self.upsample3(x)
        x = self.upsample4(x)
        fake_img = self.gen_img(x)

        return fake_img, mu, sigma

In [None]:
class stage1_disc(torch.nn.Module):
    """Stage-I discriminator: scores 64x64 images against conditioning vectors.

    The image is encoded by four stride-2 convolutions into a (down_dim*8) x 4 x 4 map, the
    conditioning vector is tiled over that 4x4 grid and concatenated, and a final conv stack
    reduces the result to one sigmoid probability per image.
    """
    def __init__(self, cond_dim=128, down_dim=64):
        super(stage1_disc, self).__init__()
        self.cond_dim = cond_dim
        self.down_dim = down_dim
        ndf = self.down_dim

        encoder = [
            torch.nn.Conv2d(3, ndf, kernel_size=4, stride=2, padding=1, bias=False),
            torch.nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three more stride-2 stages, doubling channels: ndf -> 2*ndf -> 4*ndf -> 8*ndf.
        for mult in (1, 2, 4):
            encoder.extend([
                torch.nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=2, padding=1, bias=False),
                torch.nn.BatchNorm2d(ndf * mult * 2),
                torch.nn.LeakyReLU(0.2, inplace=True),
            ])
        self.enc_img = torch.nn.Sequential(*encoder)

        self.get_logits = torch.nn.Sequential(
            torch.nn.Conv2d(ndf * 8 + self.cond_dim, ndf * 8, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(ndf * 8),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),
            torch.nn.Sigmoid())

    def forward(self, image, cond_vec):
        """Return a flat tensor with one probability per image in the batch."""
        features = self.enc_img(image)
        tiled_cond = cond_vec.view(-1, self.cond_dim, 1, 1).repeat(1, 1, 4, 4)
        joint = torch.cat((features, tiled_cond), 1)  # N x (down_dim*8 + cond_dim) x 4 x 4
        return self.get_logits(joint).view(-1)

In [None]:
class stage2_gen(torch.nn.Module):
    """Stage-II generator: refine a Stage-I image, conditioned on text, to 4x its size.

    The Stage-I image is encoded (one stride-1 conv, two stride-2 convs), fused with the tiled
    conditioning vector, passed through ``num_residuals`` residual blocks, and upsampled four
    times. For a 64x64 Stage-I image the output is 256x256.
    """
    def __init__(self, down_dim=128, embedding_dim=1024, cond_dim=128, num_residuals=4):
        super(stage2_gen,self).__init__()
        self.down_dim = down_dim
        self.embedding_dim = embedding_dim
        self.cond_dim = cond_dim
        self.downsampler = torch.nn.Sequential(
            torch.nn.Conv2d(3, self.down_dim, kernel_size = 3, stride = 1, padding = 1, bias = False),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(self.down_dim, self.down_dim * 2, kernel_size = 4, stride = 2, padding = 1, bias=False),
            torch.nn.BatchNorm2d(self.down_dim * 2),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(self.down_dim * 2, self.down_dim * 4, kernel_size = 4, stride = 2, padding = 1, bias=False),
            torch.nn.BatchNorm2d(self.down_dim * 4),
            torch.nn.ReLU(inplace=True))

        self.augm = cond_aug(self.embedding_dim, self.cond_dim)

        self.joint_proc = torch.nn.Sequential(
            torch.nn.Conv2d(self.down_dim * 4 + self.cond_dim, self.down_dim * 4, kernel_size = 3, stride = 1, padding = 1, bias = False),
            torch.nn.BatchNorm2d(self.down_dim * 4),
            torch.nn.ReLU(True))

        # Plain list kept for backward compatibility; the blocks are registered as submodules
        # through self.residual.
        self.layers = [ResBlock(self.down_dim * 4) for _ in range(num_residuals)]
        self.residual = torch.nn.Sequential(*self.layers)

        self.upsample1 = upSample(self.down_dim * 4, self.down_dim * 2)
        self.upsample2 = upSample(self.down_dim * 2, self.down_dim)
        self.upsample3 = upSample(self.down_dim, self.down_dim // 2)
        self.upsample4 = upSample(self.down_dim // 2, self.down_dim // 4)
        self.gen_img = torch.nn.Sequential(
            torch.nn.Conv2d(self.down_dim // 4, 3, kernel_size = 3, stride = 1, padding = 1, bias = False),
            torch.nn.Tanh())

    def forward(self, text_embedding, stage1_image):
        """Return ``(fake_img, mu, sigma)``.

        Shape comments assume the default ``down_dim=128`` and a 64x64 Stage-I image.
        """
        encoded_img = self.downsampler(stage1_image) # --> N x 512 x 16 x 16
        mu, sigma, c = self.augm(text_embedding)
        # Tile the conditioning vector over the encoder's actual spatial grid rather than a
        # hard-coded 16x16, so other Stage-I resolutions also work.
        height, width = encoded_img.size(2), encoded_img.size(3)
        c = c.view(-1, self.cond_dim, 1, 1) # --> N x 128 x 1 x 1
        c = c.repeat(1, 1, height, width) # --> N x 128 x 16 x 16
        conc_inp = torch.cat([encoded_img, c], 1) # --> N x 640 x 16 x 16

        conc_out = self.joint_proc(conc_inp) # --> N x 512 x 16 x 16
        conc_out = self.residual(conc_out)   # --> N x 512 x 16 x 16

        conc_out = self.upsample1(conc_out) # --> N x 256 x 32 x 32
        conc_out = self.upsample2(conc_out) # --> N x 128 x 64 x 64
        conc_out = self.upsample3(conc_out) # --> N x 64 x 128 x 128
        conc_out = self.upsample4(conc_out) # --> N x 32 x 256 x 256

        fake_img = self.gen_img(conc_out) # --> N x 3 x 256 x 256
        return fake_img, mu, sigma

In [None]:
class stage2_disc(torch.nn.Module):
    """Stage-II discriminator: scores 256x256 images against conditioning vectors.

    Six stride-2 convolutions reduce the image to 4x4 (channels up to down_dim*32). Two 3x3
    convolutions then bring the channels back to down_dim*8. The conditioning vector is tiled
    over the 4x4 grid, concatenated, and reduced to one sigmoid probability per image.
    """
    def __init__(self, cond_dim=128, down_dim=64):
        super(stage2_disc, self).__init__()
        self.cond_dim = cond_dim
        self.down_dim = down_dim
        ndf = self.down_dim

        encoder = [
            torch.nn.Conv2d(3, ndf, kernel_size=4, stride=2, padding=1, bias=False),
            torch.nn.LeakyReLU(0.2, inplace=True),
        ]
        # Five more stride-2 stages: ndf -> 2 -> 4 -> 8 -> 16 -> 32 * ndf channels.
        for power in range(5):
            c_in, c_out = ndf * 2 ** power, ndf * 2 ** (power + 1)
            encoder.extend([
                torch.nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                torch.nn.BatchNorm2d(c_out),
                torch.nn.LeakyReLU(0.2, inplace=True),
            ])
        # Two same-size 3x3 stages shrinking channels: 32*ndf -> 16*ndf -> 8*ndf.
        for c_in, c_out in ((ndf * 32, ndf * 16), (ndf * 16, ndf * 8)):
            encoder.extend([
                torch.nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False),
                torch.nn.BatchNorm2d(c_out),
                torch.nn.LeakyReLU(0.2, inplace=True),
            ])
        self.enc_img = torch.nn.Sequential(*encoder)

        self.get_logits = torch.nn.Sequential(
            torch.nn.Conv2d(ndf * 8 + self.cond_dim, ndf * 8, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(ndf * 8),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),
            torch.nn.Sigmoid())

    def forward(self, image, cond_vec):
        """Return a flat tensor with one probability per image in the batch."""
        features = self.enc_img(image)
        tiled_cond = cond_vec.view(-1, self.cond_dim, 1, 1).repeat(1, 1, 4, 4)
        joint = torch.cat((features, tiled_cond), 1)  # N x (down_dim*8 + cond_dim) x 4 x 4
        return self.get_logits(joint).view(-1)

In [None]:
def train_s1(gen, disc, train_s1_data, epochs=1000, eta=0.0001, opt=optim.Adam, model_name='A_Model_Has_No_Name',
             kl_lambda=2.0, embedding_dim=1024, cond_dim=128, max_entries=None, device=None):
    """Train a Stage-I generator/discriminator pair with alternating D and G updates.

    Entries of ``train_s1_data`` are ``(image, text_embedding, label)`` tuples as built by
    ``read_input``. Each matched entry (label 1) is followed by a mismatched entry (label 0)
    for the same image. Matched entries are the "real" pairs. The caption of the following
    entry supplies the "wrong" condition. Pixel values are assumed to be in [0, 255].

    Per-epoch mean losses are appended to ``./observations/Loss_S1_<timestamp>.txt`` and printed.

    Args:
        gen: Stage-I generator; must expose ``augm`` (its conditioning-augmentation module).
        disc: Stage-I discriminator taking ``(image, cond)``.
        train_s1_data: list of ``(image, embedding, label)`` tuples.
        epochs: number of passes over the data.
        eta: learning rate for both optimisers.
        opt: optimiser class.
        model_name: suffix for the log-file name.
        kl_lambda: weight of the KL term in the generator loss.
        embedding_dim, cond_dim: unused; kept for backward compatibility.
        max_entries: optional cap on the number of entries scanned per epoch (quick runs).
        device: torch device; defaults to CUDA when available, else CPU.

    Returns:
        ``(l_tr_g, l_tr_d)``: per-epoch mean generator and discriminator losses.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    gen.to(device)
    disc.to(device)

    optim_gen = opt(gen.parameters(), lr=eta)
    optim_disc = opt(disc.parameters(), lr=eta)

    l_tr_g = []
    l_tr_d = []

    n_train = len(train_s1_data)
    if max_entries is not None:
        n_train = min(n_train, max_entries)

    os.makedirs('./observations/', exist_ok=True)
    os.makedirs('./saved_models/', exist_ok=True)
    # model_name is appended after formatting so a '%' in it is not read as a strftime code.
    ts = strftime("%Y-%m-%d__%Hh%Mm%Ss_S1_", gmtime()) + model_name

    for epoch in range(epochs):
        l_tr_g_temp = 0.0
        l_tr_d_temp = 0.0
        n_steps = 0

        ############ Training #############
        # Stop one short of the end: entry idx + 1 provides the mismatched caption.
        for idx in range(n_train - 1):
            img_np, desc_np, label = train_s1_data[idx]
            if label != 1:
                continue  # mismatched pairs must never be shown to D as real
            wrong_desc_np = train_s1_data[idx + 1][1]

            # Scale pixels from [0, 255] to [-1, 1] to match the generator's Tanh output.
            orig_img = torch.from_numpy(img_np).float().div(127.5).sub(1.0).unsqueeze(0).to(device)
            text = torch.from_numpy(desc_np).float().view(1, -1).to(device)
            wrong_text = torch.from_numpy(wrong_desc_np).float().view(1, -1).to(device)

            fake_img, mu, sigma = gen(text)
            cond = mu.detach()
            # The wrong condition comes from the generator's own (trained) augmentation net.
            with torch.no_grad():
                mu_w = gen.augm(wrong_text)[0]

            # Discriminator step: detach generator outputs so D's loss cannot update G.
            optim_disc.zero_grad()
            ld_ = loss_disc(disc(orig_img, cond), disc(fake_img.detach(), cond), disc(orig_img, mu_w))
            ld_.backward()
            optim_disc.step()

            # Generator step, scored by the freshly updated discriminator (new graph, so no
            # retain_graph and no in-place-modified weights in the backward pass).
            optim_gen.zero_grad()
            lg_ = loss_gen(disc(fake_img, cond), mu, sigma, kl_lambda)
            lg_.backward()
            optim_gen.step()

            l_tr_g_temp += lg_.item()
            l_tr_d_temp += ld_.item()
            n_steps += 1

        # Average over the steps actually taken (not the raw dataset length).
        mean_g = l_tr_g_temp / max(n_steps, 1)
        mean_d = l_tr_d_temp / max(n_steps, 1)
        l_tr_g.append(mean_g)
        l_tr_d.append(mean_d)
        msg = 'Epoch ' + str(epoch) + ': Generator Loss = ' + str(mean_g) + \
              '         Discriminator Loss = ' + str(mean_d)
        with open('./observations/Loss_S1_' + ts + '.txt', 'a+') as log_file:
            log_file.write(msg + '\n')
        print(msg)

    return l_tr_g, l_tr_d
        

In [None]:
def train_s2(gen, disc, trained_gen1, train_s2_data, epochs=1000, eta=0.0001, opt=optim.Adam, model_name='A_Model_Has_No_Name',
             kl_lambda=2.0, embedding_dim=1024, cond_dim=128, max_entries=None, device=None):
    """Train a Stage-II generator/discriminator pair on top of a frozen Stage-I generator.

    Entries of ``train_s2_data`` are ``(image, text_embedding, label)`` tuples as built by
    ``read_input``. Each matched entry (label 1) is followed by a mismatched entry (label 0)
    for the same image. Matched entries are the "real" pairs. The caption of the following
    entry supplies the "wrong" condition. Pixel values are assumed to be in [0, 255].

    ``trained_gen1`` is put in eval mode and run without gradients, so it is neither trained
    nor has its BatchNorm statistics updated; it is left in eval mode on return.
    Per-epoch mean losses are appended to ``./observations/Loss_S2_<timestamp>.txt`` and printed.

    Args:
        gen: Stage-II generator taking ``(text, stage1_image)``; must expose ``augm``.
        disc: Stage-II discriminator taking ``(image, cond)``.
        trained_gen1: trained Stage-I generator; its first output is the Stage-I image.
        train_s2_data: list of ``(image, embedding, label)`` tuples.
        epochs: number of passes over the data.
        eta: learning rate for both optimisers.
        opt: optimiser class.
        model_name: suffix for the log-file name.
        kl_lambda: weight of the KL term in the generator loss.
        embedding_dim, cond_dim: unused; kept for backward compatibility.
        max_entries: optional cap on the number of entries scanned per epoch (quick runs).
        device: torch device; defaults to CUDA when available, else CPU.

    Returns:
        ``(l_tr_g, l_tr_d)``: per-epoch mean generator and discriminator losses.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    gen.to(device)
    disc.to(device)
    trained_gen1.to(device)
    trained_gen1.eval()  # frozen: BatchNorm uses (and stops updating) its running statistics

    optim_gen = opt(gen.parameters(), lr=eta)
    optim_disc = opt(disc.parameters(), lr=eta)

    l_tr_g = []
    l_tr_d = []

    # Use this function's own dataset, not a global from another cell.
    n_train = len(train_s2_data)
    if max_entries is not None:
        n_train = min(n_train, max_entries)

    os.makedirs('./observations/', exist_ok=True)
    os.makedirs('./saved_models/', exist_ok=True)
    # model_name is appended after formatting so a '%' in it is not read as a strftime code.
    ts = strftime("%Y-%m-%d__%Hh%Mm%Ss_S2_", gmtime()) + model_name

    for epoch in range(epochs):
        l_tr_g_temp = 0.0
        l_tr_d_temp = 0.0
        n_steps = 0

        ############ Training #############
        # Stop one short of the end: entry idx + 1 provides the mismatched caption.
        for idx in range(n_train - 1):
            img_np, desc_np, label = train_s2_data[idx]
            if label != 1:
                continue  # mismatched pairs must never be shown to D as real
            wrong_desc_np = train_s2_data[idx + 1][1]

            # Scale pixels from [0, 255] to [-1, 1] to match the generator's Tanh output.
            orig_img = torch.from_numpy(img_np).float().div(127.5).sub(1.0).unsqueeze(0).to(device)
            text = torch.from_numpy(desc_np).float().view(1, -1).to(device)
            wrong_text = torch.from_numpy(wrong_desc_np).float().view(1, -1).to(device)

            with torch.no_grad():
                fake_s1 = trained_gen1(text)[0]
                # The wrong condition comes from the generator's own (trained) augmentation net.
                mu_w = gen.augm(wrong_text)[0]
            fake_img, mu, sigma = gen(text, fake_s1)
            cond = mu.detach()

            # Discriminator step: detach generator outputs so D's loss cannot update G.
            optim_disc.zero_grad()
            ld_ = loss_disc(disc(orig_img, cond), disc(fake_img.detach(), cond), disc(orig_img, mu_w))
            ld_.backward()
            optim_disc.step()

            # Generator step, scored by the freshly updated discriminator.
            optim_gen.zero_grad()
            lg_ = loss_gen(disc(fake_img, cond), mu, sigma, kl_lambda)
            lg_.backward()
            optim_gen.step()

            l_tr_g_temp += lg_.item()
            l_tr_d_temp += ld_.item()
            n_steps += 1

        # Average over the steps actually taken (not the raw dataset length).
        mean_g = l_tr_g_temp / max(n_steps, 1)
        mean_d = l_tr_d_temp / max(n_steps, 1)
        l_tr_g.append(mean_g)
        l_tr_d.append(mean_d)
        msg = 'Epoch ' + str(epoch) + ': Generator Loss = ' + str(mean_g) + \
              '         Discriminator Loss = ' + str(mean_d)
        with open('./observations/Loss_S2_' + ts + '.txt', 'a+') as log_file:
            log_file.write(msg + '\n')
        print(msg)

    return l_tr_g, l_tr_d
        

In [None]:
# Stage-I training: 64x64 generator and discriminator conditioned on 1024-d text embeddings.
gen = stage1_gen(embedding_dim=1024, cond_dim=128, noise_dim=100, ups_input_dim=1024)
disc = stage1_disc(cond_dim=128, down_dim=64)
# Keep the per-epoch loss curves in named variables instead of letting the cell echo them.
loss_g_s1, loss_d_s1 = train_s1(gen, disc, train_s1_data, epochs=1000, eta=0.0001, opt=optim.Adam,
                                model_name='A_Model_Has_No_Name', kl_lambda=2.0, embedding_dim=1024, cond_dim=128)