In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image

In [2]:
# Image.open is lazy and keeps the file handle open until the pixel data is
# read. Decoding inside a context manager closes the handle promptly, and
# convert("RGB") guarantees three channels so the per-channel normalisation
# applied later also works for grayscale / RGBA / palette files.
with Image.open("./dataset/gan-getting-started/photo/photo_jpg/000910d219.jpg") as img:
    content_image = img.convert("RGB")
with Image.open("./dataset/gan-getting-started/monet/monet_jpg/000c1e3bff.jpg") as img:
    style_image = img.convert("RGB")

In [3]:
def load_image(image):
    """Turn an image into a batched tensor on the active device.

    The image is converted with ``transforms.ToTensor``, given a leading
    dimension of size 1, and moved to the module-level ``device``.

    Args:
        image: Input accepted by ``transforms.ToTensor`` (the notebook passes
            PIL images).

    Returns:
        The converted tensor with one extra leading dimension, on ``device``.
    """
    to_tensor = transforms.ToTensor()
    batched = to_tensor(image).unsqueeze(0)
    return batched.to(device)

def normalize_image(image):
    """Standardise an image tensor channel-wise: ``(image - mean) / std``.

    The constants are created directly on ``image.device`` rather than on a
    global ``device``. This removes the dependency on a variable defined in a
    later cell and avoids a CPU allocation plus host-to-device copy on every
    call (the function runs once per optimisation step).

    Args:
        image: Tensor whose third-from-last dimension has size 3
            (e.g. ``(1, 3, H, W)`` or ``(3, H, W)``).

    Returns:
        A new tensor of the same shape with each channel shifted and scaled.
    """
    # Presumably the ImageNet channel statistics expected by torchvision's
    # pretrained models - confirm if a different backbone is swapped in.
    mean = torch.tensor([0.485, 0.456, 0.406], device=image.device)
    std = torch.tensor([0.229, 0.224, 0.225], device=image.device)
    # [:, None, None] reshapes (3,) -> (3, 1, 1) so it broadcasts over H and W.
    return (image - mean[:, None, None]) / std[:, None, None]

In [4]:
# Fall back to the CPU when no GPU is present instead of crashing on the
# first .to(device) call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Only the convolutional feature extractor is needed; eval() fixes the
# behaviour of any train/eval-dependent layers.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of `weights=`; left unchanged to stay compatible with the installed version.
vgg = models.vgg19(pretrained=True).features.to(device).eval()

# The network is a fixed feature extractor: freezing its weights stops
# autograd from computing (and accumulating) unused weight gradients on
# every backward pass.
for param in vgg.parameters():
    param.requires_grad_(False)

def get_features(image, model):
    """Run ``image`` through ``model`` and collect selected intermediate outputs.

    The child modules of ``model`` are applied in order. The outputs of the
    children named '0', '5', '10', '19' and '28' are stored under the keys
    'conv1_1', 'conv2_1', 'conv3_1', 'conv4_1' and 'conv5_1'.

    Iteration stops as soon as every requested layer has been captured, so
    children after the last tapped one are not executed (they cannot affect
    the returned features).

    Args:
        image: Input tensor fed to the first child module.
        model: An ``nn.Module`` whose direct children are applied in sequence
            (e.g. ``nn.Sequential``).

    Returns:
        dict mapping feature name to the captured output tensor. If the model
        has fewer children than the highest tapped index, only the layers that
        exist are returned.
    """
    layers = {
        '0': 'conv1_1',
        '5': 'conv2_1',
        '10': 'conv3_1',
        '19': 'conv4_1',
        '28': 'conv5_1'
    }
    features = {}
    x = image
    # named_children() is the public equivalent of iterating model._modules.
    for name, layer in model.named_children():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
            if len(features) == len(layers):
                # Everything requested is captured; skip the remaining layers.
                break
    return features

content = load_image(content_image)
style = load_image(style_image)

# The reference features are fixed targets, so compute them without autograd.
# Otherwise (when the VGG weights require grad) they carry a graph that every
# later backward() would traverse, which forces retain_graph=True in the
# optimisation loop and wastes memory.
with torch.no_grad():
    content_features = get_features(normalize_image(content), vgg)
    style_features = get_features(normalize_image(style), vgg)



In [5]:
def gram_matrix(tensor, normalize=False):
    """Compute the Gram matrix of a single feature map.

    The feature map is flattened to ``(c, h * w)`` and multiplied by its own
    transpose, giving the channel-by-channel inner products.

    Args:
        tensor: 4-D tensor of shape ``(1, c, h, w)``. A batch size other than
            1 cannot be reshaped to ``(c, h * w)`` and raises ``RuntimeError``.
        normalize: If True, divide the result by ``c * h * w`` so the values
            do not grow with the feature-map size. Defaults to False, which
            returns the raw products exactly as before.

    Returns:
        Tensor of shape ``(c, c)``.
    """
    _, c, h, w = tensor.size()
    # reshape (unlike view) also accepts non-contiguous inputs, copying only
    # when it has to.
    flat = tensor.reshape(c, h * w)
    gram = torch.mm(flat, flat.t())
    if normalize:
        gram = gram / (c * h * w)
    return gram

In [6]:
# The image being optimised starts as a copy of the content tensor; enabling
# gradients on it lets the optimiser update its values directly.
target = content.clone()
target.requires_grad_(True)

In [7]:
# Relative weighting of the two loss terms in the total objective.
content_weight, style_weight = 1, 10

# Number of optimisation iterations to run.
steps = 10000

# Adam updates only the target tensor.
optimizer = optim.Adam(params=[target], lr=0.002)

In [None]:
# The optimisation targets never change, so prepare them once:
#  - detach() cuts any autograd history attached to the reference features,
#    so each step's graph is built from `target` alone and backward() no
#    longer needs retain_graph=True;
#  - the style Gram matrices are loop-invariant and are computed a single
#    time instead of once per step.
content_target = content_features['conv5_1'].detach()
style_grams = {
    layer: gram_matrix(features.detach())
    for layer, features in style_features.items()
}

# Saving snapshots fails if the destination directory is missing.
os.makedirs("output", exist_ok=True)

for i in range(steps):
    target_features = get_features(normalize_image(target), vgg)

    # Content term: mean squared difference of the conv5_1 activations.
    content_loss = torch.mean((target_features['conv5_1'] - content_target) ** 2)

    # Style term: mean squared Gram-matrix difference, summed over all layers.
    style_loss = 0
    for layer, style_gram in style_grams.items():
        target_gram = gram_matrix(target_features[layer])
        style_loss += torch.mean((target_gram - style_gram) ** 2)

    total_loss = content_weight * content_loss + style_weight * style_loss

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if i % 100 == 0:
        print("Step [{}/{}], Total Loss: {:.4f}, Content Loss: {:.4f}, Style Loss: {:.4f}"
              .format(i + 1, steps, total_loss.item(), content_loss.item(), style_loss.item()))
        # The optimiser is free to push pixels outside [0, 1]. ToPILImage
        # scales float tensors by 255 and casts to uint8, so out-of-range
        # values would wrap around into artefacts. Clamp the saved copy only,
        # leaving the optimised tensor itself untouched.
        output = target.detach().clamp(0, 1).cpu().squeeze(0)
        output = transforms.ToPILImage()(output)
        output.save("output/output-{}.jpg".format(i))

Step [1/10000], Total Loss: 2923044864.0000, Content Loss: 0.0000, Style Loss: 292304480.0000
Step [101/10000], Total Loss: 188212960.0000, Content Loss: 1.8690, Style Loss: 18821296.0000
Step [201/10000], Total Loss: 119511128.0000, Content Loss: 2.0679, Style Loss: 11951113.0000
Step [301/10000], Total Loss: 94424560.0000, Content Loss: 2.1746, Style Loss: 9442456.0000
Step [401/10000], Total Loss: 78702088.0000, Content Loss: 2.2341, Style Loss: 7870208.5000
Step [501/10000], Total Loss: 65855664.0000, Content Loss: 2.2863, Style Loss: 6585566.0000
Step [601/10000], Total Loss: 54105988.0000, Content Loss: 2.3231, Style Loss: 5410598.5000
Step [701/10000], Total Loss: 43633892.0000, Content Loss: 2.3469, Style Loss: 4363389.0000
Step [801/10000], Total Loss: 35300180.0000, Content Loss: 2.3680, Style Loss: 3530017.7500
Step [901/10000], Total Loss: 29145512.0000, Content Loss: 2.3878, Style Loss: 2914551.0000
Step [1001/10000], Total Loss: 24690344.0000, Content Loss: 2.4068, Style 