Image Inpainting with Stable Diffusion
=====================================

[View on Google Colab](https://colab.research.google.com/drive/1H0hDErdyI_Qg5dpULRls8a8-3tOyhqMB?usp=sharing)

### Import the necessary libraries

In [23]:
# !pip install torch matplotlib pillow diffusers
# !pip install "numpy<2.0.0"

# Core imports
import torch
import numpy as np
from typing import List, Tuple, Optional
import matplotlib.pyplot as plt

# Image processing
from PIL import Image, ImageDraw
import requests

# Diffusion models
from diffusers import (
    StableDiffusionInpaintPipeline,
    AutoPipelineForInpainting
)

---

### Device Setup

In [24]:
def setup_device() -> Tuple[str, torch.dtype]:
    """
    Pick the device and tensor dtype to run the diffusion pipeline with.

    Preference order is CUDA, then Apple MPS, then CPU. Half precision is
    only selected for CUDA; MPS and CPU use float32.

    Returns:
        Tuple[str, torch.dtype]: Device name ("cuda", "mps" or "cpu") and
        the dtype paired with it.
    """
    if torch.cuda.is_available():
        return "cuda", torch.float16

    # Older torch builds do not ship ``torch.backends.mps`` at all, so look
    # it up defensively instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps", torch.float32

    return "cpu", torch.float32

---

### Load Image Inpainting Pipeline from HuggingFace

In [25]:
def load_inpainting_pipeline(device="cuda", dtype=torch.float16):
    """
    Build the DreamShaper-8 inpainting pipeline and move it to ``device``.

    Args:
        device: Target device passed to ``pipe.to``.
        dtype: Torch dtype the weights are loaded with.

    Returns:
        The loaded pipeline, with the safety checker disabled and attention
        slicing switched on when the pipeline exposes that method.
    """
    model_id = "Lykon/dreamshaper-8-inpainting"
    load_options = {
        "torch_dtype": dtype,
        "safety_checker": None,
        "requires_safety_checker": False,
    }

    pipeline = StableDiffusionInpaintPipeline.from_pretrained(model_id, **load_options)
    pipeline = pipeline.to(device)

    # Attention slicing is optional; only enable it when available.
    if hasattr(pipeline, 'enable_attention_slicing'):
        pipeline.enable_attention_slicing()

    return pipeline

---

### Loading Image from URL or File Path

In [26]:
def load_image_from_url(url: str, timeout: float = 30.0) -> Image.Image:
    """
    Load an image from a URL.

    The response body is downloaded completely and decoded from an in-memory
    buffer, so the returned image does not depend on an open network
    connection.

    Args:
        url (str): URL of the image
        timeout (float): Seconds to wait for the server before giving up

    Returns:
        Image.Image: PIL Image object with its pixel data already loaded

    Raises:
        Exception: Any error raised while downloading or decoding is printed
            and then re-raised unchanged.
    """
    from io import BytesIO

    try:
        # A timeout keeps a stalled server from hanging the notebook forever.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        # ``response.content`` has transfer encodings (gzip/deflate) already
        # decoded, unlike the raw socket stream.
        image = Image.open(BytesIO(response.content))
        # Force decoding now so corrupt data is reported here rather than at
        # some later, unrelated call site.
        image.load()
        return image
    except Exception as e:
        print(f"Error loading image from URL: {e}")
        raise


def load_image_from_path(path: str) -> Image.Image:
    """
    Open an image stored on the local filesystem.

    Args:
        path (str): Location of the image file

    Returns:
        Image.Image: The image as returned by ``Image.open``

    Raises:
        Exception: Whatever ``Image.open`` raised; the error is printed
            first and then propagated unchanged.
    """
    try:
        return Image.open(path)
    except Exception as err:
        print(f"Error loading image from path: {err}")
        raise


---

### Custom Mask Generation

In [27]:
def create_circular_mask(image, center, radius):
    """
    Build a black mask the size of ``image`` with a white filled circle.

    Args:
        image: Image whose ``size`` determines the mask dimensions
        center: Sequence holding the circle centre as (x, y)
        radius: Circle radius in pixels

    Returns:
        Single-channel ("L") mask: 255 inside the circle, 0 elsewhere
    """
    cx, cy = center[0], center[1]
    # Bounding box of the circle as [left, top, right, bottom].
    bounds = [cx - radius, cy - radius, cx + radius, cy + radius]

    canvas = Image.new("L", image.size, 0)
    ImageDraw.Draw(canvas).ellipse(bounds, fill=255)
    return canvas


def create_rectangular_mask(image: Image.Image, bbox: Tuple[int, int, int, int]) -> Image.Image:
    """
    Build a black mask the size of ``image`` with a white filled rectangle.

    Args:
        image (Image.Image): Image whose size determines the mask dimensions
        bbox (Tuple[int, int, int, int]): Rectangle as (left, top, right, bottom)

    Returns:
        Image.Image: Single-channel mask, 255 inside ``bbox`` and 0 elsewhere
    """
    canvas = Image.new("L", image.size, 0)
    # White marks the region to repaint; everything else stays black.
    ImageDraw.Draw(canvas).rectangle(bbox, fill=255)
    return canvas


def create_custom_mask(image: Image.Image, mask_type: str = "center_circle") -> Image.Image:
    """
    Build one of several preset masks sized to ``image``.

    Supported ``mask_type`` values are "center_circle", "large_circle",
    "top_left", "center_square" and "bottom_strip". Any other value yields
    the same mask as "center_circle".

    Args:
        image (Image.Image): Image the mask is sized for
        mask_type (str): Name of the preset to generate

    Returns:
        Image.Image: Binary mask image
    """
    width, height = image.size
    middle = (width // 2, height // 2)
    short_side = min(width, height)

    if mask_type == "large_circle":
        return create_circular_mask(image, middle, short_side // 4)

    if mask_type == "top_left":
        return create_rectangular_mask(image, (0, 0, width // 3, height // 3))

    if mask_type == "center_square":
        side = short_side // 3
        x0 = (width - side) // 2
        y0 = (height - side) // 2
        return create_rectangular_mask(image, (x0, y0, x0 + side, y0 + side))

    if mask_type == "bottom_strip":
        return create_rectangular_mask(image, (0, height * 3 // 4, width, height))

    # "center_circle" and every unrecognised name share the small centred circle.
    return create_circular_mask(image, middle, short_side // 6)


---

### Run Inference

In [28]:
def inpaint_image(
    pipe,
    image,
    mask,
    prompt,
    seed=42,
    *,
    negative_prompt="blurry, low quality",
    num_inference_steps=20,
    guidance_scale=7.5,
    size=(512, 512),
):
    """
    Fill the masked region of an image according to a text prompt.

    Args:
        pipe: Callable inpainting pipeline; it is invoked with keyword
            arguments and must return an object exposing ``images``.
        image: Source image; must provide ``resize``.
        mask: Mask image; must provide ``resize``.
        prompt: Text describing what to paint into the masked region.
        seed: Seed applied to torch's global random number generators.
        negative_prompt: Text describing what the result should avoid.
        num_inference_steps: Number of denoising steps passed to ``pipe``.
        guidance_scale: Guidance strength passed to ``pipe``.
        size: (width, height) both image and mask are resized to before
            being handed to ``pipe``.

    Returns:
        The first entry of the pipeline result's ``images``.
    """
    # Seeding happens on the global generators, so repeated calls with the
    # same seed start from the same random state.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Image and mask must share the same dimensions when passed to the pipeline.
    image = image.resize(size)
    mask = mask.resize(size)

    with torch.no_grad():
        result = pipe(
            prompt=prompt,
            image=image,
            mask_image=mask,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
    return result.images[0]

---

In [29]:
def _show_panel(ax, picture, title, cmap=None):
    """Draw one image into ``ax`` with a small title and the axes hidden."""
    ax.imshow(picture, cmap=cmap)
    ax.set_title(title, fontsize=10)
    ax.axis('off')


def visualize_results(base_image, mask, results):
    """
    Plot the original image, the mask and every inpainting result.

    The figure has three rows. The first column shows the original, the mask
    and the original with the masked area tinted. Each following column shows
    the original, the mask and one inpainted result titled with its prompt.

    Args:
        base_image: PIL image that was inpainted.
        mask: PIL mask image; values above 128 count as masked in the overlay.
        results: Sequence of (prompt, image) pairs; may be empty.
    """
    max_title_len = 25

    # squeeze=False keeps ``axes`` two-dimensional even when ``results`` is
    # empty and the grid has a single column.
    fig, axes = plt.subplots(3, len(results) + 1, figsize=(16, 10), squeeze=False)

    # First column: original, mask, and the tinted overlay.
    _show_panel(axes[0, 0], base_image, "Original")
    _show_panel(axes[1, 0], mask, "Mask", cmap='gray')

    # Normalise modes so the RGB tint assignment works for any input mode.
    overlay_array = np.array(base_image.convert("RGB"))
    mask_array = np.array(mask.convert("L"))
    overlay_array[mask_array > 128] = [255, 100, 100]
    _show_panel(axes[2, 0], Image.fromarray(overlay_array), "Original + Mask")

    # Remaining columns: one per result.
    for col, (prompt, image) in enumerate(results, start=1):
        _show_panel(axes[0, col], base_image, "Original")
        _show_panel(axes[1, col], mask, "Mask", cmap='gray')

        # Only mark the title as truncated when something was cut off.
        if len(prompt) > max_title_len:
            title = prompt[:max_title_len] + "..."
        else:
            title = prompt
        _show_panel(axes[2, col], image, title)

    plt.tight_layout()
    plt.show()


def create_sample_image():
    """
    Generate a 512x512 RGB test picture made of three horizontal bands.

    The top third is a flat sky colour; the middle ("mountains") and bottom
    ("ground") bands each carry a left-to-right colour gradient.

    Returns:
        The generated PIL image.
    """
    width, height = 512, 512
    sky_end = height // 3
    mountain_end = height * 2 // 3

    # Colours depend only on x within a band, so build one row per band.
    sky_row = [(135, 206, 235)] * width
    mountain_row = [
        (
            int(139 + (x / width) * 50),
            int(69 + (x / width) * 30),
            int(19 + (x / width) * 20),
        )
        for x in range(width)
    ]
    ground_row = [
        (
            int(34 + (x / width) * 50),
            int(139 + (x / width) * 30),
            int(34 + (x / width) * 20),
        )
        for x in range(width)
    ]

    pixels = []
    for y in range(height):
        if y < sky_end:
            band = sky_row
        elif y < mountain_end:
            band = mountain_row
        else:
            band = ground_row
        pixels.extend(band)

    canvas = Image.new('RGB', (width, height))
    canvas.putdata(pixels)
    return canvas


In [None]:
# Execute: run the full demo - pick a device, load the pipeline, inpaint the
# same masked region once per prompt, then plot everything side by side.
device, dtype = setup_device()
pipe = load_inpainting_pipeline(device, dtype)

# Create the synthetic 512x512 sample image and a circular mask of radius 100
# centred at (256, 256), i.e. the middle of that image.
base_image = create_sample_image()
mask = create_circular_mask(base_image, (256, 256), 100)

# One inpainting run is made per prompt.
prompts = [
    "A beautiful butterfly",
    "A glowing crystal",
    "A small robot"
]

# Collect (prompt, image) pairs; each prompt gets its own seed (200, 201, ...).
results = []
for i, prompt in enumerate(prompts):
    inpainted = inpaint_image(pipe, base_image, mask, prompt, seed=200+i)
    results.append((prompt, inpainted))

# Show the original, the mask and every result in a single figure.
visualize_results(base_image, mask, results)

---