Image-to-Image Transformation with Stable Diffusion
==================================================

[View on Google Colab](https://colab.research.google.com/drive/1q00kVpnYnKZmYNICE05_1NHC8bSUFQH0?usp=sharing)

### Import the necessary libraries

In [1]:
# !pip install torch matplotlib pillow diffusers

import torch
import matplotlib.pyplot as plt
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline
import requests
from io import BytesIO

---

### Set Up the Device

In [2]:
def setup_device():
    """Choose the best available accelerator and a dtype suited to it.

    Accelerators are probed in priority order: CUDA (half precision), then
    Apple MPS (full precision). If neither is present, the CPU is used with
    full precision.

    Returns:
        tuple: ``(device_name, dtype)``, e.g. ``("cuda", torch.float16)``.
    """
    # Lambdas keep the probes lazy: MPS is only queried when CUDA is absent.
    probes = (
        ("cuda", lambda: torch.cuda.is_available(), torch.float16),
        ("mps", lambda: torch.backends.mps.is_available(), torch.float32),
    )
    for name, is_present, precision in probes:
        if is_present():
            return name, precision
    return "cpu", torch.float32

---


### Load the Image-to-Image Pipeline from Hugging Face

In [3]:
def load_img2img_pipeline(device="cuda", dtype=torch.float16,
                          model_id="runwayml/stable-diffusion-v1-5",
                          attention_slicing=True):
    """Load Stable Diffusion pipeline for image-to-image generation.

    Args:
        device (str): Device to load the pipeline on.
        dtype: Data type for the pipeline weights (e.g. ``torch.float16``).
        model_id (str): Hugging Face Hub repo id (or local path) of the
            Stable Diffusion checkpoint to load.
        attention_slicing (bool): Enable attention slicing when the pipeline
            supports it (lower peak memory at some cost in speed).

    Returns:
        StableDiffusionImg2ImgPipeline: Loaded pipeline.
    """
    # NOTE(review): the "runwayml" repo may no longer be hosted on the Hub;
    # if loading fails, try model_id="stable-diffusion-v1-5/stable-diffusion-v1-5".
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        # The safety checker is deliberately disabled: outputs are not filtered.
        safety_checker=None,
        requires_safety_checker=False,
    )
    pipe = pipe.to(device)
    # Guarded with hasattr so pipelines lacking the method still load.
    if attention_slicing and hasattr(pipe, 'enable_attention_slicing'):
        pipe.enable_attention_slicing()
    return pipe

In [4]:
def transform_image(pipe, init_image, prompt, strength=0.7, seed=42,
                    num_inference_steps=20, guidance_scale=7.5,
                    negative_prompt="blurry, low quality", size=(512, 512)):
    """Transform an existing image based on a text prompt.

    Args:
        pipe: The diffusion pipeline (called with ``prompt``, ``image``,
            ``negative_prompt``, ``strength``, ``num_inference_steps``,
            ``guidance_scale`` and ``generator`` keyword arguments).
        init_image (PIL.Image): Initial image to transform.
        prompt (str): Text prompt for transformation.
        strength (float): Transformation strength (0-1).
        seed (int): Random seed for reproducibility.
        num_inference_steps (int): Number of denoising steps requested.
        guidance_scale (float): Classifier-free guidance scale.
        negative_prompt (str | None): Description of what to avoid.
        size (tuple[int, int]): ``(width, height)`` the input image is
            resized to before being passed to the pipeline.

    Returns:
        PIL.Image: Transformed image.

    Raises:
        ValueError: If ``init_image`` is None (e.g. the download failed).
    """
    if init_image is None:
        raise ValueError("init_image is None; was the source image loaded?")

    # A dedicated CPU generator makes the noise reproducible for a given seed
    # on any device, without reseeding (mutating) the global torch RNG.
    generator = torch.Generator(device="cpu").manual_seed(seed)

    # Force 3-channel RGB so RGBA/greyscale inputs don't break the VAE encoder.
    init_image = init_image.convert("RGB").resize(size)

    with torch.no_grad():
        result = pipe(
            prompt=prompt,
            image=init_image,
            negative_prompt=negative_prompt,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
    return result.images[0]

---

### Run Inference

In [5]:
def download_coco_image(url, timeout=30):
    """Download a real image from the COCO dataset.

    Args:
        url (str): URL of the COCO image to download.
        timeout (float): Seconds to wait for the server before giving up.

    Returns:
        PIL.Image | None: The downloaded image converted to RGB, or None if
        the request or image decoding failed (the error is printed).
    """
    try:
        # Without a timeout, requests can block forever on a stalled server.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy; convert() decodes the pixels, so decoding
        # errors are also caught below.
        return image.convert('RGB')
    except (requests.RequestException, OSError) as e:
        # Best-effort: report and signal failure explicitly to the caller.
        print(f"Error downloading image: {e}")
        return None

In [6]:
def visualize_results(init_image, results):
    """Visualize transformation results.

    Each column pairs the original image (top row) with one transformed
    image (bottom row), titled with its prompt truncated to 30 characters.

    Args:
        init_image (PIL.Image): Original image.
        results (list): List of (prompt, transformed_image) tuples.

    Raises:
        ValueError: If ``results`` is empty.
    """
    if not results:
        raise ValueError("results must contain at least one (prompt, image) pair")

    # squeeze=False keeps `axes` 2-D even for a single column; otherwise
    # axes[0, i] raises IndexError when len(results) == 1.
    _, axes = plt.subplots(2, len(results), figsize=(15, 8), squeeze=False)

    for col, (prompt, image) in enumerate(results):
        top, bottom = axes[0, col], axes[1, col]

        top.imshow(init_image)
        top.set_title("Original COCO Image", fontsize=10)
        top.axis('off')

        # Only add an ellipsis when the prompt was actually truncated.
        title = prompt if len(prompt) <= 30 else prompt[:30] + "..."
        bottom.imshow(image)
        bottom.set_title(title, fontsize=10)
        bottom.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Execute
coco_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
print(f"Downloading COCO image from: {coco_image_url}")
init_image = download_coco_image(coco_image_url)
# download_coco_image returns None on failure; stop with a clear message
# before loading the (large) model rather than failing inside transform_image.
if init_image is None:
    raise RuntimeError(f"Could not download the COCO image from {coco_image_url}")

device, dtype = setup_device()
pipe = load_img2img_pipeline(device, dtype)

prompts = [
    "Transform into a watercolor painting",
    "Make it look like a vintage photograph",
    "Convert to anime art style",
]

results = []
for i, prompt in enumerate(prompts):
    print(f"Processing prompt {i+1}/{len(prompts)}: {prompt}")
    # A distinct seed per prompt; each run is reproducible on its own.
    transformed = transform_image(pipe, init_image, prompt, seed=100 + i)
    results.append((prompt, transformed))

visualize_results(init_image, results)

---