# Load packages

In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision.datasets as dset
from torch.utils.data import Dataset, DataLoader, Subset
from PIL import Image

In [2]:
# Run on the GPU when one is available; fall back to CPU so the notebook still
# executes (slowly) on machines without CUDA instead of failing at first use.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load image data

## Dataset with transform

In [3]:
# PIL image -> float tensor in [0, 1] -> per-channel normalisation with the
# ImageNet mean/std that the pretrained VGG19 weights were trained with.
transform_normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

In [4]:
# PIL image -> float tensor in [0, 1]; no normalisation is applied.
transform_tensor = transforms.Compose([transforms.ToTensor()])

In [5]:
# Style source: the first 150 Monet paintings, ImageNet-normalised.
style_data = Subset(
    dataset=dset.ImageFolder(root="./dataset/gan-getting-started/monet/", transform=transform_normalize),
    indices=range(150),
)
# Content source: a single photo, ImageNet-normalised.
content_data = Subset(
    dataset=dset.ImageFolder(root="./dataset/gan-getting-started/photo/", transform=transform_normalize),
    indices=range(1),
)

In [6]:
def normalize_image(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Apply per-channel ``(image - mean) / std`` normalisation.

    The defaults are the same ImageNet statistics used by ``transform_normalize``,
    so an un-normalised [0, 1] image passed through here matches what the
    datasets produce at load time. Built from tensor ops, so gradients flow
    back to ``image``.

    Args:
        image: tensor of shape ``(C, H, W)`` or ``(N, C, H, W)`` with C equal
            to ``len(mean)``.
        mean: per-channel means.
        std: per-channel standard deviations.

    Returns:
        Normalised tensor with the same shape, dtype and device as ``image``.
    """
    # Build the statistics on the image's own device/dtype (rather than a
    # global device) so CPU and GPU inputs both work; view(-1, 1, 1) lets them
    # broadcast over H, W and any leading batch dimension.
    mean_t = torch.as_tensor(mean, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    return (image - mean_t) / std_t

In [7]:
# The image being optimised starts as a copy of the content photo in plain
# [0, 1] RGB (loaded with `transform_tensor`). It must not be taken from
# `content_data`, which is already ImageNet-normalised: get_target_features()
# runs its input through normalize_image(), so that would normalise twice and
# the saved outputs would be in normalised rather than RGB space.
# It is created directly on `device` so the optimiser updates it there.
target = (
    dset.ImageFolder("./dataset/gan-getting-started/photo/", transform_tensor)[0][0]
    .clone()
    .to(device)
    .requires_grad_(True)
)


## DataLoader

In [8]:
# One image per batch, in dataset order (DataLoader's defaults, spelled out).
style_loader = DataLoader(style_data, batch_size=1, shuffle=False)
content_loader = DataLoader(content_data, batch_size=1, shuffle=False)

# Load VGG model

In [9]:
# Pretrained VGG19 convolutional stack used purely as a fixed feature extractor:
# eval() mode, and requires_grad_(False) so autograd never builds graphs through
# (or accumulates gradients into) its weights; only `target`'s pixels are optimised.
# NOTE(review): newer torchvision deprecates `pretrained=True` in favour of
# `weights=models.VGG19_Weights.IMAGENET1K_V1`; kept here for compatibility.
vgg = models.vgg19(pretrained=True).features.to(device).eval().requires_grad_(False)



In [10]:
def gram_matrix(tensor, normalize=False):
    """Return the channel-by-channel Gram matrix of a single feature map.

    Args:
        tensor: activation of shape ``(C, H, W)`` or ``(1, C, H, W)``; a
            leading batch dimension of size 1 is squeezed away.
        normalize: if True, divide the result by ``C * H * W`` so its scale
            does not depend on the feature-map size. Defaults to False
            (raw inner products).

    Returns:
        ``(C, C)`` tensor whose ``[i, j]`` entry is the inner product of the
        flattened channel-i and channel-j activations.

    Raises:
        ValueError: if ``tensor`` is not 3-D, or is 4-D with a batch size
            other than 1.
    """
    # Explicit shape dispatch instead of a bare `except:` around tuple unpacking,
    # which also hid unrelated errors.
    if tensor.dim() == 4:
        if tensor.size(0) != 1:
            raise ValueError(
                f"gram_matrix expects a batch size of 1, got shape {tuple(tensor.shape)}"
            )
        tensor = tensor.squeeze(0)
    elif tensor.dim() != 3:
        raise ValueError(
            f"gram_matrix expects a (C, H, W) or (1, C, H, W) tensor, got shape {tuple(tensor.shape)}"
        )
    channels = tensor.size(0)
    # reshape (unlike view) also accepts non-contiguous inputs.
    features = tensor.reshape(channels, -1)
    gram = features @ features.t()
    if normalize:
        gram = gram / tensor.numel()
    return gram

In [11]:
# get feature of image using pre-trained model
def get_style_features(style_loader, model):
    layers = {
        '0':'conv1_1',
#         '5':'conv2_1',
#         '10':'conv3_1',
#         '19':'conv4_1',
#         '28':'conv5_1',
    }
    style_features = {}
    for i, (x, _) in enumerate(style_loader):
        x = x.to(device)
        style_features[f'style_{i}'] = {}
        for name, layer in model._modules.items():
            x = layer(x)
            if name in layers:
#                 style_gram = gram_matrix(x)
                style_features[f'style_{i}'][layers[name]] = x
    return style_features

In [12]:
# Activations of every style painting, keyed 'style_<i>'.
style_features = get_style_features(style_loader=style_loader, model=vgg)

In [13]:
def get_content_features(content_loader, model, layers=None):
    """Collect activations of every content image at the requested layers.

    Content activations are fixed targets for the content loss, so extraction
    runs under ``torch.no_grad()``. Keeping their autograd graph would only hold
    memory, and would let the loss back-propagate into the model weights if
    those require grad.

    Args:
        content_loader: iterable of ``(images, label)`` batches; labels are ignored.
        model: module whose child modules are applied in order (here the VGG19
            ``features`` stack).
        layers: mapping from child-module name to feature key. Defaults to
            ``{'0': 'conv1_1'}``. Other VGG19 entries used for style transfer
            are ``'5': 'conv2_1'``, ``'10': 'conv3_1'``, ``'19': 'conv4_1'`` and
            ``'28': 'conv5_1'``.

    Returns:
        ``{'content_<i>': {feature_key: activation}}`` for the i-th batch.
    """
    if layers is None:
        layers = {'0': 'conv1_1'}
    # Feed inputs to wherever the model's weights live.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    content_features = {}
    with torch.no_grad():
        for i, (x, _) in enumerate(content_loader):
            x = x.to(model_device)
            captured = {}
            # Deliberately run every layer: torchvision's VGG uses
            # ReLU(inplace=True), which rectifies an already-stored conv output
            # in place, so stopping early would change the stored features.
            # `_modules` (not named_children) applies a reused module at every
            # position, as nn.Sequential.forward does.
            for name, layer in model._modules.items():
                x = layer(x)
                if name in layers:
                    captured[layers[name]] = x
            content_features[f'content_{i}'] = captured
    return content_features

In [14]:
# Activations of the content photo, keyed 'content_<i>'.
content_features = get_content_features(content_loader=content_loader, model=vgg)

In [15]:
def get_target_features(target, model, layers=None):
    """Run the image being optimised through ``model`` and collect activations.

    Unlike the style/content extractors this keeps autograd enabled: the
    returned activations must stay connected to ``target`` so the loss can be
    back-propagated into its pixels.

    Args:
        target: image tensor, moved to ``device`` and passed through
            ``normalize_image`` before the forward pass, so it is expected to
            be an un-normalised RGB image.
        model: module whose child modules are applied in order (here the VGG19
            ``features`` stack).
        layers: mapping from child-module name to feature key. Defaults to
            ``{'0': 'conv1_1'}``. Other VGG19 entries used for style transfer
            are ``'5': 'conv2_1'``, ``'10': 'conv3_1'``, ``'19': 'conv4_1'`` and
            ``'28': 'conv5_1'``.

    Returns:
        ``{feature_key: activation}`` for each requested layer.
    """
    if layers is None:
        layers = {'0': 'conv1_1'}
    x = normalize_image(target.to(device))

    target_features = {}
    # Deliberately run every layer: torchvision's VGG uses ReLU(inplace=True),
    # which rectifies an already-stored conv output in place, so stopping early
    # would change the features. `_modules` (not named_children) applies a
    # reused module at every position, as nn.Sequential.forward does.
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            target_features[layers[name]] = x

    return target_features

In [16]:
# Activations of the image being optimised, at its initial pixel values.
target_features = get_target_features(target=target, model=vgg)

# Training

In [17]:
# Optimisation hyper-parameters.
steps = 1_000            # number of optimiser updates
content_weight = 1       # weight of the content (feature-matching) term
style_weight = 100_000   # weight of the style (Gram-matching) term
optimizer = optim.Adam(params=[target], lr=0.01)  # updates only the target pixels

In [18]:
# Optimise the pixels of `target` so that its VGG activations stay close to the
# content photo's while its Gram matrices approach those of the style paintings.

# Style Gram matrices and content activations do not depend on `target`, so they
# are built once, outside the loop, and cut off from any autograd graph: nothing
# is back-propagated into VGG or the source images, and no graph is reused
# across steps (hence no `retain_graph`).
with torch.no_grad():
    style_grams = torch.stack(
        [gram_matrix(features['conv1_1']) for features in style_features.values()]
    )
content_target = content_features['content_0']['conv1_1'].detach()
os.makedirs("output", exist_ok=True)

for step in range(steps):
    # Re-run VGG on the current pixels every step; features computed once before
    # the loop would never reflect the optimiser's updates.
    target_features = get_target_features(target, vgg)
    target_feature = target_features['conv1_1']

    content_loss = torch.mean((target_feature - content_target) ** 2)

    # Mean over paintings of the per-painting mean squared Gram difference;
    # the target Gram broadcasts against the stacked (n_styles, C, C) tensor.
    target_gram = gram_matrix(target_feature)
    style_loss = torch.mean((target_gram - style_grams) ** 2)

    # style_weight is applied exactly once.
    total_loss = content_weight * content_loss + style_weight * style_loss

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if step % 100 == 0:
        print("Step [{}/{}], Total Loss: {:.4f}, Content Loss: {:.4f}, Style Loss: {:.4f}"
              .format(step + 1, steps, total_loss.item(), content_loss.item(), style_loss.item()))
        # `target` is passed through normalize_image() inside get_target_features,
        # i.e. it is treated as an un-normalised RGB image; clamp to [0, 1] so
        # ToPILImage's scaling to 8-bit does not overflow.
        output = target.detach().clamp(0, 1).cpu().squeeze(0)
        transforms.ToPILImage()(output).save("output/output-{}.jpg".format(step))

KeyboardInterrupt: 