## import packages

In [None]:
import time
import os
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
# Pick the compute device once: CUDA when torch reports it as available,
# otherwise the CPU. The chosen device is printed for reference.
cuda = torch.cuda.is_available()
device = torch.device('cuda' if cuda else 'cpu')
print(device)

import torchvision
from torchvision import transforms

import matplotlib.pyplot as plt
from PIL import Image
from collections import OrderedDict
import numpy as np

cuda


## Directories and files

In [None]:
# Input/output locations, all resolved relative to the current working
# directory. Each path keeps a trailing '/' because callers build file
# paths by plain string concatenation.
_base_dir = os.getcwd()
image_dir = '%s/Images/' % _base_dir
model_dir = '%s/Models/' % _base_dir
output_dir = '%s/Output/' % _base_dir

## VGG NN and functions

In [None]:
#vgg definition that conveniently lets you grab the outputs from any layer
class VGG(nn.Module):
    """Convolutional feature extractor that exposes every intermediate activation.

    The network is sixteen 3x3 convolutions (padding 1) arranged in five
    groups of 2, 2, 4, 4 and 4 layers. Each convolution is followed by a
    ReLU and each group by a 2x2, stride-2 pooling layer.

    Args:
        pool: Pooling type used after every group, either ``'max'`` or
            ``'avg'``.

    Raises:
        ValueError: If ``pool`` is not one of the supported values. (Without
            this check no pooling layers would be created and the failure
            would only surface later inside ``forward``.)
    """

    def __init__(self, pool='max'):
        pool_types = {'max': nn.MaxPool2d, 'avg': nn.AvgPool2d}
        if pool not in pool_types:
            raise ValueError(
                "pool must be 'max' or 'avg', got %r" % (pool,))
        super(VGG, self).__init__()
        #vgg modules (attribute names double as the state-dict keys)
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        # One pooling module per group; all share the same 2x2/stride-2 setup.
        pool_cls = pool_types[pool]
        self.pool1 = pool_cls(kernel_size=2, stride=2)
        self.pool2 = pool_cls(kernel_size=2, stride=2)
        self.pool3 = pool_cls(kernel_size=2, stride=2)
        self.pool4 = pool_cls(kernel_size=2, stride=2)
        self.pool5 = pool_cls(kernel_size=2, stride=2)

    def forward(self, x, out_keys):
        """Run the full stack and return the requested activations.

        Args:
            x: Input batch with 3 channels (``conv1_1`` expects 3 input
                channels).
            out_keys: Iterable of activation names to return. ``'rXY'`` is
                the ReLU output of ``convX_Y`` and ``'pX'`` is the output of
                ``poolX``.

        Returns:
            A list of tensors, one per entry of ``out_keys``, in the same
            order.

        Raises:
            KeyError: If ``out_keys`` contains a name that is not produced
                by this network.
        """
        out = {}
        out['r11'] = F.relu(self.conv1_1(x))
        out['r12'] = F.relu(self.conv1_2(out['r11']))
        out['p1'] = self.pool1(out['r12'])
        out['r21'] = F.relu(self.conv2_1(out['p1']))
        out['r22'] = F.relu(self.conv2_2(out['r21']))
        out['p2'] = self.pool2(out['r22'])
        out['r31'] = F.relu(self.conv3_1(out['p2']))
        out['r32'] = F.relu(self.conv3_2(out['r31']))
        out['r33'] = F.relu(self.conv3_3(out['r32']))
        out['r34'] = F.relu(self.conv3_4(out['r33']))
        out['p3'] = self.pool3(out['r34'])
        out['r41'] = F.relu(self.conv4_1(out['p3']))
        out['r42'] = F.relu(self.conv4_2(out['r41']))
        out['r43'] = F.relu(self.conv4_3(out['r42']))
        out['r44'] = F.relu(self.conv4_4(out['r43']))
        out['p4'] = self.pool4(out['r44'])
        out['r51'] = F.relu(self.conv5_1(out['p4']))
        out['r52'] = F.relu(self.conv5_2(out['r51']))
        out['r53'] = F.relu(self.conv5_3(out['r52']))
        out['r54'] = F.relu(self.conv5_4(out['r53']))
        out['p5'] = self.pool5(out['r54'])
        return [out[key] for key in out_keys]

In [None]:
'''
# gram matrix and loss
class GramMatrix(nn.Module):
    def forward(self, input):
        b,c,h,w = input.size()
        F = input.view(b, c, h*w)
        F = torch.nn.functional.normalize(F, dim = 2) #normalize the matrix, don't emphasize on bright-darkness
        # entries of F looks small
        F = F.multiply(100)
        G = torch.bmm(F, F.transpose(1,2)) 
        #G.div_(h*w)
        return G
'''

# gram matrix and loss
class GramMatrix(nn.Module):
    """Scaled Gram matrix of L2-normalised feature maps.

    Each channel's flattened spatial response is normalised to unit L2 norm,
    multiplied by ``scale`` and then correlated with every other channel.
    The result is not divided by ``h*w``.

    Args:
        scale: Factor applied to the normalised features before the matrix
            product (the normalised entries are small). Defaults to 150.
    """

    def __init__(self, scale=150):
        super(GramMatrix, self).__init__()
        self.scale = scale

    def forward(self, input):
        """Return the ``(b, c, c)`` Gram matrix of a ``(b, c, h, w)`` tensor."""
        b, c, h, w = input.size()
        # reshape (not view) so non-contiguous feature maps are accepted too
        feats = input.reshape(b, c, h * w)
        # normalize the matrix, don't emphasize on bright-darkness
        feats = torch.nn.functional.normalize(feats, dim=2)
        feats = feats * self.scale
        return torch.bmm(feats, feats.transpose(1, 2))

class GramMSELoss(nn.Module):
    """Mean-squared error between the Gram matrix of ``input`` and ``target``.

    ``target`` is compared as-is, so it must already be a Gram matrix.
    """

    def forward(self, input, target):
        gram = GramMatrix()(input)
        criterion = nn.MSELoss()
        return criterion(gram, target)

## Image pre- and post-processing

In [None]:
# pre and post processing for images
img_size = 512


def _reverse_channels(x):
    # Re-index the first axis in the order 2,1,0 (RGB <-> BGR).
    return x[torch.LongTensor([2,1,0])]


def _scale_to_255(x):
    # In-place multiply by 255.
    return x.mul_(255)


def _scale_to_unit(x):
    # In-place multiply by 1/255; note this modifies the tensor passed in.
    return x.mul_(1./255)


# Per-channel mean in BGR order, expressed on the [0, 1] scale.
_mean_bgr = [0.40760392, 0.45795686, 0.48501961]

# PIL image -> tensor: resize, convert, swap to BGR, subtract the mean,
# then scale up by 255.
prep = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Lambda(_reverse_channels), #turn to BGR
    transforms.Normalize(mean=_mean_bgr, std=[1,1,1]), #subtract imagenet mean
    transforms.Lambda(_scale_to_255),
])

# Inverse of the tensor part of `prep`: scale down, add the mean back
# (normalising with the negated mean), swap back to RGB.
postpa = transforms.Compose([
    transforms.Lambda(_scale_to_unit),
    transforms.Normalize(mean=[-m for m in _mean_bgr], std=[1,1,1]), #add imagenet mean
    transforms.Lambda(_reverse_channels), #turn to RGB
])

# tensor -> PIL image
postpb = transforms.Compose([transforms.ToPILImage()])
def postp(tensor): # to clip results in the range [0,1]
    """Convert a pre-processed image tensor back into a PIL image.

    Applies ``postpa`` (undo scaling, mean shift and channel swap), clips the
    values to ``[0, 1]`` and converts the result with ``postpb``.

    Args:
        tensor: Image tensor in the representation produced by ``prep``.
            It is left unmodified.

    Returns:
        The PIL image produced by ``postpb``.
    """
    # postpa begins with an in-place mul_, so hand it a copy; otherwise the
    # caller's tensor would be rescaled as a side effect.
    t = postpa(tensor.clone())
    t.clamp_(0, 1)
    img = postpb(t)
    return img

## Model

In [None]:
def run_transfer(style_name, content_name,
                 output_name="output",
                 output_dir=output_dir,
                 pool_method="avg",
                 style_layers=['r11','r21','r31','r41','r51'],
                 content_layers=['r42'],
                 content_weight=1e0,
                 style_weight=1e3,
                 style_layer_weights=[64,128,256,512,512], 
                 init_method="content", 
                 max_iter=1000, 
                 show_iter=50,
                 output_metrics=True):
    """Optimise an image to match one image's style and another's content.

    Loads the VGG weights from ``model_dir + 'vgg_conv.pth'``, reads both
    images from ``image_dir``, optimises the output image with L-BFGS and
    saves it as ``output_dir + output_name + ".png"``.

    Args:
        style_name: Style image path, relative to ``image_dir``.
        content_name: Content image path, relative to ``image_dir``.
        output_name: File name (without extension) of the saved result.
        output_dir: Prefix prepended to ``output_name``; missing parent
            directories of the resulting path are created.
        pool_method: Pooling type passed to ``VGG``.
        style_layers: VGG activation names used for the style (Gram) loss.
        content_layers: VGG activation names used for the content loss.
        content_weight: Weight applied to every content layer loss.
        style_weight: Global style weight; layer ``i`` gets
            ``style_weight / style_layer_weights[i]**2``.
        style_layer_weights: One divisor per entry of ``style_layers``.
        init_method: Starting point of the optimisation: ``"random"``
            (Gaussian noise), ``"content"`` (copy of the content image) or
            ``"style"`` (style image resized to the content image's size).
        max_iter: The optimisation loop keeps calling ``optimizer.step``
            while the number of closure evaluations is ``<= max_iter``.
        show_iter: Logging interval, in closure evaluations.
        output_metrics: When true, return the recorded losses and times.

    Returns:
        ``(losses, times)`` when ``output_metrics`` is true, else ``None``.
        ``times`` holds seconds since the start of the call and has one more
        entry than ``losses`` (a sample is taken just before optimisation).

    Raises:
        ValueError: If ``init_method`` is unknown, or ``style_layers`` and
            ``style_layer_weights`` differ in length.
    """
    # Fail fast on bad arguments instead of after the expensive set-up.
    if init_method not in ("random", "content", "style"):
        raise ValueError(
            "init_method must be 'random', 'content' or 'style', got %r"
            % (init_method,))
    if len(style_layer_weights) != len(style_layers):
        raise ValueError(
            "style_layer_weights needs one entry per style layer "
            "(%d weights for %d layers)"
            % (len(style_layer_weights), len(style_layers)))

    # set up time and loss trackers
    start_time = time.time()
    times = []
    losses = []

    run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    #get network
    vgg = VGG(pool=pool_method)
    # map_location lets weights saved on a GPU load on a CPU-only machine
    vgg.load_state_dict(torch.load(model_dir + 'vgg_conv.pth',
                                   map_location=run_device))
    for param in vgg.parameters():
        param.requires_grad = False
    vgg.to(run_device)

    #load images; convert to RGB so the 3-channel BGR swap in `prep` also
    #works for palette / greyscale files, and close the file handles
    with Image.open(image_dir + style_name) as img_file:
        style_pil = img_file.convert('RGB')
    with Image.open(image_dir + content_name) as img_file:
        content_pil = img_file.convert('RGB')
    style_image = prep(style_pil).unsqueeze(0).to(run_device)
    content_image = prep(content_pil).unsqueeze(0).to(run_device)

    #initialize the output image
    if init_method == "random":
        opt_img = torch.randn(content_image.size()).type_as(content_image)
    elif init_method == "content":
        opt_img = content_image.detach().clone()
    else:  # "style"
        resize_to_content = transforms.Resize((content_pil.height, content_pil.width))
        opt_img = prep(resize_to_content(style_pil)).unsqueeze(0).to(run_device)
    opt_img.requires_grad_(True)

    optimizer = optim.LBFGS([opt_img])

    #define layers, loss functions, weights and compute optimization targets
    loss_layers = list(style_layers) + list(content_layers)
    loss_fns = [GramMSELoss()] * len(style_layers) + [nn.MSELoss()] * len(content_layers)
    loss_fns = [loss_fn.to(run_device) for loss_fn in loss_fns]

    #these are good weights settings:
    style_weights = [style_weight/n**2 for n in style_layer_weights]
    # one weight per content layer so `weights` lines up with `loss_layers`
    content_weights = [content_weight] * len(content_layers)
    weights = style_weights + content_weights

    #compute optimization targets
    style_targets = [GramMatrix()(A).detach() for A in vgg(style_image, style_layers)]
    content_targets = [A.detach() for A in vgg(content_image, content_layers)]
    targets = style_targets + content_targets

    n_iter = 0

    def closure():
        # Evaluate the weighted loss and its gradient for the optimiser.
        nonlocal n_iter
        optimizer.zero_grad()
        out = vgg(opt_img, loss_layers)
        layer_losses = [weights[a] * loss_fns[a](A, targets[a]) for a,A in enumerate(out)]
        loss = torch.stack(layer_losses, dim=0).sum(dim=0)
        loss.backward()
        n_iter += 1
        # Logs one evaluation early and labels it n_iter + 1, e.g. the 49th
        # evaluation is reported as "Iteration: 50" when show_iter is 50.
        if n_iter%show_iter == (show_iter-1):
            print('Iteration: %d, loss: %f'%(n_iter+1, loss.item()))
            losses.append(loss.item())
            times.append(time.time()-start_time)
        return loss

    #run style transfer
    times.append(time.time()-start_time)
    # The counter advances inside the closure, so max_iter bounds closure
    # evaluations; it is only checked between step() calls.
    while n_iter <= max_iter:
        optimizer.step(closure)

    #save result, creating the target directory if it does not exist yet
    out_img = postp(opt_img.detach()[0].cpu().squeeze())
    out_path = output_dir + output_name + ".png"
    out_parent = os.path.dirname(out_path)
    if out_parent:
        os.makedirs(out_parent, exist_ok=True)
    out_img.save(out_path)
    if output_metrics:
        return losses, times
    return None

## Running code

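Argument notes (NOTE: this list does not match the `run_transfer` signature defined above, which has no `variation` or `style_weights` parameter and uses different defaults): `style_name, content_name, variation, style_layers, content_layers, style_weights, init_method="random", max_iter=500, show_iter=50, output_dir=output_dir`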

In [None]:
# The same content image is paired with each of the ten style images
# Style/01.png ... Style/10.png.
content_images = ["Content/10.png"] * 10
style_images = ["Style/%02d.png" % k for k in range(1, 11)]

In [None]:
# Run one transfer per (style, content) pair and collect the metrics
# returned by each run.
align_losses, align_times = [], []
for i, content_path in enumerate(content_images):
    style_path = style_images[i]
    # characters 6-7 of the style path become part of the output file name
    l, t = run_transfer(style_name=style_path,
                        content_name=content_path,
                        max_iter=2000,
                        output_name="s%s_align" % style_path[6:8],
                        output_dir=output_dir + "portrait_aware_10/",
                        output_metrics=True)
    align_losses.append(l)
    align_times.append(t)

Iteration: 50, loss: 53583.304688
Iteration: 50, loss: 54582.390625
Iteration: 50, loss: 50437.375000
Iteration: 50, loss: 59566.734375
Iteration: 50, loss: 53388.316406
Iteration: 50, loss: 39823.324219
Iteration: 50, loss: 39757.039062
Iteration: 50, loss: 63783.648438
Iteration: 50, loss: 75314.265625
Iteration: 50, loss: 99806.093750


In [None]:
# Convert the collected metrics to arrays, then write each one to its own
# comma-separated file under output_dir.
align_losses = np.array(align_losses)
align_times = np.array(align_times)

for _csv_name, _table in (("portrait_align_losses.csv", align_losses),
                          ("portrait_align_times.csv", align_times)):
    np.savetxt(output_dir+_csv_name, _table, delimiter=",")