In [1]:
import torch
import torch.nn as nn
import torchvision # for pretrained models
from torchvision import transforms, models # for pretrained models
from PIL import Image # Python Image Library for image processing
import matplotlib.pyplot as plt # for plotting
import numpy as np # for numerical calculations

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
def get_image(path, img_transform, size = (300, 300)):
    """Load an image from disk and turn it into a batched tensor on `device`.

    Args:
        path: Location of the image file, passed to `Image.open`.
        img_transform: Callable applied to the resized PIL image; its result
            must support `.unsqueeze(0)` and `.to(device)` (i.e. a tensor).
        size: Target size handed to `Image.resize`; defaults to (300, 300).

    Returns:
        The transformed image with a leading batch dimension of 1, moved to
        the module-level `device`.
    """
    # The context manager closes the underlying file as soon as the pixel data
    # has been copied out by convert()/resize().
    with Image.open(path) as src:
        # Force three channels: grayscale, palette, RGBA or CMYK files would
        # otherwise produce a tensor whose channel count does not match the
        # 3-channel mean/std normalization used in this notebook.
        img = src.convert("RGB").resize(size, Image.LANCZOS)
    img = img_transform(img).unsqueeze(0) # add batch dimension
    return img.to(device)

In [4]:
def get_gram(m):
    """Return the size-normalized Gram matrix of a 4-D feature tensor.

    `m` has shape (batch_size, channels, height, width). It is flattened to
    (batch_size * channels, height * width), multiplied by its own transpose,
    and the product is divided by the total number of elements in `m`.
    """
    n, c, h, w = m.shape
    rows = n * c
    cols = h * w
    flat = m.view(rows, cols)  # one row per (batch, channel) feature map
    return flat.mm(flat.t()) / (rows * cols)

In [5]:
# Denormalize the image
def denormalize_img(img):
    """Undo the ImageNet mean/std normalization and return a displayable image.

    Args:
        img: Tensor of shape (channels, height, width) with 3 channels,
            normalized with the mean/std listed below.

    Returns:
        A NumPy array of shape (height, width, channels) with values clipped
        to [0, 1].
    """
    # detach() + cpu() so tensors that require grad or live on the GPU can be
    # converted; a plain .numpy() raises for both.
    img = img.detach().cpu().numpy().transpose(1, 2, 0) # (channels, height, width) -> (height, width, channels)
    mean = np.array([0.485, 0.456, 0.406]) # mean of the ImageNet dataset
    std = np.array([0.229, 0.224, 0.225]) # standard deviation of the ImageNet dataset
    img = std * img + mean # denormalize
    # Clip once and return. Rescaling the clipped result by `* 0.5 + 0.5`
    # would compress it into [0.5, 1] and wash the picture out.
    return np.clip(img, 0, 1)

![figure](https://user-images.githubusercontent.com/30661597/107026142-96fa0100-67aa-11eb-9f71-4adce01dd362.png)

In [6]:
class FeatureExtractor(nn.Module):
    """Collects intermediate activations from the VGG16 convolutional stack.

    `forward` pushes the input through `self.vgg` layer by layer and returns
    the outputs of the layers whose index is listed in `self.selected_layers`.
    """

    def __init__(self):
        super().__init__()
        # Indices into `vgg16().features` whose outputs are returned.
        # NOTE(review): presumably these are the ReLU outputs closing the
        # first four conv blocks - confirm against the torchvision layer list.
        self.selected_layers = [3, 8, 15, 22 ] # layers to extract features from
        # NOTE(review): newer torchvision releases deprecate `pretrained` in
        # favor of `weights=`; kept as-is so older installs still work.
        self.vgg = models.vgg16(pretrained = True).features # pretrained VGG16 model

    def forward(self, x):
        """Return a list with the activation of every selected layer, in order.

        Args:
            x: Input batch accepted by the first layer of `self.vgg`.

        Returns:
            List of tensors, one per index in `self.selected_layers` that
            exists in `self.vgg`, ordered by layer position.
        """
        wanted = set(self.selected_layers)
        last_wanted = max(wanted, default = -1)
        features = []
        for layer_num, layer in self.vgg.named_children():
            idx = int(layer_num)
            if idx > last_wanted:
                # Nothing past the last selected layer is returned, so running
                # the remaining layers would be wasted computation.
                break
            x = layer(x)
            if idx in wanted:
                features.append(x)
        return features

In [8]:
# Image transformation: PIL image -> float tensor, then per-channel
# normalization with the mean/std values below.
img_transform = transforms.Compose([
    transforms.ToTensor(), # convert to tensor
    transforms.Normalize(mean = [0.485, 0.456, 0.406], # normalize using the mean
                         std = [0.229, 0.224, 0.225]) # normalize using the standard deviation
])
content_img = get_image("images/content.jpg", img_transform)
style_img = get_image("images/style.jpg", img_transform)
# Start the optimization from a copy of the content image; this tensor is the
# only thing the optimizer updates.
generated_img = content_img.clone().requires_grad_(True) # clone the content image and set requires_grad to True

# Optimizer
optimizer = torch.optim.Adam([generated_img], lr = 0.003, betas = (0.5, 0.999))

# Encoder
encoder = FeatureExtractor().to(device) # put the encoder in the device

# Freeze the encoder weights: gradients are only needed for `generated_img`.
for p in encoder.parameters():
    p.requires_grad = False # we don't need to train the encoder
