In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR
from net_sphere import sphere20a
# Reference: https://github.com/davrempe/domain-transfer-net/blob/master/FaceMain.ipynb
import os
from torch.utils.data import Dataset, DataLoader
import csv
from PIL import Image
import numpy as np
import torch
import torchvision
from torch.autograd import Variable
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from collections import OrderedDict
# import

In [2]:
class ResizeTransform(object):
    '''Resize a PIL image to (size, size) and return it as a float32 CHW torch tensor in [0, 1].

    The image is converted to RGB first so the result always has 3 channels.

    Args:
        size (int): side length in pixels of the square output image.
    '''
    def __init__(self, size):
        self.size = size

    def __call__(self, sample):
        '''
        Args:
            sample (PIL.Image.Image): input image in any PIL mode.
        Returns:
            torch.FloatTensor of shape (3, size, size) with values in [0, 1].
        '''
        # Grayscale/palette/RGBA inputs would otherwise give 2-D or 4-channel arrays
        # that break the HWC -> CHW transpose or the 3-channel networks downstream.
        if sample.mode != 'RGB':
            sample = sample.convert('RGB')
        img = sample.resize((self.size, self.size), Image.BILINEAR)
        # Convert to ndarray explicitly: np.transpose on a PIL image only worked through
        # numpy's fallback after PIL's own Image.transpose(method) rejected the tuple.
        arr = np.asarray(img, dtype=np.float32) / 255.0
        arr = np.ascontiguousarray(arr.transpose(2, 0, 1))  # HWC -> CHW
        return torch.from_numpy(arr)


class BitEmoji(Dataset):
    '''
    Dataset of Bitmoji images stored as ``<data_dir>/emoji_<n>.png``.

    Args:
        data_dir (str): directory containing the emoji PNG files.
        start_index (int): number of the first image in the dataset (inclusive).
        end_index (int): number where the dataset ends (exclusive).
        transform (callable, optional): applied to each RGB PIL image.

    Dataset position ``idx`` maps to file ``emoji_{start_index + idx}.png``.
    '''
    def __init__(self, data_dir, start_index=0, end_index=100000, transform=None):
        if end_index < start_index:
            raise ValueError('end_index ({}) must be >= start_index ({})'.format(end_index, start_index))
        self.data_dir = data_dir
        self.start_index = start_index
        self.end_index = end_index
        self.transform = transform
        self.data_len = end_index - start_index

    def __getitem__(self, idx):
        """
        Args:
            idx (int): position in the dataset, 0 <= idx < len(self).
        Returns:
            The image converted to RGB, or ``transform(image)`` if a transform was given.
        Raises:
            IndexError: if idx is out of range.
            FileNotFoundError: if the corresponding PNG is missing on disk.
        """
        if not 0 <= idx < self.data_len:
            raise IndexError('BitEmoji index {} out of range [0, {})'.format(idx, self.data_len))
        # Offset by start_index so a sub-range (e.g. a held-out split) reads its own files.
        img_name = os.path.join(self.data_dir, 'emoji_{}.png'.format(self.start_index + idx))
        # The context manager releases the file handle once convert() has decoded the pixels;
        # converting to RGB gives a 3-channel image (PNGs may carry alpha or a palette).
        with Image.open(img_name) as src:
            img = src.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img

    def __len__(self):
        return self.data_len
    
class MSCelebDataset(Dataset):
    '''
    MS-Celeb-1M face image dataset (aligned and cropped version).

    Args:
        data_dir (str): directory holding the annotation files and an image subdirectory.
        split (str): either 'train' or 'test'.
        transform (callable, optional): applied to each RGB PIL image.

    The annotation file is read as space-delimited rows; the first column of each
    row is the image file name relative to the split's image directory.
    '''
    def __init__(self, data_dir, split, transform=None):
        self.transform = transform

        # Each split has its own annotation file and image subdirectory under data_dir.
        if split == 'train':
            info_path = 'train_data_info.txt'
            self.data_path = os.path.join(data_dir, 'images_train/')
        elif split == 'test':
            info_path = 'info/test_data_info.txt'
            self.data_path = os.path.join(data_dir, 'data/')
        else:
            raise ValueError("split must be 'train' or 'test', got {!r}".format(split))

        with open(os.path.join(data_dir, info_path)) as info_file:
            reader = csv.reader(info_file, delimiter=' ')
            # Skip blank lines (e.g. a trailing newline): csv yields [] for them, which makes
            # the rows ragged and breaks np.array and the [idx, 0] lookup below.
            info_data = [row for row in reader if row]

        self.info = np.array(info_data)
        self.data_len = self.info.shape[0]

    def __getitem__(self, idx):
        """
        Args:
            idx (int): row of the annotation file.
        Returns:
            The image converted to RGB, or ``transform(image)`` if a transform was given.
        """
        img_name = os.path.join(self.data_path, self.info[idx, 0])
        # Close the file after decoding and force 3 channels, matching BitEmoji.
        with Image.open(img_name) as src:
            img = src.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img

    def __len__(self):
        return self.data_len

In [3]:
# self.target_path='/home/jupyter/emoji_data'
# self.train_set_bitmoji = BitEmoji(a, 0, 100000, transform = ResizeTransform(96))
# train_loader_bitmoji = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
# source_path = '/home/jupyter/MSceleb_data/data/'
# train_set_celeb = MSCelebDataset(b, 'test', ResizeTransform(96))
# train_loader_celeb = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)

In [4]:
class G(nn.Module):
    '''Generator: decodes a flat feature vector into a 3-channel image in [-1, 1].

    The input is reshaped to (B, channels, 1, 1) and passed through six stride-2
    ConvTranspose2d stages, each doubling H and W (1 -> 2 -> ... -> 64). The first five
    stages (widths 512, 256, 128, 64, 32) are each followed by BatchNorm + LeakyReLU(0.2)
    and a 1x1 conv with BatchNorm + LeakyReLU(0.2); the last maps to 3 channels + Tanh.

    Args:
        channels (int): length of the input feature vector.
    '''
    def __init__(self, channels):
        # Name the class explicitly: super(self.__class__, self) recurses forever
        # as soon as G is subclassed.
        super(G, self).__init__()
        self.channels = channels
        widths = [self.channels, 512, 256, 128, 64, 32]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(c_out, c_out, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers += [
            nn.ConvTranspose2d(widths[-1], 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        # Layer order matches the hand-written Sequential exactly, so state_dict keys
        # (g.0.weight, ...) and previously saved checkpoints remain compatible.
        self.g = nn.Sequential(*layers)

    def forward(self, inp):
        '''Map features (B, channels) to images (B, 3, 64, 64).'''
        # Use self.channels rather than a hard-coded 512 so other widths work.
        resh_inp = inp.view(inp.size(0), self.channels, 1, 1)
        return self.g(resh_inp)

In [5]:
class D(nn.Module):
    '''Discriminator producing 3 class scores per image.

    Five 3x3 stride-2 conv stages (each BatchNorm + LeakyReLU) with widths
    channels -> 2*channels -> 4*channels -> 2*channels -> channels, followed by an
    unpadded 3x3 conv to 3 outputs. For a 96x96 input the spatial size goes
    96 -> 48 -> 24 -> 12 -> 6 -> 3 -> 1, giving logits of shape (B, 3, 1, 1).

    Args:
        channels (int): base width of the conv stages.
        al (float): negative slope of the LeakyReLU activations.
    '''
    def __init__(self, channels, al=0.2):
        # Explicit class name: super(self.__class__, self) breaks under subclassing.
        super(D, self).__init__()
        self.channels = channels
        self.al = al
        widths = [3, self.channels, self.channels * 2, self.channels * 4,
                  self.channels * 2, self.channels]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(self.al, inplace=True),
            ]
        layers.append(nn.Conv2d(widths[-1], 3, kernel_size=3, stride=1))
        # Same layer order/indices as the hand-written Sequential (checkpoint-compatible).
        self.d = nn.Sequential(*layers)

    def forward(self, inp):
        return self.d(inp)

In [6]:
# Quick instantiation check of the discriminator with 128 base channels;
# `dis` is not referenced by the later cells shown here.
dis = D(channels=128)

In [7]:
class ZeroPadBottom(object):
    ''' Zero-pads a batch of image tensors at the bottom of the H axis up to a given height.

    Input (B, C, H, W) -> output (B, C, size, W); the original rows come first.

    Args:
        size (int): target height; must be >= H of the input.
        use_gpu (bool): kept for backward compatibility. The padding is created on the
            input's own device and dtype, so CPU and GPU batches both work.
    '''
    def __init__(self, size, use_gpu=True):
        self.size = size
        self.use_gpu = use_gpu

    def __call__(self, sample):
        B, C, H, W = sample.size()
        diff = self.size - H
        if diff < 0:
            raise ValueError('ZeroPadBottom: input height {} exceeds target size {}'.format(H, self.size))
        # new_zeros matches sample's device and dtype and does not require grad,
        # so gradients still flow only into `sample` through the concatenation.
        padding = sample.new_zeros(B, C, diff, W)
        return torch.cat((sample, padding), dim=2)

In [8]:
# Per-iteration loss histories filled by face_model.train:
# l_g collects generator losses, l_d discriminator losses.
l_g, l_d = [], []

In [9]:

class tanh_unnormalize(object):
    ''' Maps values from the Tanh range [-1, 1] back to [0, 1] (inverse of tanh_normalize). '''

    def __call__(self, sample):
        # Affine map with scale 0.5 and offset 0.5: -1 -> 0, 0 -> 0.5, 1 -> 1.
        return sample * 0.5 + 0.5
        
class tanh_normalize(object):
    ''' Maps values from [0, 1] into the Tanh range [-1, 1]. '''

    def __call__(self, sample):
        # Affine map with scale 2 and offset -1: 0 -> -1, 0.5 -> 0, 1 -> 1.
        return sample * 2.0 - 1.0
    
class face_model():
    '''Trainer for a Domain Transfer Network mapping MS-Celeb faces to Bitmoji-style images.

    Components (created by ``make_model``):
        f -- sphere20a feature network loaded from a checkpoint and frozen;
        g -- generator ``G`` decoding f-features into images;
        d -- 3-way discriminator ``D``: class 0 = translated source face,
             class 1 = translated target emoji, class 2 = real target emoji.

    Call order (one notebook cell each):
        make_model() -> make_loader() -> make_loss_func() -> make_opt() -> train(n_epo)

    Args:
        use_gpu (bool): run networks, label tensors and batches on CUDA.
        source_path (str): MS-Celeb data directory (see MSCelebDataset).
        target_path (str): directory holding the ``emoji_<n>.png`` files.
        batch_size (int): batch size of both loaders; also the length of the label tensors.
    '''
    def __init__(self, use_gpu=True, source_path='/home/jupyter/MSceleb_data/',
                 target_path='/home/jupyter/emoji_data', batch_size=128):
        super(face_model, self).__init__()
        self.use_gpu = use_gpu
        self.generator_loss_func = None
        self.discriminator_loss_func = None
        self.gan_loss_func = None
        self.generator_smooth_func = None
        self.source_val_loader = None
        self.source_test_loader = None
        self.target_test_loader = None
        self.source_train_loader = None
        self.target_train_loader = None
        self.batch_size = batch_size
        self.lossCE = nn.CrossEntropyLoss()
        self.target_path = target_path
        self.source_path = source_path

    def _device(self):
        '''torch.device on which batches and label tensors are placed.'''
        return torch.device('cuda' if self.use_gpu else 'cpu')

    def make_loader(self, num_workers=8, target_end_index=100000):
        '''Build the MS-Celeb (source) and Bitmoji (target) training loaders.

        Both datasets are resized to 96x96 and mapped from [0, 1] to [-1, 1] to match
        the generator's Tanh output. Incomplete final batches are dropped because the
        losses use label tensors of fixed length ``self.batch_size``.

        Args:
            num_workers (int): DataLoader worker processes per loader.
            target_end_index (int): emoji files emoji_0.png .. emoji_<target_end_index - 1>.png
                must all exist under ``self.target_path``.
        '''
        image_transform = transforms.Compose([ResizeTransform(96), tanh_normalize()])

        # NOTE(review): the source *training* loader is built from the 'test' split — confirm intended.
        source_train_set = MSCelebDataset(self.source_path, 'test', transform=image_transform)
        self.source_train_loader = torch.utils.data.DataLoader(
            source_train_set, batch_size=self.batch_size, shuffle=True,
            num_workers=num_workers, drop_last=True)

        target_train_set = BitEmoji(self.target_path, 0, target_end_index, transform=image_transform)
        self.target_train_loader = torch.utils.data.DataLoader(
            target_train_set, batch_size=self.batch_size, shuffle=True,
            num_workers=num_workers, drop_last=True)
        # TODO: build source/target test loaders (self.source_test_loader / self.target_test_loader).

    def make_model(self):
        '''Create f (frozen sphere20a features), g (generator) and d (discriminator).'''
        self.model = {}
        f = sphere20a(feature=True)
        # map_location='cpu' lets the checkpoint load on CPU-only machines; f is moved below.
        f.load_state_dict(torch.load('./sphere20a_20171020.pth', map_location='cpu'))
        # Freeze f: it is never optimised. (The attribute is `requires_grad`; the
        # misspelled `require_grad` silently froze nothing.) Gradients still flow
        # through f into its *inputs*, which the f-constancy loss needs.
        for param in f.parameters():
            param.requires_grad = False
        f.eval()

        self.model['f'] = f
        self.model['g'] = G(channels=512)
        self.model['d'] = D(128)

        if self.use_gpu:
            for name in ('g', 'd', 'f'):
                self.model[name] = self.model[name].cuda()

        # g emits 64x64 images; bring them to the 96x96 working resolution.
        self.up = nn.Upsample(size=(96, 96), mode='bilinear', align_corners=False)
        # Pad 96-row images to 112 rows before feeding them to f.
        self.pad = ZeroPadBottom(112, use_gpu=self.use_gpu)

    def make_loss_func(self):
        '''Create CE/MSE losses, the constant class-label tensors and the loss closures.'''
        device = self._device()
        self.lossCE = nn.CrossEntropyLoss()
        self.lossMSE = nn.MSELoss()
        # Constant targets for the 3-way discriminator; length must equal the batch size.
        self.lab0, self.lab1, self.lab2 = (
            torch.full((self.batch_size,), k, dtype=torch.long, device=device)
            for k in range(3))

        self.make_generator_loss_func()
        self.make_discriminator_loss_func()
        self.make_smooth_func()

    def make_opt(self):
        '''Create Adam optimisers for g and d and MultiStepLR schedulers (x0.1 at 15000).

        ``train`` steps the schedulers once per iteration, so the milestone counts iterations.
        '''
        self.generator_opt = optim.Adam(self.model['g'].parameters(), lr=2e-4, betas=(0.5, 0.999), weight_decay=1e-6)
        self.discriminator_opt = optim.Adam(self.model['d'].parameters(), lr=2e-4, betas=(0.5, 0.999), weight_decay=1e-6)

        self.generator_lr_sche = MultiStepLR(self.generator_opt, milestones=[15000], gamma=0.1)
        self.discriminator_lr_sche = MultiStepLR(self.discriminator_opt, milestones=[15000], gamma=0.1)

    def cos_sim(self, x, y, eps=1e-8):
        '''Return 1 - mean cosine similarity between matching rows of x and y (along dim 1).'''
        ab = torch.sum(x * y, dim=1)
        norms = torch.sqrt(torch.sum(x * x, dim=1)) * torch.sqrt(torch.sum(y * y, dim=1))
        # Clamp so an all-zero vector yields similarity 0 instead of NaN.
        sim = ab / norms.clamp(min=eps)
        return 1 - torch.mean(sim)

    @staticmethod
    def _flat_logits(scores):
        '''Reshape discriminator output (B, 3, 1, 1) to (B, 3); unlike squeeze(), keeps B == 1.'''
        return scores.reshape(scores.size(0), -1)

    def make_smooth_func(self):
        '''Create ``self.gen_smooth_func``: anisotropic total-variation loss of an image batch.'''
        def gen_smooth_func(s_g):
            # s_g: (B, C, H, W). Compare each pixel with its lower and right neighbour.
            z = s_g[:, :, :-1, :-1]
            z_down = s_g[:, :, 1:, :-1]
            # Right neighbour must use rows :-1; slicing rows -1: would compare every
            # pixel against the last row instead.
            z_right = s_g[:, :, :-1, 1:]
            diff_sum = torch.abs(z_down - z) + torch.abs(z_right - z)
            # Mean over channels, sum over H and W, mean over the batch.
            return diff_sum.mean(dim=1).sum(dim=1).sum(dim=1).mean()

        self.gen_smooth_func = gen_smooth_func
        self.generator_smooth_func = gen_smooth_func

    def make_generator_loss_func(self):
        '''Create ``self.generator_loss_func``.

        gloss(s_g, s_f, s_g_f, s_d_g, t, t_g, t_d_g, al, be, gam) returns
            l_gang + al * l_const + be * ltid + gam * ltv
        with
            l_gang  -- CE pushing d to label both translated batches as class 2 (real target),
            l_const -- 1 - cosine similarity between f(source) and f(translated source),
            ltid    -- MSE between translated target and the target itself,
            ltv     -- total variation of the translated source.
        '''
        def gloss(s_g, s_f, s_g_f, s_d_g, t, t_g, t_d_g, al=0.01, be=100, gam=0.0001):
            l_gang = (self.lossCE(self._flat_logits(s_d_g), self.lab2)
                      + self.lossCE(self._flat_logits(t_d_g), self.lab2))
            l_const = self.cos_sim(s_f.detach(), s_g_f)
            ltv = self.gen_smooth_func(s_g)
            ltid = self.lossMSE(t_g, t.detach())
            return l_gang + al * l_const + be * ltid + gam * ltv

        self.generator_loss_func = gloss

    def make_discriminator_loss_func(self):
        '''Create ``self.discriminator_loss_func``: 3-way CE with labels 0 / 1 / 2 for
        translated source, translated target and real target respectively.'''
        def dloss(s_d_g, t_d_g, t_d):
            return (self.lossCE(self._flat_logits(s_d_g), self.lab0)
                    + self.lossCE(self._flat_logits(t_d_g), self.lab1)
                    + self.lossCE(self._flat_logits(t_d), self.lab2))

        self.discriminator_loss_func = dloss

    def seeResultsSrc(self, s_data, s_G, t):
        '''Show 4x4 grids of source faces, their translations and target emojis.

        Inputs are expected in the Tanh range [-1, 1]; only the first 16 items are shown.
        '''
        unnorm = tanh_unnormalize()
        for batch in (s_data, s_G, t):
            self.imshow(unnorm(batch.detach().cpu()[:16]))

    def imshow(self, img):
        '''Display an (N, C, H, W) batch as a grid with 4 images per row, clipped to [0, 1].'''
        plt.figure()
        grid = torchvision.utils.make_grid(img, nrow=4).numpy()
        plt.imshow(np.clip(np.transpose(grid, (1, 2, 0)), 0.0, 1.0))
        plt.show()

    def _translate(self, x):
        '''Return (f(pad(x)), up(g(f(pad(x))))) for a batch of 96x96 images.'''
        feats = self.model['f'](self.pad(x))
        return feats, self.up(self.model['g'](feats))

    @staticmethod
    def _set_requires_grad(net, flag):
        '''Enable/disable gradient tracking for every parameter of ``net``.'''
        for param in net.parameters():
            param.requires_grad = flag

    def train(self, n_epo, **kwargs):
        '''Train g and d for ``n_epo`` epochs on paired source/target batches.

        Each epoch walks both loaders in lockstep and stops at the shorter one. Every
        iteration performs one discriminator update followed by one generator update,
        and appends the losses to the module-level lists ``l_d`` and ``l_g``.

        Keyword Args:
            visualize_batches (int): show translated samples every this many iterations
                of an epoch (default 100; 0 or None disables).
        '''
        visualize_batches = kwargs.get('visualize_batches', 100)
        f, g, d = self.model['f'], self.model['g'], self.model['d']
        device = self._device()

        for e in range(n_epo):
            print("Epoch", e)
            # zip stops at the shorter loader; new iterators (fresh shuffles) each epoch.
            batches = zip(self.source_train_loader, self.target_train_loader)
            for i, (source_data, target_data) in enumerate(batches):
                # Label tensors have fixed length batch_size; skip any short batch.
                if self.batch_size != source_data.size(0) or self.batch_size != target_data.size(0):
                    continue

                source_data = source_data.float().to(device)
                target_data = target_data.float().to(device)

                # ---- discriminator step ----
                self._set_requires_grad(d, True)
                d.zero_grad()
                # Translations only feed d here, so don't build an autograd graph through g/f.
                with torch.no_grad():
                    _, s_g = self._translate(source_data)
                    _, t_g = self._translate(target_data)
                d_loss = self.discriminator_loss_func(d(s_g), d(t_g), d(target_data))
                d_loss.backward()
                self.discriminator_opt.step()

                # ---- generator step (d frozen, gradients flow through it to g) ----
                self._set_requires_grad(d, False)
                g.zero_grad()
                s_f, s_g = self._translate(source_data)
                s_d_g = d(s_g)
                s_g_f = f(self.pad(s_g))
                _, t_g = self._translate(target_data)
                t_d_g = d(t_g)

                g_loss = self.generator_loss_func(s_g, s_f, s_g_f, s_d_g, target_data, t_g, t_d_g)
                g_loss.backward()
                self.generator_opt.step()

                # Schedulers step after their optimisers (ordering required since PyTorch 1.1).
                self.generator_lr_sche.step()
                self.discriminator_lr_sche.step()

                l_d.append(d_loss.item())
                l_g.append(g_loss.item())

                if visualize_batches and i % visualize_batches == 0:
                    with torch.no_grad():
                        _, s_g_vis = self._translate(source_data)
                    self.seeResultsSrc(source_data, s_g_vis, target_data)
                

In [10]:
# Build the trainer (GPU by default); the following cells configure it step by step.
model = face_model()

In [11]:
# Load the sphere20a face-feature network f and create generator g and discriminator d.
model.make_model()

In [12]:
# Create the MS-Celeb (source) and Bitmoji (target) training DataLoaders.
model.make_loader()

In [13]:
# Create the loss functions and the constant class-label tensors used by them.
model.make_loss_func()

In [14]:
# Create Adam optimisers and MultiStepLR schedulers for g and d.
model.make_opt()

In [15]:
# Train for 5 epochs, periodically displaying sample translations.
# Every emoji_<n>.png in the target index range must exist under target_path,
# otherwise a DataLoader worker raises FileNotFoundError.
model.train(5)

Epoch 0


FileNotFoundError: Traceback (most recent call last):
  File "/home/jupyter/.local/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 138, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/jupyter/.local/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 138, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "<ipython-input-2-97bce4a37eb1>", line 30, in __getitem__
    img = Image.open(img_name)
  File "/home/jupyter/.local/lib/python3.5/site-packages/PIL/Image.py", line 2634, in open
    fp = builtins.open(filename, "rb")
FileNotFoundError: [Errno 2] No such file or directory: '/home/jupyter/emoji_data/emoji_85536.png'


In [None]:
# Standalone copy of the sphere20a face-feature network, for inspection.
f = sphere20a(feature=True)
# map_location='cpu' lets the checkpoint load even if it stores CUDA tensors and no GPU is present.
f.load_state_dict(torch.load('./sphere20a_20171020.pth', map_location='cpu'))

In [109]:
# Display the architecture of the standalone sphere20a network.
f

sphere20a(
  (conv1_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (relu1_1): PReLU(num_parameters=64)
  (conv1_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1_2): PReLU(num_parameters=64)
  (conv1_3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1_3): PReLU(num_parameters=64)
  (conv2_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (relu2_1): PReLU(num_parameters=128)
  (conv2_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2_2): PReLU(num_parameters=128)
  (conv2_3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2_3): PReLU(num_parameters=128)
  (conv2_4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2_4): PReLU(num_parameters=128)
  (conv2_5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2_5): PReLU(num_parameters=128)
  (conv3_1): Conv2d(128, 256, kernel_siz

In [68]:
# Save generator and discriminator weights. state_dict must be *called*: passing the
# bound method itself pickles the method/module object instead of the weight tensors.
# File names encode the generator-loss weights al=1e-2, be=100, gam=1e-4.
gener = {'state_dict': model.model['g'].state_dict()}
disc = {'state_dict': model.model['d'].state_dict()}
torch.save(gener['state_dict'], 'generator_al1e-2_be100_gam_1e-4.tar')
torch.save(disc['state_dict'], 'discriminator_al1e-2_be100_gam_1e-4.tar')