# Cartoonify Demo

This notebook demonstrates how to use the Cartoonify tool to transform regular images into cartoon-style artwork.

## Setup
First, let's import the necessary modules and set up our environment.

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Make the project root importable so that the `src.*` modules resolve.
# '..' is resolved against the current working directory, which is assumed
# to be the notebook's own folder (one level below the project root).
module_path = os.path.abspath(os.pardir)
# Guard against duplicates so re-running this cell leaves sys.path unchanged.
if module_path not in sys.path:
    sys.path.append(module_path)

# Import project modules
from src.cartoonify import cartoonify_image
from src.utils import plot_comparison

# Input and output folders, relative to the current working directory.
sample_dir = os.path.join('..', 'data', 'sample_images')
output_dir = os.path.join('..', 'data', 'output')

# Create both up front; exist_ok=True keeps the cell safe to re-run.
for _directory in (sample_dir, output_dir):
    os.makedirs(_directory, exist_ok=True)

## 1. Load a Sample Image

Let's start by loading one of the images in `data/sample_images`. If the folder contains no images, we'll generate a synthetic gradient image to use instead.

In [None]:
# Find sample images or create a sample.
# os.listdir() returns entries in arbitrary, filesystem-dependent order, so
# sort the names to make "the first sample image" the same on every run.
sample_images = sorted(
    f for f in os.listdir(sample_dir)
    if os.path.isfile(os.path.join(sample_dir, f))
    and f.lower().endswith(('.png', '.jpg', '.jpeg'))
)

# Create a sample if none exists
if not sample_images:
    print("No sample images found. Creating a sample image...")

    # Create a simple 512x512 gradient image as a sample
    x = np.linspace(0, 1, 512)
    y = np.linspace(0, 1, 512)
    xx, yy = np.meshgrid(x, y)

    # RGB channels in [0, 1]: red varies along the columns, green along the
    # rows, and blue is their product
    r = xx
    g = yy
    b = xx * yy

    # Stack into an (H, W, 3) array and convert to 8-bit pixel values
    img = np.stack([r, g, b], axis=2)
    img = (img * 255).astype(np.uint8)

    # Save the sample image
    sample_path = os.path.join(sample_dir, 'sample_gradient.jpg')
    Image.fromarray(img).save(sample_path)

    sample_images = ['sample_gradient.jpg']
    print(f"Created sample image at {sample_path}")

# Select the first sample image
sample_image_path = os.path.join(sample_dir, sample_images[0])
print(f"Using sample image: {sample_image_path}")

# Read the pixels inside a context manager so the file handle opened by
# Image.open() is released deterministically.
with Image.open(sample_image_path) as sample_image:
    sample_pixels = np.array(sample_image)

# Display the image
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(sample_pixels)
ax.set_title("Sample Image")
ax.axis('off')
plt.show()

## 2. Cartoonify the Image

Now, let's transform the image into a cartoon style.

In [None]:
# Define output path: the result goes to output_dir with a "cartoon_" prefix
output_image_path = os.path.join(output_dir, f"cartoon_{os.path.basename(sample_image_path)}")

try:
    # Apply cartoonification
    print("Applying cartoonification...")
    cartoonify_image(sample_image_path, output_image_path)
except Exception as e:
    # Best-effort: report the failure with a hint instead of aborting the
    # notebook (a missing trained model is the case the hint addresses).
    print(f"Error during cartoonification: {e}")
    print("\nNote: If you're seeing a model not found error, you need to train the model first.")
    print("Run 'python train.py' from the project root directory to train the model.")
else:
    print(f"Cartoonified image saved to {output_image_path}")

    # Compare original and cartoonified images. This is deliberately outside
    # the try block so a plotting failure surfaces as a real error rather than
    # being misreported as a cartoonification failure with a "train the model"
    # hint.
    plot_comparison(sample_image_path, output_image_path, figsize=(15, 7))
    plt.show()

## 3. Process Multiple Images

Next, let's process several images in a loop. If fewer than three sample images are available, a couple of synthetic pattern images are generated first.

In [None]:
# Number of images to cartoonify in this section
N_DEMO_IMAGES = 3


def _list_sample_images(directory):
    """Return the sorted names of the .png/.jpg/.jpeg files in ``directory``.

    The names are sorted because os.listdir() returns entries in arbitrary,
    filesystem-dependent order, which would otherwise make the selection of
    images below vary between machines and runs.
    """
    return sorted(
        name for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
        and name.lower().endswith(('.png', '.jpg', '.jpeg'))
    )


# Start from what is actually on disk instead of the `sample_images` list left
# over from the previous section, so this cell also works when run on its own.
sample_images = _list_sample_images(sample_dir)

# This code creates a few more sample images if needed
if len(sample_images) < N_DEMO_IMAGES:
    print("Creating additional sample images...")

    # Shared 512x512 coordinate grid over [0, 4*pi] on both axes; it is the
    # same for every pattern, so build it once outside the loop.
    x = np.linspace(0, 4*np.pi, 512)
    y = np.linspace(0, 4*np.pi, 512)
    xx, yy = np.meshgrid(x, y)

    # Create samples with different patterns
    for i in range(2):
        if i == 0:
            # Sine wave pattern
            r = 0.5 + 0.5 * np.sin(xx)
            g = 0.5 + 0.5 * np.sin(yy)
            b = 0.5 + 0.5 * np.sin(xx + yy)
        else:
            # Radial pattern around the centre of the grid
            center_x, center_y = 2*np.pi, 2*np.pi
            dist = np.sqrt((xx - center_x)**2 + (yy - center_y)**2)
            r = 0.5 + 0.5 * np.sin(dist)
            g = 0.5 + 0.5 * np.cos(dist)
            b = 0.5 + 0.5 * np.sin(2 * dist)

        # Channels are in [0, 1]; stack and convert to 8-bit pixel values
        img = np.stack([r, g, b], axis=2)
        img = (img * 255).astype(np.uint8)

        # Save the sample image
        sample_path = os.path.join(sample_dir, f'sample_pattern_{i+1}.jpg')
        Image.fromarray(img).save(sample_path)
        print(f"Created sample image at {sample_path}")

    # Refresh the listing so the newly created images are included
    sample_images = _list_sample_images(sample_dir)

# Process each sample image (up to N_DEMO_IMAGES)
for img_file in sample_images[:N_DEMO_IMAGES]:
    input_path = os.path.join(sample_dir, img_file)
    output_path = os.path.join(output_dir, f"cartoon_{img_file}")

    try:
        print(f"Processing {img_file}...")
        cartoonify_image(input_path, output_path)
    except Exception as e:
        # Best-effort: report the failure and move on to the next image.
        print(f"Error processing {img_file}: {e}")
        continue

    print(f"Cartoonified {img_file} saved to {output_path}")

    # Show comparison. Kept outside the try block so plotting bugs surface as
    # real errors instead of being misreported as processing failures.
    plot_comparison(input_path, output_path, figsize=(12, 6))
    plt.show()

## Conclusion

This notebook demonstrated the basic functionality of the Cartoonify tool. With it you can:

1. Process individual images
2. Process multiple images
3. Compare original and cartoonified versions

For more advanced options and customization, please check the detailed documentation in the `docs/usage.md` file.