In [1]:
import torch
import torchvision.models as models
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
import glob
from torch.utils.data.dataset import Dataset

In [13]:
def LoadImage(fname, size=(128, 128)):
    """Load an image file as a normalized 1 x 3 x H x W tensor.

    The image is converted to RGB, resized to ``size``, turned into a float
    tensor by ``ToTensor`` and normalized per channel with the same mean/std
    constants that ``normalize`` in this file uses.

    Args:
        fname: path (or file object) accepted by ``PIL.Image.open``.
        size: target size handed to the resize transform.  Defaults to
            ``(128, 128)``, the value that was previously hard-coded.

    Returns:
        A ``Variable`` with ``requires_grad=False`` and a leading batch
        dimension of 1.
    """
    # ``Scale`` was renamed to ``Resize`` in later torchvision releases; use
    # whichever one this installation provides.
    resize = getattr(transforms, 'Resize', None) or transforms.Scale
    loader = transforms.Compose([
        resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.40), (0.229, 0.224, 0.225)),
    ])
    # Force three channels so grayscale / RGBA / palette files do not break
    # the 3-channel Normalize above.  The context manager closes the file
    # handle once the pixel data has been converted.
    with Image.open(fname) as image:
        data = loader(image.convert('RGB'))
    data = Variable(data, requires_grad=False)
    return data.unsqueeze(0)

def ViewImage(Data):
    """Convert an image tensor back into a PIL image and display it.

    Args:
        Data: object whose ``.data`` attribute is an image tensor.  All
            size-1 dimensions (e.g. a leading batch dimension of 1) are
            squeezed away before conversion.  Float data is expected to be
            in the [0, 1] range; values outside it are clamped.
    """
    data = torch.squeeze(Data.data)
    # ToPILImage multiplies float tensors by 255 and casts them to uint8, so
    # values outside [0, 1] would wrap around and show up as speckle noise.
    # Clamp float data first; integer tensors are passed through untouched.
    # The type-string check works on both old and new PyTorch releases.
    if any(kind in data.type() for kind in ('Float', 'Double', 'Half')):
        data = data.clamp(0, 1)
    img = transforms.ToPILImage()(data)
    img.show()

def SaveImage(Data, fname='output.png'):
    """Convert an image tensor back into a PIL image and save it to disk.

    Args:
        Data: object whose ``.data`` attribute is an image tensor.  All
            size-1 dimensions (e.g. a leading batch dimension of 1) are
            squeezed away before conversion.  Float data is expected to be
            in the [0, 1] range; values outside it are clamped.
        fname: destination path handed to ``PIL.Image.save``.  Defaults to
            ``'output.png'``, the path that was previously hard-coded.
    """
    data = torch.squeeze(Data.data)
    # ToPILImage multiplies float tensors by 255 and casts them to uint8, so
    # values outside [0, 1] would wrap around and show up as speckle noise.
    # Clamp float data first; integer tensors are passed through untouched.
    # The type-string check works on both old and new PyTorch releases.
    if any(kind in data.type() for kind in ('Float', 'Double', 'Half')):
        data = data.clamp(0, 1)
    img = transforms.ToPILImage()(data)
    img.save(fname)
    
def Gram(Fmap):
    """Return the scaled Gram matrix of every feature map in a batch.

    ``Fmap`` is read as a 4-D tensor.  Its last two dimensions are flattened
    into one, each item is multiplied by its own transpose, and the result is
    divided by the product of all four dimension sizes.  Note that the batch
    size is part of that divisor.

    Returns:
        A tensor of shape (batch, channels, channels).
    """
    batch = Fmap.size(0)
    channels = Fmap.size(1)
    spatial = Fmap.size(2) * Fmap.size(3)

    flat = Fmap.view(batch, channels, spatial)
    # Batched product: one (channels x channels) matrix per item.
    products = flat.bmm(flat.transpose(1, 2))
    scale = batch * channels * spatial
    return products / scale


def StyleGram(style):
    """Collect Gram matrices of the style image at selected VGG layers.

    ``style`` is fed through ``vgg.features[0]`` up to and including
    ``vgg.features[22]``.  After each of the layers 3, 8, 15 and 22 the
    current activation is passed to ``Gram`` and the result is stored.

    Returns:
        dict mapping layer index -> Gram matrix of the activation there.
    """
    wanted = (3, 8, 15, 22)
    grams = {}
    activation = style
    for idx in range(max(wanted) + 1):
        activation = vgg.features[idx](activation)
        if idx in wanted:
            grams[idx] = Gram(activation)
    return grams

def normalize(x):
    """Return a copy of ``x`` with its first three channels standardized.

    Channel ``c`` (indexed along dimension 1) becomes
    ``(x[:, c] - mean[c]) / std[c]`` for c in 0..2.  The work is done on a
    clone, so ``x`` itself is left unchanged; any channels beyond the third
    are copied over as they are.
    """
    # (mean, std) pairs for channels 0, 1 and 2.
    stats = ((0.485, 0.229), (0.456, 0.224), (0.40, 0.225))
    out = x.clone()
    for channel, (mu, sigma) in enumerate(stats):
        out[:, channel, :, :] = (x[:, channel, :, :] - mu) / sigma
    return out


def TotalLoss(pred, content, alpha=0):
    """Compute the loss that drives the optimization of ``pred``.

    ``pred`` is first passed through ``normalize`` and then through
    ``vgg.features``.  ``content`` is pushed through the same layers as
    given (the caller in this notebook passes an image that ``LoadImage``
    has already normalized).  The content term is the mean squared
    difference of the two activations after layer 8.

    When ``alpha`` is non-zero a style term is added: ``alpha`` times the
    summed squared difference between ``Gram(pred)`` and the matching entry
    of the module-level ``Style_Gram`` dict, at layers 3, 8, 15 and 22.  In
    that case ``Style_Gram`` must already hold an entry for each of those
    layers.

    Args:
        pred: image tensor being optimized (un-normalized).
        content: content target image tensor.
        alpha: weight of the style term.  The default of 0 reproduces the
            previous hard-coded behaviour (content loss only; the VGG pass
            stops at the content layer).

    Returns:
        Tuple ``(Loss, ContentLoss, Loss - ContentLoss)``; the last item is
        the style contribution and is zero when ``alpha`` is 0.
    """
    pred = normalize(pred)

    Style_layers = (3, 8, 15, 22)
    Content_layer = 8
    use_style = alpha != 0
    # Only run the deeper layers when the style term actually needs them.
    last_layer = max(Content_layer, max(Style_layers)) if use_style else Content_layer

    Loss = Variable(torch.zeros(1))
    ContentLoss = None
    for i in range(last_layer + 1):
        layer = vgg.features[i]
        pred = layer(pred)
        if i <= Content_layer:
            # The content target is only needed up to the content layer.
            content = layer(content)
        if i == Content_layer:
            ContentLoss = torch.sum((pred - content)**2) / np.prod(content.size())
            Loss = Loss + ContentLoss
        if use_style and i in Style_layers:
            Loss = Loss + alpha * torch.sum((Gram(pred) - Style_Gram[i])**2)
    return Loss, ContentLoss, Loss - ContentLoss

In [3]:
# Build VGG16 with its pretrained weights.
vgg = models.vgg16(pretrained=True)
# Freeze every weight: only the input image is optimized, never the network.
for paras in vgg.parameters():
    paras.requires_grad = False

In [4]:
# Seed the global RNG so the random starting image, and therefore the whole
# optimization, is reproducible after a kernel restart.
torch.manual_seed(0)
# Image being optimized: Gaussian noise of shape 1 x 3 x 128 x 128, tracked
# by autograd so the optimizer can update its pixels.
x = Variable(torch.randn(1, 3, 128, 128), requires_grad=True)
# load content image and style image
style_data = LoadImage('van.jpg')
# Gram matrices of the style image, computed once up front.
Style_Gram = StyleGram(style_data)
content_data = LoadImage('content.jpg')

In [9]:
# L-BFGS optimizes the pixels of x directly.
optimizer = torch.optim.LBFGS([x], lr=0.8)


def closure():
    """Re-evaluate the loss for the current x and backpropagate it.

    The optimizer may call this several times within a single ``step``;
    ``i`` is looked up at call time, so every evaluation made during one
    outer iteration prints the same step index.
    """
    optimizer.zero_grad()
    total, content_term, style_term = TotalLoss(x, content_data)
    print(i, 'loss:', total.data[0], content_term.data[0], style_term.data[0])
    total.backward()
    return total


for i in range(10):
    optimizer.step(closure)

(0, 'loss:', 0.03996054083108902, 0.03996054083108902, 0.0)
(0, 'loss:', 0.03991414234042168, 0.03991414234042168, 0.0)
(0, 'loss:', 0.039859838783741, 0.039859838783741, 0.0)
(0, 'loss:', 0.03981596603989601, 0.03981596603989601, 0.0)
(0, 'loss:', 0.03978440538048744, 0.03978440538048744, 0.0)
(0, 'loss:', 0.03975490853190422, 0.03975490853190422, 0.0)
(0, 'loss:', 0.03971362113952637, 0.03971362113952637, 0.0)
(0, 'loss:', 0.039675962179899216, 0.039675962179899216, 0.0)
(0, 'loss:', 0.03962545841932297, 0.03962545841932297, 0.0)
(0, 'loss:', 0.03958992660045624, 0.03958992660045624, 0.0)
(0, 'loss:', 0.03955646604299545, 0.03955646604299545, 0.0)
(0, 'loss:', 0.03949887678027153, 0.03949887678027153, 0.0)
(0, 'loss:', 0.03945951536297798, 0.03945951536297798, 0.0)
(0, 'loss:', 0.03941833972930908, 0.03941833972930908, 0.0)
(0, 'loss:', 0.03938131779432297, 0.03938131779432297, 0.0)
(0, 'loss:', 0.03934456780552864, 0.03934456780552864, 0.0)
(0, 'loss:', 0.039308443665504456, 0.03930

In [14]:
# Write the optimized image x to disk via SaveImage.
SaveImage(x)