In [4]:
import torch
import torchvision.models as models
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
import glob
from torch.utils.data.dataset import Dataset

In [19]:
class COCODataset(Dataset):
    """Dataset that yields one transformed image per file matching a glob pattern.

    Args:
        path: glob pattern (e.g. ``'Dataset/PreData/*.jpg'``) selecting the image files.
        transform: callable applied to each loaded PIL image; whatever it
            returns is what ``__getitem__`` returns.
    """

    def __init__(self, path, transform):
        # glob.glob returns matches in arbitrary order; sorting keeps the
        # index -> file mapping stable from run to run.
        self.images = sorted(glob.glob(path))
        self.transform = transform

    def __getitem__(self, index):
        """Load file ``index`` as a 3-channel RGB image and return ``transform(image)``."""
        fname = self.images[index]
        # convert() reads the pixel data and returns a new image, so the file
        # handle can be closed as soon as the ``with`` block ends.
        # Forcing RGB keeps grayscale / RGBA / CMYK files compatible with
        # 3-channel transforms such as a 3-value Normalize.
        with Image.open(fname) as img:
            img = img.convert('RGB')
        return self.transform(img)

    def __len__(self):
        """Number of files matched by the glob pattern."""
        return len(self.images)

In [20]:
# Preprocessing: PIL image -> tensor, then per-channel normalisation with fixed mean/std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.47, 0.45, 0.4), std=(0.24, 0.24, 0.24)),
])
# Every .jpg file directly under Dataset/PreData.
dataset = COCODataset('Dataset/PreData/*.jpg', transform)
# Shuffled mini-batches of 4 images, loaded by 4 worker processes.
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
)

In [35]:
# Sanity check: print the shape of the first few batches and display one image from each.
# NOTE: ViewImage is defined in a later cell of this notebook; run that cell first.
for i, data in enumerate(loader):
    # A single pre-formatted string prints the same text under Python 2 and Python 3
    # (the bare `print i, ...` statement is a syntax error on Python 3).
    print('{} {}'.format(i, data.shape))
    data = Variable(data, requires_grad=False)
    # third image of the batch
    ViewImage(data[2,:,:,:])
    # stop after batch index 5, i.e. six batches in total
    if i > 4:
        break

0 torch.Size([4, 3, 256, 256])
1 torch.Size([4, 3, 256, 256])
2 torch.Size([4, 3, 256, 256])
3 torch.Size([4, 3, 256, 256])
4 torch.Size([4, 3, 256, 256])
5 torch.Size([4, 3, 256, 256])


In [30]:
def ViewImage(Data, mean=None, std=None):
    """Convert an image tensor back into a PIL image and open it in the system viewer.

    Args:
        Data: Variable/tensor exposing ``.data``. Singleton dimensions (for
            example a batch axis of size 1) are squeezed away before conversion.
        mean, std: optional per-channel sequences. When both are given the
            tensor is mapped back with ``x * std + mean``, i.e. the inverse of
            ``transforms.Normalize(mean, std)``, before it is displayed.

    Returns:
        None. The image is shown via ``PIL.Image.show``.
    """
    # ToPILImage works on CPU tensors; .cpu() is a no-op for tensors already there.
    data = torch.squeeze(Data.data).cpu()
    if mean is not None and std is not None:
        # Shape (C, 1, 1) so the statistics broadcast over height and width.
        mean_t = torch.Tensor(list(mean)).type_as(data).view(-1, 1, 1)
        std_t = torch.Tensor(list(std)).type_as(data).view(-1, 1, 1)
        data = data * std_t + mean_t
    # Float images are scaled by 255 and cast to 8-bit during conversion, so
    # values outside [0, 1] (e.g. normalised inputs) would overflow. Integer
    # tensors are left untouched.
    if isinstance(data, (torch.FloatTensor, torch.DoubleTensor)):
        data = data.clamp(0, 1)
    f = transforms.ToPILImage()
    img = f(data)
    img.show()

In [None]:
def Gram(Fmap):
    """Compute the normalised Gram matrix of a single feature map.

    Args:
        Fmap: feature map of shape ``(1, C, H, W)`` or ``(C, H, W)``.

    Returns:
        ``(C, C)`` matrix ``F @ F.T / (C * H * W)`` where ``F`` is the feature
        map flattened to ``(C, H * W)``.

    Raises:
        ValueError: if ``Fmap`` is 4-D with a batch size other than 1, or is
            neither 3-D nor 4-D.
    """
    # Drop only the batch axis. An unqualified squeeze() would also remove a
    # channel/height/width axis of size 1 and break the reshape below.
    if Fmap.dim() == 4:
        if Fmap.size(0) != 1:
            raise ValueError('Gram expects a single feature map, got batch size %d' % Fmap.size(0))
        f = Fmap[0]
    elif Fmap.dim() == 3:
        f = Fmap
    else:
        raise ValueError('Gram expects a 3-D or 4-D feature map, got %d dimensions' % Fmap.dim())
    numFeatures, H, W = f.size()
    # view() needs contiguous memory; contiguous() is a no-op when it already is.
    f = f.contiguous().view(numFeatures, H * W)
    return torch.mm(f, torch.t(f)) / (numFeatures * H * W)

In [None]:
# load pretrained VGG16 network
vgg = models.vgg16(True)
# first convolution of the feature extractor
vgg_conv = vgg.features[0]
# Freeze every VGG weight, not just vgg_conv: StyleLoss runs vgg.features[0..7],
# and unfrozen layers would make autograd compute gradients that no optimizer uses.
for paras in vgg.parameters():
    paras.requires_grad = False

In [None]:
def StyleLoss(input1, input2):
    """Sum of squared Gram-matrix differences between two images' VGG features.

    Both inputs are pushed through ``vgg.features[0]`` .. ``vgg.features[7]``;
    after layers 0, 2, 5 and 7 the squared difference of the two Gram matrices
    is summed into the loss.

    Returns:
        Variable of shape ``(1,)`` holding the accumulated loss.
    """
    tap_points = [0, 2, 5, 7]
    total = Variable(torch.zeros(1))
    feat_a, feat_b = input1, input2
    for depth in range(max(tap_points) + 1):
        stage = vgg.features[depth]
        # each layer is fed a copy of the previous activation
        feat_a = stage(feat_a.clone())
        feat_b = stage(feat_b.clone())
        if depth not in tap_points:
            continue
        gram_gap = Gram(feat_a) - Gram(feat_b)
        total = total + torch.sum(gram_gap ** 2)
    return total

In [None]:
class Res(nn.Module):
    """Residual block: conv -> BN -> ReLU -> conv -> BN, added back onto the input."""

    def __init__(self, numChannels):
        super(Res, self).__init__()
        # 3x3 convolutions with stride 1 and padding 1 keep channel count and spatial size.
        self.conv1 = nn.Conv2d(numChannels, numChannels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(numChannels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(numChannels, numChannels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(numChannels)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        branch = self.conv1(x)
        branch = self.bn1(branch)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        branch = self.bn2(branch)
        return x + branch

In [None]:
class ContentNet(nn.Module):
    """Image-to-image network: three convolutions, five residual blocks, two
    transposed convolutions and a final 9x9 convolution squashed into (0, 1).

    Input and output are 3-channel image batches. For a 256x256 input the
    spatial size goes 256 -> 127 -> 64 through the two stride-2 convolutions
    and 64 -> 128 -> 256 through the two transposed convolutions.
    """

    def __init__(self):
        super(ContentNet, self).__init__()
        # Reflection padding of 4 compensates for the unpadded 9x9 conv1.
        self.pad = nn.ReflectionPad2d(4)
        self.conv1 = nn.Conv2d(3, 32, 9, stride=1)
        # NOTE(review): conv2 has no padding while conv3 uses padding=1 -- confirm this asymmetry is intended.
        self.conv2 = nn.Conv2d(32, 64, 3, stride=2)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.res1 = Res(128)
        self.res2 = Res(128)
        self.res3 = Res(128)
        self.res4 = Res(128)
        self.res5 = Res(128)
        self.deconv1 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)
        self.conv4 = nn.Conv2d(32, 3, 9, stride=1, padding=4)

    def forward(self, x):
        """Run the network on ``x`` and return an image batch with values in (0, 1)."""
        # encoder
        x = self.conv3(self.conv2(self.conv1(self.pad(x))))
        # residual trunk
        x = self.res5(self.res4(self.res3(self.res2(self.res1(x)))))
        # decoder
        x = self.conv4(self.deconv2(self.deconv1(x)))
        # torch.tanh replaces the deprecated F.tanh; (tanh + 1) / 2 maps (-1, 1) onto (0, 1).
        return (torch.tanh(x) + 1)/2

In [None]:
# Load the content image and the style image (both currently read the same file, 'van.jpg').
# NOTE(review): LoadImage is not defined in any visible cell of this notebook.
content_data, style_data = LoadImage('van.jpg'), LoadImage('van.jpg')

In [None]:
# Gaussian-noise image batch of shape (1, 3, 256, 256); it carries no gradient.
random_data = Variable(torch.randn(1, 3, 256, 256), requires_grad=False)

In [None]:
# Mean-squared-error criterion, referenced only by the commented-out content loss in the training cell.
loss_fn = nn.MSELoss()
# Response of the first VGG convolution to the content image.
raw_vgg_feature = vgg_conv(content_data)

In [None]:
# Instantiate the network that the training cell optimises.
contentnet = ContentNet()

In [None]:
# Adam over the transformation network's own parameters.
optimizer = torch.optim.Adam(contentnet.parameters(), lr=1e-3)
# Checkpoint threshold: the model is only saved once the loss falls below this value
# (presumably the best loss of an earlier run -- confirm before reusing).
best_loss = 0.00223590899258852
# weight of the style term in the (currently disabled) combined loss
alpha = 1
for i in range(50000):
    x = content_data
    y_pred = contentnet(x)
#     loss1 = loss_fn(vgg_conv(y_pred), raw_vgg_feature)
    # Reuse y_pred rather than running the network a second time on the same input.
    loss2 = StyleLoss(y_pred, style_data)
#     loss = loss1 + loss2 * alpha
    loss = loss2
    if i % 5 == 0:
        # plain Python float, so comparison and printing do not depend on tensor types
        loss_value = float(loss.data[0])
        if loss_value < best_loss:
            best_loss = loss_value
            # saves the whole module object, not just its state_dict
            torch.save(contentnet, 'Best_model_Res.dat')
#         print(i, loss1.data[0], loss2.data[0], loss.data[0])
        # one formatted string prints identically under Python 2 and Python 3
        print('{} {}'.format(i, loss_value))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
# Display, in order: the network output for the content image, the content image, the style image.
for shown in (contentnet(content_data), content_data, style_data):
    ViewImage(shown)