# 🎨 Colorful Image Colorization - Quick Demo

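This notebook demonstrates the image colorization pipeline with runnable examples. By default it uses a randomly initialized model and a synthetic noise image, so the outputs illustrate the API and workflow rather than real colorization quality; point `model_path` at a trained checkpoint and load a real photo for meaningful results.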

## Setup

In [None]:
import sys
import os

# Add parent directory to path
sys.path.insert(0, os.path.abspath('..'))

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch

from src.infer import ColorizationInference
from src.models.ops import rgb_to_lab, lab_to_rgb

# Report the runtime environment so the outputs below can be tied to a specific PyTorch/CUDA setup.
cuda_available = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Initialize Inference Engine

In [None]:
# Create the inference engine.
# NOTE: with model_path=None this uses a randomly initialized model for demonstration,
# so the "colorized" outputs below are not meaningful. For real results, provide a
# trained checkpoint: model_path="checkpoints/best_model.pth"

# Seed torch before building the model so the random initialization (and therefore
# every figure/GIF below) is reproducible under Restart & Run All.
# NOTE(review): assumes ColorizationInference initializes weights via torch's global RNG.
TORCH_SEED = 0
torch.manual_seed(TORCH_SEED)

engine = ColorizationInference(
    model_path=None,  # Use trained checkpoint path here
    model_config={'model_type': 'mobile', 'num_classes': 313, 'base_channels': 32},
    device=None,  # Auto-detect
    use_cache=False
)

print("✅ Inference engine initialized")

## Load Example Image

In [None]:
# Create a sample image (or load one from examples/).
# For real usage, load an actual image instead:
# img = Image.open('../examples/sample.jpg').convert('RGB')

# Synthetic RGB noise image for the demo. A seeded generator keeps the image
# (and every downstream figure and GIF) identical across Restart & Run All.
IMAGE_SEED = 42
IMAGE_SIZE = 256
rng = np.random.default_rng(IMAGE_SEED)
img_array = rng.random((IMAGE_SIZE, IMAGE_SIZE, 3))  # floats in [0, 1)
img = Image.fromarray((img_array * 255).astype(np.uint8))

# Grayscale copy, replicated back to 3 channels, for side-by-side visualization.
img_gray = img.convert('L').convert('RGB')

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, shown, title in zip(axes, (img_gray, img), ('Input (Grayscale)', 'Original Color')):
    ax.imshow(shown)
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()

## Colorize at Different Temperatures

In [None]:
# Colorize the same input with the classification method at several temperatures.
# NOTE(review): presumably `temperature` is the annealed-mean temperature over the
# 313 color bins (lower T -> more saturated, higher T -> more desaturated) — confirm
# against ColorizationInference.colorize_image.
methods = ['classification', 'opencv']  # NOTE: only 'classification' is demonstrated below
temperatures = [0.1, 0.38, 0.8]  # For classification method

# One panel for the input plus one per temperature, so editing `temperatures`
# never causes an IndexError. squeeze=False keeps `axes` an array even for 1 panel.
n_panels = len(temperatures) + 1
fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)
axes = axes.ravel()

axes[0].imshow(np.array(img_gray))
axes[0].set_title('Input')
axes[0].axis('off')

for ax, temp in zip(axes[1:], temperatures):
    result = engine.colorize_image(
        img,
        method='classification',
        temperature=temp
    )

    ax.imshow(result)
    ax.set_title(f'T={temp}')
    ax.axis('off')

plt.suptitle('Classification Method with Different Temperatures', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Create Blend Animation

In [None]:
# Generate blend-animation frames from the engine.
# NOTE(review): presumably these interpolate between the grayscale input and the
# colorized result — confirm in ColorizationInference.create_blend_animation.
NUM_FRAMES = 20
frames = engine.create_blend_animation(
    img,
    method='classification',
    temperature=0.38,
    num_frames=NUM_FRAMES
)

print(f"Generated {len(frames)} animation frames")
assert len(frames) > 0, "create_blend_animation returned no frames"

# Evenly spaced key frames: first, quarter, middle, three-quarter, last.
# Positive indices make the titles direct; dict.fromkeys drops duplicate
# indices (which occur when there are fewer than 5 frames) while keeping order.
n_frames = len(frames)
key_frames = list(dict.fromkeys(
    [0, n_frames // 4, n_frames // 2, 3 * n_frames // 4, n_frames - 1]
))

fig, axes = plt.subplots(1, len(key_frames), figsize=(3 * len(key_frames), 3), squeeze=False)

for ax, frame_idx in zip(axes.ravel(), key_frames):
    ax.imshow(frames[frame_idx])
    ax.set_title(f'Frame {frame_idx}')
    ax.axis('off')

plt.suptitle('Blend Animation Frames', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Export Animation as GIF

In [None]:
# Save the animation frames as a looping GIF.
# (`Image` and `np` come from the setup cell; no re-import needed here.)


def _frame_to_uint8(frame):
    """Convert one animation frame to a uint8 array that PIL can encode.

    uint8 frames are returned unchanged. Any other dtype is assumed to be float
    in [0, 1] (the scaling the original `frame * 255` implied); it is clipped
    first so slightly out-of-range values saturate instead of wrapping around
    when cast, then scaled and rounded.
    """
    frame = np.asarray(frame)
    if frame.dtype == np.uint8:
        return frame
    return (np.clip(frame, 0.0, 1.0) * 255).round().astype(np.uint8)


# Convert frames to PIL images
pil_frames = [Image.fromarray(_frame_to_uint8(frame)) for frame in frames]

# Save as GIF
output_path = 'colorization_animation.gif'
GIF_FRAME_MS = 50  # display time per frame, in milliseconds
pil_frames[0].save(
    output_path,
    save_all=True,
    append_images=pil_frames[1:],
    duration=GIF_FRAME_MS,
    loop=0  # 0 = loop forever
)

print(f"✅ Animation saved to {output_path}")

## Color Space Visualization

In [None]:
# Decompose the sample image into its CIE Lab channels.
# NOTE(review): assumes rgb_to_lab takes an (H, W, 3) float array in [0, 1] and
# returns an array of the same shape with channels ordered (L, a, b) — confirm in src.models.ops.
img_np = np.array(img) / 255.0
lab = rgb_to_lab(img_np)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# L channel
im_l = axes[0].imshow(lab[:, :, 0], cmap='gray')
axes[0].set_title('L (Lightness)')
axes[0].axis('off')
fig.colorbar(im_l, ax=axes[0], fraction=0.046, pad=0.04)

# a and b are signed opponent axes (negative a = green, positive a = red;
# negative b = blue, positive b = yellow). Use zero-centred limits so neutral (0)
# sits at the colormap midpoint. Use reversed colormaps so low values render
# green/blue and high values red/yellow, matching the panel titles. The plain
# 'RdYlGn' / 'YlGnBu' maps draw these the wrong way round.
a_lim = float(np.abs(lab[:, :, 1]).max()) or 1.0  # fall back to 1.0 for an all-zero channel
b_lim = float(np.abs(lab[:, :, 2]).max()) or 1.0

# a channel
im_a = axes[1].imshow(lab[:, :, 1], cmap='RdYlGn_r', vmin=-a_lim, vmax=a_lim)
axes[1].set_title('a (Green-Red)')
axes[1].axis('off')
fig.colorbar(im_a, ax=axes[1], fraction=0.046, pad=0.04)

# b channel
im_b = axes[2].imshow(lab[:, :, 2], cmap='YlGnBu_r', vmin=-b_lim, vmax=b_lim)
axes[2].set_title('b (Blue-Yellow)')
axes[2].axis('off')
fig.colorbar(im_b, ax=axes[2], fraction=0.046, pad=0.04)

plt.suptitle('CIE Lab Color Space Decomposition', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Summary

This notebook demonstrated:
- ✅ Setting up the inference engine
- ✅ Colorizing images with the classification method
- ✅ Effect of temperature on results
- ✅ Creating blend animations
- ✅ Exporting results as GIF
- ✅ Visualizing Lab color space

For production use:
1. Train a model on your dataset (see `configs/quicktrain.yaml`)
2. Load the trained checkpoint in the inference engine
3. Use the Streamlit or Gradio UIs for interactive exploration

See README.md for complete documentation.