In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

from __future__ import print_function
import PIL.Image as Image
import matplotlib.pyplot as plt
import copy
import numpy as np
from scipy.misc import fromimage, toimage
from scipy.interpolate import interp1d

# Run on the GPU whenever PyTorch can see one; every image tensor is cast to `dtype`.
use_cuda = torch.cuda.is_available()
if use_cuda:
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor

In [2]:
# Desired size of the output image (used as `max_size` when loading the content image).
# Author's note: on the GPU, sizes of 512 or larger caused out-of-memory errors.
# NOTE(review): 512 is nevertheless used below on GPU -- confirm this is intended.
imsize = 128
if use_cuda:
    imsize = 512

# ToTensor conversion used when loading images and inside the color transform.
transform = transforms.Compose([transforms.ToTensor()])

def image_loader(image_name, transform=None, max_size=None, shape=None):
    """Load an image file and return it as a batched tensor of type ``dtype``.

    Args:
        image_name: path of the image file to open.
        transform: optional callable applied to the (resized) PIL image. Its
            result is wrapped in a ``Variable``, so in practice it must return
            a tensor (e.g. ``transforms.ToTensor()``).
        max_size: if given, the image is rescaled (aspect ratio preserved) so
            that its longest side is ``max_size`` pixels (truncated to int).
        shape: if given, the size passed to PIL's ``resize``, i.e.
            ``(width, height)``; applied after ``max_size``.

    Returns:
        The image with a leading batch dimension added, cast to ``dtype``.
    """
    # Always work in RGB: PNGs are often RGBA/palette and some files are
    # grayscale, while the VGG network and the RGB color transform need
    # exactly 3 channels. The `with` block closes the file handle.
    with Image.open(image_name) as opened:
        image = opened.convert('RGB')

    if max_size is not None:
        # float() avoids integer division under Python 2.
        scale = float(max_size) / max(image.size)
        new_size = tuple(int(dim * scale) for dim in image.size)
        # LANCZOS is the filter formerly aliased as ANTIALIAS (removed in Pillow 10).
        image = image.resize(new_size, Image.LANCZOS)

    if shape is not None:
        image = image.resize(tuple(shape), Image.LANCZOS)

    if transform is not None:
        image = transform(image)

    image = Variable(image)
    image = image.unsqueeze(0)

    return image.type(dtype)

In [3]:
# This helper came with the reference code; I tried using it, but it turned out not to be needed.
# Using it produced strange-looking results.
def match_histogram(content_img, generated_img):
    """Remap the values of ``generated_img`` so its histogram follows ``content_img``'s.

    Both arguments are numpy arrays (flattened internally). Every distinct
    value of ``generated_img`` is replaced by the ``content_img`` value found at
    the same empirical-CDF quantile, using linear interpolation
    (``np.interp``). The result is a float array with ``generated_img``'s shape.
    """
    out_shape = generated_img.shape

    uniq_src, src_inverse, src_counts = np.unique(
        generated_img.ravel(), return_inverse=True, return_counts=True)
    uniq_tmpl, tmpl_counts = np.unique(content_img.ravel(), return_counts=True)

    def _normalized_cdf(counts):
        # Cumulative counts scaled so the last entry is 1.0.
        cdf = np.cumsum(counts).astype(np.float64)
        return cdf / cdf[-1]

    mapped_values = np.interp(_normalized_cdf(src_counts),
                              _normalized_cdf(tmpl_counts),
                              uniq_tmpl)

    return mapped_values[src_inverse].reshape(out_shape)

In [36]:
def orginal_color_transform(content_img, generated_img, hist_match=0, mode='YCbCr'):
    """Recolor ``generated_img`` using the colors of ``content_img``.

    Both inputs are batched image Variables; batch dim 0 is squeezed away, and
    they are presumably 1 x 3 x H x W RGB images (TODO confirm). Each image is
    converted to an 8-bit PIL image and then to the ``mode`` color space.
    NOTE(review): scipy.misc.toimage byte-scales float data using the array's
    own min/max by default.

    Args:
        content_img: image whose colors are kept.
        generated_img: image whose structure is kept.
        hist_match: if 1, each of the three ``mode`` channels of the generated
            image is histogram-matched to the corresponding content channel.
            Otherwise, channels 1 and 2 are copied from ``content_img``, which
            for YCbCr means the luminance (Y) of ``generated_img`` is kept with
            the chroma (Cb, Cr) of ``content_img``.
        mode: PIL color mode used for the channel swap (default 'YCbCr').

    Returns:
        The recolored image as a 1 x C x H x W Variable of type ``dtype``.
    """
    content_np = content_img.squeeze(0).data.cpu().numpy()
    generated_np = generated_img.squeeze(0).data.cpu().numpy()

    # RGB tensors -> H x W x 3 arrays in the `mode` color space.
    content_np = fromimage(toimage(content_np, mode='RGB'), mode=mode)
    generated_np = fromimage(toimage(generated_np, mode='RGB'), mode=mode)

    if hist_match == 1:
        for channel in range(3):
            # match_histogram(template, source): remap the generated channel
            # (source) onto the content channel's distribution (template).
            generated_np[:, :, channel] = match_histogram(content_np[:, :, channel],
                                                          generated_np[:, :, channel])
    else:
        # Core of the color transfer: keep channel 0 of the generated image and
        # take the remaining channels from the content image.
        generated_np[:, :, 1:] = content_np[:, :, 1:]

    # Back to RGB, then to a tensor via the module-level ToTensor transform.
    recolored = fromimage(toimage(generated_np, mode=mode), mode='RGB')
    recolored = transform(recolored)

    # Cast with `dtype` (instead of an unconditional .cuda()) so CPU-only runs work.
    recolored = Variable(recolored).type(dtype)
    return recolored.unsqueeze(0)

In [5]:
class ContentLoss(nn.Module):
    """Pass-through module that records a weighted MSE content loss.

    ``forward`` returns its input unchanged and stores
    ``MSE(input * weight, target * weight)`` in ``self.loss``.
    """

    def __init__(self, target, weight):
        super(ContentLoss, self).__init__()
        self.weight = weight
        # Detached so no gradient flows into the target; pre-scaled by `weight`.
        self.target = target.detach() * weight
        self.criterion = nn.MSELoss()

    def forward(self, input):
        weighted_input = input * self.weight
        self.loss = self.criterion(weighted_input, self.target)
        self.output = input
        return input

    def backward(self, retain_graph=True):
        """Backpropagate the loss recorded by the last forward pass and return it."""
        recorded = self.loss
        recorded.backward(retain_graph=retain_graph)
        return recorded

In [6]:
class GramMatrix(nn.Module):
    """Normalized Gram matrix of a 4-D feature map.

    An input of size (a, b, c, d) is flattened to an (a*b) x (c*d) matrix F,
    and ``F @ F.T / (a*b*c*d)`` is returned.
    """

    def forward(self, input):
        n, ch, h, w = input.size()
        flat = input.view(n * ch, h * w)
        return torch.mm(flat, flat.t()).div(n * ch * h * w)

In [7]:
class StyleLoss(nn.Module):
    """Pass-through module that records a weighted Gram-matrix MSE style loss.

    ``forward`` returns a clone of its input and stores
    ``MSE(gram(input) * weight, target * weight)`` in ``self.loss``, where
    ``target`` is a precomputed Gram matrix.
    """

    def __init__(self, target, weight):
        super(StyleLoss, self).__init__()
        self.weight = weight
        # Detached so no gradient flows into the target; pre-scaled by `weight`.
        self.target = target.detach() * weight
        self.gram = GramMatrix()
        self.criterion = nn.MSELoss()

    def forward(self, input):
        scaled_gram = self.gram(input)
        scaled_gram.mul_(self.weight)
        self.G = scaled_gram
        self.loss = self.criterion(scaled_gram, self.target)
        # A clone is passed on, so later layers never touch `input` itself.
        self.output = input.clone()
        return self.output

    def backward(self, retain_graph=True):
        """Backpropagate the loss recorded by the last forward pass and return it."""
        recorded = self.loss
        recorded.backward(retain_graph=retain_graph)
        return recorded

In [8]:
# Use VGG19's convolutional feature layers (`.features`, pretrained weights)
# for transfer learning, moved to the GPU when one is available.
cnn = (models.vgg19(pretrained=True).features.cuda() if use_cuda
       else models.vgg19(pretrained=True).features)

In [9]:
# Names of the layers (as assigned in get_style_model_and_losses: conv_<i>,
# relu_<i>) after which the content / style losses are computed.
content_layers_default = ['conv_4']
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']


def get_style_model_and_losses(cnn, style_img, content_img,
                               style_weight=1000, content_weight=1,
                               content_layers=content_layers_default,
                               style_layers=style_layers_default):
    """Copy ``cnn`` into a new Sequential with loss modules after chosen layers.

    Conv2d, ReLU and MaxPool2d layers of ``cnn`` are added in order and named
    ``conv_<i>``, ``relu_<i>`` and ``pool_<i>``. ``i`` starts at 1 and is
    incremented after each ReLU. Other layer types are skipped.

    Right after a conv/ReLU layer named in ``content_layers``, a ContentLoss is
    inserted; its target is the model's current output for ``content_img``.
    Right after one named in ``style_layers``, a StyleLoss is inserted; its
    target is the Gram matrix of the current output for ``style_img``.

    Args:
        cnn: feature extractor to copy (deep-copied; the original is untouched).
        style_img, content_img: images used to compute the loss targets.
        style_weight, content_weight: weights passed to StyleLoss / ContentLoss.
        content_layers, style_layers: layer names after which losses are added.

    Returns:
        ``(model, style_losses, content_losses)``. The lists hold the inserted
        loss modules in network order.
    """
    cnn = copy.deepcopy(cnn)

    content_losses = []
    style_losses = []

    model = nn.Sequential()  # the new Sequential module network
    gram = GramMatrix()  # used to compute the style targets

    # move these modules to the GPU if possible:
    if use_cuda:
        model = model.cuda()
        gram = gram.cuda()

    def _add_losses_after(name, i):
        # Insert the requested loss modules after the layer just added as `name`.
        if name in content_layers:
            target = model(content_img).clone()
            content_loss = ContentLoss(target, content_weight)
            model.add_module("content_loss_" + str(i), content_loss)
            content_losses.append(content_loss)
        if name in style_layers:
            target_feature = model(style_img).clone()
            style_loss = StyleLoss(gram(target_feature), style_weight)
            model.add_module("style_loss_" + str(i), style_loss)
            style_losses.append(style_loss)

    i = 1
    for layer in list(cnn):
        if isinstance(layer, nn.Conv2d):
            name = "conv_" + str(i)
            model.add_module(name, layer)
            _add_losses_after(name, i)
        elif isinstance(layer, nn.ReLU):
            name = "relu_" + str(i)
            # Out-of-place ReLU: torchvision's VGG ReLUs are typically in-place.
            # An in-place op here would overwrite the tensor that a preceding
            # ContentLoss returns unchanged (and keeps as `.output`).
            model.add_module(name, nn.ReLU(inplace=False))
            _add_losses_after(name, i)
            i += 1
        elif isinstance(layer, nn.MaxPool2d):
            model.add_module("pool_" + str(i), layer)

    return model, style_losses, content_losses

In [10]:
# Optimizer (LBFGS)
def get_input_param_optimizer(input_img):
    """Wrap the image data in a trainable Parameter and build an LBFGS optimizer for it.

    The Parameter wraps ``input_img.data`` without copying, so optimizing it
    updates that same storage.

    Returns:
        ``(input_param, optimizer)``.
    """
    image_param = nn.Parameter(input_img.data)
    return image_param, optim.LBFGS([image_param])

In [11]:
# Defaults: 150 optimization steps, style weight 1000, content weight 1.
def _loss_value(score):
    """Convert an accumulated loss to a Python float.

    ``score`` is either the int ``0`` (no loss modules were summed) or a
    one-element loss tensor/Variable.
    """
    if isinstance(score, (int, float)):
        return float(score)
    # `.data[0]` raises on the 0-dim losses of newer PyTorch releases;
    # `.view(-1)[0]` works for both 1-element (old) and 0-dim (new) losses.
    return float(score.data.view(-1)[0])


def train_style_transfer(cnn, content_img, style_img, input_img, num_steps=150,
                         style_weight=1000, content_weight=1):
    """Optimize ``input_img`` toward the content of one image and the style of another.

    Builds the loss network with ``get_style_model_and_losses``. It then wraps
    ``input_img``'s data in a Parameter. No copy is made, so ``input_img`` is
    presumably updated in place; pass a clone.

    Next, it calls ``optimizer.step`` until more than ``num_steps`` closure
    evaluations have run. LBFGS may evaluate the closure several times per
    step, so the final count can exceed ``num_steps``. Losses are printed
    every 50 evaluations.

    Returns:
        The optimized image data, clamped to [0, 1].
    """
    print('Starting style transfer!!')
    model, style_losses, content_losses = get_style_model_and_losses(
        cnn, style_img, content_img, style_weight, content_weight)
    input_param, optimizer = get_input_param_optimizer(input_img)

    # A list so the closure can mutate the counter without `nonlocal`.
    run = [0]
    while run[0] <= num_steps:
        def steps():
            # Keep pixel values in [0, 1] before each evaluation.
            input_param.data.clamp_(0, 1)
            optimizer.zero_grad()
            model(input_param)  # sets `.loss` on every inserted loss module
            style_score = 0
            content_score = 0

            for sl in style_losses:
                style_score += sl.backward()
            for cl in content_losses:
                content_score += cl.backward()

            run[0] += 1
            if run[0] % 50 == 0:
                print("run {}: Style Loss : {:4f} Content Loss: {:4f}".format(
                    run[0], _loss_value(style_score), _loss_value(content_score)))

            return style_score + content_score

        optimizer.step(steps)

    # a last correction...
    input_param.data.clamp_(0, 1)

    return input_param.data

In [54]:
# Change the Style or Content image paths here.
content_img = image_loader("images/Content/content.png", transform=transform, max_size=imsize)
# Resize the color and style images to the content image's size. `shape` goes
# to PIL's resize as (width, height). Width is read from dim 3 and height from
# dim 2, assuming the loaded tensor is N x C x H x W.
color_img = image_loader("images/Style/color.jpg", transform=transform,
                         shape=[content_img.size(3), content_img.size(2)])
style_img = image_loader("images/Style/style.jpg", transform=transform,
                         shape=[content_img.size(3), content_img.size(2)])
# Optimization starts from a copy of the content image.
input_img = content_img.clone()

In [55]:
# Training
# Pass 1: build a "color style" image and stylize a copy of content_img toward
# it. With the default YCbCr mode, that image has content_img's luminance and
# color_img's chroma. A new name keeps this cell re-runnable (re-running no
# longer re-applies the color transform to an already-transformed color_img).
color_style_img = orginal_color_transform(color_img, content_img)
input_img = content_img.clone()
color_output = train_style_transfer(cnn, content_img, color_style_img, input_img)

# Pass 2: stylize a fresh copy of content_img toward style_img.
input_img = content_img.clone()
output = train_style_transfer(cnn, content_img, style_img, input_img)

Starting style transfer!!
run [50]: Style Loss : 1.558883 Content Loss: 2.138358
run [100]: Style Loss : 0.301498 Content Loss: 1.707018
run [150]: Style Loss : 0.157995 Content Loss: 1.557527
Starting style transfer!!
run [50]: Style Loss : 3.795955 Content Loss: 6.441866
run [100]: Style Loss : 1.076885 Content Loss: 6.035984
run [150]: Style Loss : 0.689826 Content Loss: 5.704264


In [56]:
# Save Output Image
# Combine both results: channel 0 (luminance, in the default YCbCr mode) comes
# from `output`, and the chroma comes from `color_output`. New names leave
# `output` / `color_output` untouched, so this cell can be re-run safely.
combined_output = orginal_color_transform(Variable(color_output), Variable(output))
torchvision.utils.save_image(combined_output.data.cpu(), 'images/Output/ouput-comb.jpg')

# To keep the content image's own colors instead, pass `content_img` as the
# first argument and save to 'images/Output/ouput-color.jpg'.