In [None]:
import torch
import torchvision.models as models
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
import glob
from torch.utils.data.dataset import Dataset

In [None]:
class COCODataset(Dataset):
    """Dataset that yields every image file matching a glob pattern.

    Args:
        path: glob pattern (e.g. ``'dir/*.jpg'``) selecting the image files.
        transform: callable applied to the loaded RGB ``PIL.Image``; its
            return value is what ``__getitem__`` returns.
    """

    def __init__(self, path, transform):
        # glob returns files in arbitrary, filesystem-dependent order; sorting
        # makes the index -> file mapping reproducible across runs/machines.
        self.images = sorted(glob.glob(path))
        self.transform = transform

    def __getitem__(self, index):
        """Return the transformed image stored at position ``index``."""
        fname = self.images[index]
        # Use a context manager so the underlying file is closed promptly
        # (important when iterating over many files in worker processes).
        with Image.open(fname) as handle:
            # Force three channels: grayscale / palette / CMYK files would
            # otherwise yield tensors that do not match the 3-channel
            # normalisation and cannot be stacked into one batch.
            img = handle.convert('RGB')
        return self.transform(img)

    def __len__(self):
        """Return the number of files matched by the glob pattern."""
        return len(self.images)

In [None]:
# Input pipeline for the content images: convert each PIL image to a tensor,
# then normalise every channel with a fixed mean / standard deviation.
# NOTE(review): the style image loaded by LoadImage is NOT normalised this
# way -- confirm that the mismatch is intended.
normalize = transforms.Normalize(mean=(0.47, 0.45, 0.4), std=(0.24, 0.24, 0.24))
transform = transforms.Compose([transforms.ToTensor(), normalize])

# Every .jpg under Dataset/PreData becomes one sample.
dataset = COCODataset('Dataset/PreData/*.jpg', transform)

# Shuffled mini-batches of 4 images, loaded by a single worker process.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=1,
)

In [None]:
def LoadImage(fname):
    """Load an image file as a batch of one 256x256 RGB tensor.

    Args:
        fname: path of the image file to read.

    Returns:
        A ``Variable`` (``requires_grad=False``) holding the resized image
        with a leading batch dimension of size 1 added by ``unsqueeze(0)``.
    """
    # torchvision renamed Scale to Resize and later removed Scale entirely;
    # prefer Resize and only fall back to Scale on very old installations.
    resize_cls = getattr(transforms, 'Resize', None) or transforms.Scale
    to_tensor = transforms.Compose([resize_cls((256, 256)), transforms.ToTensor()])
    with Image.open(fname) as handle:
        # Force three channels so grayscale / palette / CMYK / RGBA files do
        # not produce tensors the 3-channel networks cannot consume.
        image = handle.convert('RGB')
    data = to_tensor(image)
    data = Variable(data, requires_grad=False)
    # Add the batch dimension expected by the convolutional networks.
    return data.unsqueeze(0)

In [None]:
def ViewImage(Data):
    """Convert a tensor back into a PIL image and display it.

    Args:
        Data: tensor/Variable holding either a single image or a batch of
            images (4-D). For a batch, only the first image is shown.
    """
    data = Data.data
    if data.dim() == 4:
        # squeeze() only removes size-1 dimensions, so a batch with more
        # than one image would stay 4-D and ToPILImage would reject it.
        # Display the first image of the batch instead.
        data = data[0]
    else:
        data = torch.squeeze(data)
    # ToPILImage goes through numpy, which needs the tensor in host memory.
    img = transforms.ToPILImage()(data.cpu())
    img.show()

In [None]:
def Gram(Fmap):
    """Compute the normalised Gram matrix of every feature map in a batch.

    Args:
        Fmap: 4-D tensor laid out as (images, features, height, width).

    Returns:
        Tensor of shape (images, features, features). Entry ``[n, i, j]`` is
        the inner product of feature channels ``i`` and ``j`` of image ``n``,
        divided by ``features * height * width``.
    """
    numImages, numFeatures, H, W = Fmap.size()
    # Flatten the spatial dimensions so each channel becomes one row vector.
    flat = Fmap.view(numImages, numFeatures, H * W)
    # bmm already yields one Gram matrix per image, so the normaliser must
    # not include the batch size: otherwise the Gram of an image would depend
    # on how many other images share its batch, and Grams computed from a
    # batch of 4 could not be compared with Grams computed from a batch of 1.
    return torch.bmm(flat, flat.transpose(1, 2)) / (numFeatures * H * W)

In [None]:
# Pretrained VGG16; only its `features` layers are used, as a fixed feature
# extractor for the loss functions.
vgg = models.vgg16(pretrained=True)
# Freeze every weight so that backpropagating a loss never updates VGG.
for weight in vgg.parameters():
    weight.requires_grad = False

In [None]:
def StyleGram(style):
    """Collect the Gram matrices of a style image at selected VGG layers.

    Args:
        style: image tensor fed through ``vgg.features``.

    Returns:
        dict mapping each tapped layer index (3, 8, 15, 22) to the Gram
        matrix of the activations produced by that layer.
    """
    tap_layers = (3, 8, 15, 22)
    deepest = max(tap_layers)
    grams = {}
    activations = style
    # Run the image through VGG one layer at a time, stopping after the
    # deepest tapped layer, and record a Gram matrix at every tap.
    for idx in range(deepest + 1):
        activations = vgg.features[idx](activations)
        if idx in tap_layers:
            grams[idx] = Gram(activations)
    return grams

In [None]:
def TotalLoss(pred, content):
    """Combined content + style loss computed in one pass through VGG.

    Reads the module-level ``Style_Gram`` dict (layer index -> Gram matrix of
    the style image).

    Args:
        pred: output of the style network.
        content: the content image batch the prediction was generated from.

    Returns:
        One-element tensor: mean squared feature difference at layer 15 plus
        the summed squared Gram differences at layers 3, 8, 15 and 22.
    """
    style_layers = (3, 8, 15, 22)
    content_layer = 15
    style_weight = 1
    total = Variable(torch.zeros(1))
    for idx in range(max(style_layers) + 1):
        vgg_layer = vgg.features[idx]
        pred = vgg_layer(pred)
        # The content image only needs to travel as deep as the content layer.
        if idx <= content_layer:
            content = vgg_layer(content)
            if idx == content_layer:
                squared_error = torch.sum((pred - content) ** 2)
                total = total + squared_error / np.prod(content.size())
        if idx in style_layers:
            gram_gap = Gram(pred) - Style_Gram[idx]
            total = total + style_weight * torch.sum(gram_gap ** 2)
    return total

In [None]:
def StyleLoss(input1, input2):
    """Weighted sum of squared Gram-matrix differences between two inputs.

    Both inputs are passed through ``vgg.features`` layer by layer; at each
    of the layers 3, 8, 15 and 22 the squared difference of their Gram
    matrices is summed, multiplied by that layer's weight and accumulated.

    Args:
        input1: first image tensor.
        input2: second image tensor.

    Returns:
        One-element tensor holding the accumulated loss.
    """
    Loss = Variable(torch.zeros(1))
    layers = [3, 8, 15, 22]
    weights = [1, 1, 1, 1]
    # Pair every tapped layer with its weight. The weights were previously
    # declared but never applied; with all weights equal to 1 the result is
    # numerically unchanged.
    layer_weight = dict(zip(layers, weights))
    for i in range(max(layers)+1):
        layer = vgg.features[i]
        input1 = layer(input1)
        input2 = layer(input2)
        if i in layer_weight:
            Loss = Loss + layer_weight[i] * torch.sum((Gram(input1) - Gram(input2))**2)
    return Loss

In [None]:
def ContentLoss(input1, input2):
    """Mean squared difference of two inputs' VGG activations at layer 15.

    Args:
        input1: first image tensor.
        input2: second image tensor.

    Returns:
        Sum of squared activation differences divided by the number of
        elements in the activation tensor.
    """
    last_layer = 15
    # Push both inputs through VGG layers 0..15 in lock-step.
    for idx in range(last_layer + 1):
        vgg_layer = vgg.features[idx]
        input1 = vgg_layer(input1)
        input2 = vgg_layer(input2)
    difference = input1 - input2
    return torch.sum(difference ** 2) / np.prod(input1.size())

In [None]:
class Res(nn.Module):
    """Residual block: conv -> batch norm -> ReLU -> conv -> batch norm,
    added back onto the block's input.

    Both convolutions are 3x3, stride 1, padding 1 and keep the channel
    count, so the output has the same shape as the input.

    Args:
        numChannels: number of input (and output) channels.
    """

    def __init__(self, numChannels):
        super(Res, self).__init__()
        conv_args = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(numChannels, numChannels, **conv_args)
        self.bn1 = nn.BatchNorm2d(numChannels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(numChannels, numChannels, **conv_args)
        self.bn2 = nn.BatchNorm2d(numChannels)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        branch = self.bn1(self.conv1(x))
        branch = self.relu(branch)
        branch = self.bn2(self.conv2(branch))
        # Skip connection: the input bypasses the branch unchanged.
        return x + branch

In [None]:
class StyleNet(nn.Module):
    """Image-transformation network: three convolutions (two of them with
    stride 2), five residual blocks, two stride-2 transposed convolutions
    and a final 9x9 convolution back to 3 channels.

    The output is squashed to the range [0, 1] via ``(tanh(x) + 1) / 2``.

    NOTE(review): no activation or normalisation is applied between the
    convolutions outside the residual blocks, and ``conv2`` uses no padding
    while ``conv3`` uses padding=1 -- confirm both are intended.
    """

    def __init__(self):
        super(StyleNet, self).__init__()
        # Encoder: reflection padding of 4 offsets the unpadded 9x9 conv.
        self.pad = nn.ReflectionPad2d(4)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=9, stride=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        # Bottleneck: five residual blocks at 128 channels.
        self.res1 = Res(128)
        self.res2 = Res(128)
        self.res3 = Res(128)
        self.res4 = Res(128)
        self.res5 = Res(128)
        # Decoder: two stride-2 transposed convolutions, then 9x9 conv to RGB.
        self.deconv1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2,
                                          padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2,
                                          padding=1, output_padding=1)
        self.conv4 = nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        """Transform an image batch; the result has values in [0, 1]."""
        x = self.pad(x)
        for stage in (self.conv1, self.conv2, self.conv3):
            x = stage(x)
        for block in (self.res1, self.res2, self.res3, self.res4, self.res5):
            x = block(x)
        for stage in (self.deconv1, self.deconv2, self.conv4):
            x = stage(x)
        # tanh yields (-1, 1); shift and scale into (0, 1).
        return (torch.tanh(x) + 1) / 2

In [None]:
# Load the style reference image once and cache its Gram matrices; TotalLoss
# reads them from the module-level `Style_Gram`.
STYLE_IMAGE_PATH = 'van.jpg'
style_data = LoadImage(STYLE_IMAGE_PATH)
Style_Gram = StyleGram(style_data)

In [None]:
# Instantiate the image-transformation network; its parameters are the only
# ones handed to the optimizer in the training cell.
stylenet = StyleNet()

In [None]:
# Train the style network with Adam for two passes over the content images.
optimizer = torch.optim.Adam(stylenet.parameters(), lr=1e-3)
# `best_loss` and `alpha` are not read by the loop below; they are leftovers
# of an earlier variant that combined ContentLoss + alpha * StyleLoss and
# tracked the best model (checkpointing is not implemented here).
best_loss = 1
alpha = 1
for epoch in range(2):
    for step, content_data in enumerate(loader):
        content_data = Variable(content_data)
        # Clear gradients left over from the previous step before the new
        # forward/backward pass.
        optimizer.zero_grad()
        y_pred = stylenet(content_data)
        # Content and style terms are computed together in one VGG pass.
        loss = TotalLoss(y_pred, content_data)
        print(step, loss.data[0])
        loss.backward()
        optimizer.step()

In [None]:
# Visual check, in order: the stylised output for the last training batch,
# that batch's content input, and the style reference image.
# NOTE(review): `content_data` is whatever the training loop left behind,
# i.e. a whole mini-batch rather than a single image.
for shown in (stylenet(content_data), content_data, style_data):
    ViewImage(shown)