## Import Dependencies

In [None]:
import os
import copy

from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
# import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms

## Set Up

In [None]:
# A GPU can accommodate greater image sizes and gives faster image synthesis, so use it when one is available
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## Load and Process Images

In [None]:
# directory holding the content and style images that are loaded below
image_dir = "./static"

In [None]:
def load_image(image_path, size=(128, 128)):  # using default size for CPU
    """Load an image file as a batched float tensor on ``device``.

    Args:
        image_path: Path of the image file to open.
        size: Target size passed to ``transforms.Resize``.

    Returns:
        A float tensor of shape (1, 3, H, W) located on ``device``.
    """
    loader = transforms.Compose([
        transforms.Resize(size),  # resize
        transforms.ToTensor()])  # transform image into tensor

    # The context manager closes the file handle once the pixels are converted.
    # Force 3-channel RGB: grayscale, palette or RGBA files would otherwise give
    # 1 or 4 channels, which does not match the 3-channel mean/std used later.
    with Image.open(image_path) as source:
        image = loader(source.convert("RGB"))

    image = image.unsqueeze(0)  # models in torch.nn require inputs with a batch dimension
    return image.to(device, torch.float)


def show_image(tensor, title=None):
    """Render an image tensor with matplotlib.

    Args:
        tensor: Image tensor; a 4D tensor has its leading (batch) dimension squeezed.
        title: Optional title drawn above the image.
    """
    # operate on a CPU copy so the caller's tensor is never modified
    picture = tensor.cpu().clone()

    # drop the artificial batch dimension when there is one
    picture = picture.squeeze(0) if picture.dim() == 4 else picture

    to_pil = transforms.ToPILImage()
    plt.imshow(to_pil(picture))

    if title:
        plt.title(title)
    plt.pause(0.01)

In [None]:
# load the content and style images at the same resolution
image_size = (128, 128)
content_image, style_image = (
    load_image(os.path.join(image_dir, file_name), image_size)
    for file_name in ("content_image.jpeg", "style_reference.jpeg")
)

print('content image shape: ', content_image.size())
print('style image shape: ', style_image.size())

In [None]:
# preview both source images
show_image(content_image, title="Content Image")
show_image(style_image, title="Style Image")

## Gram Matrix

In [None]:
# A Gram matrix is the result of multiplying a matrix by its transpose.
# Since our inputs will be PyTorch tensors, we first need to reshape the tensor into a 2D matrix
# before performing matrix multiplication.

# Since the style features of an image tend to be in the higher (deeper) layers of the network,
# the resulting matrix must be normalized (divided by the number of elements of the input tensor)
# to reduce the influence of the first layers during gradient descent.


def gram_matrix(input):
    '''
    Computes the normalized Gram matrix of a feature tensor.

    Args:
        input: 4D tensor (batch, channels, height, width). A 3D tensor is
            accepted as well and gets a leading batch dimension of size 1.

    Returns:
        A (batch * channels, batch * channels) tensor: the flattened features
        multiplied by their transpose, divided by the total number of elements
        of the input. The division rescales the values; it does not bound
        them to the 0-1 range.

    Raises:
        ValueError: if the input is neither 3D nor 4D.
    '''
    if input.dim() == 3:
        input = input.unsqueeze(0)

    # fail loudly instead of returning None, which would only crash later in the loss
    if input.dim() != 4:
        raise ValueError(
            "Input tensor must have 4 dimensions. Received {}D tensor instead".format(input.dim()))

    batch, channels, height, width = input.size()

    # flatten tensor into 2D; reshape (unlike view) also accepts non-contiguous tensors
    features = input.reshape(batch * channels, height * width)

    # multiply matrix by its transpose to compute the gram product
    gram = torch.mm(features, features.t())

    # normalize by the number of elements of the input tensor
    return gram.div(batch * channels * height * width)

In [None]:
# sanity check: compute the Gram matrix of the style image
test_res = gram_matrix(style_image)

### Style Loss

In [None]:
# create the style and content loss as modules to add them into the model

class StyleLoss(nn.Module):
    """Pass-through module that records the style loss of whatever flows through it.

    The Gram matrix of ``target`` is stored at construction time. Each forward
    call stores the MSE between the input's Gram matrix and that stored matrix
    in ``self.loss`` and returns the input unchanged.
    """

    def __init__(self, target):
        super().__init__()
        # detached so the target stays a static value instead of taking part in the gradient computation
        target_gram = gram_matrix(target)
        self.target = target_gram.detach()

    def forward(self, input):
        # record the loss as a side effect; the input itself is handed on untouched
        self.loss = F.mse_loss(gram_matrix(input), self.target)
        return input

    # back propagation method is defined automatically

### Content Loss

In [None]:
class ContentLoss(nn.Module):
    """Pass-through module that records the content loss of whatever flows through it.

    Each forward call stores the MSE between the input and the detached
    ``target`` in ``self.loss`` and returns the input unchanged.
    """

    def __init__(self, target):
        super().__init__()
        # detached so the target is a fixed value, not part of the autograd graph
        fixed_target = target.detach()
        self.target = fixed_target

    def forward(self, input):
        # record the loss as a side effect; the input itself is handed on untouched
        mse = F.mse_loss(input, self.target)
        self.loss = mse
        return input

## Normalization

"All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]" (https://pytorch.org/vision/0.8/models.html)

In [None]:
# per-channel mean and standard deviation that the pre-trained torchvision models expect (see the quote above)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# normalization module to normalize input images for the VGG-19 model
class Normalize(nn.Module):
    """Normalizes an image tensor channel-wise: ``(input - mean) / std``.

    Args:
        mean: Sequence with one mean value per channel.
        std: Sequence with one standard deviation per channel.
    """

    def __init__(self, mean, std):
        super().__init__()
        # Registered as buffers rather than plain attributes so that
        # `.to(device)` / `.double()` etc. move them together with the module;
        # plain tensor attributes would stay on the CPU and fail against a CUDA input.
        # Viewed as (C, 1, 1) so they broadcast over the height and width of the input.
        self.register_buffer('mean', torch.tensor(mean).view(-1, 1, 1))
        self.register_buffer('std', torch.tensor(std).view(-1, 1, 1))

    def forward(self, input):
        return (input - self.mean) / self.std  # normalization formula

## L-BFGS Optimizer

We will use the L-BFGS optimizer to optimize the input image's features to minimize the style/content loss. Notice how this differs from the usual application of optimizing model params. 

In [None]:
def lbfgs_optimizer(input):
    """Create an L-BFGS optimizer whose only parameter is the given image tensor.

    The tensor is flagged in place as requiring gradients so that the
    operations applied to it are recorded.
    """
    trainable_image = input.requires_grad_()
    return torch.optim.LBFGS([trainable_image])

## Preparing the Model

The paper only uses the first module of the VGG-19 model (features), which contains the convolution and pooling layers used to extract the content and style representations. The classifier module is not needed as we are not performing any image classification.

Note: we will use the model in evaluation mode (.eval()) since it may behave differently in evaluation vs. training.

In [None]:
# install torchsummary (notebook shell command) so we can see a summary of the model
!pip install torchsummary

In [None]:
from torchsummary import summary

In [None]:
# keep only the `features` module of VGG-19, moved to `device` and switched to evaluation mode
vgg_network = models.vgg19(weights='DEFAULT').features.to(device).eval()
# print a summary of the network for an input of size (3, 224, 224)
summary(vgg_network, (3, 224, 224))

In [None]:
# layers used in the paper to determine content/style representations
# names follow the '<type>_<n>' scheme assigned in compute_losses, where n counts the Conv2d layers seen so far
content_layers = ['conv_4']
style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']

last_layer = 'conv_5'  # Losses stop being calculated after this layer

def compute_losses(network, style_image, content_image, mean=mean, std=std, content_layers=content_layers, style_layers=style_layers):
    """Build the style-transfer model and collect its loss modules.

    Walks the children of ``network`` in order and appends each one to a new
    ``nn.Sequential`` that starts with a ``Normalize`` module. After every layer
    whose generated name appears in ``content_layers`` / ``style_layers`` a
    ``ContentLoss`` / ``StyleLoss`` module is inserted; its target is obtained by
    running ``content_image`` / ``style_image`` through the model built so far.
    Construction stops after the layer named by the module-level ``last_layer``.

    Args:
        network: Feature network whose direct children are Conv2d, ReLU,
            MaxPool2d or BatchNorm2d layers.
        style_image: Image tensor used to compute the style targets.
        content_image: Image tensor used to compute the content targets.
        mean: Per-channel mean handed to ``Normalize``.
        std: Per-channel standard deviation handed to ``Normalize``.
        content_layers: Names of the layers followed by a content loss.
        style_layers: Names of the layers followed by a style loss.

    Returns:
        Tuple ``(nst_model, style_losses, content_losses)``: the assembled
        ``nn.Sequential`` and the lists of inserted loss modules.

    Raises:
        RuntimeError: if the network contains a layer of an unrecognized type.
    """
    # work on a private copy so that freezing below never touches the caller's network
    vgg_network = copy.deepcopy(network)
    # Only the input image is optimized. Without freezing, every backward pass
    # would also compute gradients for all network weights, and they would pile
    # up in `.grad` because the optimizer only zeroes the image gradient.
    vgg_network.requires_grad_(False)

    normalization = Normalize(mean, std).to(device)

    # make a new sequential model with custom loss/normalization modules
    nst_model = nn.Sequential(normalization)

    style_losses = []
    content_losses = []

    idx = 0  # number of Conv2d layers seen so far; suffix of every generated layer name
    for layer in vgg_network.children():

        # derive the name prefix from the type of layer
        if isinstance(layer, nn.Conv2d):
            idx += 1
            prefix = 'conv_'
        elif isinstance(layer, nn.ReLU):
            prefix = 'relu_'
            # use an out-of-place ReLU so activations passed through the
            # inserted loss modules are not overwritten in place
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            prefix = 'pool_'
        elif isinstance(layer, nn.BatchNorm2d):
            prefix = 'bn_'
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        name = prefix + str(idx)
        nst_model.add_module(name, layer)

        if name in content_layers:
            # content target: activations of the content image at this depth
            target = nst_model(content_image).detach()
            content_loss = ContentLoss(target)
            nst_model.add_module("content_loss_{}".format(idx), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            # style target: activations of the style image at this depth
            target = nst_model(style_image).detach()
            style_loss = StyleLoss(target)
            nst_model.add_module("style_loss_{}".format(idx), style_loss)
            style_losses.append(style_loss)

        # layers after the last loss module are not needed
        if name == last_layer:
            break

    return nst_model, style_losses, content_losses

## Preparing Images

In [None]:
# We use a white-noise image as input. It must have the same (batch, channel, height, width)
# shape as the content image: a bare 2D (128, 128) tensor has no batch or channel dimension
# and therefore cannot represent an RGB image for the network to optimize.
input_image = torch.randn(content_image.size(), device=device)

plt.figure()
show_image(input_image, title='Input Image')

## Apply the Algorithm

In [None]:
def style_transfer(network, input, content_image, style_image, mean=mean, std=std,
                   steps=100, style_weight=20000, content_weight=100):
    """Optimize ``input`` in place so it combines the content and style images.

    Args:
        network: Feature network handed to ``compute_losses``.
        input: Image tensor to optimize; it is modified in place and returned.
        content_image: Image tensor providing the content targets.
        style_image: Image tensor providing the style targets.
        mean: Per-channel mean forwarded to ``compute_losses``.
        std: Per-channel standard deviation forwarded to ``compute_losses``.
        steps: Number of ``optimizer.step`` calls to perform.
        style_weight: Factor applied to the summed style losses.
        content_weight: Factor applied to the summed content losses.

    Returns:
        The optimized ``input`` tensor, clamped to the range [0, 1].
    """
    # get model, losses, and optimizer
    model, style_losses, content_losses = compute_losses(
        network, style_image, content_image, mean=mean, std=std)
    optimizer = lbfgs_optimizer(input)

    def closure():
        # correct the values of the updated input image; done under no_grad so
        # the in-place clamp is not recorded by autograd
        with torch.no_grad():
            input.clamp_(0, 1)

        optimizer.zero_grad()
        # the forward pass is run for its side effect: every loss module stores its `.loss`
        model(input)

        style_loss = style_weight * sum(loss.loss for loss in style_losses)
        content_loss = content_weight * sum(loss.loss for loss in content_losses)

        total_loss = style_loss + content_loss
        total_loss.backward()  # Compute the gradient
        return total_loss

    # exactly `steps` optimizer steps
    for step in range(1, steps + 1):
        loss = optimizer.step(closure)
        if step % 10 == 0 or step == steps:
            print('step {}/{} - total loss: {:.4f}'.format(step, steps, loss.item()))

    # a last correction...
    with torch.no_grad():
        input.clamp_(0, 1)

    return input

In [None]:
result = style_transfer(vgg_network, input_image, content_image, style_image)

plt.figure()
# display with the notebook's show_image helper; detached because the optimized tensor still tracks gradients
show_image(result.detach(), title='Output')