In [None]:
import torch
# Display the installed PyTorch version. NOTE(review): this notebook uses
# torch.autograd.Variable and other older-style APIs, so behaviour may depend
# on the version shown here.
torch.__version__

In [None]:
import copy
import os.path

import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable

In [None]:
# Device selection. CPU is forced by default (as before); set FORCE_CPU = False
# to run on the GPU whenever one is available.
FORCE_CPU = True
use_cuda = torch.cuda.is_available() and not FORCE_CPU
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
print(use_cuda)

In [None]:
# Smaller images keep CPU runs tractable.
imsize = 512 if use_cuda else 128
# Resize(int) scales the shorter side to imsize (aspect ratio preserved);
# ToTensor converts the PIL image to a (C, H, W) float tensor.
loader = transforms.Compose([transforms.Resize(imsize), transforms.ToTensor()])


def image_loader(image_name):
    """Load an image file as a (1, 3, H, W) tensor wrapped in a Variable.

    Args:
        image_name: path of the image file to read.

    Returns:
        The image after ``loader`` has been applied, with a leading batch
        dimension of size 1.
    """
    # The context manager closes the underlying file once the pixels are read.
    with Image.open(image_name) as raw_image:
        # Force 3 channels: grayscale, palette or RGBA files would otherwise
        # yield 1- or 4-channel tensors that the 3-channel network rejects.
        rgb_image = raw_image.convert("RGB")
    image = Variable(loader(rgb_image))
    # Add a fake batch dimension to feed into the network.
    return image.unsqueeze(0)

In [None]:
style_filename = "./images/picasso.jpg"
content_filename = "./images/dancing.jpg"

# Check each file separately so the error names the one that is missing.
for image_path in (style_filename, content_filename):
    assert os.path.isfile(image_path), "Image file does not exist: {}".format(image_path)

style_img = image_loader(style_filename).type(dtype)
content_img = image_loader(content_filename).type(dtype)

# The loader resizes only the shorter side, so images with different aspect
# ratios end up with different sizes; report both to ease diagnosis.
assert style_img.size() == content_img.size(), \
    "content and style image need to have same size (content {}, style {})".format(
        tuple(content_img.size()), tuple(style_img.size()))

In [None]:
unloader = transforms.ToPILImage()

# Interactive mode so figures update while the notebook keeps running.
plt.ion()


def imshow(tensor, title=None):
    """Show a (1, C, H, W) image tensor with matplotlib.

    Args:
        tensor: a batch holding one image; it is copied to the CPU and
            converted with ``ToPILImage``, and is not modified.
        title: optional figure title.
    """
    image = tensor.clone().cpu()
    # Drop the fake batch dimension without assuming an imsize x imsize image:
    # the loader preserves aspect ratio, so a fixed-shape view() would fail
    # for non-square inputs.
    image = image.squeeze(0)
    image = unloader(image)

    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so the plots are updated


plt.figure()
imshow(style_img.data, title="Style Image")

plt.figure()
imshow(content_img.data, title="Content Image")

In [None]:
class ContentLoss(nn.Module):
    """Transparent layer that records a weighted content (feature MSE) loss.

    ``forward`` passes its input through unchanged and stores, in
    ``self.loss``, the MSE between ``input * weight`` and the pre-weighted
    ``target``.
    """

    def __init__(self, target, weight):
        super(ContentLoss, self).__init__()
        self.weight = weight
        # Detach so the target is a constant, then scale it once up front.
        self.target = target.detach() * weight
        self.criterion = nn.MSELoss()

    def forward(self, input):
        weighted_input = input * self.weight
        self.loss = self.criterion(weighted_input, self.target)
        self.output = input
        return input

    def backward(self, retain_graph=True):
        """Backpropagate the loss stored by the last ``forward`` and return it.

        ``retain_graph`` defaults to True so several loss modules can
        backpropagate through the same forward graph.
        """
        loss = self.loss
        loss.backward(retain_graph=retain_graph)
        return loss

In [None]:
class GrammMatrix(nn.Module):
    """Normalized Gram matrix of a 4-D feature tensor.

    For an input of size (a, b, c, d) (batch, feature maps, spatial dims),
    the a*b feature maps are flattened into rows of length c*d. The result is
    the (a*b) x (a*b) matrix of their inner products, divided by the total
    element count a*b*c*d.
    """

    def forward(self, input):
        batch, maps, height, width = input.size()
        rows = input.view(batch * maps, height * width)
        gram = rows.mm(rows.t())
        # Normalize by the number of elements in the input.
        return gram / (batch * maps * height * width)
        

In [None]:
class StyleLoss(nn.Module):
    """Transparent layer that records a weighted Gram-matrix (style) loss.

    ``forward`` stores, in ``self.loss``, the MSE between the weighted Gram
    matrix of its input and the pre-weighted ``target`` Gram matrix. It
    returns a copy of its input.
    """

    def __init__(self, target, weight):
        super(StyleLoss, self).__init__()
        self.weight = weight
        # Detach so the target is a constant, then scale it once up front;
        # forward scales the input's Gram matrix by the same weight.
        self.target = target.detach() * weight
        self.gramm = GrammMatrix()
        self.criterion = nn.MSELoss()

    def forward(self, input):
        # Pass a copy downstream instead of `input` itself. NOTE(review): this
        # presumably protects `input`, which the Gram matrix's backward still
        # needs, from in-place layers that follow (torchvision's VGG uses
        # ReLU(inplace=True)). Confirm before removing the clone.
        passthrough = input.clone()
        self.output = passthrough
        gram = self.gramm(input)
        gram.mul_(self.weight)
        self.G = gram
        self.loss = self.criterion(gram, self.target)
        return passthrough

    def backward(self, retain_graph=True):
        """Backpropagate the loss stored by the last ``forward`` and return it.

        ``retain_graph`` defaults to True so several loss modules can
        backpropagate through the same forward graph.
        """
        loss = self.loss
        loss.backward(retain_graph=retain_graph)
        return loss

In [None]:
# Only the convolutional part ("features") of VGG-19 is used for the losses.
cnn = models.vgg19(pretrained=True).features

# The network is a fixed feature extractor; only the input image is optimized.
# Freezing the weights avoids computing and accumulating their gradients on
# every backward pass, without changing the gradient w.r.t. the image.
for vgg_param in cnn.parameters():
    vgg_param.requires_grad = False

if use_cuda:
    cnn = cnn.cuda()

In [None]:
content_layers_default = ['conv_4']
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']


def get_style_model_and_losses(cnn, style_img, content_img,
                               style_weight=1000, content_weight=1,
                               content_layers=content_layers_default,
                               style_layers=style_layers_default):
    """Rebuild ``cnn`` as an nn.Sequential with content/style loss modules inserted.

    Layers are named ``conv_<i>``, ``relu_<i>`` and ``pool_<i>``. The index
    ``i`` starts at 1 and advances after every ReLU, so ``relu_<i>`` is the
    activation of ``conv_<i>``.

    - A ContentLoss, targeting the features of ``content_img``, is inserted
      after each layer named in ``content_layers``.
    - A StyleLoss, targeting the Gram matrix of ``style_img``'s features, is
      inserted after each layer named in ``style_layers``.

    Args:
        cnn: stack of Conv2d / ReLU / MaxPool2d layers (e.g. VGG ``features``).
            It is deep-copied, so the original is not modified. Layers of any
            other type are skipped.
        style_img, content_img: image tensors fed through the partial network
            to compute the loss targets.
        style_weight, content_weight: weights passed to the loss modules.
        content_layers, style_layers: names of the layers after which the
            losses are inserted.

    Returns:
        (model, style_losses, content_losses)
    """
    cnn = copy.deepcopy(cnn)

    content_losses = []
    style_losses = []

    model = nn.Sequential()
    gram = GrammMatrix()

    if use_cuda:
        model = model.cuda()
        gram = gram.cuda()

    i = 1
    for layer in list(cnn):
        if isinstance(layer, nn.Conv2d):
            name = "conv_" + str(i)
        elif isinstance(layer, nn.ReLU):
            name = "relu_" + str(i)
            # Advance after the activation so relu_<i> pairs with conv_<i>.
            i += 1
        elif isinstance(layer, nn.MaxPool2d):
            name = "pool_" + str(i)
        else:
            # Other layer types are skipped, as before.
            continue

        # Pooling layers must be kept. Dropping them changes the spatial
        # resolution (and receptive field) seen by every later layer.
        model.add_module(name, layer)

        if name in content_layers:
            target = model(content_img).clone()
            content_loss = ContentLoss(target, content_weight)
            # Name loss modules after their layer. add_module would otherwise
            # overwrite an existing entry when both conv_<i> and relu_<i> are
            # selected.
            model.add_module("content_loss_" + name, content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            target_feature = model(style_img).clone()
            target_feature_gramm = gram(target_feature)
            style_loss = StyleLoss(target_feature_gramm, style_weight)
            model.add_module("style_loss_" + name, style_loss)
            style_losses.append(style_loss)

    return model, style_losses, content_losses
            
    

In [None]:
# Start the optimization from a copy of the content image.
# Set INIT_FROM_NOISE = True to start from white noise instead.
INIT_FROM_NOISE = False
if INIT_FROM_NOISE:
    input_img = Variable(torch.randn(content_img.data.size())).type(dtype)
else:
    input_img = content_img.clone()
plt.figure()
imshow(input_img.data, title="Input Image")

In [None]:
def input_param_optimizer(input_img):
    """Wrap the image data in a Parameter and build an LBFGS optimizer for it.

    The Parameter is built from ``input_img.data``, so it shares storage with
    ``input_img``. Optimizing the Parameter therefore updates that data too.

    Returns:
        (input_param, optimizer)
    """
    param = nn.Parameter(input_img.data)
    lbfgs = optim.LBFGS([param])
    return param, lbfgs

In [None]:
def _loss_value(score):
    """Return a loss accumulator (int 0 or a scalar loss) as a Python number."""
    if isinstance(score, (int, float)):
        # No loss modules of this kind: the accumulator is still the int 0.
        return float(score)
    if hasattr(score, "item"):
        return score.item()
    # Fallback for older Variable-based releases that lack .item().
    return score.data[0]


def run_style_transfer(cnn, content_img, style_img, input_img, num_steps=300,
                       style_weight=1000, content_weight=1, print_every=1):
    """Optimize ``input_img`` toward the content of one image and the style of another.

    Args:
        cnn: feature extractor handed to ``get_style_model_and_losses``.
        content_img, style_img: image tensors of identical size.
        input_img: starting image. ``input_param_optimizer`` wraps
            ``input_img.data`` in a Parameter, so ``input_img`` is likely
            updated in place.
        num_steps: optimization stops once more than this many loss
            evaluations have run. LBFGS may evaluate the closure several times
            per step, so the final count can overshoot.
        style_weight, content_weight: loss weights passed to the model builder.
        print_every: log the losses every this many evaluations; 0 or None
            disables logging.

    Returns:
        The optimized image data, clamped to [0, 1].
    """
    print("Building the style transfer model..")
    model, style_losses, content_losses = get_style_model_and_losses(
        cnn, style_img, content_img, style_weight, content_weight)
    input_param, optimizer = input_param_optimizer(input_img)

    print("Optimizing..")
    run = [0]  # mutable so the closure can update the evaluation counter

    def closure():
        # Keep pixel values in [0, 1] before each evaluation.
        input_param.data.clamp_(0, 1)

        optimizer.zero_grad()
        model(input_param)  # populates .loss on every loss module
        style_score = 0
        content_score = 0
        for sl in style_losses:
            style_score += sl.backward()
        for cl in content_losses:
            content_score += cl.backward()

        run[0] += 1
        if print_every and run[0] % print_every == 0:
            print("run {}:".format(run[0]))
            print("Style Loss: {:.4f} Content Loss: {:.4f}".format(
                _loss_value(style_score), _loss_value(content_score)))

        return style_score + content_score

    while run[0] <= num_steps:
        optimizer.step(closure)

    input_param.data.clamp_(0, 1)

    return input_param.data
                

In [None]:
# Run the optimization from input_img, then display the stylized result.
output = run_style_transfer(cnn, content_img=content_img, style_img=style_img,
                            input_img=input_img)
plt.figure()
imshow(output, title="Output Image")

# Turn interactive mode off and show all open figures.
plt.ioff()
plt.show()
