In [139]:
%matplotlib inline

In [169]:
from __future__ import print_function

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from PIL import Image
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
import torchvision.models as models
import torchvision.utils as utils 
from torch.autograd import Variable

import copy

In [170]:
# Constants for the Image
CONTENT_IMG = "content.jpg"
STYLE_IMG = "style.jpg"

STEPS = 500

IMSIZE = 512 if torch.cuda.is_available() else 128

MEAN=[0.485, 0.456, 0.406]
STD=[0.229, 0.224, 0.225]

CNTWEIGHT = 1
STLWEIGHT = 1000

#etc
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg = models.vgg19(pretrained=True).features.to(device).eval() #평가판을 사용해야 모델에 영향을 안주고 할 수 있다...?

In [171]:
#Images
convert = transforms.Compose([ #데이터셋을 가져올 때 형태를 변환해주는 부분.
    transforms.Resize(IMSIZE), #이미지 크기 변환
    transforms.ToTensor() #pytorch에서 사용하기 위한 tensor자료구조로 변환
     #받아오는 데이터를 노말라이징.(특정 부분이 너무 어둡거나 밝은 경우 데이터가 튀는 현상을 방지)
])
reconvert = transforms.ToPILImage()

def imageLoad(image_name):  #이미지를 텐서로 전환해 모델에 
    img = Image.open(image_name)
    img = convert(img).unsqueeze(0).clone()
    return img.to(device, torch.float)

def showImage(tensor): #결과를 보여주기 위한 작업
    image = tensor.cpu().clone() #원본이미지 손상 방지용 clone
    image = image.squeeze(0) #delete fake dimension
    image = reconvert(image)
    plt.imshow(image)
    
style_img = imageLoad(STYLE_IMG) 
content_img = imageLoad(CONTENT_IMG)
#plt.figure()
#showImage(style_img)
#plt.figure()
#showImage(content_img)

In [172]:
content_layers = ['conv_4']
style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
w = 0.2 #스타일 각 계층에 대한 가중치 = 1/5 

In [173]:
input_img = torch.randn(content_img.data.size(), device=device)

In [174]:
def ContentLoss(y, y_pred):
    
    loss = F.mse_loss(y,y_pred) #/2.0
    return loss

def gramMatrix(input):
    b, c, h, w = input.size() #batch size(1), chanel, height, width
    features = input.view(b*c, h*w) 
    gram = torch.mm(features, features.t())
    return gram

def StyleLoss(y_pred, y,layer):
    
    N = layer.out_channels #the number of feature maps at layer L -> 채널 수
    M = layer.kernel_size[0] #height * width of feature maps at layer L
    A = gramMatrix(y).detach() #original list
    G = gramMatrix(y_pred).detach() #generated list
    
    E = F.mse_loss(G,A)
    #E = (1/4* N*N*M*M) * F.mse_loss(G,A) 
    return E  

In [175]:
class Normalization(nn.Module):
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # .view the mean and std to make them [C x 1 x 1] so that they can
        # directly work with image Tensor of shape [B x C x H x W].
        # B is batch size. C is number of channels. H is height and W is width.
        self.mean = torch.tensor(mean).view(-1, 1, 1)
        self.std = torch.tensor(std).view(-1, 1, 1)

    def forward(self, img):
        # normalize img
        return (img - self.mean) / self.std

In [186]:
def get_losses(vgg, style_img, content_img, input_img):
    
    normalization = Normalization(MEAN, STD).to(device)
    
    content_losses = []
    style_losses = []
    i=0
    for layer in vgg.children():  
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            layer = nn.ReLU(inplace=False)
            
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
            #layer = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
            
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
    
        
        if name in content_layers:
            # add content loss:
            target = vgg(content_img).detach()
            ans = vgg(input_img).detach()
            content_loss = ContentLoss(target, ans)
            content_losses.append(content_loss)    
        
        if name in style_layers:
            # add style loss:
            target_feature = vgg(style_img).detach()
            ans_feature = vgg(input_img).detach()
            style_loss = StyleLoss(target_feature, ans_feature, layer)
            style_losses.append(style_loss)
        
    return style_losses, content_losses

In [187]:
def get_input_optimizer(input_img):
    # this line to show that input is a parameter that requires a gradient
    optimizer = optim.LBFGS([input_img.requires_grad_()])
    return optimizer

In [188]:
def makeOutput(vgg, content_img, style_img, input_img):
    
    style_losses, content_losses = get_losses(vgg,
                                              style_img, content_img, input_img)
    optimizer = get_input_optimizer(input_img)
    
    num = 0
    print('Optimizing..')
    while num <= STEPS:
        
        def closure():
            # 입력 이미지의 업데이트된 값들을 보정합니다
            input_img.data.clamp_(0, 1)
        
            optimizer.zero_grad()
            vgg(input_img)

            style_total = 0
            content_total = 0
            
            for i in content_losses:
                content_total += i
            for i in style_losses:
                style_total += i #*w
                
            total_loss = style_total * STLWEIGHT + content_total * CNTWEIGHT
            
            print("total:",total_loss)
            
            total_loss.backward()
            
            num += 1
            if num % 50 == 0:
                print("run {}:".format(num))
                print('Style Loss : {:4f} Content Loss: {:4f}'.format(style_total.item(), content_total.item()))
                print()
            return style_total + content_total
        optimizer.step(closure)
        
    input_img.data.clamp_(0, 1)
    
    return input_img

In [189]:
result_img = makeOutput(vgg, content_img, style_img, input_img)

plt.figure()
showImage(result_img)

Optimizing..
total: tensor(1.0505e+09)


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [None]:
class ContentLoss(nn.Module):

    def __init__(self, target,):
        super(ContentLoss, self).__init__()
        # 그라디언트를 동적으로 계산하는 데 사용되는 트리에서 대상 콘텐츠를 '분리' 합니다.
        # :이 값은 변수(variable)가 아니라 명시된 값입니다.
        # 그렇지 않으면 기준의 전달 메소드가 오류를 발생 시킵니다.
        self.target = target.detach()

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input

def gram_matrix(input):
    a, b, c, d = input.size()  # a=배치 크기(=1)
    # b=특징 맵의 크기
    # (c,d)=특징 맵(N=c*d)의 차원

    features = input.view(a * b, c * d)  # F_XL을 \hat F_XL로 크기 조정합니다

    G = torch.mm(features, features.t())  # 그램 곱을 수행합니다

    # 그램 행렬의 값을 각 특징 맵의 요소 숫자로 나누는 방식으로 '정규화'를 수행합니다.
    return G.div(a * b * c * d)
    
class StyleLoss(nn.Module):
    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        G = gram_matrix(input)
        self.loss = F.mse_loss(G, self.target)
        return input