In [None]:
import matplotlib
matplotlib.use('Agg')
%matplotlib inline
import matplotlib.pyplot as plt
import itertools

import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch

import numpy as np
import random
import copy
import time
from sklearn.model_selection import train_test_split

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"


from utils import CustomDataset, GANLoss, Vgg19
from generator import AttU_Net
from discriminator import Discriminator



In [None]:
class Args:
    """Plain container of run settings; read through the ``opt`` instance."""

    epoch = 0            # starting epoch index
    n_epochs = 200
    batchSize = 1
    dataroot = '../hair_swap/data/jan04_px96.npz'   # .npz archive read by np.load
    lr = 0.0005          # learning rate handed to both Adam optimizers
    decay_epoch = 100
    size = 96
    input_nc = 3         # channel count passed to the generator/discriminator constructors
    output_nc = 3
    cuda = True
    n_cpu = 32


# Single shared settings object used by the rest of the notebook.
opt = Args()

# Fall back to the CPU when no CUDA device is visible.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Read the three arrays stored in the .npz archive named by the config.
file = np.load(opt.dataroot)
x_1, x_2, y = file['only_face'], file['only_hair'], file['both']

# Stack the two inputs along the last axis so that a single split call keeps
# every face / hair / target triple together.
x_con = np.concatenate([x_1, x_2], axis=-1)

# 80/20 split with a fixed seed so the partition is repeatable.
x_train, x_test, y_train, y_test = train_test_split(
    x_con, y, test_size=0.2, random_state=42
)

# Undo the stacking: last-axis slices 0:3 and 3:6.
# NOTE(review): assumes 4-D arrays with 3 channels per input -- confirm.
x_train1, x_train2 = x_train[:, :, :, :3], x_train[:, :, :, 3:6]
x_test1, x_test2 = x_test[:, :, :, :3], x_test[:, :, :, 3:6]

In [None]:
# Augmentation pipeline passed to CustomDataset as ``transform_hair`` for the
# training set: random affine warp, tensor conversion, then random erasing.
transform_train = transforms.Compose([
    transforms.RandomAffine(
        degrees=[-10,10],
        translate=[0.00,0.08],
        scale=[0.65,1.00],
        shear=5,
        fillcolor=(255,255,255),
    ),
    transforms.ToTensor(),
    transforms.RandomErasing(
        scale=[0.05, 0.08],
        ratio=[0.02,0.05],
        p=0.5,
    ),
])

# Augmentation pipeline passed to CustomDataset as ``transform_face`` for the
# training set: random affine warp without shear, then tensor conversion.
transform_train2 = transforms.Compose([
    transforms.RandomAffine(
        degrees=[5,10],
        translate=[0.10,0.25],
        scale=[0.80,1.20],
        shear=0,
        fillcolor=(255,255,255),
    ),
    transforms.ToTensor(),
])

# No augmentation: tensor conversion only.
transform_test = transforms.Compose([transforms.ToTensor()])




In [None]:
# Optional cap on the number of training samples (useful for quick smoke runs).
# ``None`` means "use everything"; the previous value of -1 made ``[:lim]``
# silently drop the last training sample.
lim = None

# Training set: the face input gets ``transform_train2``, the hair input gets
# ``transform_train`` and the remaining item only gets tensor conversion.
train_set = CustomDataset(
    x_train1[:lim], x_train2[:lim], y_train[:lim],
    transform_face=transform_train2,
    transform_hair=transform_train,
    transform_others=transform_test,
)
trainloader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=0)

# Held-out set: no augmentation on any of the three items.
test_set = CustomDataset(
    x_test1, x_test2, y_test,
    transform_face=transform_test,
    transform_hair=transform_test,
    transform_others=transform_test,
)
# shuffle=True is kept on purpose: the training loop plots samples from the
# last test batch, so shuffling varies which samples get displayed.
testloader = DataLoader(test_set, batch_size=64, shuffle=True, num_workers=0)

In [None]:
# Visual sanity check: show the first sample (face input, hair input, target)
# of one test batch.  ``counter`` is 1-based, so with ``to_skip = [0, 1]`` the
# first batch is skipped and the second one is displayed.
counter = 0
to_skip = [0, 1]

for counter, (x1, x2, y) in enumerate(testloader, start=1):
    if counter in to_skip:
        continue

    # Move axis 1 to the end (channel-first -> channel-last) for imshow.
    x1, x2, y = (np.transpose(batch, (0, 2, 3, 1)) for batch in (x1, x2, y))

    f, axarr = plt.subplots(1, 3, figsize=(12, 6))
    for ax, images in zip(axarr, (x1, x2, y)):
        ax.imshow(images[0])
        ax.axis('off')
    plt.show()

    # Only one batch is previewed.
    break

In [None]:
# Build the generator, wrap it in DataParallel and move it to ``device``.
# The trailing expression echoes the module structure as the cell output.
generator = nn.DataParallel(AttU_Net(opt.input_nc, opt.output_nc))
generator.to(device)

In [None]:
# Build the discriminator, wrap it in DataParallel and move it to ``device``.
# The trailing expression echoes the module structure as the cell output.
discriminator = nn.DataParallel(Discriminator(opt.input_nc))
discriminator.to(device)

In [None]:
# Feature extractor used for the perceptual loss in the training loop,
# wrapped in DataParallel and moved to ``device``.
# NOTE(review): no optimizer is created for this network in the notebook;
# whether Vgg19 freezes its own weights is defined in utils -- confirm.
perceptual_model = nn.DataParallel(Vgg19())
perceptual_model.to(device)

In [None]:
# Number of passes over ``trainloader`` made by the training loop.
num_epochs: int = 200

In [None]:
# --- Loss functions -------------------------------------------------------
# Pixel loss between the generated image and the target.
criterion_g = nn.MSELoss()
# Adversarial loss; called as criterion_d(prediction, target_is_real).
criterion_d = GANLoss().to(device)
# Feature-space loss applied to the perceptual model's outputs.
criterion_p = nn.MSELoss()

# --- Optimizers -----------------------------------------------------------
# Both networks use Adam with the same learning rate and betas.
optimizer_g = torch.optim.Adam(
    generator.parameters(), lr=opt.lr, betas=(0.5, 0.999)
)
optimizer_d = torch.optim.Adam(
    discriminator.parameters(), lr=opt.lr, betas=(0.5, 0.999)
)
    


In [None]:
# Number of samples in each split.
train_total, test_total = len(train_set), len(test_set)

# Number of batches per epoch in each loader; used to average the summed
# per-batch losses.  The spelling ``test_baches`` is kept because the
# training loop reads that exact name.
train_batches, test_baches = len(trainloader), len(testloader)


In [None]:
patience = 0    # Bad epoch counter

# Best validation loss seen so far.  Starting from +inf guarantees that the
# first evaluated epoch always writes a checkpoint; the previous sentinel of
# 1024 would block checkpointing whenever the loss was at or above that value.
best_loss = float('inf')

# Cut the generator's learning rate by 10x once the monitored metric has not
# improved for more than 70 consecutive scheduler steps, down to a floor of 1e-8.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_g, factor=0.1, patience=70, verbose=True, min_lr=0.00000001,)

In [None]:
def _restore_checkpoint(model, path, label):
    """Best-effort restore of ``model`` weights from ``path``.

    A missing file is reported with the usual "path not available" message.
    Any other failure (unreadable file, mismatched keys or shapes) is reported
    together with the underlying error instead of being silently mistaken for
    a missing file.  The notebook keeps running in both cases.

    Args:
        model: network to load into; may be wrapped in ``nn.DataParallel``.
        path: checkpoint file written by ``torch.save(state_dict, path)``.
        label: human-readable network name used in the printed messages.
    """
    if not os.path.isfile(path):
        print(label + " checkpoint path not available")
        return
    try:
        # map_location lets a checkpoint saved on GPU be opened on a CPU-only machine.
        state = torch.load(path, map_location=device)
        # A state dict saved from the DataParallel wrapper has every key
        # prefixed with "module."; one saved from ``model.module`` does not.
        # Load into whichever object matches the stored key layout.
        has_prefix = all(key.startswith("module.") for key in state)
        is_wrapped = isinstance(model, nn.DataParallel)
        target = model.module if (is_wrapped and not has_prefix) else model
        target.load_state_dict(state)
    except Exception as err:
        print("%s checkpoint at %s could not be loaded: %r" % (label, path, err))


path_checkpoint = './att_unet_pivtons_for_hair_96_again.pth'
_restore_checkpoint(generator, path_checkpoint, "Generator")

path_checkpoint2 = './att_unet_pivtons_for_hair_discrimiator_96_again.pth'
_restore_checkpoint(discriminator, path_checkpoint2, "Discriminator")



In [None]:
def _show_samples(columns):
    """Display one row of images, one panel per entry of ``columns``.

    Each entry is a single image tensor in channel-first layout; it is
    detached, moved to the CPU and rearranged to channel-last for ``imshow``.
    """
    f, axarr = plt.subplots(1, len(columns), figsize=(16, 9))
    for ax, img in zip(axarr, columns):
        ax.imshow(img.detach().cpu().permute(1, 2, 0).numpy())
        ax.axis('off')
    plt.show()
    # Release the figure explicitly so repeated previews do not pile up.
    plt.close(f)


count = 0   # not read by this cell; kept in case other cells rely on it
for epoch in range(num_epochs):
    # ------------------------------------------------------------------
    # Train
    # ------------------------------------------------------------------
    generator.train()
    discriminator.train()

    train_loss = 0

    start_time = time.time()
    for x1, x2, y in trainloader:
        x1 = x1.to(device)
        x2 = x2.to(device)
        y = y.to(device)

        # Forward pass of the generator on the two inputs.
        fake = generator(x1, x2)

        #########################
        # (1) Calculate Perceptual Loss
        #########################
        fake_feature = perceptual_model(fake)
        # The target features are constants for this step, so no graph is needed.
        with torch.no_grad():
            real_feature = perceptual_model(y)

        # Sum the feature-space MSE over the first five returned feature maps.
        loss_p = 0
        for i in range(5):
            loss_p += criterion_p(fake_feature[i], real_feature[i])

        ##########################
        # (2) Update Discriminator
        ##########################
        optimizer_d.zero_grad()

        # Train with fake.  detach() stops the discriminator loss from
        # back-propagating into the generator, which also removes the need
        # for retain_graph on this backward pass.
        pred_fake = discriminator(fake.detach())
        loss_d_fake = criterion_d(pred_fake, False)

        # Train with real
        pred_real = discriminator(y)
        loss_d_real = criterion_d(pred_real, True)

        # Average and update
        loss_d_total = (loss_d_fake + loss_d_real) * 0.5
        loss_d_total.backward()
        optimizer_d.step()

        ###########################
        # (3) Update Generator
        ###########################
        optimizer_g.zero_grad()

        # Adversarial term: score the fake again with the just-updated
        # discriminator, this time labelled as real.
        pred_fake = discriminator(fake)
        loss_gan = criterion_d(pred_fake, True)

        # Pixel-space reconstruction term.
        loss_g = criterion_g(fake, y)

        loss_total = loss_gan + loss_g + loss_p
        loss_total.backward()
        optimizer_g.step()

        train_loss += loss_total.item()

    # Mean generator loss per batch for this epoch.
    train_loss = train_loss / train_batches

    end_time = time.time()

    print('[%2d / %d] train_loss: %.5f  ' % (epoch+1, num_epochs , train_loss), end = ' ')

    # ------------------------------------------------------------------
    # Evaluate on the held-out set (pixel MSE only)
    # ------------------------------------------------------------------
    generator.eval()

    test_loss = 0
    with torch.no_grad():
        for data in testloader:
            images1, images2, labels = data[0].to(device), data[1].to(device), data[2].to(device)
            outputs = generator(images1, images2)

            loss = criterion_g(outputs, labels)
            test_loss += loss.item()

    test_loss = test_loss / test_baches

    # Drive the plateau scheduler with the validation loss.  Feeding it a
    # constant made every epoch after the first count as "no improvement".
    scheduler.step(test_loss)

    # Keep the weights of the best epoch so far (lowest validation loss).
    if test_loss < best_loss:
        torch.save(generator.state_dict(), path_checkpoint)
        torch.save(discriminator.state_dict(), path_checkpoint2)
        best_loss = test_loss
    print('test_loss: %.5f -- Best loss: %.5f --- %.2f seconds' %(test_loss, best_loss, (end_time-start_time)))

    # ------------------------------------------------------------------
    # Every 5th epoch: show qualitative results
    # ------------------------------------------------------------------
    if epoch % 5 == 0:
        banner = "----------------------------------EPOCH " + str(epoch) + "-----------------------------------------------"
        print(banner)

        # Visualisation only: no gradients required for these forward passes.
        with torch.no_grad():
            # First sample of the last training batch: inputs, prediction, target.
            print("--------------------TRAINING DATA-----------------------")
            train_out = generator(x1, x2)
            _show_samples([x1[0], x2[0], train_out[0], y[0]])
            print("--------------------TRAINING DATA-----------------------")

            # First sample of the last test batch: inputs, prediction, target.
            print("--------------------TESTING DATA-----------------------")
            _show_samples([images1[0], images2[0], outputs[0], labels[0]])
            print("--------------------TESTING DATA-----------------------")

            # First input of sample 0 combined with second input of sample 1,
            # shown next to both samples' targets.
            print("--------------------MIXED DATA-----------------------")
            if images1.size(0) > 1:
                mixed_out = generator(images1[0:1], images2[1:2])
                _show_samples([images1[0], images2[1], mixed_out[0], labels[0], labels[1]])
            else:
                # The last test batch can hold a single sample; there is no
                # second sample to mix with in that case.
                print("(skipped: last test batch has fewer than 2 samples)")
            print("--------------------MIXED DATA-----------------------")

        print(banner)

In [None]:
# Persist the final-epoch weights.  Both networks are saved through their
# DataParallel wrappers so the key names match what the checkpoint-loading
# cell (generator.load_state_dict on the wrapped model) and the in-loop
# best-epoch save use; saving generator.module.state_dict() here produced a
# file that cell could not load back.
# Note: this writes to the same paths as the best-epoch checkpoints saved
# during training, so it replaces them with the last epoch's weights.
torch.save(generator.state_dict(), path_checkpoint)
torch.save(discriminator.state_dict(), path_checkpoint2)