In [12]:
import torch
from torch.autograd import Variable
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import cv2
import pickle
from random import *
import gc
from torch.distributions.multivariate_normal import MultivariateNormal
%matplotlib inline

## Load the CUB Dataset (images and char-CNN-RNN text embeddings)

In [2]:
def read_input(dataset='CUB', process='resize'):
    """Load image / text-embedding pairs for both StackGAN stages.

    For each split ('train' and 'test') this reads
    ``<dataset>/desc/<split>/char-CNN-RNN-embeddings.npy`` and the matching
    ``filenames.pickle``, then loads every ``<dataset>/images/<name>.jpg`` with
    OpenCV and builds lists of ``(image, embedding, label)`` tuples.

    * Stage-1 lists hold one positive pair (64x64 image, label 1) per caption.
    * Stage-2 lists hold, per caption, a positive pair (256x256 image, label 1)
      followed by a mismatched pair whose embedding is drawn at random from a
      different image row (label 0).

    Args:
        dataset: root directory of the dataset.
        process: image preprocessing mode; only ``'resize'`` is implemented.

    Returns:
        ``(train_s1_data, test_s1_data, train_s2_data, test_s2_data)``.

    Raises:
        ValueError: if ``process`` is unsupported, or a split's embedding array
            has fewer than two image rows (no mismatched negative can be drawn).
        FileNotFoundError: if an image listed in ``filenames.pickle`` cannot be read.
    """
    if process != 'resize':
        # Any other mode used to leave img_s1/img_s2 undefined (NameError) or,
        # worse, silently reuse the previous image's resized copies.
        raise ValueError("unsupported process %r; only 'resize' is implemented" % (process,))

    img_dir = dataset + '/images/'

    print('Loading Training Data...')
    train_s1_data, train_s2_data = _load_split(dataset + '/desc/train/', img_dir)

    print('Loading Testing Data...')
    test_s1_data, test_s2_data = _load_split(dataset + '/desc/test/', img_dir)

    return train_s1_data, test_s1_data, train_s2_data, test_s2_data


def _load_split(desc_dir, img_dir):
    """Build the (stage-1, stage-2) pair lists for one split's description directory."""
    embed = np.load(desc_dir + 'char-CNN-RNN-embeddings.npy')
    n_embed_imgs, n_captions = embed.shape[0], embed.shape[1]
    if n_embed_imgs < 2:
        # The rejection loop below would never terminate with a single row.
        raise ValueError('need at least 2 embedding rows to sample negatives, got %d'
                         % n_embed_imgs)

    # NOTE: pickle.load can execute arbitrary code -- only use trusted dataset files.
    with open(desc_dir + 'filenames.pickle', 'rb') as fh:
        filenames = pickle.load(fh)

    s1_data = []
    s2_data = []
    for i, name in enumerate(filenames):
        path = img_dir + name + '.jpg'
        img = cv2.imread(path, 1)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError('could not read image ' + path)
        img_s1 = cv2.resize(img, (64, 64), interpolation=cv2.INTER_AREA)
        img_s2 = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)

        for j in range(n_captions):
            s1_data.append((img_s1, embed[i, j, :], 1))

            # Draw a caption belonging to a *different* image as the negative.
            neg_img = randint(0, n_embed_imgs - 1)
            while neg_img == i:
                neg_img = randint(0, n_embed_imgs - 1)
            neg_idx = randint(0, n_captions - 1)
            s2_data.append((img_s2, embed[i, j, :], 1))
            s2_data.append((img_s2, embed[neg_img, neg_idx, :], 0))

        if i % 1000 == 0:
            print(i)

    return s1_data, s2_data

In [3]:
# Load every split once; the defaults are spelled out for the reader.
splits = read_input(dataset='CUB', process='resize')
train_s1_data, test_s1_data, train_s2_data, test_s2_data = splits

Loading Training Data...
0
1000
2000
3000
4000
5000
6000
7000
8000
Loading Testing Data...
0
1000
2000


In [4]:
# Report the size of each split, then check the pairing invariant: stage-1
# stores one positive pair per caption, stage-2 stores a positive and a
# mismatched negative per caption, so stage-2 must be exactly twice as long.
for split in (train_s1_data, test_s1_data, train_s2_data, test_s2_data):
    print(len(split))
assert len(train_s2_data) == 2 * len(train_s1_data)
assert len(test_s2_data) == 2 * len(test_s1_data)

88550
29330
177100
58660


In [5]:
# Full collection (generation 2 is the default) after the large loading step.
gc.collect(generation=2)

11

## StackGAN Implementation

### Loss Functions

In [6]:
def loss_gen():
    """Placeholder for the generator loss -- not implemented yet.

    TODO: presumably the StackGAN generator objective (adversarial term plus a
    KL penalty on the ``(mu, sigma)`` produced by ``cond_aug``) -- confirm and
    implement.

    Raises:
        NotImplementedError: always, so a caller cannot silently receive
            ``None`` as a loss value.
    """
    raise NotImplementedError('loss_gen is not implemented yet')
    
def loss_disc():
    """Placeholder for the discriminator loss -- not implemented yet.

    TODO: presumably scores real/matched, real/mismatched and fake pairs
    (cf. the label-0 negatives built by ``read_input``) -- confirm and implement.

    Raises:
        NotImplementedError: always, so a caller cannot silently receive
            ``None`` as a loss value.
    """
    raise NotImplementedError('loss_disc is not implemented yet')


### Model

In [9]:
def upSample(c_in, c_out, scale_factor=2):
    """Build a nearest-neighbour upsampling block.

    Layers: Upsample(nearest) -> 3x3 conv (stride 1, padding 1, no bias)
    -> BatchNorm2d -> ReLU.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        scale_factor: spatial upsampling factor (default 2, as before).

    Returns:
        torch.nn.Sequential mapping ``(N, c_in, H, W)`` to
        ``(N, c_out, H*scale_factor, W*scale_factor)``.
    """
    # The original built this module but never returned it, so every caller got None.
    return torch.nn.Sequential(
        torch.nn.Upsample(scale_factor=scale_factor, mode='nearest'),
        # bias is redundant because BatchNorm2d immediately re-centres the output.
        torch.nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False),
        torch.nn.BatchNorm2d(c_out),
        torch.nn.ReLU(inplace=True))

In [14]:
class cond_aug(torch.nn.Module):
    """Conditioning augmentation: map a text embedding to a sampled condition vector.

    Two linear heads project the embedding to ``mu`` and ``sigma`` (both passed
    through ReLU). The condition is drawn with the reparameterisation
    ``c = mu + sigma * eps``, where ``eps`` is standard-normal noise.

    Args:
        embedding_dim: size of the last dimension of the input embedding.
        cond_dim: size of the produced condition vector.
    """

    def __init__(self, embedding_dim=1024, cond_dim=128):
        super(cond_aug, self).__init__()
        self.embedding_dim = embedding_dim
        self.cond_dim = cond_dim

        self.fc_mu = torch.nn.Linear(self.embedding_dim, self.cond_dim)
        self.fc_sigma = torch.nn.Linear(self.embedding_dim, self.cond_dim)

        # In-place initialisers; the un-suffixed xavier_normal is a deprecated alias.
        torch.nn.init.xavier_normal_(self.fc_mu.weight)
        torch.nn.init.xavier_normal_(self.fc_sigma.weight)

    def forward(self, x):
        """Return ``(mu, sigma, c)`` for a batch of embeddings.

        Args:
            x: tensor whose last dimension is ``embedding_dim`` -- presumably
               ``(batch, embedding_dim)``; TODO confirm against callers.

        Returns:
            mu, sigma, c: tensors of shape ``x.shape[:-1] + (cond_dim,)``.
        """
        # torch.nn.ReLU(t) would construct a module, not apply the activation.
        mu = torch.nn.functional.relu(self.fc_mu(x))
        # NOTE(review): sigma is a ReLU output and can be exactly 0; confirm the
        # (not yet written) KL term tolerates that, e.g. if it takes log(sigma).
        sigma = torch.nn.functional.relu(self.fc_sigma(x))
        # Independent N(0, I) noise per sample, on sigma's device and dtype
        # (previously a single CPU vector was broadcast over the whole batch).
        eps = torch.randn_like(sigma)
        c = mu + sigma * eps
        return mu, sigma, c

In [16]:
class stage1_gen(torch.nn.Module):
    """Stage-I generator: text embedding -> (mu, sigma, generated image).

    The embedding is mapped to a condition vector ``c`` by ``cond_aug``, ``c`` is
    concatenated with Gaussian noise ``z`` and projected by a linear layer to a
    ``ups_input_dim x 4 x 4`` feature map. Four ``upSample`` blocks follow, each
    halving the channel count (integer division), and a final 3x3 convolution +
    ``tanh`` produces a 3-channel image. With ``upSample``'s default 2x factor
    the output is 64x64.

    Args:
        embedding_dim: size of the input text embedding.
        cond_dim: size of the condition vector produced by ``cond_aug``.
        noise_dim: size of the noise vector ``z``.
        ups_input_dim: channel count of the initial 4x4 feature map.
    """

    def __init__(self, embedding_dim=1024, cond_dim=128, noise_dim=100, ups_input_dim=1024 ):
        super(stage1_gen,self).__init__()
        self.embedding_dim = embedding_dim
        self.cond_dim = cond_dim
        self.noise_dim = noise_dim
        self.ups_input_dim = ups_input_dim
        self.conc_dim = self.cond_dim + self.noise_dim

        # Standard normal over noise_dim; sampled per batch in forward().
        self.dist = MultivariateNormal(torch.zeros(self.noise_dim), torch.eye(self.noise_dim))
        self.augm = cond_aug(self.embedding_dim, self.cond_dim)
        self.ups_input = torch.nn.Sequential(
                    torch.nn.Linear(self.conc_dim, self.ups_input_dim*4*4, bias=False),
                    torch.nn.BatchNorm1d(self.ups_input_dim*4*4),
                    torch.nn.ReLU(inplace=True))
        self.upsample1 = upSample(self.ups_input_dim,self.ups_input_dim//2)
        self.upsample2 = upSample(self.ups_input_dim//2,self.ups_input_dim//4)
        self.upsample3 = upSample(self.ups_input_dim//4,self.ups_input_dim//8)
        self.upsample4 = upSample(self.ups_input_dim//8,self.ups_input_dim//16)
        self.gen_img = torch.nn.Sequential(
                torch.nn.Conv2d(self.ups_input_dim//16, 3, kernel_size = 3, stride = 1, padding = 1, bias = False),
                torch.nn.Tanh())

    def forward(self, x):
        """Generate images from a batch of text embeddings.

        Args:
            x: text embeddings, presumably ``(batch, embedding_dim)`` -- TODO confirm.
               In training mode BatchNorm1d needs ``batch > 1``.

        Returns:
            ``(mu, sigma, out)``: the conditioning statistics from ``cond_aug``
            and the generated images of shape ``(batch, 3, H, W)``.
        """
        # augm returns a tuple; the old `.view(1, -1)` on it raised AttributeError.
        mu, sigma, c = self.augm(x)
        # One noise vector per sample (not a single (1, noise_dim) row), moved to
        # c's device and dtype so the concatenation below is valid.
        z = self.dist.sample((c.size(0),)).to(c)
        inp = torch.cat((c, z), 1)

        h = self.ups_input(inp)
        # The Linear layer yields a flat vector; reshape it into a 4x4 feature map
        # so the convolutional upsampling blocks can consume it.
        h = h.view(-1, self.ups_input_dim, 4, 4)
        h = self.upsample1(h)
        h = self.upsample2(h)
        h = self.upsample3(h)
        h = self.upsample4(h)
        out = self.gen_img(h)

        return mu, sigma, out