# Multiple-Style Blending Style Transfer

## LIBRARY IMPORTS AND SETUP

In [1]:
# Neural Style Transfer with Multiple Style Blending
# This code implements an advanced neural style transfer system that can blend two style images
# with a content image using deep learning techniques and VGG19 neural network

# Import libraries
# Deep learning and numerical operations
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageEnhance
import time

# IPython widgets for interactive UI
from ipywidgets import (VBox, HBox, FloatSlider, IntSlider, FileUpload, Button, Output, Label, 
                        Layout, Image as ipyImage, Dropdown, Checkbox)
from IPython.display import clear_output, display
import io
import math

# Configure TensorFlow Hub to use compressed model format for faster loading
# NOTE(review): no tensorflow_hub import is visible in this file (VGG19 is loaded
# through tf.keras.applications), so this setting may be a leftover — confirm.
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'

## USER INTERFACE WIDGET CREATION

In [2]:
# Upload controls: one for the content picture and one for each style picture.
# All three only accept image files.
content_upload, style_upload1, style_upload2 = (
    FileUpload(description=caption, accept='image/*')
    for caption in ("Upload Content", "Upload Style 1", "Upload Style 2")
)

# Output areas that show a thumbnail of each uploaded picture
content_preview, style_preview1, style_preview2 = Output(), Output(), Output()

# Style blend sliders
# One slider per style image, each covering [0, 1] in steps of 0.01 and starting at 0.5.
# The two weights are meant to sum to 1.0 so the styles blend proportionally.
_blend_slider_options = dict(
    min=0, max=1, value=0.5, step=0.01,
    style={'description_width': 'initial'},
    layout={'width': '400px'},
)

style_slider1 = FloatSlider(description='Style 1 Weight', **_blend_slider_options)
style_slider2 = FloatSlider(description='Style 2 Weight', **_blend_slider_options)

# Keep the two style weights complementary so they always sum to 1.0
def _sync_complement(target_slider, new_value):
    """Set ``target_slider`` to ``1 - new_value`` unless it already holds it.

    The complement is rounded to two decimals (the sliders' 0.01 step) and the
    write is skipped when the target is already there. Without this, the two
    observers feed float noise back into each other: ``1.0 - 0.7`` is
    ``0.30000000000000004``, which differs from the ``0.3`` the user picked,
    so the opposite handler would overwrite the user's value and fire again.
    """
    complement = round(1.0 - new_value, 2)
    if not math.isclose(target_slider.value, complement, abs_tol=1e-9):
        target_slider.value = complement


def update_slider1(change):
    """Observer for slider 1: move slider 2 to the complementary weight.

    Args:
        change: Change dictionary passed by the widget; only ``change['new']``
            (the new slider value) is read.
    """
    _sync_complement(style_slider2, change['new'])


def update_slider2(change):
    """Observer for slider 2: move slider 1 to the complementary weight.

    Args:
        change: Change dictionary passed by the widget; only ``change['new']``
            (the new slider value) is read.
    """
    _sync_complement(style_slider1, change['new'])

# Register the synchronization handlers so that moving either slider updates the other
for _slider, _handler in ((style_slider1, update_slider1), (style_slider2, update_slider2)):
    _slider.observe(_handler, names='value')

# Loss weight sliders
# Each slider is a multiplier in [0, 1] (default 1.0); the number in the
# description is the base weight that the multiplier scales.
_loss_slider_options = dict(
    min=0, max=1, value=1.0, step=0.01,
    style={'description_width': 'initial'},
    layout={'width': '400px'},
)

# Content weight: how closely the output keeps the content image's structure
# (effective weight = value * 1e4)
content_weight_slider = FloatSlider(description='Content Weight (1e4)', **_loss_slider_options)

# Style weight: how strongly the output takes on the artistic style
# (effective weight = value * 1e-2)
style_weight_slider = FloatSlider(description='Style Weight (1e-2)', **_loss_slider_options)

# Total variation weight: higher values reduce noise and give smoother results
# (effective weight = value * 30)
tv_weight_slider = FloatSlider(description='Total Variation (30)', **_loss_slider_options)

# Training parameter sliders
_training_slider_look = {
    'style': {'description_width': 'initial'},
    'layout': {'width': '400px'},
}

# Number of epochs (1-50, default 10)
epoch_slider = IntSlider(description='Epochs', value=10, min=1, max=50, **_training_slider_look)

# Optimization steps per epoch (10-200, default 50)
steps_slider = IntSlider(description='Steps/Epoch', value=50, min=10, max=200, **_training_slider_look)

# Color preservation slider
# 0.0 keeps the stylized colors; larger values pull the colors toward the content image
color_preserve_slider = FloatSlider(
    description='Color Preserve',
    value=0.0, min=0, max=1, step=0.01,
    layout={'width': '400px'},
    style={'description_width': 'initial'},
)

# Output file format selection (JPEG is the default choice)
_format_choices = ['JPEG (Default)', 'PNG (Lossless)']
output_quality_dropdown = Dropdown(
    description='Format:',
    options=_format_choices,
    value=_format_choices[0],
)

# Style morphing creates a gradual transition between styles over epochs.
# The label states the direction the training handler actually applies:
# on_train_click uses w1 = 1.0 - morph_factor and w2 = morph_factor, so
# Style 1 fades out (1.0 -> 0.0) while Style 2 fades in (0.0 -> 1.0).
style_morphing_toggle = Checkbox(
    description='Enable Style Morphing (Style 1 = 1.0 -> 0.0, Style 2 = 0.0 -> 1.0)',
    value=False,
    layout={'width': '500px'}
)

# Shared look for the sliders in this section
_wide_slider_look = dict(
    style={'description_width': 'initial'},
    layout={'width': '400px'},
)

# Post-processing enhancement sliders: a factor in [0.5, 1.5], where 1.0 leaves the image unchanged
brightness_slider, contrast_slider, saturation_slider = (
    FloatSlider(min=0.5, max=1.5, value=1.0, step=0.01, description=caption, **_wide_slider_look)
    for caption in ('Brightness', 'Contrast', 'Saturation')
)

# Style scale: multiplies the resolution at which the style images are processed
style_scale_slider = FloatSlider(
    description='Style Scale',
    min=0.1, max=2.0, step=0.1, value=1.0,
    **_wide_slider_look,
)

# Output resolution: longest side of the content image, in steps of 256 pixels
resolution_slider = IntSlider(
    description='Output Resolution',
    min=256, max=2048, step=256, value=512,
    **_wide_slider_look,
)

# Action buttons: green starts a run, red restores the defaults
train_btn = Button(button_style='success', description="Start Training")
reset_btn = Button(button_style='danger', description="Reset All")

# Output areas: one for progress text and messages, one for the final figures
output_area, result_area = Output(), Output()

# Image widget that shows the intermediate result while training runs
_preview_layout = Layout(width='400px', height='auto')
live_preview = ipyImage(layout=_preview_layout)

# Section headers: bold 16px labels, one per UI section
_HEADER_STYLE = {'font_weight': 'bold', 'font_size': '16px'}


def _section_header(text):
    """Build a bold header label showing ``text``."""
    return Label(value=text, style=dict(_HEADER_STYLE))


upload_header = _section_header("📁 Image Upload: ")
style_weights_header = _section_header("⚖️ Style Weights Adjustment: ")
training_header = _section_header("⚙️ Training Parameters: ")
advanced_header = _section_header("🎚️ Advanced Loss Weights:")
color_adjustment_header = _section_header("🎨 Color Adjustment:")
output_setting_header = _section_header("🖥️ Output Setting:")
results_header = _section_header("🖼️ Results: ")

## USER INTERFACE LAYOUT

In [3]:
# UI layout: each section is a vertical box with a 20px bottom margin,
# stacked top to bottom in the order the user works through them.
def _section(*rows):
    """Stack ``rows`` vertically and add the standard bottom margin."""
    return VBox(list(rows), layout=Layout(margin='0 0 20px 0'))


# Image upload: three columns of label, upload control and thumbnail
_upload_columns = HBox([
    VBox([Label("Content Image:"), content_upload, content_preview]),
    VBox([Label("Style Image 1:"), style_upload1, style_preview1]),
    VBox([Label("Style Image 2:"), style_upload2, style_preview2]),
])
_upload_section = _section(upload_header, _upload_columns)

# Style blend weights and the morphing toggle
_style_section = _section(
    style_weights_header,
    HBox([style_slider1, style_slider2], layout={'width': '100%'}),
    HBox([style_morphing_toggle]),
)

# Epochs and steps per epoch
_training_section = _section(
    training_header,
    HBox([epoch_slider, steps_slider], layout={'width': '100%'}),
)

# Loss multipliers and color preservation
_loss_section = _section(
    advanced_header,
    HBox([content_weight_slider, style_weight_slider]),
    HBox([tv_weight_slider, color_preserve_slider]),
)

# Post-processing enhancements and style scale
_color_section = _section(
    color_adjustment_header,
    HBox([brightness_slider, contrast_slider]),
    HBox([saturation_slider, style_scale_slider]),
)

# Output resolution and file format
_output_section = _section(
    output_setting_header,
    HBox([resolution_slider, output_quality_dropdown]),
)

# Train / reset buttons
_button_row = HBox([train_btn, reset_btn])

# Progress text, live preview and final figures (no bottom margin on this one)
_results_section = VBox([results_header, output_area, live_preview, result_area])

ui = VBox([
    _upload_section,
    _style_section,
    _training_section,
    _loss_section,
    _color_section,
    _output_section,
    _button_row,
    _results_section,
])

## IMAGE PREVIEW FUNCTIONS

In [4]:
# Wire an upload widget to a preview area so the thumbnail refreshes on every upload
def update_preview(upload_widget, preview_area, max_size=(200, 200)):
    """Register an observer that redraws a thumbnail when the upload changes.

    Args:
        upload_widget: Widget whose ``value`` holds the uploaded files; the
            first entry's ``'content'`` bytes are decoded as an image.
        preview_area: Output context in which the thumbnail is displayed.
        max_size: ``(width, height)`` bound passed to ``Image.thumbnail``.
    """
    def _refresh(change):
        with preview_area:
            # Always drop the old thumbnail first, even if nothing is uploaded now
            clear_output()
            if not upload_widget.value:
                return
            try:
                raw_bytes = upload_widget.value[0]['content']
                thumb = Image.open(io.BytesIO(raw_bytes))
                # thumbnail() shrinks in place and keeps the aspect ratio
                thumb.thumbnail(max_size)
                display(thumb)
            except Exception as e:
                print(f"Preview error: {str(e)}")

    # Fire the refresh whenever the widget's value changes
    upload_widget.observe(_refresh, names='value')

# Pair every upload widget with its preview area
for _uploader, _thumb_area in (
    (content_upload, content_preview),
    (style_upload1, style_preview1),
    (style_upload2, style_preview2),
):
    update_preview(_uploader, _thumb_area)

## IMAGE PROCESSING FUNCTIONS

In [5]:
# Decode the first uploaded file of a widget into an RGB pixel array
def process_upload(upload_widget):
    """Return the first uploaded image as a NumPy array in RGB mode.

    Args:
        upload_widget: Widget whose ``value`` holds the uploaded files; the
            first entry's ``'content'`` bytes are decoded.

    Raises:
        ValueError: If the widget holds no uploaded file.
    """
    if not upload_widget.value:
        raise ValueError("No file uploaded")

    first_file = upload_widget.value[0]
    byte_stream = io.BytesIO(first_file['content'])
    # Force RGB so every image has the same three-channel layout
    rgb_image = Image.open(byte_stream).convert("RGB")
    return np.array(rgb_image)

In [6]:
# Preprocess an image for neural style transfer
def load_and_preprocess(img_array, max_dim=512):
    """Convert an image array to a batched float32 tensor that fits ``max_dim``.

    The image is rescaled so that its longer side equals ``max_dim`` while the
    aspect ratio is preserved, then a leading batch axis is added.

    Args:
        img_array: Image data; all axes except the last are treated as the
            spatial size (callers in this file pass an H x W x 3 uint8 array).
        max_dim: Target length in pixels of the longer side.

    Returns:
        A float32 tensor with a leading batch dimension of size 1.
    """
    # Convert to float32 (integer inputs are rescaled to the [0, 1] range)
    img = tf.image.convert_image_dtype(img_array, tf.float32)

    # Scale factor that maps the longer side onto max_dim
    shape = tf.cast(tf.shape(img)[:-1], tf.float32)  # height and width
    scale = max_dim / tf.reduce_max(shape)

    # The int cast truncates, so a very elongated image could end up with a
    # 0-pixel short side, which tf.image.resize rejects; keep at least 1 pixel.
    new_shape = tf.maximum(tf.cast(shape * scale, tf.int32), 1)

    # Resize and add the batch dimension
    return tf.image.resize(img, new_shape)[tf.newaxis, :]

In [7]:
# Pull the colors of a stylized image toward those of a reference image.
# The blend is done in LAB so that lightness stays untouched.
def color_adjust(image_np, reference_np, weight):
    """Blend the A/B color channels of ``image_np`` toward ``reference_np``.

    Args:
        image_np: Stylized image array; cast to uint8 before use.
        reference_np: Image array whose colors are blended in; it is resized
            to the stylized image's size when the sizes differ.
        weight: Share of the reference color in the blend (0 keeps the
            stylized colors, 1 takes the reference colors).

    Returns:
        The adjusted image as an RGB NumPy array.
    """
    stylized = Image.fromarray(image_np.astype(np.uint8))
    reference = Image.fromarray(reference_np.astype(np.uint8))

    # Both images must line up pixel for pixel before blending
    if reference.size != stylized.size:
        reference = reference.resize(stylized.size, Image.Resampling.LANCZOS)

    # Work in float32 LAB: channel 0 is lightness, channels 1 and 2 carry color
    stylized_lab = np.array(stylized.convert('LAB')).astype(np.float32)
    reference_lab = np.array(reference.convert('LAB')).astype(np.float32)

    # Mix only the two color channels; the L channel is left as it is
    stylized_lab[:, :, 1:] = (
        stylized_lab[:, :, 1:] * (1 - weight) + reference_lab[:, :, 1:] * weight
    )

    # Back to 8-bit LAB, then to RGB
    blended = np.clip(stylized_lab, 0, 255).astype(np.uint8)
    return np.array(Image.fromarray(blended, 'LAB').convert('RGB'))

In [8]:
# Post-process an image with brightness, contrast and saturation factors
def apply_enhancements(image_np, brightness, contrast, saturation):
    """Apply the three PIL enhancers in a fixed order and return the result.

    Args:
        image_np: Image array; cast to uint8 before use.
        brightness: Factor passed to ``ImageEnhance.Brightness``.
        contrast: Factor passed to ``ImageEnhance.Contrast``.
        saturation: Factor passed to ``ImageEnhance.Color``.

    Returns:
        The enhanced image as a NumPy array.
    """
    picture = Image.fromarray(image_np.astype(np.uint8))

    # Order matters: brightness first, then contrast, then saturation
    pipeline = (
        (ImageEnhance.Brightness, brightness),
        (ImageEnhance.Contrast, contrast),
        (ImageEnhance.Color, saturation),
    )
    for enhancer_cls, factor in pipeline:
        picture = enhancer_cls(picture).enhance(factor)

    return np.array(picture)

## NEURAL STYLE TRANSFER CORE CLASSES

In [9]:
# Keras model that extracts both style and content features from a VGG network
class StyleContentModel(tf.keras.models.Model):
    """Wraps a VGG model and returns Gram-matrix style features plus content features."""

    def __init__(self, vgg, style_layers, content_layers):
        """Build a sub-model that outputs the activations of the requested layers.

        Args:
            vgg: Model providing ``input`` and ``get_layer(name)``.
            style_layers: Names of the layers used for style features.
            content_layers: Names of the layers used for content features.
        """
        super().__init__()
        # Style layers come first so call() can split the outputs by position
        tapped_layers = style_layers + content_layers
        tapped_outputs = [vgg.get_layer(layer_name).output for layer_name in tapped_layers]
        self.vgg = tf.keras.Model([vgg.input], tapped_outputs)

        # Remember the layer names so the outputs can be keyed by name
        self.style_layers = style_layers
        self.content_layers = content_layers
        self.num_style_layers = len(style_layers)

    def call(self, inputs):
        """Return ``{'style': {...}, 'content': {...}}`` keyed by layer name.

        ``inputs`` is multiplied by 255 before VGG19 preprocessing, so it is
        expected to hold pixel values in [0, 1].
        """
        scaled = inputs * 255.0
        features = self.vgg(tf.keras.applications.vgg19.preprocess_input(scaled))

        # The first num_style_layers outputs are style, the rest are content
        split_at = self.num_style_layers
        grams = [self.gram_matrix(feature_map) for feature_map in features[:split_at]]
        content_features = features[split_at:]

        return {
            'style': dict(zip(self.style_layers, grams)),
            'content': dict(zip(self.content_layers, content_features)),
        }

    def gram_matrix(self, input_tensor):
        """Return the channel-by-channel Gram matrix of ``input_tensor``.

        The einsum sums products of channel pairs over axes 1 and 2; the
        result is divided by the number of positions on those two axes.
        """
        channel_products = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)
        dims = tf.shape(input_tensor)
        num_locations = tf.cast(dims[1] * dims[2], tf.float32)
        return channel_products / num_locations

In [10]:
# The main NST system that coordinates the entire process
class StyleTransferSystem:
    """Coordinates VGG19 feature extraction and the image optimization loop."""

    def __init__(self):
        """Select the VGG19 layers to use and build the feature extractor."""
        # Deep layer used for content features
        self.content_layers = ['block5_conv2']

        # One layer from each VGG block, shallow to deep, used for style features
        self.style_layers = ['block1_conv1',
                             'block2_conv1',
                             'block3_conv1',
                             'block4_conv1',
                             'block5_conv1']

        # Build the feature extraction model
        self.extractor = self.build_extractor()

    def build_extractor(self):
        """Return a StyleContentModel on top of a frozen, ImageNet-pretrained VGG19."""
        # Load pre-trained VGG19 without the classification head
        vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
        vgg.trainable = False
        return StyleContentModel(vgg, self.style_layers, self.content_layers)

    def _blend_style_targets(self, style_targets, weights):
        """Return the per-layer weighted sum of the style Gram matrices."""
        return {
            layer: sum(weights[i] * style[layer] for i, style in enumerate(style_targets))
            for layer in self.style_layers
        }

    def _compute_losses(self, image, blended_style, content_targets,
                        content_weight, style_weight, tv_weight):
        """Return ``(style_loss, content_loss, total_loss)`` for ``image``."""
        outputs = self.extractor(image)

        # Mean squared difference to the blended style targets, scaled by 1e-2
        style_loss = tf.add_n([
            tf.reduce_mean((outputs['style'][layer] - blended_style[layer])**2)
            for layer in self.style_layers]) * style_weight * 1e-2

        # Mean squared difference to the content features, scaled by 1e4
        content_loss = tf.add_n([
            tf.reduce_mean((outputs['content'][layer] - content_targets[layer])**2)
            for layer in self.content_layers]) * content_weight * 1e4

        # Total variation term (encourages spatial smoothness), scaled by 30
        tv_loss = tf.image.total_variation(image) * tv_weight * 30

        return style_loss, content_loss, style_loss + content_loss + tv_loss

    @staticmethod
    def _publish_preview(image):
        """Encode the current image as JPEG and push it to the live preview widget."""
        snapshot = Image.fromarray(np.clip(image.numpy()[0] * 255, 0, 255).astype(np.uint8))
        buf = io.BytesIO()
        snapshot.save(buf, format='JPEG')
        live_preview.value = buf.getvalue()

    def run_transfer(self, content_img, style_imgs, weights_list, epochs, steps, content_weight, style_weight, tv_weight):
        """Optimize an image toward the content and the blended style targets.

        Progress is printed to ``output_area`` after every step and the
        module-level ``live_preview`` widget is refreshed after every epoch.

        Args:
            content_img: Batched content image; also the starting point of
                the optimization.
            style_imgs: Sequence of batched style images.
            weights_list: One entry per epoch, each holding one blend weight
                per style image.
            epochs: Number of epochs to run.
            steps: Optimization steps per epoch.
            content_weight: Multiplier for the content loss (scaled by 1e4).
            style_weight: Multiplier for the style loss (scaled by 1e-2).
            tv_weight: Multiplier for the total variation loss (scaled by 30).

        Returns:
            The optimized image as a NumPy array (batch dimension included).

        Raises:
            ValueError: If ``weights_list`` has fewer entries than ``epochs``
                or an entry has fewer weights than there are style images.
        """
        # Fail early instead of hitting an IndexError after epochs of work
        if len(weights_list) < epochs:
            raise ValueError(
                f"weights_list has {len(weights_list)} entries but {epochs} epochs were requested")
        if any(len(row) < len(style_imgs) for row in weights_list[:epochs]):
            raise ValueError("Each weights_list entry needs one weight per style image")

        # Extract target features from style and content images
        style_targets = [self.extractor(style_img)['style'] for style_img in style_imgs]
        content_targets = self.extractor(content_img)['content']

        # Initialize the generated image as a copy of the content image
        image = tf.Variable(content_img)

        # Adam optimizer with a fixed learning rate of 0.02
        opt = tf.keras.optimizers.Adam(0.02)

        # Clear the live preview and result area before starting
        live_preview.value = b''
        with result_area:
            clear_output()

        with output_area:
            print("Live Preview:")

        for epoch in range(epochs):
            # The blend is fixed within an epoch; per-epoch weights enable style morphing
            blended_style = self._blend_style_targets(style_targets, weights_list[epoch])

            for step in range(steps):
                # Record the forward pass so the loss can be differentiated w.r.t. the image
                with tf.GradientTape() as tape:
                    style_loss, content_loss, total_loss = self._compute_losses(
                        image, blended_style, content_targets,
                        content_weight, style_weight, tv_weight)

                # Apply gradients to update the generated image
                grad = tape.gradient(total_loss, image)
                opt.apply_gradients([(grad, image)])

                # Clamp pixel values to valid range [0,1]
                image.assign(tf.clip_by_value(image, 0.0, 1.0))

                # Display progress information
                with output_area:
                    clear_output(wait=True)
                    print(f"Epoch {epoch+1}/{epochs} - Step {step+1}/{steps}")
                    print(f"Style Loss: {style_loss.numpy():.2f} | Content Loss: {content_loss.numpy():.2f}")

            # Update live preview at the end of each epoch
            self._publish_preview(image)

        return image.numpy()

## UI EVENT HANDLERS

In [11]:
def on_reset_click(btn):
    """Restore every control to its default and clear the displayed outputs.

    Uploaded files and their thumbnails are left in place.

    Args:
        btn: The clicked button (unused; required by the on_click signature).
    """
    # Reset all sliders to their initial values
    style_slider1.value = 0.5
    style_slider2.value = 0.5
    content_weight_slider.value = 1.0
    style_weight_slider.value = 1.0
    tv_weight_slider.value = 1.0
    color_preserve_slider.value = 0.0
    epoch_slider.value = 10
    steps_slider.value = 50
    style_morphing_toggle.value = False
    brightness_slider.value = 1.0
    contrast_slider.value = 1.0
    saturation_slider.value = 1.0
    style_scale_slider.value = 1.0
    resolution_slider.value = 512
    output_quality_dropdown.value = 'JPEG (Default)'

    # Clear what the previous run left on screen so the message below is true
    live_preview.value = b''
    with result_area:
        clear_output()

    # Print inside the output widget so the confirmation shows up in the UI
    with output_area:
        clear_output()
        print("Settings and outputs have been reset.")

In [12]:
# UI Event Handlers
def on_train_click(btn):
    """Run the whole style-transfer pipeline for the current UI settings.

    Steps: validate the uploads, read the controls, build the per-epoch style
    weights, preprocess the images, run the optimization, apply optional color
    preservation and enhancements, then display and save both results.
    Any exception is caught and shown in ``output_area``.

    Args:
        btn: The clicked button (unused; required by the on_click signature).
    """
    with output_area:
        clear_output()
        print("Starting style transfer...")

    try:
        # Validate that all required images are uploaded
        if not (content_upload.value and style_upload1.value and style_upload2.value):
            raise ValueError("Please upload all required images")

        # Process uploaded images
        content_img = process_upload(content_upload)
        style_img1 = process_upload(style_upload1)
        style_img2 = process_upload(style_upload2)

        # Extract parameter values from UI controls
        content_w = content_weight_slider.value
        style_w = style_weight_slider.value
        tv_w = tv_weight_slider.value
        color_w = color_preserve_slider.value
        brightness = brightness_slider.value
        contrast = contrast_slider.value
        saturation = saturation_slider.value
        style_scale = style_scale_slider.value
        resolution = resolution_slider.value
        # Read once so the schedule and the training loop use the same epoch count
        total_epochs = epoch_slider.value

        # Determine style blending weights for each epoch
        if style_morphing_toggle.value:
            # Gradually transition from style 1 to style 2
            weights = []
            for i in range(total_epochs):
                morph_factor = i / (total_epochs - 1) if total_epochs > 1 else 0.0
                w1 = 1.0 - morph_factor  # Decreases from 1.0 to 0.0
                w2 = morph_factor  # Increases from 0.0 to 1.0
                weights.append([w1, w2])
            # Show the whole range in the titles, not just the last epoch's weights
            style1_title = f'Style 1 (Weight: {weights[0][0]:.0%} -> {weights[-1][0]:.0%})'
            style2_title = f'Style 2 (Weight: {weights[0][1]:.0%} -> {weights[-1][1]:.0%})'
        else:
            # Training based on weight sliders
            w1 = style_slider1.value
            w2 = style_slider2.value
            weights = [[w1, w2] for _ in range(total_epochs)]
            style1_title = f'Style 1 (Weight: {w1:.0%})'
            style2_title = f'Style 2 (Weight: {w2:.0%})'

        # Preprocess images for neural network input
        content_tensor = load_and_preprocess(content_img, max_dim=resolution)
        style_tensors = [
            load_and_preprocess(style_img1, max_dim=int(resolution * style_scale)),
            load_and_preprocess(style_img2, max_dim=int(resolution * style_scale))
        ]

        # Display input images for reference
        with output_area:
            clear_output()
            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
            axes[0].imshow(content_img)
            axes[0].set_title('Content Image')
            axes[1].imshow(style_img1)
            axes[1].set_title(style1_title)
            axes[2].imshow(style_img2)
            axes[2].set_title(style2_title)
            plt.show()
            print("Starting style transfer...")

        # Execute the neural style transfer
        sts = StyleTransferSystem()
        result = sts.run_transfer(content_tensor, style_tensors, weights,
                                 total_epochs, steps_slider.value,
                                 content_w, style_w, tv_w)

        # Clear the live preview after training is finished
        live_preview.value = b''

        # Convert result to displayable format
        result_img = np.clip(result[0] * 255, 0, 255).astype(np.uint8)

        # Apply color preservation if requested
        if color_w > 0:
            # Resize content image to match result dimensions
            content_resized = tf.image.resize(content_img[tf.newaxis, :],
                                             [result_img.shape[0], result_img.shape[1]])
            content_resized = tf.clip_by_value(content_resized[0], 0, 255).numpy().astype(np.uint8)
            stylized_img = color_adjust(result_img, content_resized, color_w)
        else:
            stylized_img = result_img

        # Apply final enhancements (brightness, contrast, saturation)
        final_img = apply_enhancements(stylized_img, brightness, contrast, saturation)

        # Display final results
        with result_area:
            clear_output()
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
            ax1.imshow(result_img)
            ax1.set_title('Stylized Result (Raw)')
            ax2.imshow(final_img)
            ax2.set_title('Color-Preserved & Enhanced Result')
            plt.show()

            # Save results to files in the current working directory
            file_format = "png" if output_quality_dropdown.value == 'PNG (Lossless)' else "jpg"
            Image.fromarray(result_img).save(f"result_raw.{file_format}")
            Image.fromarray(final_img).save(f"result_final.{file_format}")
            print(f"Results saved as result_raw.{file_format} and result_final.{file_format}")

    except Exception as e:
        # Top-level UI boundary: show the error to the user instead of losing it
        with output_area:
            clear_output()
            print(f"Error: {str(e)}")

## UI INITIALIZATION AND STARTUP

In [13]:
# Start UI: connect each button to its handler
# (train button -> start training, reset button -> restore defaults)
for _button, _on_click in ((train_btn, on_train_click), (reset_btn, on_reset_click)):
    _button.on_click(_on_click)

# Render the assembled interface in the notebook
display(ui)

VBox(children=(VBox(children=(Label(value='📁 Image Upload: ', style=LabelStyle(font_size='16px', font_weight='…