# Image Generator with Compel

This notebook enhances the basic generator with:
- **Compel Integration**: Encode long prompts as weighted embeddings, bypassing CLIP's 77-token limit
- **Prompt Pre-fetching**: Fetch prompts in small batches and cache them ahead of generation (each prompt is still one API request)
- **Overflow Tags**: Append every overflow tag returned by the prompt API to the prompt (Compel handles the extra length)
- **Memory Optimized**: Works on Colab free tier

In [11]:
import os
from dotenv import load_dotenv
from google.colab import drive

mount_path = '/content/drive'

# Check for MyDrive rather than the mount point itself: /content/drive may exist
# without a mounted Drive, and every Drive path used in this notebook lives under
# MyDrive.
if not os.path.exists(os.path.join(mount_path, 'MyDrive')):
    print("Drive not mounted. Mounting now...")
    drive.mount(mount_path)
else:
    print("Drive already mounted.")

# Define the path to your file
file_path = '/content/drive/MyDrive/AI/hf_token.env'

# Load the environment variables. load_dotenv() does not raise on a missing file,
# so report it here rather than failing later when a gated download is attempted.
if not os.path.isfile(file_path):
    print(f"⚠ Token file not found: {file_path}")
load_dotenv(file_path)

huggingface_token = os.getenv('HUGGINGFACE_TOKEN')

# Suppress warnings
import warnings
warnings.filterwarnings('ignore', message='Flax classes are deprecated')
warnings.filterwarnings('ignore', category=FutureWarning)

Drive already mounted.


In [12]:
# Install required packages
# NOTE(review): versions are unpinned, so every fresh runtime installs the latest
# releases. This notebook relies on fairly specific APIs (e.g. compel's
# CompelForSDXL, imported below), so consider pinning known-good versions
# (pkg==x.y.z) to keep the notebook reproducible.
!pip install -q diffusers transformers accelerate safetensors omegaconf invisible-watermark compel bitsandbytes

In [13]:
import requests
from datetime import datetime
import sys
import os
import gc
from PIL import Image
from io import BytesIO
import random
import json
from collections import defaultdict
import torch
from huggingface_hub import snapshot_download, login
from diffusers import (
    StableDiffusionXLPipeline,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    HeunDiscreteScheduler,
    UniPCMultistepScheduler
)
from compel import CompelForSDXL

# Report the runtime environment once, right after the imports.
print(f"PyTorch version: {torch.__version__}")
_cuda_available = torch.cuda.is_available()
print(f"CUDA available: {_cuda_available}")
if _cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # mem_get_info() -> (free_bytes, total_bytes); show the free part in GiB.
    _free_bytes = torch.cuda.mem_get_info()[0]
    print(f"Initial GPU Memory: {_free_bytes / 1024**3:.2f}GB free")

PyTorch version: 2.9.0+cu126
CUDA available: True
GPU: Tesla T4
Initial GPU Memory: 7.61GB free


## Configuration

In [14]:
#@markdown ### Paths and Authentication
base_path = "/content/drive/MyDrive/AI/" #@param {type:"string"}

#@markdown ### Model Configuration
model_id = "John6666/easy-sfw-nsfw-pony-cartooneasyv1-sdxl" #@param ["stablediffusionapi/mklan-xxx-nsfw-pony","stablediffusionapi/duchaiten-real3d-nsfw-xl", "John6666/cyberrealistic-pony-v7-sdxl", "John6666/uber-realistic-porn-merge-xl-urpmxl-v6final-sdxl","John6666/fucktastic-real-checkpoint-pony-pdxl-porn-realistic-nsfw-sfw-21-sdxl","Vvilams/pony-realism-v21main-sdxl","John6666/sexoholic-real-pony-nsfw-v2-sdxl","John6666/wai-ani-nsfw-ponyxl-v11-sdxl", "John6666/duchaiten-pony-real-v20-sdxl", "stable-diffusion-v1-5/stable-diffusion-v1-5", "stabilityai/stable-diffusion-xl-base-1.0", "John6666/pornworks-real-porn-v03-sdxl", "UnfilteredAI/NSFW-GEN-ANIME", "UnfilteredAI/NSFW-gen-v2", "John6666/easy-sfw-nsfw-pony-cartooneasyv1-sdxl", "John6666/wai-nsfw-illustrious-v80-sdxl"]
# NOTE(review): the model-loading cell always uses StableDiffusionXLPipeline and
# the SDXL fp16-fix VAE, so the non-SDXL option "stable-diffusion-v1-5/..."
# presumably cannot load here — confirm or remove it from the list.
# When True, the model is re-downloaded even if model_path already has files.
download_model = False #@param {type:"boolean"}

#@markdown ### Prompt Configuration
base_prompt = "" #@param {type:"string"}
sfw = False #@param {type:"boolean"}
selected_categories = [] #@param {type:"raw"}
# Number of prompts requested per pre-fetch / cache refill (one API request each).
prompts_per_batch = 2 #@param {type:"integer"}

#@markdown ### Generation Settings
# Number of prompts to generate for; each prompt is rendered once per selected scheduler.
num_of_images = 50 #@param {type:"integer"}

#@markdown #### Schedulers (select one or more)
use_model_default = False #@param {type:"boolean"}
use_dpm_2m = False #@param {type:"boolean"}
use_dpm_2m_karras = False #@param {type:"boolean"}
use_dpm_sde = False #@param {type:"boolean"}
use_dpm_sde_karras = False #@param {type:"boolean"}
use_euler = True #@param {type:"boolean"}
use_euler_a = False #@param {type:"boolean"}
use_heun = False #@param {type:"boolean"}
use_kdpm2 = True #@param {type:"boolean"}
use_kdpm2_a = False #@param {type:"boolean"}
use_lms = False #@param {type:"boolean"}
use_ddim = True #@param {type:"boolean"}
use_pndm = False #@param {type:"boolean"}
use_unipc = False #@param {type:"boolean"}

# Per-prompt sampling ranges: guidance is drawn uniformly from [low, high] and
# rounded to the nearest 0.5; steps is a random integer in [low, high].
guidance_low = 5 #@param {type:"integer"}
guidance_high = 15 #@param {type:"integer"}
steps_low = 25 #@param {type:"integer"}
steps_high = 40 #@param {type:"integer"}

#@markdown ### Resolution Settings
resolution_mode = "sdxl" #@param ["classic", "sdxl"]

# Resolution pools
# "classic": width and height are drawn independently from CLASSIC_RESOLUTIONS.
CLASSIC_RESOLUTIONS = [720, 768, 800, 1024]
# "sdxl": one (width, height) pair is drawn from SDXL_RESOLUTIONS.
SDXL_RESOLUTIONS = [
    (1024, 1024), (1152, 896), (896, 1152), (1216, 832),
    (832, 1216), (1344, 768), (768, 1344)
]

# Build selected schedulers list.
# Order matters: it is the order in which each prompt is rendered.
_scheduler_toggles = [
    ("Model Default", use_model_default),
    ("DPM++ 2M", use_dpm_2m),
    ("DPM++ 2M Karras", use_dpm_2m_karras),
    ("DPM++ SDE", use_dpm_sde),
    ("DPM++ SDE Karras", use_dpm_sde_karras),
    ("Euler", use_euler),
    ("Euler a", use_euler_a),
    ("Heun", use_heun),
    ("KDPM2", use_kdpm2),
    ("KDPM2 a", use_kdpm2_a),
    ("LMS", use_lms),
    ("DDIM", use_ddim),
    ("PNDM", use_pndm),
    ("UniPC", use_unipc),
]
selected_schedulers = [name for name, enabled in _scheduler_toggles if enabled]

# Validate at least one scheduler selected
if not selected_schedulers:
    raise ValueError("ERROR: At least one scheduler must be selected!")

# Validate the sampling ranges up front. random.randint() in the generation loop
# raises on an inverted range, and the loop's per-image error handler would turn
# that into every single image being skipped.
if steps_low > steps_high:
    raise ValueError(f"ERROR: steps_low ({steps_low}) must be <= steps_high ({steps_high})")
if guidance_low > guidance_high:
    raise ValueError(f"ERROR: guidance_low ({guidance_low}) must be <= guidance_high ({guidance_high})")

# Derived paths
model_path = base_path + "models/" + model_id
save_directory = f"{base_path}images/{datetime.now().strftime('%Y%m%d%H%M%S')}/"
os.makedirs(save_directory, exist_ok=True)

# Models requiring uniform beta_schedule instead of scaled_linear
uniform_models = [
    "John6666/uber-realistic-porn-merge-xl-urpmxl-v6final-sdxl",
    "John6666/sexoholic-real-pony-nsfw-v2-sdxl",
    "John6666/duchaiten-pony-real-v20-sdxl",
    "stabilityai/sdxl-turbo",
    "John6666/pornworks-real-porn-v03-sdxl"
]
use_uniform = model_id in uniform_models

# Fallback negative prompt (API now provides negative_prompt per image)
NEGATIVE_PROMPT = "text, writing, bad teeth, deformed face and eyes, child, childish, young, deformed, uneven eyes, too many fingers"

print(f"✓ Configuration loaded")
print(f"  Model: {model_id}")
print(f"  Selected schedulers ({len(selected_schedulers)}): {', '.join(selected_schedulers)}")
print(f"  Total images to generate: {num_of_images} × {len(selected_schedulers)} = {num_of_images * len(selected_schedulers)}")
print(f"  Guidance: {guidance_low}-{guidance_high}")
print(f"  Steps: {steps_low}-{steps_high}")
print(f"  Resolution: {resolution_mode}")
print(f"  Uniform beta schedule: {use_uniform}")
print(f"  Save to: {save_directory}")

✓ Configuration loaded
  Model: John6666/easy-sfw-nsfw-pony-cartooneasyv1-sdxl
  Selected schedulers (3): Euler, KDPM2, DDIM
  Total images to generate: 50 × 3 = 150
  Guidance: 5-15
  Steps: 25-40
  Resolution: sdxl
  Uniform beta schedule: False
  Save to: /content/drive/MyDrive/AI/images/20260101231503/


## Prompt Management System

In [15]:
class PromptManager:
    """Manages prompt fetching, caching, and variation generation.

    Prompts are requested one HTTP GET at a time from ``api_url_template`` and
    kept in a FIFO cache that ``get_prompt`` consumes from the front.
    """

    def __init__(self, api_url_template, cache_size=20, batch_size=None):
        """
        Args:
            api_url_template: Fully-formed URL requested once per prompt.
            cache_size: Maximum number of cached prompts; when exceeded, the
                oldest entries are dropped.
            batch_size: How many prompts ``get_prompt`` fetches when the cache
                runs low. ``None`` keeps the notebook-level behaviour of using
                the global ``prompts_per_batch`` setting at call time.
        """
        self.api_url_template = api_url_template
        self.cache = []
        self.cache_size = cache_size
        self.batch_size = batch_size
        self.stats = defaultdict(int)

    @staticmethod
    def _parse_overflow(overflow_raw):
        """Normalise the API's ``overflow`` field into a list of non-empty tags.

        The API returns a comma-separated string; a list is also accepted in case
        the API changes. Anything else yields ``[]``. Blank and non-string items
        are dropped so they cannot produce empty ", ," segments in the prompt.
        """
        if isinstance(overflow_raw, str):
            items = overflow_raw.split(',')
        elif isinstance(overflow_raw, list):
            items = overflow_raw
        else:
            return []
        return [item.strip() for item in items if isinstance(item, str) and item.strip()]

    def fetch_prompts(self, count=5):
        """Fetch ``count`` prompts (one request each) and append them to the cache.

        Request and parsing failures are printed and counted in
        ``stats['failed']`` instead of being raised, so an unreliable API only
        reduces the number of prompts obtained.

        Returns:
            int: Number of prompts successfully fetched by this call.
        """
        prompts = []
        for _ in range(count):
            try:
                response = requests.get(self.api_url_template, timeout=10)
                if response.status_code != 200:
                    print(f"Failed to fetch prompt: {response.status_code}")
                    self.stats['failed'] += 1
                    continue

                data = response.json()
                prompts.append({
                    'prompt': data['prompt'],
                    'overflow': self._parse_overflow(data.get('overflow', '')),
                    'refined': data.get('refined', ''),
                    # `or` also covers an explicit null/"" from the API, which would
                    # otherwise be passed straight to compel as the negative prompt.
                    'negative_prompt': data.get('negative_prompt') or NEGATIVE_PROMPT,
                })
                self.stats['fetched'] += 1
            except Exception as e:
                print(f"Error fetching prompt: {e}")
                self.stats['failed'] += 1

        self.cache.extend(prompts)
        # Keep cache size manageable (drop the oldest entries first)
        if len(self.cache) > self.cache_size:
            self.cache = self.cache[-self.cache_size:]

        return len(prompts)

    def get_prompt(self):
        """Pop the oldest cached prompt, refilling the cache when it runs low.

        Returns:
            dict | None: The prompt record, or ``None`` if nothing could be fetched.
        """
        if len(self.cache) < 3:
            print(f"Cache low ({len(self.cache)} prompts), fetching more...")
            batch = self.batch_size if self.batch_size is not None else prompts_per_batch
            self.fetch_prompts(batch)

        if not self.cache:
            # Emergency fetch
            self.fetch_prompts(1)

        if not self.cache:
            return None
        self.stats['used'] += 1
        return self.cache.pop(0)

    def build_weighted_prompt(self, prompt_data):
        """
        Build the final prompt by appending all overflow tags.

        Args:
            prompt_data: Dict with 'prompt', 'overflow', 'refined', 'negative_prompt'

        Returns:
            tuple: (final_prompt, overflow_count)
        """
        prompt = prompt_data['prompt']
        overflow = prompt_data.get('overflow', [])

        if not overflow:
            return prompt, 0

        # Use ALL overflow tags (API provides them for a reason)
        overflow_text = ", ".join(overflow)
        final_prompt = f"{prompt}, {overflow_text}"

        return final_prompt, len(overflow)

    def print_stats(self):
        """Print fetch/use/failure counters and the current cache size."""
        print(f"\nPrompt Manager Stats:")
        print(f"  Fetched: {self.stats['fetched']}")
        print(f"  Used: {self.stats['used']}")
        print(f"  Failed: {self.stats['failed']}")
        print(f"  Cached: {len(self.cache)}")

# Initialize prompt manager.
# The query values are URL-encoded so that a base prompt containing '&', '#', '='
# or non-ASCII text reaches the API intact instead of corrupting the query string.
# quote_via=quote encodes spaces as %20, matching how they were sent before.
import urllib.parse

prompt_query = urllib.parse.urlencode(
    {
        'sfw': sfw,
        'base_prompt': base_prompt,
        'selected_categories': ' '.join(selected_categories),
    },
    quote_via=urllib.parse.quote,
)
prompt_url = f"https://prompt-gen.squigglypickle.co.uk/generate-prompt?{prompt_query}"
prompt_manager = PromptManager(prompt_url)

# Pre-fetch prompts
print(f"Pre-fetching {prompts_per_batch} prompts...")
fetched = prompt_manager.fetch_prompts(prompts_per_batch)
print(f"Fetched {fetched} prompts successfully")

Pre-fetching 2 prompts...
Fetched 2 prompts successfully


In [None]:
# Check if model needs to be downloaded (isdir() already implies the path exists)
model_exists = os.path.isdir(model_path) and len(os.listdir(model_path)) > 0
should_download = download_model or not model_exists

# Hugging Face APIs expect None rather than an empty token string.
hf_token = huggingface_token if huggingface_token else None

if should_download:
    if not model_exists:
        print(f"Model not found at {model_path}, downloading from HuggingFace...")
    elif download_model:
        print(f"Re-downloading model {model_id} (download_model checkbox enabled)...")

    if not hf_token:
        print("Warning: No HuggingFace token provided. Download may fail for gated models.")

    snapshot_download(
        repo_id=model_id,
        local_dir=model_path,
        token=hf_token,
        ignore_patterns=["*.safetensors.lock"]
    )
    print(f"✓ Model downloaded to: {model_path}")
else:
    print(f"✓ Model already exists at {model_path}, skipping download")

# Load pipeline with memory optimizations
print(f"Loading model {model_id}...")

# Base load kwargs
load_kwargs = {
    "torch_dtype": torch.float16,
    "use_safetensors": True,
}

# Try loading with fp16 variant first, fallback to no variant if unavailable
pipe = None
for attempt in ["fp16", "no_variant"]:
    try:
        if attempt == "fp16":
            load_kwargs["variant"] = "fp16"
            print("Attempting to load with fp16 variant...")
        else:
            load_kwargs.pop("variant", None)
            print("Retrying without variant (using full precision weights)...")

        # Try loading from local path first, fall back to HuggingFace if corrupted
        try:
            if os.path.exists(model_path):
                pipe = StableDiffusionXLPipeline.from_pretrained(model_path, **load_kwargs)
                print(f"✓ Loaded from local path: {model_path}")
            else:
                raise FileNotFoundError("Local model path does not exist")
        except (OSError, FileNotFoundError) as local_error:
            # Local model is corrupted or missing, load from HuggingFace
            if os.path.exists(model_path):
                print(f"⚠ Local model corrupted: {local_error}")
                print(f"Loading from HuggingFace instead...")
            # The token is only needed for Hub loads, so pass it here instead of
            # storing it in load_kwargs (where it leaked into later local loads).
            pipe = StableDiffusionXLPipeline.from_pretrained(model_id, token=hf_token, **load_kwargs)
            print(f"✓ Loaded from HuggingFace: {model_id}")

        print(f"✓ Model loaded successfully ({attempt})")
        break

    except ValueError as e:
        if "variant" in str(e) and attempt == "fp16":
            # fp16 variant not available, will retry without variant
            continue
        else:
            # Some other ValueError, re-raise it
            raise
    except Exception as e:
        print(f"Error loading model with {attempt}: {e}")
        if attempt == "no_variant":
            # Last attempt failed, re-raise
            raise

if pipe is None:
    raise RuntimeError("Failed to load model after all attempts")

# NOW load the enhanced VAE (after pipeline is loaded)
# This ensures we can properly replace the model's default VAE
print("\nLoading enhanced VAE for better color accuracy...")
from diffusers import AutoencoderKL

try:
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix",
        torch_dtype=torch.float16
    )
    print("✓ Enhanced VAE loaded from HuggingFace")

    # Replace the pipeline's VAE with the enhanced one
    pipe.vae = vae
    print("✓ Enhanced VAE applied to pipeline")

except Exception as e:
    print(f"⚠ Could not load enhanced VAE, using model's default: {e}")

# Move pipeline to GPU
pipe = pipe.to("cuda")

# Memory optimizations
print("\nApplying memory optimizations...")
pipe.enable_attention_slicing()
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()

# Try to enable xformers for better memory efficiency. Catch Exception rather than
# using a bare except, so KeyboardInterrupt/SystemExit still propagate, and show
# why it was unavailable.
try:
    pipe.enable_xformers_memory_efficient_attention()
    print("✓ xformers enabled")
except Exception as e:
    print(f"⚠ xformers not available: {e}")

# NOTE: enable_model_cpu_offload() is DISABLED because it conflicts with Compel
# Compel needs text encoders to stay on GPU for prompt weighting
print("⚠ CPU offload disabled (incompatible with Compel)")

# Initialize CompelForSDXL for advanced prompt weighting
# This wrapper automatically handles SDXL's dual text encoders and pooled embeddings
compel = CompelForSDXL(pipe)
print("✓ CompelForSDXL initialized for advanced prompt weighting")

print(f"\n✓ Pipeline ready!")
if torch.cuda.is_available():
    print(f"GPU Memory: {torch.cuda.mem_get_info()[0] / 1024**3:.2f}GB free")

✓ Model already exists at /content/drive/MyDrive/AI/models/John6666/easy-sfw-nsfw-pony-cartooneasyv1-sdxl, skipping download
Loading model John6666/easy-sfw-nsfw-pony-cartooneasyv1-sdxl...
Attempting to load with fp16 variant...
Retrying without variant (using full precision weights)...


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

✓ Loaded from local path: /content/drive/MyDrive/AI/models/John6666/easy-sfw-nsfw-pony-cartooneasyv1-sdxl
✓ Model loaded successfully (no_variant)

Loading enhanced VAE for better color accuracy...
✓ Enhanced VAE loaded from HuggingFace
✓ Enhanced VAE applied to pipeline


## Sampler Configuration (simplified)

In [None]:
# Sampler classes
from diffusers import (
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    KDPM2DiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    HeunDiscreteScheduler,
    LMSDiscreteScheduler,
    DDIMScheduler,
    PNDMScheduler,
    UniPCMultistepScheduler
)

# Apply uniform beta_schedule for models that require it (model_id in uniform_models).
# NOTE(review): this probably does not do what the message says. In diffusers a
# scheduler computes its betas when it is constructed, and `config` is a frozen
# config mapping, so assigning `config.beta_schedule` afterwards presumably changes
# neither the running scheduler nor the config items that `from_config` reads when
# SAMPLERS is built below. "uniform" is also not among the beta_schedule values
# diffusers documents ("linear", "scaled_linear", "squaredcos_cap_v2"). Verify —
# the only effective change here may be the switch to DPMSolverMultistepScheduler.
if use_uniform:
    print(f"Applying uniform beta_schedule for {model_id}...")
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.scheduler.config.beta_schedule = "uniform"
    print(f"✓ Scheduler set to DPMSolverMultistepScheduler with uniform beta_schedule")

# IMPORTANT: Initialize samplers AFTER pipeline is loaded so we can use the model's scheduler config
# This preserves the model's trained scheduler settings (beta schedules, timestep spacing, etc.)

def initialize_samplers(base_scheduler_config, default_scheduler_cls=None):
    """
    Create scheduler factories that reuse the model's trained scheduler config.

    Building each scheduler with ``from_config(base_scheduler_config)`` inherits
    settings such as beta_schedule and timestep_spacing from the model.

    For the Karras variants a copy of the config without ``beta_schedule`` is
    used; per the original notes this avoids duplicate-timestep (index out of
    bounds) errors seen when combining beta schedules with Karras sigmas.

    Args:
        base_scheduler_config: Mapping holding the pipeline scheduler's config
            (normally ``pipe.scheduler.config``).
        default_scheduler_cls: Scheduler class used for the "Model Default"
            entry. When None, the class of ``pipe.scheduler`` is captured once,
            at the time this function is called.

    Returns:
        dict: sampler display name -> zero-argument callable returning a new
        scheduler instance.
    """
    # Resolve the default class eagerly. Looking up pipe.scheduler inside the
    # lambda would return whichever scheduler was assigned most recently, and
    # the generation/test cells reassign pipe.scheduler for every image.
    if default_scheduler_cls is None:
        default_scheduler_cls = pipe.scheduler.__class__

    # Create a clean config for Karras variants (without beta_schedule that might conflict)
    # NOTE(review): with beta_schedule removed, from_config presumably falls back to
    # the scheduler class's own default rather than the model's trained schedule, and
    # the Karras sigma range is derived from those betas — verify the Karras results.
    clean_config = {k: v for k, v in base_scheduler_config.items() if k not in ['beta_schedule']}

    return {
        # DPM++ variants (multistep)
        "DPM++ 2M": lambda: DPMSolverMultistepScheduler.from_config(
            base_scheduler_config
        ),
        "DPM++ 2M Karras": lambda: DPMSolverMultistepScheduler.from_config(
            clean_config,
            use_karras_sigmas=True
        ),

        # DPM++ SDE variants (singlestep)
        # NOTE(review): diffusers maps A1111's "DPM++ SDE" to DPMSolverSDEScheduler;
        # DPMSolverSinglestepScheduler is the single-step DPM-Solver. Names are kept
        # as-is so existing metadata stays comparable.
        "DPM++ SDE": lambda: DPMSolverSinglestepScheduler.from_config(
            base_scheduler_config
        ),
        "DPM++ SDE Karras": lambda: DPMSolverSinglestepScheduler.from_config(
            clean_config,
            use_karras_sigmas=True
        ),

        # Euler variants
        "Euler": lambda: EulerDiscreteScheduler.from_config(
            base_scheduler_config
        ),
        "Euler a": lambda: EulerAncestralDiscreteScheduler.from_config(
            base_scheduler_config
        ),

        # KDPM2 variants
        "KDPM2": lambda: KDPM2DiscreteScheduler.from_config(
            base_scheduler_config
        ),
        "KDPM2 a": lambda: KDPM2AncestralDiscreteScheduler.from_config(
            base_scheduler_config
        ),

        # Other popular schedulers
        "Heun": lambda: HeunDiscreteScheduler.from_config(
            base_scheduler_config
        ),
        "LMS": lambda: LMSDiscreteScheduler.from_config(
            base_scheduler_config
        ),
        "DDIM": lambda: DDIMScheduler.from_config(
            base_scheduler_config
        ),
        "PNDM": lambda: PNDMScheduler.from_config(
            base_scheduler_config
        ),
        "UniPC": lambda: UniPCMultistepScheduler.from_config(
            base_scheduler_config
        ),

        # Model's original scheduler (class captured above, not at call time)
        "Model Default": lambda: default_scheduler_cls.from_config(
            base_scheduler_config
        )
    }

# Initialize samplers with the model's scheduler config.
# Each SAMPLERS value is a zero-argument factory; the comparison test and the
# generation loop call SAMPLERS[name]() to get a fresh scheduler instance.
SAMPLERS = initialize_samplers(pipe.scheduler.config)

def get_resolution():
    """Pick a random (width, height) for the configured ``resolution_mode``.

    "classic" draws width and then height independently from
    CLASSIC_RESOLUTIONS; any other mode returns one preset pair from
    SDXL_RESOLUTIONS.
    """
    if resolution_mode != "classic":
        return random.choice(SDXL_RESOLUTIONS)
    w = random.choice(CLASSIC_RESOLUTIONS)
    h = random.choice(CLASSIC_RESOLUTIONS)
    return (w, h)

# Summarise the sampler setup; config keys missing from the scheduler show as 'N/A'.
_sched_cfg = pipe.scheduler.config
print(f"✓ Samplers configured using model's scheduler config")
print(f"  Base scheduler: {type(pipe.scheduler).__name__}")
print(f"  Available samplers: {len(SAMPLERS)}")
print(
    f"  Config: beta_schedule={_sched_cfg.get('beta_schedule', 'N/A')}, "
    f"timestep_spacing={_sched_cfg.get('timestep_spacing', 'N/A')}"
)
print(f"  Karras variants use clean config (no beta_schedule) to prevent timestep conflicts")

## Scheduler Comparison Test (Optional)

Render the same prompt, seed, size, step count and guidance with every available scheduler to compare their output side by side. Runs only when `run_sampler_test` is enabled.

In [None]:
#@markdown ### Scheduler Comparison Settings
run_sampler_test = False #@param {type:"boolean"}
test_prompt = "a beautiful woman with long flowing hair, photorealistic, high quality, detailed" #@param {type:"string"}
test_seed = 42 #@param {type:"integer"}
test_width = 1024 #@param {type:"integer"}
test_height = 1024 #@param {type:"integer"}
test_steps = 30 #@param {type:"integer"}
test_guidance = 7.5 #@param {type:"number"}
test_columns = 3 #@param {type:"integer"}

if run_sampler_test:
    # matplotlib is otherwise first imported in the Statistics cell further down,
    # so import it here to keep this cell runnable on a fresh kernel.
    import matplotlib.pyplot as plt

    print("=" * 60)
    print("SCHEDULER COMPARISON TEST")
    print("=" * 60)
    print(f"\nTest Parameters:")
    print(f"  Prompt: {test_prompt[:60]}...")
    print(f"  Seed: {test_seed}")
    print(f"  Resolution: {test_width}x{test_height}")
    print(f"  Steps: {test_steps}")
    print(f"  Guidance: {test_guidance}")
    print(f"\nGenerating with {len(SAMPLERS)} schedulers...")

    # Create test directory
    test_directory = f"{base_path}scheduler_tests/{datetime.now().strftime('%Y%m%d%H%M%S')}/"
    os.makedirs(test_directory, exist_ok=True)

    # A sampler factory may consult pipe.scheduler when it is called (the
    # "Model Default" entry derives its class from it). Build every scheduler
    # while the original one is active, and restore it when the test ends so
    # the generation cell starts from the model's scheduler.
    original_scheduler = pipe.scheduler

    # Encode prompt once (use for all schedulers)
    # IMPORTANT: Pass BOTH positive and negative prompts to compel() together
    print("\nEncoding prompt...")
    conditioning = compel(test_prompt, negative_prompt=NEGATIVE_PROMPT)

    # Generate with each scheduler
    scheduler_names = sorted(SAMPLERS.keys())
    test_results = []
    try:
        for position, scheduler_name in enumerate(scheduler_names, start=1):
            try:
                # Report the position in the run (not the success count), so the
                # counter stays correct after a scheduler fails.
                print(f"\n[{position}/{len(scheduler_names)}] Testing: {scheduler_name}")

                # Set scheduler (built while the original scheduler is active)
                pipe.scheduler = original_scheduler
                pipe.scheduler = SAMPLERS[scheduler_name]()

                # Generate image
                # IMPORTANT: Must pass all 4 embedding parameters when using compel
                result = pipe(
                    prompt_embeds=conditioning.embeds,
                    pooled_prompt_embeds=conditioning.pooled_embeds,
                    negative_prompt_embeds=conditioning.negative_embeds,
                    negative_pooled_prompt_embeds=conditioning.negative_pooled_embeds,
                    num_images_per_prompt=1,
                    width=test_width,
                    height=test_height,
                    guidance_scale=test_guidance,
                    num_inference_steps=test_steps,
                    generator=torch.Generator(device="cuda").manual_seed(test_seed),
                ).images[0]

                # Save image
                filename = f"{scheduler_name.replace(' ', '_').replace('+', 'p')}.png"
                filepath = os.path.join(test_directory, filename)
                result.save(filepath)

                test_results.append({
                    'scheduler': scheduler_name,
                    'image': result,
                    'filename': filename
                })

                print(f"  ✓ Saved: {filename}")

                # Cleanup
                del result
                torch.cuda.empty_cache()

            except Exception as e:
                print(f"  ✗ Failed: {e}")
                continue
    finally:
        # Restore the model's scheduler and free the encoded prompt even if the
        # run is interrupted.
        pipe.scheduler = original_scheduler
        del conditioning
        torch.cuda.empty_cache()

    # Display results in grid
    if test_results:
        print(f"\n{'='*60}")
        print(f"RESULTS: {len(test_results)} schedulers tested")
        print(f"Saved to: {test_directory}")
        print(f"{'='*60}\n")

        # Create comparison grid (at least one column to avoid division by zero)
        columns = max(1, test_columns)
        num_results = len(test_results)
        num_rows = (num_results + columns - 1) // columns

        fig = plt.figure(figsize=(6 * columns, 6 * num_rows))

        for idx, entry in enumerate(test_results):
            ax = fig.add_subplot(num_rows, columns, idx + 1)
            ax.imshow(entry['image'])
            ax.axis('off')
            ax.set_title(entry['scheduler'], fontsize=12, weight='bold')

        plt.tight_layout()
        plt.show()

        # Save metadata
        metadata = {
            'prompt': test_prompt,
            'negative_prompt': NEGATIVE_PROMPT,
            'seed': test_seed,
            'width': test_width,
            'height': test_height,
            'steps': test_steps,
            'guidance_scale': test_guidance,
            'model': model_id,
            'schedulers_tested': [r['scheduler'] for r in test_results],
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        }

        with open(os.path.join(test_directory, 'test_metadata.json'), 'w') as f:
            json.dump(metadata, f, indent=2)

        print(f"\nℹ️  To compare specific schedulers, look at the images in:")
        print(f"   {test_directory}")
    else:
        print("\n✗ No results generated")
else:
    print("ℹ️  Scheduler comparison test disabled")
    print("   Set run_sampler_test = True to compare all schedulers with identical parameters")

## Image Generation

In [None]:
metadata_list = []
metadata_path = os.path.join(save_directory, "metadata.json")

print(f"Starting generation of {num_of_images} images with {len(selected_schedulers)} scheduler(s)...")
print(f"Total images: {num_of_images * len(selected_schedulers)}")
print(f"Save directory: {save_directory}\n")

# A sampler factory may consult pipe.scheduler when it is called (the "Model
# Default" entry derives its class from it), but this loop overwrites
# pipe.scheduler for every image. Build each scheduler while the scheduler that
# was active at the start of the run is in place, and restore it afterwards so
# later cells and re-runs start from the same state.
base_scheduler = pipe.scheduler

image_counter = 0
try:
    for i in range(num_of_images):
        try:
            # Get prompt from manager
            prompt_data = prompt_manager.get_prompt()
            if not prompt_data:
                print(f"Failed to get prompt for image {i}, skipping...")
                continue

            # Build final prompt with ALL overflow tags
            final_prompt, overflow_count = prompt_manager.build_weighted_prompt(prompt_data)

            # Randomize parameters (same for all schedulers); guidance rounded to 0.5
            width, height = get_resolution()
            guidance_scale = round(random.uniform(guidance_low, guidance_high) * 2) / 2
            num_steps = random.randint(steps_low, steps_high)

            # Use CompelForSDXL for prompt encoding (handles long prompts and SDXL dual encoders)
            # IMPORTANT: Pass BOTH positive and negative prompts to compel() together
            negative_prompt_text = prompt_data.get('negative_prompt', NEGATIVE_PROMPT)
            conditioning = compel(final_prompt, negative_prompt=negative_prompt_text)

            # Generate with each selected scheduler
            for scheduler_name in selected_schedulers:
                try:
                    # Set scheduler (built while base_scheduler is active, see above)
                    pipe.scheduler = base_scheduler
                    pipe.scheduler = SAMPLERS[scheduler_name]()

                    seed = random.randint(0, 2**32 - 1)

                    # Generate image
                    # IMPORTANT: Must pass all 4 embedding parameters when using compel
                    result = pipe(
                        prompt_embeds=conditioning.embeds,
                        pooled_prompt_embeds=conditioning.pooled_embeds,
                        negative_prompt_embeds=conditioning.negative_embeds,
                        negative_pooled_prompt_embeds=conditioning.negative_pooled_embeds,
                        num_images_per_prompt=1,
                        width=width,
                        height=height,
                        guidance_scale=guidance_scale,
                        num_inference_steps=num_steps,
                        generator=torch.Generator(device="cuda").manual_seed(seed),
                    ).images[0]

                    # Save image with scheduler name in filename
                    scheduler_short = scheduler_name.replace(' ', '_').replace('+', 'p')
                    filename = f"{datetime.now().strftime('%Y%m%d%H%M%S')}_{str(image_counter).zfill(4)}_{scheduler_short}.png"
                    filepath = os.path.join(save_directory, filename)
                    result.save(filepath)

                    # Store metadata
                    metadata = {
                        "filename": filename,
                        "model": model_id,
                        "scheduler": scheduler_name,
                        "base_prompt": base_prompt,
                        "prompt": prompt_data['prompt'],
                        "final_prompt": final_prompt,
                        "overflow_total": len(prompt_data.get('overflow', [])),
                        "overflow_used": overflow_count,
                        "negative_prompt": negative_prompt_text,
                        "sfw": sfw,
                        "seed": seed,
                        "width": width,
                        "height": height,
                        "guidance_scale": guidance_scale,
                        "num_steps": num_steps,
                        "prompt_index": i
                    }
                    metadata_list.append(metadata)

                    # Save metadata after each image (in case of crash/failure)
                    with open(metadata_path, 'w') as f:
                        json.dump(metadata_list, f, indent=2)

                    # Memory cleanup
                    del result
                    torch.cuda.empty_cache()

                    image_counter += 1

                except Exception as e:
                    print(f"Error with scheduler {scheduler_name}: {e}")
                    continue

            # Progress update
            overflow_info = f" +{overflow_count} overflow" if overflow_count > 0 else ""
            schedulers_info = f"{len(selected_schedulers)} schedulers" if len(selected_schedulers) > 1 else selected_schedulers[0]
            print(f"[{i+1}/{num_of_images}] {width}x{height} | {num_steps} steps | G:{guidance_scale:.1f} | {schedulers_info}{overflow_info}")

            # Cleanup prompt embeddings
            del conditioning
            torch.cuda.empty_cache()
            gc.collect()

        except Exception as e:
            print(f"Error generating image {i}: {e}")
            import traceback
            traceback.print_exc()
            continue
finally:
    # Leave the pipeline on the scheduler it had before this cell ran.
    pipe.scheduler = base_scheduler

print(f"\n✓ Generation complete!")
print(f"✓ {len(metadata_list)} images saved to: {save_directory}")
print(f"✓ Metadata saved to: {metadata_path}")

# Print stats
prompt_manager.print_stats()

if torch.cuda.is_available():
    print(f"\nFinal GPU Memory: {torch.cuda.mem_get_info()[0] / 1024**3:.2f}GB free")

## Statistics and Analysis

In [None]:
import matplotlib.pyplot as plt
from collections import Counter

if not run_sampler_test and 'metadata_list' in locals() and metadata_list:
    total_images = len(metadata_list)

    # Sampler distribution. The generation cell records the sampler under the
    # "scheduler" key; entries have no "sampler" key.
    sampler_counts = Counter(m['scheduler'] for m in metadata_list)

    # Resolution distribution
    resolution_counts = Counter(f"{m['width']}x{m['height']}" for m in metadata_list)

    # Overflow usage
    avg_overflow_used = sum(m['overflow_used'] for m in metadata_list) / total_images
    avg_overflow_total = sum(m['overflow_total'] for m in metadata_list) / total_images

    print("\n=== Generation Statistics ===")
    print(f"\nSampler Usage:")
    for sampler, count in sampler_counts.most_common():
        print(f"  {sampler}: {count} ({count/total_images*100:.1f}%)")

    print(f"\nTop 5 Resolutions:")
    for res, count in resolution_counts.most_common(5):
        print(f"  {res}: {count} ({count/total_images*100:.1f}%)")

    print(f"\nOverflow Usage:")
    print(f"  Average used: {avg_overflow_used:.1f}")
    print(f"  Average total: {avg_overflow_total:.1f}")
    # Guard the division; keep the label in the N/A case too.
    if avg_overflow_total > 0:
        print(f"  Usage rate: {avg_overflow_used/avg_overflow_total*100:.1f}%")
    else:
        print("  Usage rate: N/A")
elif run_sampler_test:
    print("ℹ️  Statistics skipped (sampler test mode - see test results above)")
else:
    print("ℹ️  No metadata available for statistics")

## Display Results

In [None]:
#@markdown Display generated images in a grid
show_results = True #@param {type:"boolean"}
num_columns = 3 #@param {type:"integer"}
max_images_to_show = 12 #@param {type:"integer"}

if not run_sampler_test and show_results and 'metadata_list' in locals() and metadata_list:
    display_metadata = metadata_list[:max_images_to_show]
    num_images = len(display_metadata)
    num_rows = (num_images + num_columns - 1) // num_columns

    plt.figure(figsize=(5 * num_columns, 5 * num_rows))

    for idx, metadata in enumerate(display_metadata):
        filepath = os.path.join(save_directory, metadata['filename'])
        # Copy the pixels into memory and close the file right away instead of
        # leaving one open handle per displayed image.
        with Image.open(filepath) as img_file:
            img = img_file.copy()

        plt.subplot(num_rows, num_columns, idx + 1)
        plt.imshow(img)
        plt.axis('off')

        # Title with key info. The generation cell stores the sampler under "scheduler".
        title = f"{metadata['scheduler']}\n{metadata['width']}x{metadata['height']} | {metadata['num_steps']} steps"
        plt.title(title, fontsize=8)

    plt.tight_layout()
    plt.show()

    print(f"Displayed {num_images} of {len(metadata_list)} images")
elif run_sampler_test:
    print("ℹ️  Display skipped (sampler test mode - results shown above)")
elif not show_results:
    print("ℹ️  Display disabled (set show_results = True to enable)")
else:
    print("ℹ️  No images to display")

## Cleanup

In [None]:
#@markdown Clean up and optionally end session
end_session = False #@param {type:"boolean"}

# Memory cleanup. Besides `pipe` and `compel`, the model-loading cell keeps the
# replacement VAE in a global named `vae` (also attached as pipe.vae), which would
# otherwise keep it alive. pop() instead of `del` makes this cell safe to re-run.
for _name in ("pipe", "compel", "vae"):
    globals().pop(_name, None)
# Collect first so freed tensors return to PyTorch's allocator cache, then release it.
gc.collect()
torch.cuda.empty_cache()

print("Cleanup complete")

if torch.cuda.is_available():
    print(f"GPU Memory after cleanup: {torch.cuda.mem_get_info()[0] / 1024**3:.2f}GB free")

if end_session:
    print("Ending Colab session...")
    try:
        # sys.exit() in a notebook cell only raises SystemExit, which IPython
        # catches, so the runtime keeps running. Ask Colab to release it instead.
        from google.colab import runtime
        runtime.unassign()
    except (ImportError, AttributeError):
        # Not running on Colab (or runtime.unassign is unavailable): fall back
        # to the previous behaviour.
        sys.exit()