# Group name: ClosedAI
Project ID: 043

Project: Images Style Transfer


| Name    | Email | Student ID |
| -------- | ------- | ------- |
| Elijah Maron  | z5372352@ad.unsw.edu.au   | z5372352
| Hari Birudavolu | z5419889@ad.unsw.edu.au     | z5419889
| Michael Girikallo    | z5416925@ad.unsw.edu.au   | z5416925
| Tianshuo Xu    | z5358205@ad.unsw.edu.au   | z5358205
| Vincent Pham    | z5363266@ad.unsw.edu.au   | z5363266

Codebase available on [GitHub](https://github.com/teddyld/image-style-transfer)

# Table of Contents

1. [Introduction](#introduction)
2. [Motivation](#motivation)
3. [Problem Statement](#problem-statement)
4. [Data Sources](#data-sources)
5. [Exploratory Analysis of Data](#exploratory-analysis)
6. [Models and Methods](#models-methods)
7. [Results](#results)
8. [Discussion](#discussion)

# 1. <a id="introduction">Introduction</a>

# 2. <a id="motivation">Motivation</a>

# 3. <a id="problem-statement">Problem Statement</a>

# 4. <a id="data-sources">Data Sources</a>

This section first explains how we prepared the datasets and then describes the characteristics of the COCO2014 and WikiArt datasets used for image style transfer.

## a. Imports

In [None]:
from torch.utils.data import Dataset, DataLoader
import os
import random
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import shutil
from PIL import Image
import torch.nn.functional as F
import torch
import matplotlib.pyplot as plt

# Locations of the raw and prepared datasets on disk
COCO2014_DATA_PATH = './data/coco2014/'
WIKIART_DATA_PATH = './data/wikiart/'
STYLE_PATH = './data/style'

# Integer label of each WikiArt style class kept in the style dataset:
# labels 1-5 are assigned in the order the names are listed here
WIKIART_STYLE_MAP = {
    style_name: style_label
    for style_label, style_name in enumerate(
        ['Baroque', 'Cubism', 'Early_Renaissance', 'Pointillism', 'Ukiyo_e'],
        start=1,
    )
}

## b. Data Preparation

After downloading our datasets from Kaggle, we prepared the style dataset by splitting the WikiArt images into train, validation, and test splits (80% train, 10% validation, 10% test). Note that this style dataset is a subset of the WikiArt dataset on Kaggle, as it only contains five style classes: Baroque, Cubism, Early Renaissance, Pointillism, and Ukiyo-e. We justify this choice later in [this section](#style-justification).

In [None]:
"""
Initial Data structure
/data
    /coco2014
        /annotations
        /images
            /test2014
            /train2014
            /val2014
        /labels
    /wikiart
        /Abstract_Expressionism
            /aaron-siskind_acolman-1-1955.jpg
            ...
        /Action_painting
        ...
        /Ukiyo_e
        /classes.csv
        /wclasses.csv
        
Target Data structure
/data
    /coco2014
        /annotations
        /images
            /test2014
            /train2014
            /val2014
        /labels
    /style
        /test
        /train
        /val
"""

### i) Flattening the WikiArt directory

In [None]:
def flatten_data(include=('Cubism', 'Baroque', 'Early_Renaissance', 'Pointillism', 'Ukiyo_e'), src=WIKIART_DATA_PATH):
    """
    Move the files of selected sub-directories of `src` up into `src` itself.

    Each moved file is renamed to '<sub-directory name>_<original file name>', so the
    style class it came from can still be recovered from the file name. Each emptied
    sub-directory is then removed. Sub-directories whose names are not in `include`
    are left untouched.

    Args:
        include: names of the sub-directories to unpack. An immutable tuple is used as
            the default so the default can never be shared and mutated between calls.
        src: directory whose sub-directories are unpacked.
    """
    wanted = set(include)  # O(1) membership checks
    subdirs = [name for name in os.listdir(src)
               if name in wanted and os.path.isdir(os.path.join(src, name))]

    for subdir in subdirs:
        subdir_path = os.path.join(src, subdir)
        loop = tqdm(os.listdir(subdir_path))
        loop.set_description(f'Unpacking sub-directory {subdir}')
        for file in loop:
            # Prepend the style class the file is from
            os.rename(os.path.join(subdir_path, file), os.path.join(src, subdir + '_' + file))

        # os.rmdir only removes empty directories, so anything left behind raises instead of being deleted
        os.rmdir(subdir_path)
    
# Unpack the five default style sub-directories of WIKIART_DATA_PATH into WIKIART_DATA_PATH itself
flatten_data()

### ii) Splitting the WikiArt dataset into train, validation, and test splits

In [None]:
def make_split(src, dest, files, ttv):
    """
    Copy every file named in `files` from directory `src` into directory `dest`.

    Progress is shown with a tqdm bar labelled with the split name `ttv`.
    """
    progress = tqdm(files, total=len(files))
    progress.set_description(f"Writing {ttv} split")
    for name in progress:
        shutil.copyfile(os.path.join(src, name), os.path.join(dest, name))

def split_data(src, dest, split_size=0.8, max_files=30000, random_seed=42):
    """
    Copy the '.jpg' files in `src` into 'train', 'val', and 'test' directories under `dest`.

    A `split_size` fraction of the files goes to 'train'. The remainder is split in half
    between 'val' and 'test'. At most `max_files` randomly chosen files are used. Existing
    split directories under `dest` are deleted and recreated first.

    Args:
        src: directory containing the (flattened) image files.
        dest: directory in which the split directories are created.
        split_size: fraction of files assigned to the train split, in (0, 1).
        max_files: maximum number of files used across all splits (must be >= 1).
        random_seed: seed for the file selection shuffle and for train_test_split.

    Raises:
        ValueError: if split_size is outside (0, 1), max_files < 1, or src contains
            no '.jpg' files.
    """
    if not (0 < split_size < 1):
        raise ValueError(f"split_size must be between 0 and 1. Got: {split_size}")
    if max_files < 1:
        raise ValueError(f"max_files must be a positive integer. Got: {max_files}")

    # os.listdir order is arbitrary and filesystem-dependent. Sort first so the seeded
    # shuffle below selects the same files on every machine.
    all_files = sorted(file for file in os.listdir(src) if file.endswith('.jpg'))

    if not all_files:
        raise ValueError("src directory did not contain any files")

    # Shuffle BEFORE truncating: file names are prefixed with their style class, so
    # truncating an ordered listing could drop entire classes. A private generator
    # avoids resetting the process-wide `random` state.
    random.Random(random_seed).shuffle(all_files)
    all_files = all_files[:max_files]

    # Create train split
    train_files, remaining_files = train_test_split(all_files, train_size=split_size, random_state=random_seed)

    # Create validation and test split
    validation_files, test_files = train_test_split(remaining_files, test_size=0.5, random_state=random_seed)

    # (Re)create empty destination directories
    for ttv in ('train', 'val', 'test'):
        split_path = os.path.join(dest, ttv)
        if os.path.exists(split_path):
            shutil.rmtree(split_path)
        os.makedirs(split_path, exist_ok=True)

    make_split(src, os.path.join(dest, 'train'), train_files, 'train')
    make_split(src, os.path.join(dest, 'val'), validation_files, 'val')
    make_split(src, os.path.join(dest, 'test'), test_files, 'test')

split_data(WIKIART_DATA_PATH, STYLE_PATH, 0.8)

## c. <a id="style-justification">Data Source Characteristics</a>

With our data sources prepared, we now describe and [illustrate](#dataset-examples) the characteristics of our content and style datasets. For our content images, we use the COCO2014 dataset whilst for our style images we use the WikiArt dataset.

The COCO 2014 Dataset on [Kaggle](https://www.kaggle.com/datasets/jeffaudi/coco-2014-dataset-for-yolov3/) is a popular dataset developed by [Lin et al.](https://arxiv.org/pdf/1405.0312) with 80 classes, 82,783 training images and 40,504 validation images in RGB format. During image collection, the authors filtered out iconic images in favour of non-iconic images. Iconic images are characterised by single large objects in a canonical perspective centred in the image. Research from [Torralba and Efros](https://ieeexplore-ieee-org.wwwproxy1.library.unsw.edu.au/document/5995347) indicates that the lack of contextual information and of non-canonical viewpoints in iconic images can reduce how well a dataset generalises, through capture bias and negative set bias. By removing iconic images, the COCO dataset therefore generalises well and provides rich contextual relationships between objects in their natural environments. For the task of image style transfer, the COCO 2014 dataset is an excellent source of 'content' images. Chiefly, the prevalence of non-iconic images provides diverse structural features, which challenge the robustness and efficacy of style transfer models. For example, models like AdaAttN and MAST hyper-fixate on local structure, leading to style leakage [(Xu et al.)](https://arxiv.org/abs/2304.00414). The breadth of object classes collected for this dataset adds to this structural diversity and challenges a model's ability to generalise.

The WikiArt Dataset accessed from [Kaggle](https://www.kaggle.com/datasets/steubk/wikiart) is a dataset of 80,020 images by 1,119 different artists across 27 distinct style classes. The images are sourced from WikiArt.org, an encyclopedia of art. Combined, the 27 style classes offer a wide range of unique salient features: the small visible brushstrokes of 'Impressionism', the vibrant and bold colouring of 'Fauvism', and the visual blending of small dots of colour that defines 'Pointillism'. For the task of image style transfer, the WikiArt dataset is a standard benchmark for 'style' images used across research. The large number of classes and the depth of salient features make WikiArt useful for evaluating how robustly and reliably a model transfers artistic styles while preserving style patterns. However, using all 27 style classes would hurt model performance and dramatically increase the cost of computation. Hence, we reduce the number of style classes we consider to five: Cubism, Pointillism, Baroque, Early_Renaissance, and Ukiyo_e. We selected these styles based on the criteria of uniqueness and within-class consistency of style.

### i) Dataset and DataLoaders

In [None]:
# Content dataset
class COCO2014(Dataset):
    """
    Content-image dataset over one COCO 2014 image split on disk.

    Items are RGB PIL images, or whatever `transform` returns for them. No labels are
    returned because only the content images are needed for style transfer.
    """

    def __init__(self, split, max_files=None, transform=None):
        """
        Args:
            split: 'train', 'val', or 'test'; images are read from
                '<COCO2014_DATA_PATH>/images/<split>2014'.
            max_files: keep at most this many images (the first ones in sorted
                file-name order); None keeps all of them.
            transform: optional callable applied to each loaded PIL image.

        Raises:
            ValueError: if split is not 'train', 'val', or 'test'.
        """
        if split not in ['train', 'val', 'test']:
            raise ValueError(f"split must be 'train', 'val', or 'test'. Got: {split}")

        self.image_path = os.path.join(COCO2014_DATA_PATH, 'images', split + '2014')
        # Sort because os.listdir order is arbitrary. Without sorting, the subset kept
        # by max_files would differ between machines and filesystems.
        images = sorted(os.listdir(self.image_path))
        if max_files is not None:
            images = images[:max_files]

        self.images = images
        self.length = len(images)
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        img_name = self.images[idx]
        # The context manager closes the file handle; convert() returns an independent image
        with Image.open(os.path.join(self.image_path, img_name)) as src_img:
            img = src_img.convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img  # Note that we do not need the label of the content classes

# Style dataset
class StyleDataset(Dataset):
    """
    Style-image dataset over one split directory under STYLE_PATH.

    Items are `(image, label)` pairs. The image is an RGB PIL image, or whatever
    `transform` returns for it. The label is the integer from WIKIART_STYLE_MAP for the
    '<Style>_' prefix at the start of the file name.
    """

    def __init__(self, ttv, transform=None):
        """
        Args:
            ttv: split directory name under STYLE_PATH (e.g. 'train', 'val', 'test').
            transform: optional callable applied to each loaded PIL image.
        """
        self.image_path = os.path.join(STYLE_PATH, ttv)
        # Sorted so the index -> file mapping is reproducible across filesystems
        self.images = sorted(os.listdir(self.image_path))
        self.transform = transform

    @staticmethod
    def _style_label(img_name, style_map=None):
        """
        Return the label of the style class whose '<Style>_' prefix starts `img_name`.

        Style names are tried longest first, so a name that begins with another style
        name still resolves to the longer one. Matching on the prefix, rather than
        collecting capitalised tokens, keeps capitalised or single-letter words and
        repeated underscores in the rest of the file name from breaking the lookup.

        Args:
            img_name: file name to classify.
            style_map: mapping of style name -> label; defaults to WIKIART_STYLE_MAP.

        Raises:
            KeyError: if no style name in `style_map` prefixes `img_name`.
        """
        if style_map is None:
            style_map = WIKIART_STYLE_MAP
        for style in sorted(style_map, key=len, reverse=True):
            if img_name.startswith(style + '_'):
                return style_map[style]
        raise KeyError(f"Could not determine the style class of file: {img_name}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        # The context manager closes the file handle; convert() returns an independent image
        with Image.open(os.path.join(self.image_path, img_name)) as src_img:
            img = src_img.convert('RGB')

        # Get image label from file name
        label = self._style_label(img_name)

        if self.transform:
            img = self.transform(img)

        return img, label

In [None]:
import torchvision.transforms as transforms

# Demonstration-only transform: resize every image to 128x128 and convert it to a tensor,
# so images from the training splits can be plotted below
train_tf = transforms.Compose([transforms.Resize(size=(128, 128)), transforms.ToTensor()])

# Training splits of the content dataset (at most 24000 COCO images) and the style dataset
content_trainset = COCO2014('train', 24000, train_tf)
style_trainset = StyleDataset('train', transform=train_tf)

# Shuffled batches of 64 images
content_trainloader = DataLoader(content_trainset, batch_size=64, shuffle=True)
style_trainloader = DataLoader(style_trainset, batch_size=64, shuffle=True)

### ii) <a id="dataset-examples">Examples from our Datasets</a>

In [None]:
def plot_dataset(loader, title, styleset=False):
    """
    Plot up to eight images from the first batch of a DataLoader in a 2x4 grid.

    Args:
        loader (iterable): a PyTorch DataLoader (or any iterable of batches). Batches are
            image tensors, or `(images, labels)` pairs when `styleset` is True.
        title (str): figure title.
        styleset (bool): if True, each subplot is titled with the style name that
            WIKIART_STYLE_MAP maps to the image's label.
    """
    fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(10, 5), subplot_kw={'xticks': [], 'yticks': []})

    # Only the first batch is plotted; an empty loader leaves the grid blank
    batch = next(iter(loader), None)
    if batch is not None:
        if styleset:
            images, labels = batch
            # Invert the name -> label mapping once instead of searching it per image
            label_to_style = {label: style for style, label in WIKIART_STYLE_MAP.items()}
        else:
            images, labels = batch, None

        for i, ax in enumerate(axes.flat):
            if i >= len(images):
                # Batch smaller than the grid: hide the unused subplot instead of raising IndexError
                ax.axis('off')
                continue
            if labels is not None:
                ax.set_title(label_to_style[int(labels[i])])
            # Tensors are (channels, height, width); imshow expects (height, width, channels)
            ax.imshow(images[i].permute(1, 2, 0))

    fig.suptitle(title)
    plt.tight_layout()
    plt.show()

plot_dataset(content_trainloader, 'COCO2014 Data')
plot_dataset(style_trainloader, 'WikiArt Data', styleset=True)

# 5. <a id="exploratory-analysis">Exploratory Analysis of Data</a>

In our exploration of the data, we investigate how effectively Canny and Sobel edge detection capture the structural features of images, illustrate the distribution of the style classes, and describe the characteristics and properties of each of our style classes.

## a. Structural features of images using Canny and Sobel edge detection

Canny edge detection is an edge detection technique that extracts structural information from images. We utilise the function `skimage.feature.canny()` implemented in the scikit-image library to qualitatively display the structural features of images. The function uses the following steps:

1. A Gaussian blur with standard deviation of the Gaussian kernel equal to `sigma` to reduce image noise
2. Sobel edge detection
3. Non-maximum suppression is applied to thin the edges, discarding pixels that are not local maxima of the gradient magnitude along the gradient direction.
4. Hysteresis thresholding is applied, labelling all points above the high threshold value as edges and recursively labelling any point above the low threshold value connected to a labeled point as an edge.

Sobel edge detection is a convolution-based method used for image edge detection that approximates the gradient of the image intensity. By convolving with the 3x3 Sobel kernel, the algorithm estimates the gradient magnitude and direction between regions of low and high intensity, hence, emphasising the edges of objects.

By comparing the edges of the style transfer output with those of its content image, we qualitatively evaluate the change in structural features. Indicators of effective style transfer include a consistent global structure between the output and content images, which we assess by examining object alignment and shape. Small local details in the content image should also be preserved rather than ignored.

In [None]:
from skimage.feature import canny
from skimage.filters import sobel

# Get random image from content dataset
images = next(iter(content_trainloader))
random_index = random.randrange(len(images))
img = images[random_index]

def plot_canny_edges(img, sigma=1):
    """
    Plot an image next to its Canny and Sobel edge maps.

    Args:
        img: image tensor laid out as (channels, height, width) with RGB channels,
            as produced by transforms.ToTensor().
        sigma: standard deviation of the Gaussian blur applied by `canny`.
    """
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))

    # Change view to (height, width, channels)
    img = img.permute(1, 2, 0)
    ax[0].imshow(img)
    ax[0].set_title('Image')

    # Both detectors need a single channel. Use luminance (the same weights as
    # skimage.color.rgb2gray) rather than only the red channel, so edges that appear
    # only in the green or blue channel are still detected.
    luminance_weights = torch.tensor([0.2125, 0.7154, 0.0721], dtype=img.dtype)
    gray = (img @ luminance_weights).numpy()

    ax[1].imshow(canny(gray, sigma=sigma), cmap="copper")
    ax[1].set_title(f'Canny (sigma={sigma})')

    ax[2].imshow(sobel(gray), cmap="copper")
    ax[2].set_title('Sobel')

plot_canny_edges(img)

## b. The characteristics and properties of styles

In this section, we explain the characteristics and properties of the five style classes selected as well as illustrate these images by plotting them without alteration from the original source.

In [None]:
def plot_style(style):
    """
    Plot four randomly chosen train-split images of the given style class.

    Args:
        style: style class name as used in the file-name prefix, e.g. 'Cubism' or
            'Early_Renaissance'.

    If the class has fewer than four images, only those are shown.
    """
    train_dir = os.path.join(STYLE_PATH, 'train')
    # Match on '<style>_' so a style name cannot match another style that merely starts with it.
    # Sorting makes the random choice reproducible under a fixed seed.
    style_images = sorted(name for name in os.listdir(train_dir) if name.startswith(style + '_'))
    chosen = random.sample(style_images, k=min(4, len(style_images)))

    fig, axes = plt.subplots(1, 4, figsize=(10, 5), subplot_kw={'xticks': [], 'yticks': []})
    fig.suptitle(f"Style - {style}", y=0.8)
    for ax, name in zip(axes, chosen):
        with Image.open(os.path.join(train_dir, name)) as painting:
            ax.imshow(painting.convert('RGB'))

    plt.show()

### i) Cubism

Cubism, popularised by artists such as Pablo Picasso and Georges Braque in the early 20th century, is a style of art characterised by its use of interweaving planes and lines to depict abstracted objects. Notably, Cubist art lacks conventional form, and the arrangement of simple planes often merges the foreground and background. Cubist artworks use either muted tones of blacks and greys or bright, solid colours.

The Cubism class contains 2235 images.

In [None]:
# Show sample Cubism images from the train split
plot_style('Cubism')

### ii) Pointillism

Artworks in the style of Pointillism are defined by their distinct painting technique, utilising small dots of colour so that from a distance, they visually blend together to form a vibrant composition. 

The Pointillism class contains 513 images. 

In [None]:
# Show sample Pointillism images from the train split
plot_style('Pointillism')

### iii) Baroque

The Baroque style refers to artworks produced in Europe from the early 17th to the mid-18th century. Artworks of this style are associated with deep colours, dramatic light, sharp shadows, and dark backgrounds.

The Baroque class contains 4240 images.

In [None]:
# Show sample Baroque images from the train split
plot_style('Baroque')

### iv) Early Renaissance

Artworks of the Early Renaissance period are characterised by realistic depictions of human anatomy and space in scenes from mythology and religion. The paintings of this period use muted colours, whilst sculptures of human forms are made of bronze and marble.

The Early Renaissance class contains 1391 images.

In [None]:
# Show sample Early Renaissance images from the train split
plot_style('Early_Renaissance')

### v) Ukiyo-e

Ukiyo-e refers to a style of Japanese paintings and woodblock prints from the Edo period. The style is defined by its bold and flat brush strokes, asymmetric composition, and unusual graphical perspective. In stark contrast, the artwork is complemented by its colourfulness in depictions of flora and fauna. Furthermore, artworks of this style often contain text, written in black or red ink. The text stands out prominently against the style's generous use of negative space, as the artwork's background is often painted in only a single colour.

The Ukiyo-e class contains 1167 images.

In [None]:
# Show sample Ukiyo-e images from the train split
plot_style('Ukiyo_e')

## c. Train style class distribution

In [None]:
import seaborn as sns
sns.set_theme(rc={'figure.figsize':(11.7,8.27)})

# Read the same train directory the StyleDataset uses instead of duplicating the literal path
train_all_files = os.listdir(os.path.join(STYLE_PATH, 'train'))

style_frequency = {
    "Baroque": 0,
    "Cubism": 0,
    "Early_Renaissance": 0,
    "Pointillism": 0,
    "Ukiyo_e": 0,
}

# Each style file name starts with '<Style>_'. Match on that prefix (longest style name
# first) rather than on capitalised tokens, which miscount or raise KeyError when other
# capitalised words, or stray files, appear in the directory.
styles_longest_first = sorted(style_frequency, key=len, reverse=True)
unmatched = []
for img_name in train_all_files:
    img_style = next((s for s in styles_longest_first if img_name.startswith(s + '_')), None)
    if img_style is None:
        unmatched.append(img_name)
    else:
        style_frequency[img_style] += 1

if unmatched:
    print(f"Skipped {len(unmatched)} file(s) without a known style prefix, e.g. {unmatched[0]!r}")

keys = list(style_frequency.keys())
vals = [style_frequency[k] for k in keys]

plot = sns.barplot(x=keys, y=vals, hue=keys)
plot.set_ylabel("Count")
plot.set_xlabel("Style Class")

# 6. <a id="models-methods">Models and Methods</a>

## a. Choice of Models

We trained two generative networks, MSG-Net and CycleGAN, and evaluated two pretrained traditional CNN methods, AdaIN and SANet.

In [None]:
# Imports
import torch.nn as nn
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torch.optim as optim

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

# Models
import models.AdaIN as adain
import models.SANet as SANet
import models.CycleGAN as cyclegan 
import models.MSGNet as MSGNet

# Utils
import utils.data as data
from utils.eval import compute_ssim, plot_results, calc_content_loss, calc_style_loss, calculate_fid_from_dataset, plot_training_history


from tqdm import tqdm
from PIL import Image
import numpy as np
import itertools
import datetime
import time
import os

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## b. Traditional CNNs

### i) AdaIN
  
Huang and Belongie propose [AdaIN](https://arxiv.org/abs/1703.06868), an image style transfer architecture that aims to improve performance compared to previous works by using an adaptive instance normalization mechanism.  

Initially, the content image and the style image are passed through a pretrained VGG encoder to extract their feature maps, producing a content feature map and a style feature map.
 
#### VGG  

| ![vgg_architecture.png](./images/vgg_architecture.png) | 
|:--:| 
| *VGG Architecture* |

The Visual Geometry Group (VGG) network is a deep convolutional neural network architecture. It improves image classification and feature extraction performance primarily by using smaller convolution kernels and a deeper network structure. The fundamental building blocks of the VGG network are 3x3 convolutional kernels and 2x2 max pooling layers: small convolutional kernels capture fine features, and stacking multiple convolutional layers mimics a larger receptive field. VGG improves the expressiveness and accuracy of the model by increasing the network depth (more convolution layers), since deeper networks can learn more complex, higher-level features. Downsampling uses fixed-size 2x2 max pooling layers, which gradually reduce the spatial dimensions of the feature maps while increasing the receptive field. This design makes VGG an excellent feature extractor for tasks such as style transfer and image retrieval.

The extracted feature maps enter the AdaIN module. In this module, the channel-wise mean and standard deviation of the content feature map are adjusted to match those of the style feature map, aligning the statistical properties of the content features with those of the style features. The adjusted feature map is then passed through the decoder to generate an image in the target style. Finally, the generated image is passed through the VGG encoder again to extract features for calculating the loss.
  
To maintain the structure of the content images and the characteristics of the style images, Huang and Belongie introduce a content loss and a style loss. The content loss keeps the content of the generated image faithful to the target by comparing the encoded features of the generated image with the output of the AdaIN module. The style loss ensures that the style of the generated image matches that of the input style image by comparing the statistics (channel-wise mean and standard deviation) of their encoded features across several VGG layers.

### ii) SANet

Park and Lee propose [SANet](https://arxiv.org/abs/1812.02342v5), an image style transfer architecture which aims to improve performance compared to previous work using a novel style-attentional network and identity loss function. 

Initially, the content and style images are forwarded to two pretrained VGG-19 encoders. To combine global and local style patterns, two style-attentional networks (SANets) learn the mapping between the content and style features. The proposed SANet is a modified version of the [self-attention mechanism](https://arxiv.org/abs/1706.03762) proposed by Vaswani et al. The self-attention mechanism captures the relative dependencies within its input by attending to each position and calculating a weighted importance for it. The SANets take the encoded VGG content and style feature maps and create a mapping between the content and style features. The SANet layer also upsamples its intermediate output, which improves the model's ability to learn local style patterns. Finally, a 3x3 convolution combines the feature maps, and the stylised image is synthesised by the decoder.

To maintain the structure of the content images and the characteristics of the style images, Park and Lee propose a novel identity loss function. Unlike the content and style losses, it passes two copies of the same content (or style) image through the forward pass of the network. The identity loss therefore measures how far the model deviates from the original characteristics of the content and style images.

| ![sanet_architecture.png](./images/sanet_architecture.png) | 
|:--:| 
| *SANet Architecture* |

Code adapted from https://github.com/GlebSBrykin/SANET/tree/master

## c. Generative Networks

### i) MSG-Net

Zhang and Dana propose [MSG-Net](https://openaccess.thecvf.com/content_ECCVW_2018/papers/11132/Zhang_Multi-style_Generative_Network_for_Real-time_Transfer_ECCVW_2018_paper.pdf), an image style transfer architecture that addresses the limited flexibility of standard generative methods by using a generative network with a novel CoMatch Layer and an Upsampled Convolution.

MSG-Net is composed of a Siamese network, which takes the style images as input and shares weights with the encoder of the Transformation network. In the feed-forward pass, the CoMatch Layer embeds the style images with a 2D representation and learns to match the second-order feature statistics (Gram matrix) of the style targets during training. This enables multi-style generation from a single feed-forward network. The Transformation network's decoder has also been extended with an Upsampled Convolution operation, which improves image quality by removing checkerboard artifacts. This operation upsamples its input and then applies an integer-stride convolution, producing an upsampled feature map.

The loss network adopts the VGG architecture to compute the content and style losses. Training minimises their mean-squared errors, reducing the perceptual difference between the generative network's output and its content and style targets.

| ![msgnet.png](./images/msgnet_architecture.png) | 
|:--:| 
| *MSG-Net Architecture* |

#### Setup

In [None]:
# DataLoaders for MSG-Net training and validation

# Training transform: 64x64 tensors rescaled from [0, 1] to the [0, 255] range
train_tf = transforms.Compose(
    [
        transforms.Resize(size=(64, 64)),
        transforms.CenterCrop(size=64),  # NOTE(review): no-op after resizing to exactly 64x64
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ]
)

# Validation transform: 128x128 tensors left in the [0, 1] range produced by ToTensor.
# NOTE(review): this differs from the [0, 255] scale used for training, and msgnet_eval
# feeds validation batches to the model without rescaling; confirm this is intended.
val_tf = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
    ]
)

content_trainloader, content_validloader, _, _, style_validloader, _ = data.get_dataloaders(
    bs=64, train_tf=train_tf, valid_tf=val_tf
)

In [None]:
# Model definitions: MSG-Net generator, its Adam optimiser, and the MSE criterion
style_model = MSGNet.Net(ngf=128)
optimizer = optim.Adam(style_model.parameters(), lr=1e-3)
mse_loss = nn.MSELoss()

# VGG-16 loss network (pretrained weights are loaded below)
vgg = MSGNet.Vgg16()

def init_vgg16(model_folder='./models/output/MSG-Net/'):
    """
    Create '<model_folder>/vgg16_msg.pth' from converted 'vgg16.pth' weights if it does not exist yet.

    The tensors in 'vgg16.pth' are copied, in file order, into the parameters of the
    module-level `vgg` model. That model's state dict is then saved as 'vgg16_msg.pth'.

    Args:
        model_folder: directory holding 'vgg16.pth' and receiving 'vgg16_msg.pth'.

    Raises:
        FileNotFoundError: if neither 'vgg16_msg.pth' nor 'vgg16.pth' exists in model_folder.
    """
    target_path = os.path.join(model_folder, 'vgg16_msg.pth')
    # Check for the file this function produces, so the conversion runs only once
    if os.path.exists(target_path):
        return

    source_path = os.path.join(model_folder, 'vgg16.pth')
    if not os.path.exists(source_path):
        # Must be `raise`: `assert SomeError(...)` asserts a truthy object and never fails
        raise FileNotFoundError(
            "Pretrained weights must be prepared, download vgg16.t7 weights and convert to vgg16.pth first"
        )

    pretrained = torch.load(source_path)
    # Relies on vgg16.pth listing its tensors in the same order as vgg.parameters()
    with torch.no_grad():
        for (_, weights), param in zip(pretrained.items(), vgg.parameters()):
            param.copy_(weights)
    torch.save(vgg.state_dict(), target_path)
        
init_vgg16()

# Load the converted VGG-16 weights into the loss network
vgg_state = torch.load('./models/output/MSG-Net/vgg16_msg.pth')
vgg.load_state_dict(vgg_state)

style_loader = data.StyleLoader()

# Helper functions
def preprocess_batch(batch):
    """
    Reverse the channel order of an (N, C, H, W) image batch, e.g. RGB -> BGR.

    The channel dimension is split into three chunks with torch.chunk, and the chunks
    are concatenated back in reverse order.
    """
    red, green, blue = torch.chunk(batch, 3, dim=1)
    return torch.cat((blue, green, red), dim=1)

def subtract_imagenet_mean_batch(batch):
    """
    Subtract the ImageNet per-channel mean pixel from a BGR image batch.

    Args:
        batch: tensor of shape (N, 3, H, W) in BGR channel order; the means below are
            on the 0-255 pixel scale.

    Returns:
        A new tensor `batch - mean` on the same device as `batch`; the input is not modified.
    """
    # A broadcastable (1, 3, 1, 1) mean created on the batch's own device. This avoids
    # depending on the global `device` and allocating a full-size uninitialised tensor,
    # which silently left any channel beyond the third as garbage.
    mean = torch.tensor([103.939, 116.779, 123.680], device=batch.device)
    return batch - mean.view(1, 3, 1, 1)

def gram_matrix(y):
    """
    Compute the normalised Gram matrix of each feature map in a batch.

    Args:
        y: feature tensor of shape (batch, channels, height, width).

    Returns:
        Tensor of shape (batch, channels, channels). Entry [b, i, j] is the inner product
        of channels i and j of sample b, divided by channels * height * width.
    """
    (b, ch, h, w) = y.size()
    # reshape (not view) so non-contiguous inputs, e.g. transposed or sliced tensors, also work
    features = y.reshape(b, ch, h * w)
    gram = torch.bmm(features, features.transpose(1, 2)) / (ch * h * w)
    return gram

#### Run model training

In [None]:
from utils.earlystopper import EarlyStopper

def _msgnet_timestamp():
    """Return time.ctime() with spaces and colons replaced so it can be used in file names."""
    # ':' is not a legal file-name character on Windows
    return str(time.ctime()).replace(' ', '_').replace(':', '-')


def _save_msgnet_model(style_model, save_model_dir, filename):
    """Put `style_model` in eval mode on the CPU, save its state dict, and return the file path."""
    style_model.eval()
    style_model.cpu()
    save_model_path = os.path.join(save_model_dir, filename)
    torch.save(style_model.state_dict(), save_model_path)
    return save_model_path


def msgnet_train(style_model, optimizer, mse_loss, vgg, content_loader, style_loader, num_epochs=10,
                 batch_size=5, save_model_dir='./models/output/MSG-Net/', content_weight=1.0,
                 style_weight=5.0, checkpoint_interval=4 * 500):
    """
    Train an MSG-Net style model with VGG perceptual content and style losses.

    Each content batch is paired with `style_loader.get(batch_id)`, which is set as the
    model's style target. The content loss is the MSE between VGG feature map 1 of the
    output and of the input. The style loss sums the MSE between the Gram matrices of
    every VGG feature map of the output and of the style image.

    A checkpoint is saved every `checkpoint_interval` batches, and a final model is saved
    when training ends. The final model is left on the CPU in eval mode. After each
    epoch, training stops early if `EarlyStopper.early_stop` rejects the epoch's mean
    total loss.

    Args:
        style_model: MSG-Net generator exposing `setTarget`.
        optimizer: optimiser over `style_model`'s parameters.
        mse_loss: loss criterion, e.g. torch.nn.MSELoss().
        vgg: loss network returning a list of feature maps.
        content_loader: iterable of RGB content batches; the ImageNet mean subtraction
            assumes values on the 0-255 scale (TODO confirm for other loaders).
        style_loader: object whose `get(i)` returns a style batch, presumably already on
            `device` (TODO confirm against utils.data.StyleLoader).
        num_epochs: maximum number of epochs.
        batch_size: retained for backward compatibility and unused. Style Gram matrices are
            repeated to each batch's actual size, so a value smaller than the loader's
            batch size can no longer cause a shape mismatch.
        save_model_dir: directory for checkpoints and the final model.
        content_weight: weight of the content loss; also recorded in saved file names.
        style_weight: weight of the style loss; also recorded in saved file names.
        checkpoint_interval: number of batches between checkpoints.

    Returns:
        (content_losses, style_losses, total_losses): per-epoch mean losses.

    Raises:
        ValueError: if `content_loader` yields no batches.
    """
    style_model.to(device)
    vgg.to(device)
    early_stopper = EarlyStopper()
    content_losses = []
    style_losses = []
    total_losses = []
    epochs_run = 0
    for e in range(num_epochs):
        epochs_run = e + 1
        style_model.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        num_batches = 0
        print('.' * 64)
        print(f"--- Epoch {e + 1}/{num_epochs} ---")
        pbar = tqdm(content_loader, leave=False)
        for batch_id, x in enumerate(pbar):
            pbar.set_description(f"Epoch [{e + 1}/{num_epochs}]")
            n_batch = len(x)
            count += n_batch
            num_batches += 1
            optimizer.zero_grad()
            x = preprocess_batch(x).to(device)

            style_v = style_loader.get(batch_id)
            style_model.setTarget(style_v)

            # Loss targets need no gradients, so compute them without building a graph
            with torch.no_grad():
                features_style = vgg(subtract_imagenet_mean_batch(style_v))
                gram_style = [gram_matrix(f) for f in features_style]
                features_xc = vgg(subtract_imagenet_mean_batch(x))

            y = style_model(x)
            features_y = vgg(subtract_imagenet_mean_batch(y))

            content_loss = content_weight * mse_loss(features_y[1], features_xc[1])

            style_loss = 0.
            for f_y, g_style in zip(features_y, gram_style):
                gram_y = gram_matrix(f_y)
                # Repeat the style Grams to this batch's actual size
                gram_s = g_style.repeat(n_batch, 1, 1, 1)
                style_loss += style_weight * mse_loss(gram_y.unsqueeze(1), gram_s)

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % checkpoint_interval == 0:
                filename = f"Epoch_{e}iters_{count}_{_msgnet_timestamp()}_{content_weight}_{style_weight}.model"
                save_model_path = _save_msgnet_model(style_model, save_model_dir, filename)
                style_model.train()
                # Return to the training device (works on CPU-only machines, unlike .cuda())
                style_model.to(device)
                # set_description's second argument is `refresh`, so print the full message instead
                pbar.write(f"Checkpoint, trained model saved at {save_model_path}")

        if num_batches == 0:
            raise ValueError("content_loader yielded no batches")

        mean_content = agg_content_loss / num_batches
        mean_style = agg_style_loss / num_batches
        mean_total = (agg_content_loss + agg_style_loss) / num_batches
        mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
            time.ctime(), e + 1, num_batches, len(content_loader),
            mean_content, mean_style, mean_total
        )
        print(mesg)

        content_losses.append(mean_content)
        style_losses.append(mean_style)
        total_losses.append(mean_total)

        # Early stopping on the epoch's mean total training loss; the last batch's loss alone
        # is noisy and still attached to the autograd graph.
        # NOTE(review): assumes EarlyStopper.early_stop accepts a plain float; confirm.
        if early_stopper.early_stop(mean_total):
            print(f'Stopping early at Epoch {e + 1}, min training loss failed to decrease after {early_stopper.get_patience()} epochs')
            break

    # Save final model, named after the number of epochs actually run
    filename = f"Final_epoch_{epochs_run}_{_msgnet_timestamp()}_{content_weight}_{style_weight}.model"
    save_model_path = _save_msgnet_model(style_model, save_model_dir, filename)

    print("\nDone, trained model saved at", save_model_path)
    return content_losses, style_losses, total_losses

content_losses, style_losses, total_losses = msgnet_train(style_model, optimizer, mse_loss, vgg, content_trainloader, style_loader, num_epochs=300)

In [None]:
# Plot the per-epoch mean content, style, and total losses returned by msgnet_train
plot_training_history(content_losses, style_losses, total_losses)

#### Run model evaluation

In [None]:
# Load chkpt model for evaluation
style_model = MSGNet.Net(ngf=128)
model_dict = torch.load('./models/output/MSG-Net/train_3.model')
# Drop running_mean / running_var buffers before loading (presumably left over from
# normalisation layers the current model does not expect; TODO confirm). Keys are removed
# in place while iterating a copy, so the loaded state-dict object itself is what gets
# passed on. strict=False tolerates any remaining key mismatches.
model_dict_clone = model_dict.copy()
stale_keys = [key for key in model_dict_clone if key.endswith(('running_mean', 'running_var'))]
for key in stale_keys:
    model_dict.pop(key)
style_model.load_state_dict(model_dict, strict=False)

In [None]:
from PIL import Image

def tensor_to_img(tensor, cuda=False):
    """
    Convert a BGR image tensor of shape (3, H, W) into an RGB PIL image.

    Args:
        tensor: image tensor in BGR channel order. Values are clamped to [0, 255] and
            truncated to uint8.
        cuda: retained for backward compatibility and ignored. The tensor is always
            detached and copied to the CPU, which works for tensors on any device.

    Returns:
        PIL image built from the (H, W, 3) uint8 RGB array.
    """
    (b, g, r) = torch.chunk(tensor, 3)
    tensor = torch.cat((r, g, b))
    # detach() so tensors that require grad can be converted; .cpu() is a no-op for CPU tensors
    img = tensor.detach().cpu().clamp(0, 255).numpy()
    img = img.transpose(1, 2, 0).astype('uint8')
    return Image.fromarray(img)

def msgnet_eval(style_model, content_loader, style_loader):
    """Evaluate MSG-Net on paired content and style batches.

    Content and style batches are consumed in lockstep (``zip`` stops at the
    shorter loader). For each pair the network is conditioned on the style
    batch with ``setTarget`` and applied to the content batch. SSIM, FID,
    content loss and style loss are then accumulated, and the first batch is
    plotted for qualitative inspection.

    Args:
        style_model: MSG-Net model. It is moved to ``device`` and switched to
            eval mode.
        content_loader: iterable of content image batches.
        style_loader: iterable of ``(style_images, style_labels)`` batches.

    Returns:
        Tuple ``(avg_ssim, avg_fid, avg_content_loss, avg_style_loss)``: the
        accumulated metric sums divided by the number of evaluated images.

    Raises:
        ValueError: if no content/style batch pair was evaluated.
    """
    style_model.to(device)
    # A newly constructed nn.Module is in training mode. Switch to eval so
    # any dropout/normalisation layers use inference behaviour and running
    # statistics are not updated during evaluation.
    style_model.eval()
    ssim_sum = 0.0
    fid_sum = 0.0
    running_content_loss, running_style_loss = 0.0, 0.0
    total_samples = 0
    with torch.no_grad():
        for content_images, (style_images, style_labels) in zip(content_loader, style_loader):
            content_images = content_images.to(device)
            style_images, style_labels = style_images.to(device), style_labels.to(device)

            # Condition the network on the style batch, then stylise the content batch.
            style_model.setTarget(style_images)
            output = style_model(content_images)

            # Compute quantitative evaluation metrics.
            # NOTE(review): each helper is called once per batch, but the sums
            # are divided by the number of images below. If the helpers return
            # batch means, the reported averages are scaled down by roughly the
            # batch size. The convention is kept so results stay comparable
            # with cyclegan_eval -- confirm against the helper definitions.
            ssim_sum += compute_ssim(content_images, output)
            fid_sum += calculate_fid_from_dataset(content_images, output, device, dims=2048)
            running_content_loss += calc_content_loss(content_images, output)
            running_style_loss += calc_style_loss(style_images, output)

            # Only the first batch is displayed, so only that batch is
            # converted to PIL images.
            if total_samples == 0:
                stylised_images = [tensor_to_img(img, cuda=True) for img in output]
                plot_results(content_images, style_images, style_labels, stylised_images, nrows=5, model_name="MSG-Net", msgnet=True)

            total_samples += style_labels.size(0)

    if total_samples == 0:
        raise ValueError("msgnet_eval: no content/style batch pairs were evaluated")
    avg_ssim = ssim_sum / total_samples
    avg_fid = fid_sum / total_samples
    avg_content_loss = running_content_loss / total_samples
    avg_style_loss = running_style_loss / total_samples
    return avg_ssim, avg_fid, avg_content_loss, avg_style_loss

# Evaluate MSG-Net on the validation splits and report the averaged metrics.
msgnet_ssim, msgnet_fid, msgnet_content_loss, msgnet_style_loss = msgnet_eval(
    style_model, content_validloader, style_validloader
)
print("\n".join([
    "--- MSG-Net results ---",
    f"Average SSIM = {msgnet_ssim:.7f}",
    f"Average FID = {msgnet_fid:.4f}",
    f"Average content loss = {msgnet_content_loss:.4f}",
    f"Average style loss = {msgnet_style_loss:.4f}",
]))

### ii) CycleGAN

#### Hyper-parameter Definitions

In [None]:
# Use the GPU whenever one is available.
cuda = torch.cuda.is_available()

# ---- Optimisation hyper-parameters ----
lr = 0.0002                  # Adam learning rate
batch_size = 64
num_epochs = 100
b1, b2 = 0.5, 0.999          # Adam betas
decay_epoch = 25             # decay epoch handed to cyclegan.LambdaLR
sample_interval = 100        # save a sample image grid every this many batches
checkpoint_interval = 100    # currently not referenced by cyclegan_train
n_residual_blocks = 9        # residual blocks in each ResNet generator
lambda_cyc = 10.0            # weight of the cycle-consistency loss
lambda_id = 5.0              # weight of the identity loss

# ---- Image geometry ----
num_channels = 3
img_height = img_width = 64
input_shape = (num_channels, img_height, img_width)

# Tensor type used to cast incoming batches onto the chosen device.
Tensor = torch.Tensor
if cuda:
    Tensor = torch.cuda.FloatTensor

#### Calculate mean and std of dataset

In [None]:
# Calculate the per-channel mean and standard deviation of the dataset,
# pooled over every pixel of every content (A) and style (B) image.
test_tf = transforms.Compose([
    transforms.Resize(size=(64, 64)),
    transforms.ToTensor()
])
dataset = cyclegan.ImageDataset(transform=test_tf, unaligned=True)

# Accumulate per-channel sums of x and x**2 in float64, which limits
# cancellation in E[x^2] - E[x]^2. Averaging per-image standard deviations
# instead would ignore variation between images and underestimate the
# dataset std.
# NOTE(review): with unaligned=True the B image returned for an index may be
# drawn at random -- confirm; if so, the B statistics vary between runs.
channel_sum, channel_sq_sum, pixel_count = 0.0, 0.0, 0

for i in tqdm(range(len(dataset))):
    batch = dataset[i]
    for img in (batch['A'], batch['B']):
        pixels = img.to(torch.float64).flatten(start_dim=1)  # (channels, H * W)
        channel_sum += pixels.sum(dim=1)
        channel_sq_sum += pixels.pow(2).sum(dim=1)
        pixel_count += pixels.size(1)

mean64 = channel_sum / pixel_count
var64 = channel_sq_sum / pixel_count - mean64 ** 2
mean = mean64.float()
std = var64.clamp(min=0).sqrt().float()
print(mean) # tensor([0.4820, 0.4424, 0.3893])
print(std) # recorded with the earlier per-image averaging: tensor([0.2064, 0.1953, 0.1880])

#### Transformations and DataLoaders

`cyclegan.ImageDataset` yields randomly paired content and style images, returned under the keys `A` (content) and `B` (style).

In [None]:
# Per-channel mean/std of the images (values recorded from the statistics
# cell). Every split must be normalised with the same constants; otherwise
# the generators see a different input distribution at validation time than
# during training.
NORM_MEAN = (0.4820, 0.4424, 0.3893)
NORM_STD = (0.2064, 0.1953, 0.1880)

train_tf = transforms.Compose([
    # Resize the shorter side to 1.12x the target, then random-crop back to
    # the target size (light scale/position jitter).
    transforms.Resize(int(img_height * 1.12), transforms.InterpolationMode.BICUBIC),
    # torchvision size tuples are (height, width).
    transforms.RandomCrop(size=(img_height, img_width)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

val_tf = transforms.Compose([
    transforms.Resize(size=(img_height, img_width)),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Training data loader
train_dataloader = DataLoader(
    cyclegan.ImageDataset(transform=train_tf, unaligned=True),
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

# Validation data loader
val_dataloader = DataLoader(
    cyclegan.ImageDataset(transform=val_tf, unaligned=True, mode="val"),
    batch_size=batch_size,
    shuffle=False,
    num_workers=1,
)

#### Run model training

In [None]:
# Losses: least-squares (MSE) adversarial loss, and L1 losses for cycle
# consistency and identity preservation.
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# Initialize generator and discriminator
# G_AB translates domain A -> B and G_BA translates B -> A; D_A judges
# A-domain images and D_B judges B-domain images.
# NOTE: construction and weights_init_normal presumably draw from the RNG,
# so keep this statement order to reproduce runs under a fixed seed.
G_AB = cyclegan.GeneratorResNet(input_shape, n_residual_blocks)
G_BA = cyclegan.GeneratorResNet(input_shape, n_residual_blocks)
D_A = cyclegan.Discriminator(input_shape)
D_B = cyclegan.Discriminator(input_shape)

G_AB = G_AB.to(device)
G_BA = G_BA.to(device)
D_A = D_A.to(device)
D_B = D_B.to(device)
criterion_GAN.to(device)
criterion_cycle.to(device)
criterion_identity.to(device)


# Initialize weights
G_AB.apply(cyclegan.weights_init_normal)
G_BA.apply(cyclegan.weights_init_normal)
D_A.apply(cyclegan.weights_init_normal)
D_B.apply(cyclegan.weights_init_normal)

# Optimizers
# One Adam optimiser for both generators (parameters chained together),
# plus one per discriminator.
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=lr, betas=(b1, b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr, betas=(b1, b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=(b1, b2))

# Learning rate update schedulers
# Each optimiser's base LR is multiplied by the value returned from
# cyclegan.LambdaLR(num_epochs, 0, decay_epoch).step(epoch); cyclegan_train
# steps these schedulers once per epoch.
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=cyclegan.LambdaLR(num_epochs, 0, decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=cyclegan.LambdaLR(num_epochs, 0, decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=cyclegan.LambdaLR(num_epochs, 0, decay_epoch).step
)


# Buffers of previously generated samples
# (the discriminators are trained on fakes drawn through these buffers
# rather than only on the latest generator output).
fake_A_buffer = cyclegan.ReplayBuffer()
fake_B_buffer = cyclegan.ReplayBuffer()

In [None]:
def sample_images(batches_done):
    """Save a grid of validation samples translated by both generators.

    Takes the first batch of ``val_dataloader`` and writes
    ``./models/output/CycleGAN/images/{batches_done}.png``. The image
    contains four grids stacked vertically: real A, G_AB(real A), real B and
    G_BA(real B). Each grid lays its images out 5 per row and is min-max
    rescaled for display.

    The generators run in eval mode without autograd, and their previous
    train/eval mode is restored afterwards.

    Args:
        batches_done: global batch counter, used as the output file name.
    """
    # save_image does not create missing directories.
    os.makedirs("./models/output/CycleGAN/images", exist_ok=True)
    imgs = next(iter(val_dataloader))
    was_training = (G_AB.training, G_BA.training)
    G_AB.eval()
    G_BA.eval()
    # Sampling needs no gradients; skipping autograd avoids building a graph.
    with torch.no_grad():
        real_A = imgs["A"].type(Tensor)
        fake_B = G_AB(real_A)
        real_B = imgs["B"].type(Tensor)
        fake_A = G_BA(real_B)
    G_AB.train(was_training[0])
    G_BA.train(was_training[1])
    # Arrange each image set into a grid (5 images per row, rescaled to [0, 1]).
    grids = [make_grid(images, nrow=5, normalize=True) for images in (real_A, fake_B, real_B, fake_A)]
    # Stack the grids vertically: dim 1 is height for a (C, H, W) grid.
    image_grid = torch.cat(grids, 1)
    save_image(image_grid, f"./models/output/CycleGAN/images/{batches_done}.png", normalize=False)

# ----------
#  Training function
# ----------
def cyclegan_train():
    """Train the CycleGAN (G_AB, G_BA, D_A, D_B) for ``num_epochs`` epochs.

    Uses the module-level models, optimisers, LR schedulers, replay buffers,
    loss criteria and hyper-parameters defined in earlier cells.

    Each batch performs one generator update (adversarial + cycle + identity
    losses), followed by one update of each discriminator. A sample grid is
    saved every ``sample_interval`` batches, and the LR schedulers are
    stepped once per epoch.

    After the last epoch, all four networks' state dicts are saved. The
    files carry the final epoch index (``num_epochs - 1``) in their names.

    Returns:
        Five lists with one entry per epoch, each the mean over that epoch's
        batches: discriminator loss, total generator loss, adversarial loss,
        cycle loss and identity loss.
    """
    # Per-epoch mean losses (one entry appended per epoch).
    D_losses = []
    G_losses = []
    adv_losses = []
    cycle_losses = []
    I_losses = []
    # Timestamp of the previous batch, used for the ETA estimate.
    prev_time = time.time()
    for epoch in range(num_epochs):
        print('.' * 64)
        print(f"--- Epoch {epoch + 1}/{num_epochs} ---")
        pbar = tqdm(train_dataloader, leave=False)
        # Running sums of per-batch losses for the current epoch.
        agg_D_loss = 0.0
        agg_G_loss = 0.0
        agg_adv_loss = 0.0
        agg_cycle_loss = 0.0
        agg_I_loss = 0.0
        for i, batch in enumerate(pbar):
            pbar.set_description(f"Epoch [{epoch + 1}/{num_epochs}]")
            # Set model input
            real_A = Variable(batch["A"].type(Tensor))
            real_B = Variable(batch["B"].type(Tensor))

            # Adversarial ground truths
            # Targets for criterion_GAN: all ones = "real", all zeros = "fake",
            # shaped like the discriminator output. D_A.output_shape is reused
            # for D_B; both discriminators are built from input_shape.
            valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
            fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)

            # ------------------
            #  Train Generators
            # ------------------

            # Re-enable training mode every batch (sample_images puts the
            # generators into eval mode).
            G_AB.train()
            G_BA.train()

            optimizer_G.zero_grad()

            # Identity loss
            # Each generator is fed an image already in its output domain
            # (G_BA maps B -> A, so it gets real_A) and should leave it unchanged.
            loss_id_A = criterion_identity(G_BA(real_A), real_A)
            loss_id_B = criterion_identity(G_AB(real_B), real_B)

            loss_identity = (loss_id_A + loss_id_B) / 2
            agg_I_loss += loss_identity.item()

            # GAN loss
            # The generators are rewarded when the discriminators score their
            # translations as "valid".
            fake_B = G_AB(real_A)
            loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
            fake_A = G_BA(real_B)
            loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
            agg_adv_loss += loss_GAN.item()

            # Cycle loss
            # A -> B -> A and B -> A -> B round trips should reconstruct the input.
            recov_A = G_BA(fake_B)
            loss_cycle_A = criterion_cycle(recov_A, real_A)
            recov_B = G_AB(fake_A)
            loss_cycle_B = criterion_cycle(recov_B, real_B)

            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
            agg_cycle_loss += loss_cycle.item()

            # Total loss
            # Weighted sum: adversarial + lambda_cyc * cycle + lambda_id * identity.
            loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity
            agg_G_loss += loss_G.item()

            loss_G.backward()
            optimizer_G.step()

            # -----------------------
            #  Train Discriminator A
            # -----------------------

            optimizer_D_A.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_A(real_A), valid)
            # Fake loss (on batch of previously generated samples)
            # detach() keeps this loss from backpropagating into G_BA.
            fake_A_ = fake_A_buffer.push_and_pop(fake_A)
            loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
            # Total loss
            loss_D_A = (loss_real + loss_fake) / 2

            loss_D_A.backward()
            optimizer_D_A.step()

            # -----------------------
            #  Train Discriminator B
            # -----------------------

            optimizer_D_B.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_B(real_B), valid)
            # Fake loss (on batch of previously generated samples)
            # detach() keeps this loss from backpropagating into G_AB.
            fake_B_ = fake_B_buffer.push_and_pop(fake_B)
            loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
            # Total loss
            loss_D_B = (loss_real + loss_fake) / 2

            loss_D_B.backward()
            optimizer_D_B.step()

            loss_D = (loss_D_A + loss_D_B) / 2
            agg_D_loss += loss_D.item()

            # --------------
            #  Log Progress
            # --------------

            # Determine approximate time left
            # (remaining batches x duration of the most recent batch).
            batches_done = epoch * len(train_dataloader) + i
            batches_left = num_epochs * len(train_dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()


            # If at sample interval save image
            if batches_done % sample_interval == 0:
                sample_images(batches_done)


        # Epoch means: i is the index of the last batch, so i + 1 batches ran.
        # NOTE(review): i is unbound if train_dataloader yields no batches.
        D_losses.append(agg_D_loss / (i + 1))
        G_losses.append(agg_G_loss / (i + 1))
        adv_losses.append(agg_adv_loss / (i + 1))
        cycle_losses.append(agg_cycle_loss / (i + 1))
        I_losses.append(agg_I_loss / (i + 1))

        # Print log
        # NOTE: these are the losses of the epoch's final batch, not the epoch
        # means appended above.
        print(f"[D loss: {loss_D.item()}] [G loss: {loss_G.item()}, adv: {loss_GAN.item()}, cycle: {loss_cycle.item()}, identity: {loss_identity.item()}] ETA: {time_left}")

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()


    # Save the final weights. `epoch` is the last loop index (num_epochs - 1),
    # e.g. "..._99.pth" for 100 epochs.
    # NOTE(review): raises NameError if num_epochs == 0 (epoch is never bound);
    # torch.save does not create missing directories.
    torch.save(G_AB.state_dict(), f"./models/output/CycleGAN/G_AB_trained_{epoch}.pth")
    torch.save(G_BA.state_dict(), f"./models/output/CycleGAN/G_BA_trained_{epoch}.pth")
    torch.save(D_A.state_dict(), f"./models/output/CycleGAN/D_A_trained_{epoch}.pth")
    torch.save(D_B.state_dict(), f"./models/output/CycleGAN/D_B_trained_{epoch}.pth")

    return D_losses, G_losses, adv_losses, cycle_losses, I_losses

# Train and keep the per-epoch loss histories for plotting.
D_losses, G_losses, adv_losses, cycle_losses, I_losses = cyclegan_train()

In [None]:
import matplotlib.pyplot as plt

# Draw two figures of the per-epoch CycleGAN losses: first every curve, then
# the same curves without the total generator loss.
for loss_curves in (
    {
        'Discriminator': D_losses,
        'Generator': G_losses,
        'GAN': adv_losses,
        'Cycle': cycle_losses,
        'Identity': I_losses,
    },
    {
        'Discriminator': D_losses,
        'GAN': adv_losses,
        'Cycle': cycle_losses,
        'Identity': I_losses,
    },
):
    plt.figure()
    for curve in loss_curves.values():
        plt.plot(curve)
    plt.title('Losses')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    # Dict insertion order matches the plotting order, so labels line up.
    plt.legend(list(loss_curves), loc='upper left')

plt.show()

#### Run model evaluation

In [None]:
# Load the trained A -> B generator for style transfer.
# The checkpoint was saved while G_AB lived on `device` (possibly a GPU), so
# map it to the CPU first; this keeps loading working on CPU-only machines,
# and cyclegan_eval moves the model to `device` itself.
G_AB = cyclegan.GeneratorResNet(input_shape, n_residual_blocks)
G_AB.load_state_dict(torch.load("./models/output/CycleGAN/G_AB_trained_99.pth", map_location="cpu"))

In [None]:
# Evaluation inputs: resized to 128x128 and converted to tensors.
# NOTE(review): training inputs were 64x64 crops normalised with the dataset
# mean/std, while these stay unnormalised in [0, 1] -- confirm this is
# intended for the generator and for the metric helpers.
test_tf = transforms.Compose(
    [transforms.Resize(size=(128, 128)), transforms.ToTensor()]
)

test_dataloader = DataLoader(
    dataset=cyclegan.ImageDatasetTesting(mode="val", transform=test_tf),
    batch_size=64,
    shuffle=False,
)

def cyclegan_eval(G_AB, test_dataloader):
    """Evaluate the A -> B (content -> style) generator on a test loader.

    Each batch is a dict with ``"A"`` (content images), ``"B"`` (style
    images) and ``"label"`` (style labels). SSIM, FID and the content loss
    compare the content batch with its translation; the style loss compares
    the style batch with the translation. The first batch is plotted.

    Args:
        G_AB: generator to evaluate. It is moved to ``device`` and switched to
            eval mode.
        test_dataloader: iterable of such batch dicts.

    Returns:
        Tuple ``(avg_ssim, avg_fid, avg_content_loss, avg_style_loss)``: the
        accumulated metric sums divided by the number of evaluated images.

    Raises:
        ValueError: if the loader yields no batches.
    """
    G_AB.to(device)
    # A newly constructed nn.Module is in training mode. Switch to eval so
    # any dropout/normalisation layers use inference behaviour and running
    # statistics are not updated during evaluation.
    G_AB.eval()
    ssim_sum = 0.0
    fid_sum = 0.0
    running_content_loss, running_style_loss = 0.0, 0.0
    total_samples = 0
    with torch.no_grad():
        for batch in tqdm(test_dataloader):
            real_A = batch["A"].type(Tensor)
            real_B = batch["B"].type(Tensor)
            style_labels = batch["label"]
            output = G_AB(real_A)

            # Compute quantitative evaluation metrics.
            # NOTE(review): each helper is called once per batch, but the sums
            # are divided by the number of images below. If the helpers return
            # batch means, the reported averages are scaled down by roughly the
            # batch size. The convention is kept so results stay comparable
            # with msgnet_eval -- confirm against the helper definitions.
            ssim_sum += compute_ssim(real_A, output)
            fid_sum += calculate_fid_from_dataset(real_A, output, device, dims=2048)
            running_content_loss += calc_content_loss(real_A, output)
            running_style_loss += calc_style_loss(real_B, output)

            # Display qualitative results for the first batch only.
            if total_samples == 0:
                plot_results(real_A, real_B, style_labels, output, nrows=5, model_name="CycleGAN")
            total_samples += real_A.size(0)

    if total_samples == 0:
        raise ValueError("cyclegan_eval: test_dataloader yielded no batches")
    avg_ssim = ssim_sum / total_samples
    avg_fid = fid_sum / total_samples
    avg_content_loss = running_content_loss / total_samples
    avg_style_loss = running_style_loss / total_samples
    return avg_ssim, avg_fid, avg_content_loss, avg_style_loss

# Evaluate the CycleGAN generator and report the averaged metrics.
cyclegan_ssim, cyclegan_fid, cyclegan_content_loss, cyclegan_style_loss = cyclegan_eval(
    G_AB, test_dataloader
)

print("\n".join([
    "--- CycleGAN results ---",
    f"Average SSIM = {cyclegan_ssim:.7f}",
    f"Average FID = {cyclegan_fid:.4f}",
    f"Average content loss = {cyclegan_content_loss:.4f}",
    f"Average style loss = {cyclegan_style_loss:.4f}",
]))

# 7. <a id="results">Results</a>

# 8. <a id="discussion">Discussion</a>