# Setup

In [1]:
import os, sys
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
from IPython.display import clear_output
from tqdm import tqdm

project_root = os.path.dirname(os.getcwd())

In [None]:
import tensorflow as tf

# Ask TensorFlow to grow GPU memory on demand instead of reserving it all.
# This must run before the GPUs are initialized; the RuntimeError handler
# reports the case where the setting could not be applied.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu_device in gpus:
            tf.config.experimental.set_memory_growth(gpu_device, True)
    except RuntimeError as growth_error:
        print("Failed to set memory growth:", growth_error)
    else:
        print("GPU memory growth enabled for all GPUs.")

# Environment report: version, visible GPUs, then every physical device
print("TensorFlow version:", tf.__version__)
print("Is GPU available?", tf.config.list_physical_devices('GPU'))

print("\nAvailable devices:")
for physical_device in tf.config.list_physical_devices():
    print(physical_device)

# Smoke test: a small matrix product requested on the first GPU
try:
    with tf.device('/GPU:0'):
        print("\nRunning a simple computation on the GPU...")
        lhs = tf.constant([[1.0, 2.0], [3.0, 4.0]])
        rhs = tf.constant([[1.0, 1.0], [0.0, 1.0]])
        product = tf.matmul(lhs, rhs)
        print("Matrix multiplication result:\n", product)
except RuntimeError as gpu_error:
    print("Error using GPU:", gpu_error)


# Load Image

In [7]:
# Size passed to load_image for both the style and the content image, and
# reused as the spatial part of the VGG16 input_shape. PIL's resize reads it
# as (width, height); with a square size the distinction does not matter.
target_size = (256, 256)

In [None]:
# Define a function to load and preprocess images
def load_image(image_path, target_size=(128, 128)):
    """
    Load an image, resize it, and preprocess it for neural network input.

    Args:
        image_path (str): Path to the image file.
        target_size (tuple): Size handed to PIL's ``resize``, which reads it
            as (width, height), e.g., (128, 128).

    Returns:
        tf.Tensor: float32 tensor of shape (1, height, width, 3) with pixel
        values scaled to [0, 1].
    """
    # Open through a context manager so the file handle is released as soon
    # as the pixels are decoded; convert() returns an independent image that
    # stays usable after the file is closed.
    with Image.open(image_path) as source:
        image = source.convert('RGB')  # Ensure 3 channels (RGB)

    # Resize the image
    image = image.resize(target_size, Image.LANCZOS)

    # Convert to a NumPy array and scale pixel values to [0, 1]
    image_array = np.array(image) / 255.0

    # Convert to a TensorFlow tensor
    image_tensor = tf.convert_to_tensor(image_array, dtype=tf.float32)

    # Add a leading batch dimension for processing in neural networks
    return tf.expand_dims(image_tensor, axis=0)

# Display an image
def display_image(image_tensor, title="Image"):
    """
    Show a batched image with matplotlib.

    Args:
        image_tensor (tf.Tensor): Image tensor with shape (1, height, width, channels).
        title (str): Title of the plot.
    """
    # Drop the batch axis, then keep every pixel inside [0, 1] for imshow
    pixels = np.clip(tf.squeeze(image_tensor, axis=0).numpy(), 0, 1)

    plt.imshow(pixels)
    plt.axis('off')
    plt.title(title)
    plt.show()

# Locations of the style and content images, relative to the project root
style_image_path = os.path.join(project_root, "data/optimization_method/starry_night.jpg")
content_image_path = os.path.join(project_root, "data/optimization_method/wedding_maya.jpg")

# Both images are loaded at the shared target_size
style_image = load_image(style_image_path, target_size=target_size)
content_image = load_image(content_image_path, target_size=target_size)

# Preview: style first, then content
for preview_title, preview_tensor in (
    ("Style Image", style_image),
    ("Content Image", content_image),
):
    display_image(preview_tensor, title=preview_title)


# Pre-trained VGG

In [None]:
from tensorflow.keras.applications import VGG16

# VGG16 with ImageNet weights and without its top (include_top=False),
# built for inputs of target_size with 3 channels
vgg = VGG16(
    weights='imagenet',
    include_top=False,
    input_shape=(*target_size, 3),
)

# Freeze the layers: only the image is optimized, never the network
for layer in vgg.layers:
    layer.trainable = False

vgg.summary()

In [31]:
def get_feature_extractor(pre_trained_model: tf.keras.Model, layer_names: list) -> tf.keras.Model:
    """
    Wrap a model so that a forward pass yields the outputs of chosen layers.

    Args:
        pre_trained_model (tf.keras.Model): The pretrained model whose intermediate layers are tapped.
        layer_names (list): Names of the layers to expose, in the order their outputs are returned.

    Returns:
        tf.keras.Model: Model sharing the input of ``pre_trained_model`` and
        returning one output per entry of ``layer_names``.
    """
    # Collect the symbolic output of every requested layer
    tapped_outputs = []
    for layer_name in layer_names:
        tapped_outputs.append(pre_trained_model.get_layer(layer_name).output)

    return tf.keras.Model(inputs=pre_trained_model.input, outputs=tapped_outputs)

# Layers tapped for the content and for the style representation
content_layers = ["block4_conv2"]
style_layers = [
    "block1_conv1",
    "block2_conv1",
    "block3_conv1",
    "block4_conv1",
]

# Extractor outputs are ordered content layers first, then style layers
feature_extractor = get_feature_extractor(vgg, content_layers + style_layers)


# Extract activation for given image
def get_activations(image, feature_extractor: tf.keras.Model) -> tuple:
    """
    Run an image through the feature extractor and split the result into
    content and style activations.

    Args:
        image (tf.Tensor): Input image tensor. It is multiplied by 255 before
            VGG16 preprocessing, so values are presumably expected in [0, 1].
        feature_extractor (tf.keras.Model): Feature extractor model whose
            outputs are ordered content layers first, then style layers.

    Returns:
        tuple: ``(content_activations, style_activations)``. The split point
        is the length of the module-level ``content_layers`` list.
    """
    # Rescale to the 0-255 range before applying the VGG16 preprocessing
    vgg_input = tf.keras.applications.vgg16.preprocess_input(image * 255)

    layer_outputs = feature_extractor(vgg_input)

    # Leading outputs belong to the content layers, the rest to the style layers
    n_content = len(content_layers)
    return layer_outputs[:n_content], layer_outputs[n_content:]

# Model training

#### Loss

In [32]:
# Content Loss
def content_loss(content_activation: list[tf.Tensor], target_activation: list[tf.Tensor]) -> tf.Tensor:
    """
    Mean squared error between two lists of activations.

    The squared differences of every pair are summed and divided by the total
    number of elements across all pairs.

    Args:
        content_activation (list[tf.Tensor]): Reference activations.
        target_activation (list[tf.Tensor]): Activations of the image being optimized.

    Returns:
        tf.Tensor: Scalar loss.
    """
    squared_error = 0
    element_count = 0
    for reference, candidate in zip(content_activation, target_activation):
        assert reference.shape == candidate.shape, "Activations are different shape."
        difference = reference - candidate
        squared_error = squared_error + tf.math.reduce_sum(tf.math.square(difference))
        element_count = element_count + tf.size(reference)

    return squared_error / tf.cast(element_count, tf.float32)
    

# Style loss
def gram_matrix(activation: tf.Tensor) -> tf.Tensor:
    """
    Compute the (unnormalized) Gram matrix of a 4-axis activation.

    Axis 3 is moved right after axis 0 and the two remaining axes are
    flattened, so the result has shape (dim0, dim3, dim3). Presumably the
    input layout is (batch, height, width, channels).

    Args:
        activation (tf.Tensor): Activation tensor with four axes.

    Returns:
        tf.Tensor: Batched matrix product of the flattened features with their transpose.
    """
    # Bring axis 3 forward, then collapse the trailing two axes into one
    channels_first = tf.transpose(activation, perm=[0, 3, 1, 2])
    dims = tf.shape(channels_first)
    features = tf.reshape(channels_first, [dims[0], dims[1], -1])

    features_t = tf.transpose(features, perm=[0, 2, 1])
    return tf.matmul(features, features_t)


def style_loss(style_activations: list[tf.Tensor], target_activations: list[tf.Tensor]) -> tf.Tensor:
    """
    Mean squared error between the Gram matrices of two lists of activations.

    Args:
        style_activations (list[tf.Tensor]): Activations of the style image.
        target_activations (list[tf.Tensor]): Activations of the image being optimized.

    Returns:
        tf.Tensor: Scalar loss, i.e. the summed squared Gram differences
        divided by the total number of Gram matrix entries.
    """
    squared_error = 0
    entry_count = 0

    for style_act, target_act in zip(style_activations, target_activations):
        style_gram = gram_matrix(style_act)
        target_gram = gram_matrix(target_act)
        assert style_gram.shape == target_gram.shape, "Gram matrices are different shape."

        squared_error = squared_error + tf.math.reduce_sum(tf.math.square(style_gram - target_gram))
        entry_count = entry_count + tf.size(style_gram)

    return squared_error / tf.cast(entry_count, tf.float32)

# Total loss: weighted sum of the content and style losses
def total_loss(content_loss, style_loss, weights: dict) -> tf.Tensor:
    """
    Combine the content and style losses into a single objective.

    Args:
        content_loss: Content loss value.
        style_loss: Style loss value.
        weights (dict): Multipliers stored under the keys 'content' and 'style'.

    Returns:
        The weighted content loss plus the weighted style loss.
    """
    weighted_content = weights['content'] * content_loss
    weighted_style = weights['style'] * style_loss
    return weighted_content + weighted_style

# Default loss weights, keyed the way total_loss reads them
weights = {'content': 1, 'style': 1e3}

#### Monitoring

In [33]:
# Function to save target image
def save_image(image_tensor, epoch, method="optimization_method", version="test"):
    """
    Save an image tensor to a PNG file.

    The file is written to
    ``<project_root>/models/ouputs_monitoring/<method>/<version>/output_image_<epoch>.png``;
    the directory is created when missing.

    Args:
        image_tensor (tf.Tensor): Image tensor with shape (1, height, width, channels).
        epoch (int): Current epoch number, used in the file name.
        method (str): The optimization method used.
        version (str): The version or configuration identifier.
    """
    # Remove the batch dimension, clamp to [0, 1], then convert to uint8
    image = tf.squeeze(image_tensor, axis=0).numpy()
    image = np.clip(image, 0, 1)
    image = tf.image.convert_image_dtype(image, tf.uint8)

    # Define the directory and ensure it exists
    save_dir = os.path.join(project_root, f"models/ouputs_monitoring/{method}/{version}")
    os.makedirs(save_dir, exist_ok=True)

    # Construct file path
    file_path = os.path.join(save_dir, f"output_image_{epoch}.png")

    # The pixels are already uint8 in [0, 255]. scale=False keeps save_img
    # from stretching the minimum and maximum to 0 and 255, which would
    # change the contrast of any image that does not span the full range.
    tf.keras.utils.save_img(file_path, image, scale=False)

#### Training

In [None]:
# Initialization
# Image being optimized: a trainable copy of the content image
target_image = tf.Variable(tf.identity(content_image), trainable=True)

# Fixed targets: content activations of the content image and
# style activations of the style image
feature_extractor = get_feature_extractor(vgg, content_layers + style_layers)

content_activations, _ = get_activations(content_image, feature_extractor)
_, style_activations = get_activations(style_image, feature_extractor)

# Loss weights
weights = {
    'content': 1,
    'style': 1e4,
}

optimizer = tf.keras.optimizers.legacy.Adam(
    learning_rate=1e-3,
    beta_1=0.9,
    beta_2=0.99,
    epsilon=1e-7,
)

#@tf.function
def train_step(target_image, content_activations: list[tf.Tensor], style_activations: list[tf.Tensor], weights: dict, 
               optimizer: tf.keras.optimizers.Optimizer, feature_extractor: tf.keras.Model):
    """
    Perform one optimization step on the target image.

    Args:
        target_image: Image being optimized; gradients are taken with respect to it.
        content_activations (list[tf.Tensor]): Content targets.
        style_activations (list[tf.Tensor]): Style targets.
        weights (dict): Loss weights forwarded to ``total_loss``.
        optimizer (tf.keras.optimizers.Optimizer): Optimizer that applies the update.
        feature_extractor (tf.keras.Model): Model producing the activations.

    Returns:
        The total loss computed before the update is applied.
    """
    # Forward pass, recorded so the loss can be differentiated
    with tf.GradientTape() as tape:
        target_content, target_style = get_activations(target_image, feature_extractor)
        loss = total_loss(
            content_loss(content_activations, target_content),
            style_loss(style_activations, target_style),
            weights,
        )

    # Gradient of the loss with respect to the image itself
    gradients = tape.gradient(loss, target_image)

    # Update the image in place
    optimizer.apply_gradients([(gradients, target_image)])

    return loss

# Training loop
# 1501 iterations so that epoch 1500 is also a multiple of 250 and gets a checkpoint
n_epochs = 1501
for epoch in tqdm(range(n_epochs), desc="Image optimization"):
    loss = train_step(target_image, content_activations, style_activations, weights, optimizer, feature_extractor)
    
    if epoch % 250 == 0:
        # Request the clear first: with wait=True the previous output is
        # replaced by the next thing emitted, so the loss line printed below
        # stays visible next to the new figure instead of being wiped by it.
        clear_output(wait=True)
        print(f"Epoch {epoch}: Loss: {loss}")

        clipped_image = tf.clip_by_value(target_image, 0., 1.)
        # Save image
        save_image(clipped_image, epoch, method="optimization_method", version="maya")
        
        # Display image
        display_image(clipped_image, title=f"Epoch {epoch}")

In [58]:
def clear_monitoring_files(method="optimization_method", version="test"):
    """
    Empty the monitoring directory of a given method/version.

    The directory ``<project_root>/models/ouputs_monitoring/<method>/<version>``
    is removed with everything in it and recreated empty. A directory that
    does not exist yet is simply created.

    Args:
        method (str): The optimization method used.
        version (str): The version or configuration identifier.
    """
    import shutil
    path = os.path.join(project_root, f"models/ouputs_monitoring/{method}/{version}")
    # rmtree raises FileNotFoundError on a missing directory, so only remove
    # what is actually there
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)
    
# NOTE(review): the training loop saves its checkpoints under version="maya";
# confirm that "maya2" is the run that is meant to be wiped here.
clear_monitoring_files(version='maya2')

# Colorimetry

In [None]:
# Load the generated image: the epoch-750 output of the "maya" run, read back
# through load_image so it is resized to target_size and scaled to [0, 1]
gen_im_path = os.path.join(project_root, "models/ouputs_monitoring/optimization_method/maya/output_image_750.png")
generated_image = load_image(gen_im_path, target_size=target_size)

#### Method 1 - Histogram Matching

In [43]:
from skimage.exposure import match_histograms

def match_colors(generated_image, content_image):
    """
    Transfer the color distribution of the content image onto the generated image.

    Histogram matching is delegated to ``skimage.exposure.match_histograms``
    with the last axis treated as the channel axis.

    Args:
        generated_image (np.ndarray): Generated image as a NumPy array.
        content_image (np.ndarray): Content image as a NumPy array (the reference).

    Returns:
        np.ndarray: Color-transferred image.
    """
    return match_histograms(generated_image, content_image, channel_axis=-1)

In [None]:
matched_image = match_colors(generated_image.numpy(), content_image.numpy())

# Show the reference, the raw result and the color-corrected result in turn
for panel_title, panel in (
    ("Content Image", content_image),
    ("Generated Image", generated_image),
    ("Color-Transferred Image", matched_image),
):
    display_image(panel, title=panel_title)

#### Method 2 - Mean and Standard Deviation Matching

In [45]:
def match_mean_std(generated_image, content_image):
    """
    Adjust the mean and standard deviation of the generated image to match the content image.

    Each of the first three channels (last axis) is standardized with its own
    mean and standard deviation, then rescaled to the statistics of the same
    channel of the content image. The input array is left unmodified.

    Args:
        generated_image (np.ndarray): Generated image as a NumPy array.
        content_image (np.ndarray): Content image as a NumPy array.

    Returns:
        np.ndarray: Color-matched image, clipped to [0, 1].
    """
    # Work on a copy so the caller's array is never written to
    matched = np.array(generated_image, copy=True)

    for channel in range(3):  # Assuming RGB
        gen_channel = matched[..., channel]
        ref_channel = content_image[..., channel]
        gen_mean, gen_std = gen_channel.mean(), gen_channel.std()
        cont_mean, cont_std = ref_channel.mean(), ref_channel.std()
        # The small epsilon guards against division by zero on a flat channel
        matched[..., channel] = (
            (gen_channel - gen_mean) / (gen_std + 1e-8)
        ) * cont_std + cont_mean

    return np.clip(matched, 0, 1)  # Ensure valid pixel range

In [None]:
matched_image = match_mean_std(generated_image.numpy(), content_image.numpy())

# Show the reference, the raw result and the color-corrected result in turn
for panel_title, panel in (
    ("Content Image", content_image),
    ("Generated Image", generated_image),
    ("Color-Transferred Image", matched_image),
):
    display_image(panel, title=panel_title)