<center>
    <h1>Projet DLA</h1>
    <h2>Style Transfer Using Convolutional Neural Network</h2> 
</center>

## Import modules 

In [None]:
import torch
import torchvision
from torchvision import transforms
from torchvision import models
from torchvision.models import vgg19, VGG19_Weights
import torch.nn.functional as F
import torch.nn as nn

import torch.optim as optim
import copy

from PIL import Image
import matplotlib.pyplot as plt

## Load and display images 

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Size every loaded image is resized to before being fed to the network.
imshape = (224, 224)

def image_loader(image_name):
    """
    Loads an image from disk as a batched float tensor on the active device.

    Parameters:
    ------------
    image_name : str or path-like
        Path of the image file to open.

    Returns:
    ------------
    torch.Tensor
        The image resized to `imshape`, with a leading batch dimension of
        size 1, stored as float on `device`.
    """
    # Resize the image, then convert it into a torch tensor
    loader = transforms.Compose([transforms.Resize(imshape), transforms.ToTensor()])

    # The context manager closes the file handle once the pixels are decoded.
    with Image.open(image_name) as source:
        # Force three channels: grayscale, palette or RGBA files would
        # otherwise yield a tensor the network cannot consume.
        image = source.convert("RGB")

    batch = loader(image).unsqueeze(0)    # add a fake batch dimension (here 1)
    return batch.to(device, torch.float)  # move the image tensor to the correct device

def image_display(tensor, title=None):
    """
    Displays an image tensor with matplotlib.

    Parameters:
    ------------
    tensor : torch.Tensor
        The image tensor to display; a leading batch dimension of size 1
        is removed before conversion.
    title : str, optional
        Title drawn above the image. No title is set when None.
    """
    unloader = transforms.ToPILImage()      # reconvert into PIL image
    image = tensor.detach().cpu().clone()   # work on a CPU copy, outside of autograd
    image = unloader(image.squeeze(0))      # remove the fake batch dimension

    # Draw on a fresh figure and only then show it, so that every call
    # displays its own image (calling show() first would flush whatever
    # was drawn previously and leave the new image pending).
    plt.figure()
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.show()

In [None]:
# Size the images are resized to when loaded.
imshape = (224, 224)

# Folder and file names of the two input images.
image_path = "../data/"
content_image_name = "content.jpeg"
style_image_name = "style.jpeg"

content_image = image_loader(f"{image_path}{content_image_name}")
style_image = image_loader(f"{image_path}{style_image_name}")

# The two trailing dimensions of the loaded tensors are height and width.
content_height, content_width = content_image.shape[-2:]
style_height, style_width = style_image.shape[-2:]

print(f"Content image shape : {content_height} x {content_width}")
print(f"Style image shape : {style_height} x {style_width}")
print(content_image.shape)

# Show both inputs, each under its own title.
for picture, caption in ((content_image, "Content image"), (style_image, "Style image")):
    image_display(picture, caption)

## Load the VGG-19 pretrained model

In [None]:
# Import the VGG-19 model with its ImageNet pre-trained weights and move it
# to the correct device.
model = models.vgg19(weights=VGG19_Weights.IMAGENET1K_V1).to(device)

# The network is only used as a fixed feature extractor: the optimizers in
# this notebook update an image, never the weights. Switching to eval mode
# and freezing the parameters avoids computing and storing weight gradients
# at every optimization step.
model.eval()
model.requires_grad_(False)

<h2>1. Choice of the layers for the content representation</h2>

In [None]:
# Second convolution of each of the five VGG-19 blocks (conv1_2 ... conv5_2).
content_layers_test = [f"conv{block}_2" for block in range(1, 6)]

In [None]:
# Number of convolutional layers in each block of the VGG-19 model.
# Kept for reference: block boundaries are detected from the pooling layers below.
blocks = [2, 2, 4, 4, 4]
renamed_model = nn.Sequential()

# Rename the layers as convB_K / reluB_K / poolB, where B is the 1-based
# block number and K the position of the convolution inside that block.
current_block = 1
index_conv = 0

for layer in model.features.eval().children():
    if isinstance(layer, nn.Conv2d):  # Convolutional layers
        index_conv += 1
        renamed_model.add_module(f'conv{current_block}_{index_conv}', layer)

    elif isinstance(layer, nn.ReLU):  # ReLU layers
        # A ReLU shares the indices of the convolution it follows.
        # Use inplace=False so the stored activations are not overwritten.
        renamed_model.add_module(f'relu{current_block}_{index_conv}', nn.ReLU(inplace=False))

    elif isinstance(layer, nn.MaxPool2d):  # MaxPooling layers
        renamed_model.add_module(f'pool{current_block}', layer)
        # A pooling layer closes the current block: only move on to the next
        # block here, so the last ReLU and the pooling layer of a block keep
        # that block's number.
        current_block += 1
        index_conv = 0

# Display the name of the layers
print(renamed_model)

In [None]:
class VGGActivations_content(nn.Module):
    """
    Extracts activations from specific layers of a model for content reconstruction.

    Attributes:
    ------------
    model : nn.Module
        The model whose named children are applied one after the other.
    target_layers : list of str
        The names of the layers from which to extract activations.
    layer_outputs : dict
        A dictionary that stores the activations of the last forward pass.
    """

    def __init__(self, model, target_layers):
        """
        Initializes the class with the given model and the list of target layers.

        Parameters:
        ------------
        model : nn.Module
            The model used to extract activations.
        target_layers : list of str
            The names of the layers from which to extract activations.
        """
        super().__init__()
        self.model = model
        self.target_layers = target_layers
        self.layer_outputs = {}

    def forward(self, x):
        """
        Passes input through the model and extracts activations for the target layers.

        The layers are applied in the order given by `named_children()`, and
        the pass stops as soon as every target layer has been captured, so
        layers located after the last target are never evaluated.

        Parameters:
        ------------
        x : torch.Tensor
            The input image tensor.

        Returns:
        ------------
        dict
            A dictionary mapping each target layer name found in the model
            to the output of that layer.
        """
        self.layer_outputs = {}  # Reset the output dictionary
        pending = set(self.target_layers)  # Target layers not captured yet

        for name, layer in self.model.named_children():
            x = layer(x)  # Pass the input through each layer
            if name in pending:
                self.layer_outputs[name] = x  # Store activations for target layers
                pending.discard(name)
                if not pending:
                    break  # Nothing left to capture: skip the remaining layers

        return self.layer_outputs


def reconstruct_image_content(activations_dict, extractor=None, reference_image=None,
                              num_steps=3000, lr=0.01, log_every=50):
    """
    Reconstructs an image from the activations of specific layers using optimization.

    For every layer of `activations_dict`, a random image is optimized with
    Adam so that its activation at that layer matches the target (MSE loss).

    Parameters:
    ------------
    activations_dict : dict
        A dictionary mapping layer names to their target activations.
    extractor : callable, optional
        Maps an image tensor to a dictionary of activations keyed by layer
        name. Defaults to the notebook-level `vgg_activations_content`.
    reference_image : torch.Tensor, optional
        Tensor whose shape, dtype and device are used for the random starting
        image. Defaults to the notebook-level `input_image`.
    num_steps : int, optional
        Number of optimization steps per layer.
    lr : float, optional
        Learning rate of the Adam optimizer.
    log_every : int, optional
        The loss is printed every `log_every` steps; 0 disables the step log.

    Returns:
    ------------
    dict
        A dictionary mapping each layer to the reconstructed image (detached
        from the autograd graph).
    """
    # Fall back to the notebook-level objects when nothing is passed explicitly.
    if extractor is None:
        extractor = vgg_activations_content
    if reference_image is None:
        reference_image = input_image

    # Dictionary to store reconstructed images for each layer
    reconstructed_images = {}

    # Iterate over each target layer and its corresponding activation
    for layer, activation in activations_dict.items():
        print(f"{'-'*10} Running optimization for layer: {layer} {'-'*10}")

        # The target is a constant: never backpropagate through it.
        target = activation.detach()

        # Initialize a random image for optimization
        reconstructed_image = torch.rand_like(reference_image, requires_grad=True)

        # Set up an optimizer to update the random image
        optimizer = torch.optim.Adam([reconstructed_image], lr=lr)

        # Perform optimization to reconstruct the image
        for step in range(num_steps):
            optimizer.zero_grad()  # Reset gradients

            # Compute activations for the current reconstructed image
            generated_activations = extractor(reconstructed_image)

            # Compute the loss between the generated and target activations
            loss = torch.nn.functional.mse_loss(generated_activations[layer], target)

            loss.backward()   # Backpropagate the loss
            optimizer.step()  # Update the image based on gradients

            # Log progress periodically
            if log_every and step % log_every == 0:
                print(f"Step {step} - Loss: {loss.item()}")

        # Store the optimized image for the current layer
        reconstructed_images[layer] = reconstructed_image.detach()

    return reconstructed_images


def deprocess(tensor):
    """
    Denormalizes an image tensor for visualization.

    Reverts the per-channel normalization `(x - mean) / std` using the
    ImageNet mean and standard deviation, then clamps the result to [0, 1].

    Parameters:
    ------------
    tensor : torch.Tensor
        A normalized image tensor whose channel dimension (size 3) broadcasts
        against a (1, 3, 1, 1) tensor.

    Returns:
    ------------
    torch.Tensor
        A denormalized image tensor (in the range of 0 to 1, clamped).
    """
    # Build the statistics on the same device as the input so that tensors
    # living on the GPU can be denormalized without a device mismatch.
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(1, 3, 1, 1)  # ImageNet mean
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(1, 3, 1, 1)   # ImageNet std
    tensor = tensor * std + mean  # Denormalize
    return tensor.clamp(0, 1)  # Clamp values to ensure valid range for visualization


In [None]:
# Preprocessing for VGG images: resize, convert to tensor, normalize with the ImageNet statistics
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# The round trip through a PIL image produces a tensor on the CPU, so it has
# to be moved to the device the model lives on.
input_image = preprocess(
    torchvision.transforms.functional.to_pil_image(content_image.squeeze(0))
).unsqueeze(0).to(device)

# Instantiate the new model
vgg_activations_content = VGGActivations_content(renamed_model, content_layers_test)

# Get the activations at each layer of the list
with torch.no_grad():
    activations_content = vgg_activations_content(input_image)

reconstructed_images_content = reconstruct_image_content(activations_dict=activations_content)

for layer, img in reconstructed_images_content.items():

    # Deprocess the image (on the CPU, where it is displayed)
    output_image = deprocess(img.detach().cpu())

    # Display the image
    image_display(output_image, layer)

<h2>2. Justifying the choice of the layers for style reconstruction</h2>

In [None]:
# First convolution of each of the five VGG-19 blocks. The last entry is
# conv5_1: listing conv4_1 twice would leave the fifth block unused and
# repeat the fourth reconstruction.
style_layers_test = ["conv1_1", "conv2_1", "conv3_1", "conv4_1", "conv5_1"]

In [None]:
class VGGActivationsStyle(nn.Module):
    """
    Extracts activations from specific layers of a model and computes their Gram matrices
    for style transfer tasks, returning a structured dictionary.

    Attributes:
    ------------
    model : nn.Module
        The model whose named children are applied one after the other.
    target_layers : list of str
        The names of the layers whose Gram matrices are collected.
    layer_outputs : dict
        The dictionary built by the last forward pass.
    """
    def __init__(self, model, target_layers):
        """
        Initializes the class with the given model and the list of target layers.
        """
        super().__init__()
        self.model = model
        self.target_layers = target_layers
        self.layer_outputs = {}

    def gram_matrix(self, activation):
        """
        Computes the Gram matrix of the activation to capture style information.

        Parameters:
        ------------
        activation : torch.Tensor
            A 4-dimensional activation tensor of size (a, b, c, d).

        Returns:
        ------------
        torch.Tensor
            The (a*b, a*b) matrix of dot products between the flattened
            feature maps, divided by a*b*c*d.
        """
        a, b, c, d = activation.size()
        # reshape (rather than view) also accepts non-contiguous activations
        features = activation.reshape(a * b, c * d)  # Flatten the last two dimensions
        G = torch.mm(features, features.t())  # Compute the dot product of feature maps
        return G.div(a * b * c * d)  # Normalize by the total number of elements

    def forward(self, x):
        """
        Passes input through the model to compute and store the Gram matrices
        for the target layers, building a structured dictionary.

        Each target layer name is mapped to the list of Gram matrices of all
        target layers met so far (in model order), the layer's own Gram
        matrix being the last element. The pass stops once every target
        layer has been captured.

        Parameters:
        ------------
        x : torch.Tensor
            The input image tensor.

        Returns:
        ------------
        dict
            A dictionary mapping target layer names to lists of Gram matrices.
        """
        self.layer_outputs = {}
        cumulative_grams = []  # List to store cumulative Gram matrices
        pending = set(self.target_layers)  # Target layers not captured yet

        for name, layer in self.model.named_children():
            x = layer(x)  # Forward pass through the layer
            if name in pending:
                cumulative_grams.append(self.gram_matrix(x))  # Add current Gram matrix
                self.layer_outputs[name] = cumulative_grams.copy()  # Store copy of cumulative list
                pending.discard(name)
                if not pending:
                    break  # Nothing left to capture: skip the remaining layers

        return self.layer_outputs



def reconstruct_image_style(activations_dict, extractor=None, reference_image=None,
                            layers=None, num_steps=3000, lr=0.01, log_every=50):
    """
    Reconstructs an image from the Gram matrices of activations using optimization.

    For the k-th layer of `layers`, a random image is optimized with Adam so
    that its Gram matrices match the targets on the first k layers (mean of
    the per-layer MSE losses).

    Parameters:
    ------------
    activations_dict : dict
        Maps each layer name to a list of target Gram matrices, in which the
        Gram matrix of the i-th layer of `layers` is stored at index i.
    extractor : callable, optional
        Maps an image tensor to a dictionary with the same structure as
        `activations_dict`. Defaults to the notebook-level
        `vgg_activations_style`.
    reference_image : torch.Tensor, optional
        Tensor whose shape, dtype and device are used for the random starting
        image. Defaults to the notebook-level `input_image`.
    layers : list of str, optional
        Ordered style layer names. Defaults to the notebook-level
        `style_layers_test`. Repeated names are ignored.
    num_steps : int, optional
        Number of optimization steps per layer combination.
    lr : float, optional
        Learning rate of the Adam optimizer.
    log_every : int, optional
        The loss is printed every `log_every` steps; 0 disables the step log.

    Returns:
    ------------
    dict
        A dictionary mapping each layer (standing for the combination of all
        layers up to and including it) to the reconstructed image, detached
        from the autograd graph.
    """
    # Fall back to the notebook-level objects when nothing is passed explicitly.
    if extractor is None:
        extractor = vgg_activations_style
    if reference_image is None:
        reference_image = input_image
    if layers is None:
        layers = style_layers_test

    # Drop repeated names while keeping the order: a repeated layer would only
    # redo an optimization whose result is stored under the same key.
    layers = list(dict.fromkeys(layers))

    reconstructed_images = {}

    for depth, layer in enumerate(layers, start=1):
        print(f"{'-'*10} Running optimization for model up to layer: {layer} {'-'*10}")

        # Layers contributing to the loss, and their constant targets
        active_layers = layers[:depth]
        targets = [
            activations_dict[name][position].detach()
            for position, name in enumerate(active_layers)
        ]

        # Initialize a random image to optimize
        reconstructed_image = torch.rand_like(reference_image, requires_grad=True)

        # Set up an optimizer for the image
        optimizer = torch.optim.Adam([reconstructed_image], lr=lr)

        # Optimization loop
        for step in range(num_steps):
            optimizer.zero_grad()  # Reset gradients

            # Compute Gram matrices for the reconstructed image
            generated_activations = extractor(reconstructed_image)

            # Accumulate the loss over all target layers up to the current one
            loss = 0
            for position, name in enumerate(active_layers):
                generated_gram = generated_activations[name][position]
                loss = loss + torch.nn.functional.mse_loss(generated_gram, targets[position])

            loss = loss / len(active_layers)  # Normalize loss

            loss.backward()   # Backpropagate the loss
            optimizer.step()  # Update the image

            if log_every and step % log_every == 0:
                print(f"Step {step}, Loss: {loss.item()}")

        reconstructed_images[layer] = reconstructed_image.detach()

    return reconstructed_images



In [None]:
# Preprocessing for VGG images: resize, convert to tensor, normalize with the ImageNet statistics
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# The round trip through a PIL image produces a tensor on the CPU, so it has
# to be moved to the device the model lives on.
input_image = preprocess(
    torchvision.transforms.functional.to_pil_image(style_image.squeeze(0))
).unsqueeze(0).to(device)

# Instantiate the new model
vgg_activations_style = VGGActivationsStyle(renamed_model, style_layers_test)

# Get the Gram matrices at each layer of the list
with torch.no_grad():
    activations_style = vgg_activations_style(input_image)

reconstructed_images_style = reconstruct_image_style(activations_dict=activations_style)

for layer, img in reconstructed_images_style.items():

    # Deprocess the image (on the CPU, where it is displayed)
    output_image = deprocess(img.detach().cpu())

    # Display the image
    image_display(output_image, layer)