In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR

In [2]:
class G(nn.Module):
    '''Generator: decodes a per-sample feature vector into a 3-channel image.

    forward() views its input as (B, channels, 1, 1) and applies six stride-2
    ConvTranspose2d layers (spatial 1 -> 2 -> 4 -> 8 -> 16 -> 32 -> 64), so the
    output is (B, 3, 64, 64) squashed to [-1, 1] by Tanh. Each of the first five
    upsampling stages is ConvTranspose2d -> BatchNorm -> LeakyReLU(0.2) -> 1x1 Conv2d
    -> BatchNorm -> LeakyReLU(0.2).

    Args:
        channels: number of input feature channels (elements per sample).
    '''
    def __init__(self, channels):
        # Name the class explicitly: super(self.__class__, self) recurses infinitely in a subclass.
        super(G, self).__init__()
        self.channels = channels
        layers = []
        in_ch = channels
        for out_ch in (512, 256, 128, 64, 32):
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            in_ch = out_ch
        layers += [
            nn.ConvTranspose2d(in_ch, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        # Layer order fixes the state_dict keys (g.0 ... g.31); keep it stable for saved checkpoints.
        self.g = nn.Sequential(*layers)

    def forward(self, inp):
        # View with self.channels so the map matches the first layer's in_channels.
        resh_inp = inp.view(inp.size(0), self.channels, 1, 1)
        return self.g(resh_inp)

In [3]:
class D(nn.Module):
    '''Discriminator: maps a 3-channel image to 3 class scores.

    Five 3x3 stride-2 stages (Conv2d -> BatchNorm -> LeakyReLU) with widths
    channels, 2*channels, 4*channels, 2*channels, channels, followed by an
    unpadded 3x3 Conv2d with 3 output channels. For a 96x96 input the spatial
    size goes 96 -> 48 -> 24 -> 12 -> 6 -> 3 -> 1, i.e. output (B, 3, 1, 1).

    Args:
        channels: base width of the conv stages.
        al: negative slope of the LeakyReLU activations.
        alpha: alias for ``al`` for callers that use that name; takes precedence when given.
    '''
    def __init__(self, channels, al=0.2, alpha=None):
        # nn.Module.__init__ must run before any submodule is assigned to self.
        super(D, self).__init__()
        self.channels = channels
        self.al = al if alpha is None else alpha
        # Keep both attribute names pointing at the same slope.
        self.alpha = self.al
        widths = (3, channels, channels * 2, channels * 4, channels * 2, channels)
        layers = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
                # BatchNorm width must equal the preceding conv's out_channels.
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(self.alpha, inplace=True),
            ]
        layers.append(nn.Conv2d(channels, 3, kernel_size=3, stride=1))
        self.d = nn.Sequential(*layers)

    def forward(self, inp):
        return self.d(inp)

In [4]:
class ZeroPadBottom(object):
    ''' Zero pads batch of image tensor Variables on bottom to given size. Input (B, C, H, W) - padded on H axis.

    The padding is allocated with the same tensor type (dtype and CPU/GPU device) as the
    input, so it can always be concatenated with it. ``use_gpu`` is kept for backward
    compatibility only; placement follows the input.

    Raises:
        ValueError: if the input height H already exceeds ``size``.
    '''
    def __init__(self, size, use_gpu=True):
        self.size = size
        self.use_gpu = use_gpu  # retained for API compatibility; padding placement follows the input

    def __call__(self, sample):
        B, C, H, W = sample.size()
        diff = self.size - H
        if diff < 0:
            raise ValueError('ZeroPadBottom: input height %d exceeds target size %d' % (H, self.size))
        if diff == 0:
            return sample
        # .new() allocates with the input's dtype and device, avoiding CPU/GPU or float/double mismatches in cat.
        padding = Variable(sample.data.new(B, C, diff, W).zero_(), requires_grad=False)
        return torch.cat((sample, padding), dim=2)

In [5]:
class face_model(object):
    def __init__(self, use_gpu=True):
        super(face_model, self).__init__(use_gpu)
        self.generator_loss_func = None
        self.discriminator_loss_func = None
        self.gan_loss_func = None
        self.generator_smooth_func = None
        self.source_val_loader = None
        self.source_test_loader = None
        self.target_test_loader = None
        self.source_train_loader = None
        self.target_train_loader = None
        self.batch_size = 128
        self.lossCE = nn.CrossEntropyLoss()
        
    def make_loader(self):
        '''TO DO'''
    
    def make_model(self):
        '''Build the networks into self.model and the image resize/pad helpers.

        self.model['f']: pretrained feature extractor, frozen (no parameter gradients).
        self.model['g']: generator G (512 input channels -> 3-channel image).
        self.model['d']: 3-way discriminator D.
        self.up resizes images to 96x96 (bilinear); self.pad zero-pads the bottom to height 112.
        '''
        self.model = {}
        # NOTE(review): sphere20a is not defined or imported in the visible notebook; it must be
        # provided elsewhere (presumably the SphereFace net_sphere module) -- confirm.
        f = sphere20a(feature = True)
        # map_location keeps the checkpoint loadable on CPU-only machines; modules are moved below if needed.
        f.load_state_dict(torch.load('./sphere20a_20171020.pth', map_location=lambda storage, loc: storage))
        # requires_grad (not require_grad) is the flag autograd checks; freeze f as a fixed extractor.
        for params in f.parameters():
            params.requires_grad = False

        self.model['f'] = f
        # g decodes 512-channel features (presumably f's output) into images; d classifies images.
        self.model['g'] = G(channels=512)
        self.model['d'] = D(128, 0.2)

        if self.use_gpu:
            # f must live on the same device as the data fed to it in train().
            for key in ('f', 'g', 'd'):
                self.model[key] = self.model[key].cuda()

        self.up = nn.Upsample(size=(96,96), mode='bilinear')
        self.pad = ZeroPadBottom(112, use_gpu=self.use_gpu)
        
    def make_loss_func(self):
        
        self.lossCE = nn.CrossEntropyLoss().cuda()
        self.lossMSE = nn.MSELoss().cuda()
        lab0, lab1, lab2 = (torch.LongTensor(self.batch_size) for i in range(3))
        lab0 = Variable(lab0.cuda())
        lab1 = Variable(lab1.cuda())
        lab2 = Variable(lab2.cuda())
        
        lab0.data.resize_(self.batch_size).fill_(0)
        lab1.data.resize_(self.batch_size).fill_(1)
        lab2.data.resize_(self.batch_size).fill_(2)
        
        self.lab0 = lab0
        self.lab1 = lab1
        self.lab2 = lab2
        
        self.make_generator_loss_func()
        self.make_discriminator_loss_func()
        self.make_smooth_func()
        self.make_dist_func_targ_domain()
    
    def make_opt(self):
        
        self.generator_opt = optim.Adam(self.model['g'].parameters(), lr = 2e-4, betas=(0.5, 0.999), weight_decay=1e-6)
        self.discriminator_opt = optim.Adam(self.model['d'].parameters(), lr = 2e-4, betas=(0.5, 0.999), weight_decay=1e-6)
        
        self.generator_lr_sche = MultiStepLR(self.generator_opt, milestones=[15000], gamma=0.1)
        self.discriminator_lr_sche = MultiStepLR(self.discriminator_opt, milestones=[15000], gamma=0.1)
    
    def cos_sim(self, x, y):
        ab = torch.sum(x*y, dim=1)
        a = torch.sqrt(torch.sum(x*x, dim=1))
        b = torch.sqrt(torch.sum(y*y, dim=1))
        sim = ab/(a*b)
        avg_sim = torch.mean(sim)
        cos_loss = 1-avg_sim
        return cos_loss
        
    def make_generator_loss_func(self):
        
        def gloss(s_f, s_g_f, s_d_g, t, t_g, t_d_g, al, be, gam):
            l_gang = self.lossCE(s_d_g.squeeze(), self.lab2)+self.lossCE(t_d_g.squeeze(), self.lab2)
            l_const = self.cos_sim(s_f.detach(), s_g_f)
            ltv = self.smooth_func(s_g)
            ltid = self.lossMSE()
            
            return l_gang + al*l_const + be*ltid + gam*ltv
        
        self.generator_loss_func = gloss
        
    def make_discriminator_loss_func(self):
        
        def dloss(s_d_g, t_d_g, t_d):
            return self.lossCE(s_d_g.squeeze(), self.lab0) + self.lossCE(t_d_g.squeeze(), self.lab1) + self.lossCE(t_d.squeeze(), self.lab2)
        
        self.discriminator_loss_func = dloss
        
    def train(self, n_epo, **kwargs):
        
        l = min(len(self.source_train_loader), len(self.target_train_loader))
        msimg_count = 0
        tot_batch = 0
        
        for e in range(n_epo):
            
            source_data_iter = iter(self.source_train_loader)
            target_data_iter = iter(self.target_train_loader)
            
            for i in range(l):
                
                self.generator_lr_sche.step()
                self.discriminator_lr_sche.step()
                
                msimg_count+=1
                
                if msimg_count >= len(self.source_train_loader):
                    msimg_count = 0
                    source_data_iter = iter(self.source_train_loader)
                    
                source_data = source_data_iter.next()
                target_data = target_data_iter.next()
                
                if self.batch_size != source_data.size(0) or self.batch_size != target_data.size(0):
                    continue
                
                tot_batch+=1
                
                if self.use_gpu:
                    source_data = Variable(source_data.float().cuda())
                    target_data = Variable(target_data.float().cuda())
                else:
                    source_data = Variable(source_data.float())
                    target_data = Variable(target_data.float())
                    
                for param in self.model['d'].parameters():
                    param.requires_grad = True
                self.model['d'].zero_grad()
                
                source_data_pad = self.pad(source_data)
                s_f = self.model['f'](source_data_pad)
                s_g = self.model['g'](s_f)
                s_g = self.up(s_g)
#                 s_g_detach = s_g.detach()
                s_d_g = self.model['d'](s_g)
                
                target_data_pad = self.pad(target_data)
                t_f = self.model['f'](target_data_pad)
                t_g = self.model['g'](t_f)
                t_g = self.up(t_g)
#                 t_g_detach = t_g.detach()
                t_d_g = self.model['d'](t_g)
                
                t_d = self.model['d'](target_data)
                
                d_loss = discriminator_loss_func(s_d_g, t_d_g, t_d)
                d_loss.backward()
                self.discriminator_opt.step()
                
                for param in self.model['d'].parameters():
                    param.requires_grad = False
                self.model['g'].zero_grad()
                
                source_data_pad = self.pad(source_data)
                s_f = self.model['f'](source_data_pad)
                s_g = self.model['g'](s_f)
                s_g = self.up(s_g)
#                 s_g_detach = s_g.detach()
                s_d_g = self.model['d'](s_g)
                s_g_pad = self.pad(s_g)
                s_g_f = self.model['f'](s_g_pad)
                
                target_data_pad = self.pad(target_data)
                t_f = self.model['f'](target_data_pad)
                t_g = self.model['g'](t_f)
                t_g = self.up(t_g)
#                 t_g_detach = t_g.detach()
                t_d_g = self.model['d'](t_g)
                
                g_loss = generator_loss_func(s_f, s_g_f, s_d_g, t, t_g, t_d_g)
                g_loss.backward()
                self.generator_opt.step()
                