In [0]:
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models, transforms

from PIL import Image
import numpy as np

import matplotlib.pyplot as plt
from IPython.display import clear_output, display

# Display defaults: compact numpy printing and large matplotlib figures.
np.set_printoptions(precision=2)
plt.rcParams.update({'figure.figsize': (14.0, 10.0)})

In [3]:
# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

device(type='cuda', index=0)

In [0]:
# Pretrained torchvision models expect inputs normalised with these per-channel
# statistics: https://pytorch.org/docs/stable/torchvision/models.html
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def img2tensor(img, size=512):
    """Turn a PIL image into a normalised, batched float tensor for the model.

    Args:
        img: PIL image; callers in this notebook pass RGB images.
        size: forwarded to ``transforms.Resize`` (512 reproduces the previous,
            hard-coded behaviour).

    Returns:
        A tensor with a leading batch dimension of 1, normalised with the
        module-level ``mean``/``std``.
    """
    tfms = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    return tfms(img).unsqueeze(0)

def tensor2img(tensor):
    """Invert ``img2tensor`` for display: return an H x W x C numpy image in [0, 1].

    Accepts a (1, C, H, W) or (C, H, W) tensor on any device, with or without
    gradient tracking; the input tensor itself is not modified.
    """
    # squeeze(0) drops only the batch dim, so 1-pixel-high/wide images keep their shape
    # (a bare squeeze() would collapse those spatial dims and break the transpose).
    img = tensor.detach().cpu().squeeze(0).numpy().transpose(1, 2, 0)
    img = img * np.array(std) + np.array(mean)
    # Optimisation can push pixels outside the valid range; clip so callers
    # (imshow, the `images` snapshots) always get a valid float RGB image.
    return img.clip(0, 1)

# Input image paths.
# NOTE(review): both paths point at the same file, so the style and content
# targets coincide and every loss is 0 from the first step (see the
# `epoch 0 0.0 0.0 0.0` log). Point STYLE_PATH at an actual style image.
CONTENT_PATH = 'ana.jpg'
STYLE_PATH = 'ana.jpg'


def _load_rgb(path):
    """Open an image file, convert it to RGB and release the file handle."""
    # Image.open keeps the underlying file open; convert() returns an in-memory
    # copy, so the context manager can close the file safely afterwards.
    with Image.open(path) as f:
        return f.convert('RGB')


content_img = _load_rgb(CONTENT_PATH)
style_img = _load_rgb(STYLE_PATH)

style_tensor = img2tensor(style_img)
content_tensor = img2tensor(content_img)

In [0]:
# Preview the style and content inputs, then their PIL sizes.
for preview in (style_img, content_img):
    display(preview)
display(style_img.size, content_img.size)

In [0]:
# Convolutional trunk of a pretrained VGG-19. Freezing every weight means the
# optimiser in the training cell only updates the image being optimised.
model = models.vgg19(pretrained=True).features.requires_grad_(False)

In [0]:
# Move the network and both input tensors to the selected device.
model = model.to(device)
style_tensor = style_tensor.to(device)
content_tensor = content_tensor.to(device)

# The image whose pixels are optimised. The training cell passes `target` to
# Adam, but no cell defined it, so a fresh-kernel run failed with NameError.
# Start from a copy of the content image (consistent with the zero content loss
# logged at epoch 0); re-running the training cell continues from this tensor.
target = content_tensor.clone().requires_grad_(True)

In [0]:
# Sanity check: round-trip both preprocessed tensors back to images, one
# titled figure each so the two plots can be told apart.
for title, preview_tensor in (('style', style_tensor), ('content', content_tensor)):
    plt.figure()
    plt.imshow(tensor2img(preview_tensor))
    plt.title(title)

In [47]:
# Indices into the VGG-19 `features` stack whose activations feed the losses.
style_layers = [0, 5, 10, 19, 28]
# Exactly one weight per style layer: the training loop pairs them with zip(),
# which silently drops any surplus (a sixth weight, 0.15, was previously ignored).
style_weights = [1, 0.75, 0.5, 0.35, 0.25]
content_layers = [21]
assert len(style_weights) == len(style_layers), 'need exactly one weight per style layer'

# Every layer extract_features has to capture.
layers = sorted(style_layers + content_layers)
display(layers)


def extract_features(x, model, layer_ids=None):
    """Run ``x`` through ``model``'s child modules in order and collect activations.

    Args:
        x: input tensor fed to the first child module.
        model: module whose registered children are applied one after another
            (e.g. a VGG ``features`` stack).
        layer_ids: child indices to capture. Defaults to the notebook-level
            ``layers`` list, which preserves the original behaviour.

    Returns:
        dict mapping child index -> the tensor that child produced.

    NOTE(review): each stored tensor is the same object passed to the next child.
    If that child operates in place (torchvision's VGG uses in-place ReLU), the
    captured activation ends up post-activation — confirm this is intended.
    """
    wanted = set(layers if layer_ids is None else layer_ids)
    features = {}
    for i, layer in enumerate(model._modules.values()):
        x = layer(x)
        if i in wanted:
            features[i] = x
    return features


def calc_gram_matrix(tensor):
    """Channel-by-channel Gram matrix of one feature map, normalised by its size.

    Expects a 4-D tensor whose batch dimension is 1; the ``view`` raises for
    larger batches. Returns a (C, C) tensor equal to F @ F.T / (C*H*W), where F
    is the feature map flattened to (C, H*W).
    """
    n_ch, rows, cols = tensor.shape[1:]
    flat = tensor.view(n_ch, rows * cols)
    gram = flat @ flat.T
    return gram / (n_ch * cols * rows)
    
# Fixed reference activations: the content image is kept as raw activations,
# the style image is summarised by one Gram matrix per captured layer.
content_ftrs = extract_features(content_tensor, model)
style_ftrs = extract_features(style_tensor, model)
style_ftrs_gram_m = {idx: calc_gram_matrix(act) for idx, act in style_ftrs.items()}

# Sanity output: Gram dtype and which layers were captured.
display(style_ftrs_gram_m[0].dtype)
display(content_ftrs.keys())

[0, 5, 10, 19, 21, 28]

torch.float32

dict_keys([0, 5, 10, 19, 21, 28])

In [0]:
# Snapshots of the target image, appended by the training cell every 100 steps.
images = list()

In [0]:
# Optimise the pixels of `target` (the network is frozen) with Adam.
epochs = 500
optimizer = optim.Adam([target], lr=1e-2)
style_loss_weight = 1e6


for epoch in range(epochs):
    target_ftrs = extract_features(target, model)

    # Content term: match raw activations at the content layers.
    content_loss = 0
    for l in content_layers:
        content_loss += F.mse_loss(target_ftrs[l], content_ftrs[l])

    # Style term: match Gram matrices, one weighted term per style layer.
    # The target Grams depend only on this epoch's features, so build them once
    # per epoch; they were previously rebuilt for every layer inside the loop below.
    target_ftrs_gram_m = {l: calc_gram_matrix(target_ftrs[l]) for l in style_layers}
    style_loss = 0
    for l, w in zip(style_layers, style_weights):
        style_loss += F.mse_loss(target_ftrs_gram_m[l], style_ftrs_gram_m[l]) * w

    total_loss = content_loss + style_loss * style_loss_weight

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f'epoch {epoch}', style_loss.item(), style_loss.item() * style_loss_weight, content_loss.item())

    # Keep a snapshot every 100 steps so progress can be inspected afterwards.
    if epoch % 100 == 0:
        images.append(tensor2img(target))


epoch 0 0.0 0.0 0.0


In [0]:
# Render the optimised target in its own figure; tensor2img detaches it first.
with torch.no_grad():
    final_img = tensor2img(target)
plt.figure()
plt.imshow(final_img);