In [None]:
import torch
from PIL import Image
import numpy as np
from diffusers import RePaintPipeline, RePaintScheduler
import os

def load_image(image_path, size=(256, 256)):
    """Load an image from a local path as an RGB PIL image of the given size.

    Args:
        image_path: Path to the image file on disk.
        size: Target ``(width, height)`` passed to ``Image.resize``. Defaults to
            ``(256, 256)`` (presumably matching the ``*-256`` checkpoint this
            notebook loads -- confirm if you switch models).

    Returns:
        A new in-memory ``PIL.Image.Image`` in ``"RGB"`` mode with size ``size``.
    """
    # The ``with`` block closes the file handle deterministically; convert()
    # produces a new in-memory image, so the result remains valid afterwards.
    with Image.open(image_path) as img:
        return img.convert("RGB").resize(size)

def load_mask(mask_path, size=(256, 256), threshold=127, invert=False):
    """Load a mask from a local path as a binary (0/255) grayscale PIL image.

    Args:
        mask_path: Path to the mask image on disk.
        size: Target ``(width, height)`` passed to ``Image.resize``.
            Defaults to ``(256, 256)``.
        threshold: Pixels strictly greater than this value become 255; all
            others become 0. Defaults to 127.
        invert: If True, swap 0 and 255 after thresholding. This is useful when
            the mask file marks the region to fill in white.

    Returns:
        A ``PIL.Image.Image`` in mode ``"L"`` whose pixels are only 0 or 255.

    NOTE(review): the diffusers RePaintPipeline docs describe 0-valued mask
    pixels as the region to inpaint (255 = keep). Confirm your mask files use
    that convention, or pass ``invert=True``.
    """
    with Image.open(mask_path) as img:
        gray = img.convert("L").resize(size)
    # Resizing may interpolate edge pixels to intermediate gray levels;
    # threshold them back to a hard binary mask.
    binary = np.asarray(gray) > threshold
    if invert:
        binary = ~binary
    return Image.fromarray(binary.astype(np.uint8) * 255)

# Configuration
image_path = "test_images/sample.png"  # Change to your image path
mask_path = "test_masks/sample.png"    # Change to your mask path
output_path = "result_inpainted.png"

# RePaint sampling parameters (see the diffusers RePaintPipeline docs for
# the meaning of jump_length / jump_n_sample).
num_inference_steps = 250
jump_length = 10
jump_n_sample = 10

# Prefer CUDA (e.g. a T4 GPU). Fall back to CPU so the notebook still runs
# (much more slowly) instead of crashing in pipe.to("cuda") on GPU-less hosts.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
    print("WARNING: CUDA is not available; falling back to CPU (this will be very slow).")
print(f"Using device: {device}")

# Load images
print(f"Loading image: {image_path}")
original_image = load_image(image_path)

# NOTE(review): RePaint inpaints where the mask is 0 (black) and keeps pixels
# where it is 255 -- confirm your mask follows this convention.
print(f"Loading mask: {mask_path}")
mask_image = load_mask(mask_path)

# Load the RePaint scheduler and pipeline
print("Loading RePaint model from Hugging Face...")
print("This will download ~2GB on first run...")
scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler)
pipe = pipe.to(device)

# Set generator for reproducibility
generator = torch.Generator(device=device).manual_seed(0)

# Run inpainting. The input image is passed as ``image=``; the old
# ``original_image=`` keyword was removed from RePaintPipeline.__call__ and
# now raises TypeError.
print(f"Running inpainting (steps: {num_inference_steps}, jump_length: {jump_length}, jump_n_sample: {jump_n_sample})...")
output = pipe(
    image=original_image,
    mask_image=mask_image,
    num_inference_steps=num_inference_steps,
    eta=0.0,
    jump_length=jump_length,
    jump_n_sample=jump_n_sample,
    generator=generator,
)

# Save result, creating the destination folder if output_path points into one.
inpainted_image = output.images[0]
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
inpainted_image.save(output_path)
print(f"Saved result to: {output_path}")

# Display result
from IPython.display import display
display(original_image)
display(mask_image)
display(inpainted_image)


## 5. Download Results
