# Stable Diffusion Demo 

This notebook demonstrates the Stable Diffusion pipeline with visualization of latent representations at each denoising step.

In [1]:
import model_loader
import pipeline
from PIL import Image
from transformers import CLIPTokenizer
import torch

DEVICE = "cpu"

ALLOW_CUDA = True
ALLOW_MPS = False

if torch.cuda.is_available() and ALLOW_CUDA:
    DEVICE = "cuda"
elif (torch.has_mps or torch.backends.mps.is_available()) and ALLOW_MPS:
    DEVICE = "mps"
print(f"Using device: {DEVICE}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [2]:
# Load models and tokenizer
tokenizer = CLIPTokenizer("../data/vocab.json", merges_file="../data/merges.txt")
model_file = "../data/v1-5-pruned-emaonly.ckpt"
models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)
print("Models loaded successfully!")

Models loaded successfully!


In [3]:
# Configuration
prompt = "A dog wearing a red scarf, sitting in a dreamy flower field at golden hour, highly detailed, realistic style"
uncond_prompt = "do not change the dog's face, pose"
do_cfg = True
cfg_scale = 8  # min: 1, max: 14

# Image to image (optional)
input_image = None
# Uncomment to enable image to image
# image_path = "../images/dog.jpg"
# input_image = Image.open(image_path).convert("RGB")
# input_image.resize((512, 512))
strength = 0.8

# Sampler settings
sampler = "ddpm"
num_inference_steps = 18  # Reduced for faster execution and more frequent visualization
seed = 42

print(f"Prompt: {prompt}")
print(f"Steps: {num_inference_steps}")
print(f"Seed: {seed}")
print(f"CFG Scale: {cfg_scale}")

Prompt: A dog wearing a red scarf, sitting in a dreamy flower field at golden hour, highly detailed, realistic style
Steps: 18
Seed: 42
CFG Scale: 8


In [None]:
import threading
import gradio as gr
from PIL import Image
import pipeline
import time
import numpy as np
import torch

# Global variable to control cancellation
cancel_flag = threading.Event()

# Global variable to store intermediate results
intermediate_results = []

def generate_image_realtime(prompt, input_image=None, strength=0.8, cfg_scale=8, num_inference_steps=18, seed=42, sampler="ddpm", progress=gr.Progress()):
    """Generate images with true real-time updates showing each step of the process"""
    global intermediate_results
    
    # Reset variables
    cancel_flag.clear()
    intermediate_results = []
    
    # Process input image if provided
    if input_image:
        input_image = input_image.convert("RGB").resize((512, 512))
    else:
        input_image = None
    
    # Set deterministic seed
    if seed is None or seed < 0:
        seed = int(torch.randint(0, 2**32 - 1, (1,)).item())
        print(f"Using random seed: {seed}")
    
    # Create a progress callback that stores each intermediate image
    def store_progress(step, image_array):
        """Store each intermediate image for display"""
        # Debug: Check image array values
        print(f"Step {step}: min={image_array.min():.4f}, max={image_array.max():.4f}, shape={image_array.shape}")
        
        # Update progress bar
        progress((step) / num_inference_steps, desc=f"Step {step}/{num_inference_steps}")
        
        # Store the image - ensure values are in proper range for PIL
        image_array_clipped = np.clip(image_array, 0, 255).astype(np.uint8)
        intermediate_results.append(Image.fromarray(image_array_clipped))
    
    # Create placeholder image for initial display
    placeholder = np.zeros((512, 512, 3), dtype=np.uint8)
    placeholder_img = Image.fromarray(placeholder)
    yield placeholder_img
    
    # Run the generation with our callback
    output_image = pipeline.generate(
        prompt=prompt,
        uncond_prompt="",
        input_image=input_image,
        strength=strength,
        do_cfg=True,
        cfg_scale=cfg_scale,
        sampler_name=sampler,  # Use the selected sampler
        n_inference_steps=num_inference_steps,
        seed=seed,
        models=models,
        device=DEVICE,
        idle_device="cpu",
        tokenizer=tokenizer,
        cancel_flag=cancel_flag,
        progress_callback=store_progress
    )
    
    # Now yield each intermediate result with a slight delay
    # This creates a smooth visualization of the denoising process
    for i, img in enumerate(intermediate_results):
        progress((i + 1) / num_inference_steps, desc=f"Visualizing step {i+1}/{num_inference_steps}")
        yield img
    
    # Finally show the complete image
    if output_image is not None:
        progress(1.0, desc="Generation complete!")
        # Debug: Check final output values
        print(f"Final output: min={output_image.min():.4f}, max={output_image.max():.4f}, shape={output_image.shape}")
        
        # Make sure the final image is in proper range for PIL
        output_image_clipped = np.clip(output_image, 0, 255).astype(np.uint8)
        yield Image.fromarray(output_image_clipped)
    else:
        # If cancelled, show the last intermediate result
        if intermediate_results:
            yield intermediate_results[-1]
        else:
            # No results, show black image
            yield placeholder_img

def cancel_generation():
    """Cancel the generation process"""
    cancel_flag.set()
    return None

# Create the Gradio interface
with gr.Blocks(title="Stable Diffusion Real-Time Preview") as demo:
    gr.Markdown("# 🎨 Stable Diffusion with Real-Time Visualization")
    gr.Markdown("Watch the image evolve step-by-step as the denoising process runs!")
    
    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="Prompt", 
                placeholder="Enter your prompt here...",
                value="A dog wearing a red scarf, sitting in a dreamy flower field at golden hour, highly detailed, realistic style",
                lines=3
            )
            image_input = gr.Image(label="Input Image (optional)", type="pil")
            
            with gr.Row():
                strength_slider = gr.Slider(0.1, 1.0, value=0.8, step=0.1, label="Strength")
                cfg_scale_slider = gr.Slider(1, 14, value=8, step=1, label="CFG Scale")
            
            with gr.Row():
                steps_slider = gr.Slider(5, 50, value=18, step=1, label="Number of Steps")
                seed_input = gr.Number(value=42, label="Seed")
            
            # Add sampler selection
            sampler_dropdown = gr.Dropdown(
                choices=["ddpm", "ddim"], 
                value="ddpm", 
                label="Sampler",
                info="DDPM is higher quality but slower. DDIM is faster but may have slight quality differences."
            )
            
            with gr.Row():
                generate_button = gr.Button("🎨 Generate", variant="primary")
                cancel_button = gr.Button("❌ Cancel", variant="secondary")
            
            # Debug mode checkbox
            debug_mode = gr.Checkbox(value=False, label="Enable Debug Mode", info="Print debug info to console")
            
        with gr.Column(scale=1):
            # This is the image component that will be updated in-place
            live_preview = gr.Image(
                label="Live Generation Preview",
                height=512,
                width=512
            )
    
            # Add helpful instructions
            gr.Markdown("""
            ### 🎮 How to Use
            1. Enter a detailed prompt
            2. Set your generation parameters
            3. Select a sampler:
               - **DDPM**: Original sampler, high quality but slower
               - **DDIM**: Faster sampler, fewer steps needed for similar quality
            4. Click "Generate" and watch the image evolve step by step!
            """)
    
    # Connect the buttons to functions
    generate_button.click(
        generate_image_realtime,
        inputs=[prompt_input, image_input, strength_slider, cfg_scale_slider, steps_slider, seed_input, sampler_dropdown],
        outputs=live_preview
    )
    
    cancel_button.click(
        cancel_generation,
        outputs=live_preview
    )

    # Add example prompts
    gr.Markdown("### 🌟 Example Prompts")
    example_prompts = [
        "A majestic lion in a savanna at sunset, photorealistic, 4k",
        "A futuristic city with flying cars, cyberpunk style, neon lights",
        "A peaceful Japanese garden with cherry blossoms, watercolor painting style",
        "A dragon flying over mountains, fantasy art, detailed scales",
        "A vintage bicycle in a flower field, soft lighting, film photography"
    ]
    
    gr.Examples(
        examples=[[prompt] for prompt in example_prompts],
        inputs=[prompt_input],
    )

    # Add sampler comparison
    gr.Markdown("""
    ### 📊 Sampler Comparison
    
    #### DDPM (Denoising Diffusion Probabilistic Models)
    - The original sampler used in most diffusion models
    - More accurate and higher quality results
    - Slower as it requires more steps (25-50 recommended)
    - Better for final high-quality images
    
    #### DDIM (Denoising Diffusion Implicit Models)
    - Accelerated deterministic sampler
    - Faster generation with fewer steps (10-25 recommended)
    - Good quality with fewer steps
    - Better when generation speed is important
    - Deterministic results (same seed always gives same output)
    
    Try both samplers with different step counts to find the best balance of quality and speed!
    """)

demo.queue().launch(share=True)

* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://c1786ff5e0549e5674.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
* Running on public URL: https://c1786ff5e0549e5674.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




100%|██████████| 18/18 [00:06<00:00,  2.76it/s]

100%|██████████| 18/18 [00:06<00:00,  2.69it/s]

100%|██████████| 18/18 [00:06<00:00,  2.81it/s]



In [None]:
# Compare different samplers side by side
import matplotlib.pyplot as plt
import time
import numpy as np

def compare_samplers(prompt, num_inference_steps=20, seed=42):
    """
    Generate and compare images using different samplers with the same prompt and seed
    """
    print(f"Comparing samplers for prompt: '{prompt}' with {num_inference_steps} steps and seed {seed}")
    start_time = time.time()
    
    # Collect intermediate images for comparison
    ddpm_intermediates = []
    ddim_intermediates = []
    
    # Define callbacks to collect intermediate steps
    def ddpm_callback(step, image):
        print(f"DDPM Step {step}: min={image.min():.2f}, max={image.max():.2f}")
        ddpm_intermediates.append((step, image.copy()))
        
    def ddim_callback(step, image):
        print(f"DDIM Step {step}: min={image.min():.2f}, max={image.max():.2f}")
        ddim_intermediates.append((step, image.copy()))
    
    # Generate with DDPM
    print("\n[1/2] Generating with DDPM sampler...")
    ddpm_start = time.time()
    try:
        ddpm_image = pipeline.generate(
            prompt=prompt,
            uncond_prompt="",
            input_image=None,
            strength=0.8,
            do_cfg=True,
            cfg_scale=7.5,
            sampler_name="ddpm",
            n_inference_steps=num_inference_steps,
            seed=seed,
            models=models,
            device=DEVICE,
            idle_device="cpu",
            tokenizer=tokenizer,
            progress_callback=ddpm_callback,
        )
        ddpm_time = time.time() - ddpm_start
        print(f"✓ DDPM completed in {ddpm_time:.2f} seconds")
    except Exception as e:
        print(f"❌ DDPM generation failed: {e}")
        ddpm_image = np.zeros((512, 512, 3), dtype=np.uint8)
        ddpm_time = time.time() - ddpm_start
    
    # Generate with DDIM
    print("\n[2/2] Generating with DDIM sampler...")
    ddim_start = time.time()
    try:
        ddim_image = pipeline.generate(
            prompt=prompt,
            uncond_prompt="",
            input_image=None,
            strength=0.8,
            do_cfg=True,
            cfg_scale=7.5,
            sampler_name="ddim",
            n_inference_steps=num_inference_steps,
            seed=seed,
            models=models,
            device=DEVICE,
            idle_device="cpu",
            tokenizer=tokenizer,
            progress_callback=ddim_callback,
        )
        ddim_time = time.time() - ddim_start
        print(f"✓ DDIM completed in {ddim_time:.2f} seconds")
    except Exception as e:
        print(f"❌ DDIM generation failed: {e}")
        ddim_image = np.zeros((512, 512, 3), dtype=np.uint8)
        ddim_time = time.time() - ddim_start
    
    # Calculate speedup
    if ddpm_time > 0 and ddim_time > 0:
        speedup = ddpm_time / ddim_time
        print(f"\nDDIM is {speedup:.2f}x faster than DDPM with the same number of steps")
    
    # Debug: Check final image values
    if ddpm_image is not None:
        print(f"DDPM final image: min={ddpm_image.min():.2f}, max={ddpm_image.max():.2f}, shape={ddpm_image.shape}")
    if ddim_image is not None:
        print(f"DDIM final image: min={ddim_image.min():.2f}, max={ddim_image.max():.2f}, shape={ddim_image.shape}")
    
    # Make sure images are properly normalized
    if ddpm_image is not None:
        ddpm_image_display = np.clip(ddpm_image, 0, 255).astype(np.uint8)
    else:
        ddpm_image_display = np.zeros((512, 512, 3), dtype=np.uint8)
        
    if ddim_image is not None:
        ddim_image_display = np.clip(ddim_image, 0, 255).astype(np.uint8)
    else:
        ddim_image_display = np.zeros((512, 512, 3), dtype=np.uint8)
    
    # Display the results side by side
    plt.figure(figsize=(18, 12))
    
    plt.subplot(2, 2, 1)
    plt.imshow(ddpm_image_display)
    plt.title(f"DDPM Sampler\n{num_inference_steps} steps, {ddpm_time:.2f}s")
    plt.axis('off')
    
    plt.subplot(2, 2, 2)
    plt.imshow(ddim_image_display)
    plt.title(f"DDIM Sampler\n{num_inference_steps} steps, {ddim_time:.2f}s")
    plt.axis('off')
    
    # Show intermediate steps comparison
    if ddpm_intermediates and ddim_intermediates:
        # Show middle step
        mid_step_ddpm = len(ddpm_intermediates) // 2
        mid_step_ddim = len(ddim_intermediates) // 2
        
        plt.subplot(2, 2, 3)
        if mid_step_ddpm < len(ddpm_intermediates):
            plt.imshow(np.clip(ddpm_intermediates[mid_step_ddpm][1], 0, 255).astype(np.uint8))
            plt.title(f"DDPM - Step {ddpm_intermediates[mid_step_ddpm][0]}")
        plt.axis('off')
        
        plt.subplot(2, 2, 4)
        if mid_step_ddim < len(ddim_intermediates):
            plt.imshow(np.clip(ddim_intermediates[mid_step_ddim][1], 0, 255).astype(np.uint8))
            plt.title(f"DDIM - Step {ddim_intermediates[mid_step_ddim][0]}")
        plt.axis('off')
    
    plt.suptitle(f"Sampler Comparison\nPrompt: '{prompt}' - Seed: {seed}", fontsize=16)
    plt.tight_layout()
    plt.show()
    
    return ddpm_image, ddim_image, ddpm_time, ddim_time

# Uncomment to run the sampler comparison
# compare_samplers("A magical castle on a floating island with waterfalls, fantasy art style", num_inference_steps=20)

In [None]:
# Create a Gradio interface for sampler comparison
def create_sampler_comparison_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# Stable Diffusion Sampler Comparison")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="A magical castle on a floating island with waterfalls, fantasy art style")
                steps_slider = gr.Slider(minimum=5, maximum=50, value=20, step=1, label="Number of steps")
                seed_number = gr.Number(value=42, label="Seed", precision=0)
                compare_btn = gr.Button("Compare Samplers")
        
        with gr.Row():
            with gr.Column():
                gr.Markdown("## DDPM Sampler")
                ddpm_output = gr.Image(label="DDPM Output")
                ddpm_time = gr.Textbox(label="DDPM Generation Time")
            with gr.Column():
                gr.Markdown("## DDIM Sampler")
                ddim_output = gr.Image(label="DDIM Output")
                ddim_time = gr.Textbox(label="DDIM Generation Time")
                
        with gr.Row():
            result_text = gr.Textbox(label="Comparison Results")
        
        def run_comparison(prompt_text, num_steps, seed_val):
            try:
                ddpm_img, ddim_img, ddpm_time_val, ddim_time_val = compare_samplers(
                    prompt=prompt_text, 
                    num_inference_steps=int(num_steps), 
                    seed=int(seed_val)
                )
                
                speedup = ddpm_time_val / ddim_time_val if ddim_time_val > 0 else 0
                result = f"DDPM took {ddpm_time_val:.2f}s, DDIM took {ddim_time_val:.2f}s\n"
                result += f"DDIM is {speedup:.2f}x faster than DDPM with {num_steps} steps\n"
                
                # Verify image quality
                ddpm_std = np.std(ddpm_img) if ddpm_img is not None else 0
                ddim_std = np.std(ddim_img) if ddim_img is not None else 0
                
                result += f"\nImage statistics:\n"
                result += f"DDPM: min={np.min(ddpm_img):.2f}, max={np.max(ddpm_img):.2f}, std={ddpm_std:.2f}\n"
                result += f"DDIM: min={np.min(ddim_img):.2f}, max={np.max(ddim_img):.2f}, std={ddim_std:.2f}"
                
                # Make sure images are properly formatted for display
                ddpm_img_display = np.clip(ddpm_img, 0, 255).astype(np.uint8) if ddpm_img is not None else np.zeros((512, 512, 3), dtype=np.uint8)
                ddim_img_display = np.clip(ddim_img, 0, 255).astype(np.uint8) if ddim_img is not None else np.zeros((512, 512, 3), dtype=np.uint8)
                
                return [
                    ddpm_img_display, 
                    f"{ddpm_time_val:.2f}s",
                    ddim_img_display, 
                    f"{ddim_time_val:.2f}s",
                    result
                ]
            except Exception as e:
                return [
                    None,
                    "Error",
                    None,
                    "Error",
                    f"Error during generation: {str(e)}"
                ]
        
        compare_btn.click(
            fn=run_comparison,
            inputs=[prompt, steps_slider, seed_number],
            outputs=[ddpm_output, ddpm_time, ddim_output, ddim_time, result_text]
        )
        
    return demo

# Create and launch the sampler comparison interface
sampler_comparison_demo = create_sampler_comparison_interface()
# Uncomment to launch the interface
# sampler_comparison_demo.launch()

In [None]:
# Launch the sampler comparison interface
sampler_comparison_demo.launch(debug=True)

In [None]:
# Detailed debugging of the DDIM sampler
def debug_ddim_sampler(n_inference_steps=20, seed=42):
    """
    Function to debug the DDIM sampler by visualizing intermediate latents and checking their statistics
    """
    print(f"Debugging DDIM sampler with {n_inference_steps} steps and seed {seed}")
    
    # Setup DDIM sampler with specified seed
    generator = torch.Generator(device=DEVICE)
    generator.manual_seed(seed)
    
    # Initialize with eta=0.0 for deterministic sampling
    ddim_sampler = DDIMSampler(generator, eta=0.0)
    ddim_sampler.set_inference_timesteps(n_inference_steps)
    
    # Setup DDPM sampler with same seed for comparison
    generator2 = torch.Generator(device=DEVICE)
    generator2.manual_seed(seed)
    ddpm_sampler = DDPMSampler(generator2)
    ddpm_sampler.set_inference_timesteps(n_inference_steps)
    
    # Initialize random latents
    latents_shape = (1, 4, 64, 64)  # Standard SD latent shape
    latents_ddim = torch.randn(latents_shape, generator=generator, device=DEVICE)
    latents_ddpm = latents_ddim.clone()  # Use same initial noise
    
    # Print initial latent statistics
    print(f"Initial latents - min: {latents_ddim.min().item():.4f}, max: {latents_ddim.max().item():.4f}, mean: {latents_ddim.mean().item():.4f}, std: {latents_ddim.std().item():.4f}")
    
    # Dummy model output function (just for testing)
    dummy_model = lambda x: torch.zeros_like(x)
    
    # Visualize the starting noise
    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    plt.imshow(latents_ddim[0, 0].detach().cpu().numpy(), cmap='viridis')
    plt.title("Initial Noise (Channel 0)")
    plt.colorbar()
    
    # Test the samplers with a simple denoising loop using a dummy model
    print("\nTesting DDIM sampler with dummy model output...")
    for i, timestep in enumerate(tqdm(ddim_sampler.timesteps)):
        model_output = dummy_model(latents_ddim)
        latents_ddim = ddim_sampler.step(timestep, latents_ddim, model_output)
        
        # Print statistics every few steps
        if i % (n_inference_steps // 5) == 0 or i == len(ddim_sampler.timesteps) - 1:
            print(f"DDIM Step {i+1}/{len(ddim_sampler.timesteps)} - min: {latents_ddim.min().item():.4f}, max: {latents_ddim.max().item():.4f}, mean: {latents_ddim.mean().item():.4f}, std: {latents_ddim.std().item():.4f}")
    
    # Compare with DDPM
    print("\nTesting DDPM sampler with dummy model output...")
    for i, timestep in enumerate(tqdm(ddpm_sampler.timesteps)):
        model_output = dummy_model(latents_ddpm)
        latents_ddpm = ddpm_sampler.step(timestep, latents_ddpm, model_output)
        
        # Print statistics every few steps
        if i % (n_inference_steps // 5) == 0 or i == len(ddpm_sampler.timesteps) - 1:
            print(f"DDPM Step {i+1}/{len(ddpm_sampler.timesteps)} - min: {latents_ddpm.min().item():.4f}, max: {latents_ddpm.max().item():.4f}, mean: {latents_ddpm.mean().item():.4f}, std: {latents_ddpm.std().item():.4f}")
    
    # Visualize the final latents (channel 0) for both samplers
    plt.subplot(1, 3, 2)
    plt.imshow(latents_ddim[0, 0].detach().cpu().numpy(), cmap='viridis')
    plt.title("Final DDIM Latents (Channel 0)")
    plt.colorbar()
    
    plt.subplot(1, 3, 3)
    plt.imshow(latents_ddpm[0, 0].detach().cpu().numpy(), cmap='viridis')
    plt.title("Final DDPM Latents (Channel 0)")
    plt.colorbar()
    
    plt.tight_layout()
    plt.show()
    
    return latents_ddim, latents_ddpm

# Run the debug function
# debug_latents_ddim, debug_latents_ddpm = debug_ddim_sampler(n_inference_steps=20)

In [None]:
# Run a side-by-side comparison of DDPM vs DDIM
prompt = "A beautiful mountain landscape with a lake and trees, digital art"
print(f"Generating images with prompt: '{prompt}'")

# Run the comparison
ddpm_image, ddim_image, ddpm_time, ddim_time = compare_samplers(
    prompt=prompt, 
    num_inference_steps=20,
    seed=42
)

# Show speedup
if ddpm_time > 0 and ddim_time > 0:
    print(f"DDIM is {ddpm_time / ddim_time:.2f}x faster than DDPM with the same number of steps")

# Show image statistics for debugging
def print_img_stats(name, img):
    if img is not None:
        print(f"{name} image stats: min={img.min():.2f}, max={img.max():.2f}, mean={img.mean():.2f}, std={img.std():.2f}")
    else:
        print(f"{name} image is None")

print_img_stats("DDPM", ddpm_image)
print_img_stats("DDIM", ddim_image)

In [None]:
# Test image-to-image generation with DDIM
from PIL import Image
import numpy as np

# First, let's add the missing set_strength method to DDIM
def add_set_strength_to_ddim():
    # Only add if not already present
    if not hasattr(DDIMSampler, 'set_strength'):
        code = """
    def set_strength(self, strength: float = 1.0) -> None:
        \"\"\"
        Set the denoising strength for image-to-image generation.
        
        This method is used in img2img pipelines where we start from an existing
        image rather than pure noise. The strength parameter controls how much
        of the original image structure is preserved.
        
        Args:
            strength (float): Denoising strength between 0.0 and 1.0
                            - 1.0: Start from pure noise (like txt2img)
                            - 0.8: Add significant noise, major changes
                            - 0.5: Moderate changes to original image
                            - 0.2: Minor changes, preserve most structure
                            - 0.0: No changes (return original image)
        \"\"\"
        # Calculate how many initial denoising steps to skip
        # Higher strength = fewer skipped steps = more denoising = more changes
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        
        # Skip the initial timesteps (start from partially noised image)
        self.timesteps = self.timesteps[start_step:]
        self.start_step = start_step
        
        print(f"Set DDIM strength to {strength}, starting from step {start_step}/{self.num_inference_steps}")
        """
        # Add the method to the class
        import types
        DDIMSampler.set_strength = types.MethodType(
            lambda self, strength=1.0: exec(code), 
            DDIMSampler
        )
        print("Added set_strength method to DDIMSampler")

# Add the method
add_set_strength_to_ddim()

# Now let's try image-to-image with DDIM
def test_img2img():
    # Create a simple gradient image as input
    width, height = 512, 512
    gradient = np.zeros((height, width, 3), dtype=np.uint8)
    for y in range(height):
        for x in range(width):
            # Create a simple gradient
            gradient[y, x, 0] = int(255 * x / width)  # R increases from left to right
            gradient[y, x, 1] = int(255 * y / height)  # G increases from top to bottom
            gradient[y, x, 2] = 128  # Constant blue
    
    # Convert to PIL image
    input_img = Image.fromarray(gradient)
    
    # Display the input image
    plt.figure(figsize=(5, 5))
    plt.imshow(input_img)
    plt.title("Input Image (Gradient)")
    plt.axis('off')
    plt.show()
    
    # Run image-to-image with both samplers
    prompt = "A beautiful sunset over mountains"
    strength = 0.75
    print(f"Running img2img with prompt: '{prompt}', strength: {strength}")
    
    # Generate with DDPM
    print("Generating with DDPM sampler...")
    ddpm_start = time.time()
    try:
        ddpm_image = pipeline.generate(
            prompt=prompt,
            uncond_prompt="",
            input_image=input_img,  # Pass the input image
            strength=strength,      # Set the strength
            do_cfg=True,
            cfg_scale=7.5,
            sampler_name="ddpm",
            n_inference_steps=20,
            seed=42,
            models=models,
            device=DEVICE,
            idle_device="cpu",
            tokenizer=tokenizer,
        )
        ddpm_time = time.time() - ddpm_start
        print(f"✓ DDPM completed in {ddpm_time:.2f} seconds")
    except Exception as e:
        print(f"❌ DDPM generation failed: {e}")
        ddpm_image = np.zeros((512, 512, 3), dtype=np.uint8)
        ddpm_time = time.time() - ddpm_start
    
    # Generate with DDIM
    print("Generating with DDIM sampler...")
    ddim_start = time.time()
    try:
        ddim_image = pipeline.generate(
            prompt=prompt,
            uncond_prompt="",
            input_image=input_img,  # Pass the input image
            strength=strength,      # Set the strength
            do_cfg=True,
            cfg_scale=7.5,
            sampler_name="ddim",
            n_inference_steps=20,
            seed=42,
            models=models,
            device=DEVICE,
            idle_device="cpu",
            tokenizer=tokenizer,
        )
        ddim_time = time.time() - ddim_start
        print(f"✓ DDIM completed in {ddim_time:.2f} seconds")
    except Exception as e:
        print(f"❌ DDIM generation failed: {e}")
        ddim_image = np.zeros((512, 512, 3), dtype=np.uint8)
        ddim_time = time.time() - ddim_start
    
    # Display the results side by side
    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    plt.imshow(input_img)
    plt.title("Original Image")
    plt.axis('off')
    
    plt.subplot(1, 3, 2)
    plt.imshow(ddpm_image)
    plt.title(f"DDPM Result ({ddpm_time:.2f}s)")
    plt.axis('off')
    
    plt.subplot(1, 3, 3)
    plt.imshow(ddim_image)
    plt.title(f"DDIM Result ({ddim_time:.2f}s)")
    plt.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    return ddpm_image, ddim_image

# Uncomment to run the test
# img2img_ddpm, img2img_ddim = test_img2img()

In [None]:
# Add a proper set_strength method to the DDIM sampler file
set_strength_code = """    def set_strength(self, strength: float = 1.0) -> None:
        \"\"\"
        Set the denoising strength for image-to-image generation.
        
        This method is used in img2img pipelines where we start from an existing
        image rather than pure noise. The strength parameter controls how much
        of the original image structure is preserved.
        
        Args:
            strength (float): Denoising strength between 0.0 and 1.0
                            - 1.0: Start from pure noise (like txt2img)
                            - 0.8: Add significant noise, major changes
                            - 0.5: Moderate changes to original image
                            - 0.2: Minor changes, preserve most structure
                            - 0.0: No changes (return original image)
        \"\"\"
        # Calculate how many initial denoising steps to skip
        # Higher strength = fewer skipped steps = more denoising = more changes
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        
        # Skip the initial timesteps (start from partially noised image)
        self.timesteps = self.timesteps[start_step:]
        self.start_step = start_step
        
        print(f"Set DDIM strength to {strength}, starting from step {start_step}/{self.num_inference_steps}")
"""

import os
ddim_path = os.path.join("/home/shora/Research/stable-diffusion/sd/ddim.py")

# Check if the method already exists in the file
with open(ddim_path, 'r') as f:
    ddim_content = f.read()

if "def set_strength" not in ddim_content:
    # Find the last method in the class
    last_method_index = ddim_content.rfind("def ")
    # Find the end of this method by locating the next line with less indentation
    lines = ddim_content.splitlines()
    method_line = None
    for i, line in enumerate(lines):
        if "def " in line and "def set_strength" not in line:
            method_line = i
            
    if method_line is not None:
        # Find the end of this method
        indentation = len(lines[method_line]) - len(lines[method_line].lstrip())
        end_line = None
        for i in range(method_line + 1, len(lines)):
            if lines[i].strip() and len(lines[i]) - len(lines[i].lstrip()) <= indentation:
                end_line = i
                break
        
        if end_line is None:
            end_line = len(lines)
        
        # Insert the new method
        lines.insert(end_line, set_strength_code)
        new_content = "\n".join(lines)
        
        # Write back to the file
        with open(ddim_path, 'w') as f:
            f.write(new_content)
        
        print("Added set_strength method to DDIM sampler file")
    else:
        print("Could not find a suitable location to add the method")
else:
    print("set_strength method already exists in the DDIM sampler file")

# DDPM vs DDIM Sampler Comparison

This notebook demonstrates the implementation and comparison of two diffusion model samplers:

## 1. DDPM (Denoising Diffusion Probabilistic Models)
- **Nature**: Stochastic sampling process
- **Algorithm**: Follows the original DDPM paper by Ho et al., 2020
- **Behavior**: Adds random noise at each sampling step
- **Performance**: 
  - Higher quality, especially for complex details
  - More coherent outputs
  - Slower generation (more compute-intensive)

## 2. DDIM (Denoising Diffusion Implicit Models)
- **Nature**: Deterministic sampling process (when η=0)
- **Algorithm**: Follows the DDIM paper by Song et al., 2020
- **Behavior**: Takes larger steps with deterministic updates
- **Performance**:
  - Faster generation (typically 2-4x faster than DDPM)
  - Requires fewer inference steps for similar quality
  - Can be made stochastic by adjusting η parameter (0=deterministic, 1=DDPM-like)
  - Allows for controlled interpolation in latent space

## Implementation Details

Both samplers follow the same general approach:
1. Start with random noise
2. Iteratively apply denoising steps to reach clean image
3. Use U-Net to predict noise at each step

The key differences are in how the sampling steps are calculated:
- DDPM uses the full Markovian stochastic process
- DDIM uses a non-Markovian deterministic process that can skip steps

## Debug Notes

If you encounter issues with the DDIM sampler:
1. Check for numerical stability issues in the calculation
2. Ensure proper tensor device handling
3. Add small epsilon values to prevent division by zero or sqrt of negative values
4. Monitor for NaN/Inf values in the generation process

You can use the debug tools in this notebook to compare outputs and diagnose potential issues.