# Jupyter Notebook - Neural Style Transfer using a VGG Network

# Import Packages

In [None]:
import sys
import os
import shutil

from typing import Tuple

from PIL import Image
import matplotlib
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

import tqdm

# Print Library Versions
# (`tqdm` still names the module here; the import right after this rebinds it to the tqdm class.)
print(f"Python version={sys.version}")
print(f"torch version={torch.__version__}")
print(f"torchvision version={torchvision.__version__}")
print(f"matplotlib version={matplotlib.__version__}")
print(f"tqdm version={tqdm.__version__}")

from tqdm import tqdm

# Define Necessary Modules

## Model Architecture

In [None]:
"""Model Definition"""
class NSTNetwork(nn.Module):
    """Re-packaged CNN feature extractor that exposes selected intermediate activations.

    The direct children of ``feature_extractor`` are renamed ``conv_k``, ``relu_k``,
    ``pool_k`` or ``bn_k``, where ``k`` counts the convolutions seen so far. They are
    grouped into consecutive ``nn.Sequential`` slices, and a slice is closed right after
    every layer named in ``style_layer_names`` or ``content_layer_names``. Layers after
    the last requested one are dropped. ``forward`` returns the output of every slice,
    in network order.

    Args:
        feature_extractor: module whose direct children are Conv2d / ReLU /
            MaxPool2d / BatchNorm2d layers (e.g. a torchvision VGG ``.features``).
        style_layer_names: layer names whose activations feed the style loss.
        content_layer_names: layer names whose activations feed the content loss.
            They no longer need to be a subset of ``style_layer_names``.
        use_avgpool: replace each MaxPool2d by an AvgPool2d with the same
            kernel size, stride and padding.

    Attributes:
        style_loss_indices: positions in ``forward``'s output that end at a style layer.
        content_loss_indices: positions in ``forward``'s output that end at a content layer.

    Raises:
        RuntimeError: if ``feature_extractor`` contains an unsupported layer type.
        ValueError: if a requested layer name does not occur in ``feature_extractor``.
    """
    def __init__(
        self,
        feature_extractor : nn.Module,
        style_layer_names : list[str],
        content_layer_names : list[str],
        use_avgpool : bool = False
    ):
        super().__init__()

        style_names = set(style_layer_names)
        content_names = set(content_layer_names)
        requested = style_names | content_names

        # Mean/std used to normalise inputs before the extractor (ImageNet statistics).
        self.normalise = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

        # Indices are derived from the order in which slices are actually cut. The old
        # version took them from positions in `style_layer_names`, which broke whenever
        # names were out of network order or a content layer was not a style layer.
        self.style_loss_indices : list[int] = []
        self.content_loss_indices : list[int] = []

        slices : list[nn.Sequential] = []
        found : set[str] = set()
        current = nn.Sequential()

        conv_count = 0
        for layer in feature_extractor.children():
            if isinstance(layer, nn.Conv2d):
                conv_count += 1
                name = f'conv_{conv_count}'
            elif isinstance(layer, nn.ReLU):
                name = f'relu_{conv_count}'
                # A fresh non-inplace ReLU, so it cannot overwrite the previous layer's
                # output, which may already have been returned as a feature map.
                layer = nn.ReLU(inplace=False)
            elif isinstance(layer, nn.MaxPool2d):
                name = f'pool_{conv_count}'
                if use_avgpool:
                    layer = nn.AvgPool2d(layer.kernel_size, layer.stride, layer.padding)
            elif isinstance(layer, nn.BatchNorm2d):
                name = f'bn_{conv_count}'
            else:
                raise RuntimeError(f'Unrecognized layer: {layer.__class__.__name__}')

            current.add_module(name, layer)

            if name in requested:
                found.add(name)
                slice_index = len(slices)
                if name in style_names:
                    self.style_loss_indices.append(slice_index)
                if name in content_names:
                    self.content_loss_indices.append(slice_index)
                slices.append(current)
                current = nn.Sequential()

        # Fail loudly instead of silently producing fewer feature maps than expected.
        missing = requested - found
        if missing:
            raise ValueError(f'Layer(s) not found in feature_extractor: {sorted(missing)}')

        self.extractor = nn.Sequential()
        for i, block in enumerate(slices, 1):
            self.extractor.add_module(f"slice_{i}", block)

    def forward(self, x : torch.Tensor) -> list[torch.Tensor]:
        """Normalise ``x`` and return the output of every slice, in order.

        Assumes ``x`` is an RGB image batch with values in [0, 1], the range the
        normalisation statistics are meant for (TODO confirm for other callers).
        """
        x = self.normalise(x)
        feature_maps : list[torch.Tensor] = []
        for block in self.extractor.children():
            x = block(x)
            feature_maps.append(x)
        return feature_maps

## Gram Matrix for Loss function

In [None]:
# Define gram matrix
def gram_matrix(ip : torch.Tensor) -> torch.Tensor:
    """Return the normalised Gram matrix of a batch of feature maps.

    The ``(N, C, H, W)`` input is flattened to ``(N*C, H*W)``. The result is the
    ``(N*C, N*C)`` matrix of inner products between those rows, divided by
    ``N*C*H*W``. With ``N == 1`` this is the ``(C, C)`` channel-correlation matrix
    used by the style loss. With ``N > 1``, rows from different images are also
    correlated with each other.
    """
    num_batch, num_channels, height, width = ip.size()
    # reshape (not view): also works for non-contiguous inputs such as channels_last tensors.
    feats = ip.reshape(num_batch * num_channels, height * width)
    gram_mat = torch.mm(feats, feats.t())
    return gram_mat.div(num_batch * num_channels * height * width)

## Auxiliary Functions

### Model Preparation

In [None]:
def prepare_model(
    device,
    feature_extractor : nn.Module,
    style_layer_names: list[str],
    content_layer_names: list[str],
    use_avgpool : bool = False
) -> nn.Module:
    """Build an ``NSTNetwork``, freeze it, switch it to eval mode and move it to ``device``.

    All arguments except ``device`` are forwarded unchanged to ``NSTNetwork``.
    Every network parameter ends up with ``requires_grad=False``, so the network
    itself is never trained.
    """
    # requires_grad_(), eval() and to() on a Module each return that same module,
    # so the calls can be chained and the result returned directly.
    return NSTNetwork(
        feature_extractor=feature_extractor,
        style_layer_names=style_layer_names,
        content_layer_names=content_layer_names,
        use_avgpool=use_avgpool,
    ).requires_grad_(False).eval().to(device)

### Import image and convert to tensor

In [None]:
BIG_DIM = 512
SMALL_DIM = 128
# Working image side length: BIG_DIM when CUDA is available, SMALL_DIM otherwise.
image_dimension = BIG_DIM if torch.cuda.is_available() else SMALL_DIM

def image_to_tensor(image_filepath : str, image_dimension : int = SMALL_DIM) -> torch.Tensor:
    """Load an image file as a square ``(1, 3, image_dimension, image_dimension)`` tensor.

    The image is converted to RGB and centre-cropped to a square. If it is smaller
    than ``image_dimension`` it is upscaled with Lanczos resampling. It is then
    resized to ``image_dimension`` and converted to a float tensor with values in
    [0, 1]. The original and processed images are shown side by side with matplotlib.

    Args:
        image_filepath: path of the image to load.
        image_dimension: side length of the returned image, in pixels.

    Returns:
        CPU float tensor of shape ``(1, 3, image_dimension, image_dimension)``.
    """
    # convert() returns a new image, so the file can be closed right away.
    with Image.open(image_filepath) as src:
        img = src.convert('RGB')

    print(f"Original image size: {img.size}")

    # display image to check
    _, axs = plt.subplots(1, 2, figsize=(10, 6))
    axs[0].set_title(f"{image_filepath}")
    axs[0].imshow(img)

    # Central-crop the image if it is not square. Integer offsets keep the crop exactly
    # square. PIL rounds float box coordinates with round() (half-to-even), so a float
    # box could come out one pixel too wide/tall when the size difference was odd.
    width, height = img.size
    if width != height:
        side = min(width, height)
        left = (width - side) // 2
        top = (height - side) // 2
        img = img.crop((left, top, left + side, top + side))

    # Scale-up image if it is too small
    if img.height < image_dimension or img.width < image_dimension:
        scaling_factor = image_dimension / max(img.size)
        new_width = int(img.width * scaling_factor)
        new_height = int(img.height * scaling_factor)
        img = img.resize((new_width, new_height), Image.LANCZOS)

    print(f"New image size: {img.size}")

    # Resize(int) maps the shorter edge to image_dimension. The image is square here,
    # so both edges end up exactly image_dimension.
    torch_transformation = torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_dimension),
        torchvision.transforms.ToTensor()
    ])
    tensor = torch_transformation(img).unsqueeze(0)

    # Display Processed Image, Sub plt
    axs[1].set_title(f"{image_filepath} Processed")
    axs[1].imshow(tensor.squeeze(0).cpu().detach().numpy().transpose(1, 2, 0))

    return tensor.to(torch.float)

## Style Transfer Process

In [None]:
def style_transfer(
    # neural network
    net : nn.Module,
    # Inputs
    input_image : torch.Tensor,
    content_image : torch.Tensor,
    style_image : torch.Tensor,

    # Optimiser
    lr : float,

    # loss function
    wt_style : float,
    wt_content : float,

    # Transfer process
    num_epochs : int,
    loss_saving_freq : int,
    img_saving_freq : int,

    output_path : str,

    # Inner L-BFGS iterations per epoch (1 = one function evaluation per epoch)
    max_iter : int = 1,
) -> Tuple[list[float], list[float]]:
    """Optimise ``input_image`` in place towards the content of one image and the style of another.

    Each epoch clamps the image to [0, 1] and runs one ``optim.LBFGS`` step. The loss is
    ``wt_style * sum(MSE between Gram matrices at net.style_loss_indices)`` plus
    ``wt_content * sum(MSE between feature maps at net.content_loss_indices)``.

    Args:
        net: module returning a list of feature maps and exposing ``style_loss_indices``
            and ``content_loss_indices`` (as ``NSTNetwork`` does).
        input_image: leaf tensor that is optimised; gradients are enabled on it.
        content_image: image whose feature maps are matched.
        style_image: image whose Gram matrices are matched.
        lr: L-BFGS learning rate.
        wt_style: weight of the style loss.
        wt_content: weight of the content loss.
        num_epochs: number of optimiser steps.
        loss_saving_freq: record and print losses every this many epochs.
        img_saving_freq: show and save the image every this many epochs.
        output_path: directory for the saved images. It is DELETED and re-created.
        max_iter: L-BFGS ``max_iter``. The default of 1 performs one loss evaluation per
            epoch; larger values run more inner iterations per epoch.

    Returns:
        ``(style_losses, content_losses)``: weighted losses recorded every
        ``loss_saving_freq`` epochs, measured on the image before that epoch's step.
    """
    # Clean Output Directory
    if os.path.exists(output_path):
        shutil.rmtree(output_path)  # Deletes the directory and all its contents
    os.makedirs(output_path)  # Re-creates the empty directory

    # The targets depend only on the fixed style/content images, so compute them once.
    with torch.no_grad():
        style_grams = {
            i: gram_matrix(f)
            for i, f in enumerate(net(style_image.detach()))
            if i in net.style_loss_indices
        }
        content_targets = {
            i: f
            for i, f in enumerate(net(content_image.detach()))
            if i in net.content_loss_indices
        }

    input_image.requires_grad_(True)
    opt = optim.LBFGS([input_image], lr=lr, max_iter=max_iter)

    epoch_style_losses : list[float] = []
    epoch_content_losses : list[float] = []
    # (style_loss, content_loss) of every closure call in the current step.
    evaluations : list[tuple[torch.Tensor, torch.Tensor]] = []

    def closure() -> torch.Tensor:
        # LBFGS expects the closure to re-evaluate loss AND gradients at the current image.
        # Returning a loss computed before step() made the optimiser work on stale values.
        opt.zero_grad()
        style_loss = input_image.new_zeros(())
        content_loss = input_image.new_zeros(())
        for i, f_x in enumerate(net(input_image)):
            if i in style_grams:
                style_loss = style_loss + F.mse_loss(gram_matrix(f_x), style_grams[i])
            if i in content_targets:
                content_loss = content_loss + F.mse_loss(f_x, content_targets[i])
        style_loss = style_loss * wt_style
        content_loss = content_loss * wt_content
        total_loss = style_loss + content_loss
        total_loss.backward()
        evaluations.append((style_loss.detach(), content_loss.detach()))
        return total_loss

    for curr_epoch in range(1, num_epochs + 1):
        with torch.no_grad():
            input_image.clamp_(0, 1)

        # Snapshot before stepping, so the saved frame is the image the logged losses describe.
        snapshot = (
            input_image.detach().clamp(0, 1).squeeze(0).cpu()
            if curr_epoch % img_saving_freq == 0
            else None
        )

        evaluations.clear()
        opt.step(closure=closure)
        # LBFGS always evaluates the closure first, so this is the loss before this step.
        style_loss, content_loss = evaluations[0]

        if curr_epoch % loss_saving_freq == 0:
            style_value = style_loss.item()
            content_value = content_loss.item()
            epoch_style_losses.append(style_value)
            epoch_content_losses.append(content_value)
            print(f"epoch number {curr_epoch}")
            print(f"style loss = {style_value:.4f}, content loss = {content_value:.4f}")

        if snapshot is not None:
            plt.figure()
            plt.title(f"epoch number {curr_epoch}")
            plt.imshow(snapshot.numpy().transpose(1, 2, 0))
            plt.show()
            torchvision.utils.save_image(
                snapshot,
                f"{output_path}/image_{curr_epoch}.jpg"
            )

    return (epoch_style_losses, epoch_content_losses)

# Google Drive Setup

In [None]:
# Mount Google Drive at /content/gdrive so files stored on Drive are reachable from Colab
from google.colab import drive
drive.mount(mountpoint='/content/gdrive')

In [None]:
# Change the current working directory to the notebook folder on the mounted Drive
import os

# Folder path relative to the Drive mount point; adjust to your own layout.
YOUR_WORKING_DIR_GDRIVE = "MyDrive/Colab/neural-style-transfer/notebook"

os.chdir(f"/content/gdrive/{YOUR_WORKING_DIR_GDRIVE}")
print(os.getcwd())

# Main Functions

## Select Device

In [None]:
# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## Create Directories If They Do Not Exist

In [None]:
# Create the input and output directories (no error if they already exist)
INPUT_PATH = "./inputs"    # Input Directory
OUTPUT_PATH = "./outputs"  # Output Directory
for _directory in (INPUT_PATH, OUTPUT_PATH):
    os.makedirs(_directory, exist_ok=True)

## Prepare model

### Import pretrained model(s)

In [None]:
# Pretrained VGG19; its `.features` part is the feature extractor used below.
# (torchvision.models.vgg16 with VGG16_Weights.DEFAULT has the same layer types and
# can be swapped in. It is not downloaded here because nothing in this notebook uses it.)
vgg19_model = torchvision.models.vgg19(weights=torchvision.models.vgg.VGG19_Weights.DEFAULT)
print(vgg19_model)

### Build Our Model

In [None]:
# Tap relu_1 … relu_5 for the style loss and relu_4 for the content loss
# (relu_k is the ReLU after the k-th convolution in NSTNetwork's naming scheme).
feature_extractor = vgg19_model.features
style_layer_names = ["relu_1", "relu_2", "relu_3", "relu_4", "relu_5"]
content_layer_names = ["relu_4"]
use_avgpool = False  # keep the extractor's max-pooling layers

net = prepare_model(
    device=device,
    feature_extractor=feature_extractor,
    style_layer_names=style_layer_names,
    content_layer_names=content_layer_names,
    use_avgpool=use_avgpool,
)

In [None]:
# Print Network Architecture
print(str(net))

## Prepare Image Tensors

In [None]:
# Get Style and Content Tensors: load each image at the working resolution,
# then move it to the selected device outside of any autograd graph.
style_image = image_to_tensor(f"{INPUT_PATH}/style-1.jpg", image_dimension=image_dimension)
style_image = style_image.to(device).detach()

content_image = image_to_tensor(f"{INPUT_PATH}/content-4.jpeg", image_dimension=image_dimension)
content_image = content_image.to(device).detach()

print(f"style_image.shape: {style_image.shape}")
print(f"content_image.shape: {content_image.shape}")

In [None]:
# Get Input Tensor: "content" starts from a copy of the content image;
# any other value starts from standard-normal noise with the content image's shape.
init_mode = "random"

input_image = (
    content_image.clone().to(device)
    if init_mode == "content"
    else torch.randn(content_image.data.size(), device=device)
)

# Display input image (clipped to [0, 1] for display only; the tensor is unchanged)
plt.figure()
plt.title("Input Image")
plt.imshow(input_image.detach().squeeze(0).permute(1, 2, 0).cpu().numpy().clip(0, 1));

## Transfer the image

In [None]:
%%time

lr=0.5

wt_style=1e5
wt_content=2

num_epochs= 1000
loss_saving_freq = 10
img_saving_freq = 100

epoch_style_losses, epoch_content_losses = style_transfer(
    # Neural Network
    net,

    # Inputs
    input_image,
    content_image,
    style_image,

    # Optimiser
    lr,

    # loss function
    wt_style,
    wt_content,

    # Transfering Process
    num_epochs,
    loss_saving_freq,
    img_saving_freq,

    OUTPUT_PATH
)

### Plot Loss Curve for further analysis

In [None]:
# Losses were recorded every `loss_saving_freq` epochs; use those epochs as the x-axis.
logged_epochs = range(loss_saving_freq, num_epochs + 1, loss_saving_freq)
plt.plot(logged_epochs, epoch_style_losses, label='style_loss')
plt.plot(logged_epochs, epoch_content_losses, label='content_loss')
plt.legend();