# Image Generator with Compel

This notebook enhances the basic generator with:
- **Compel Integration**: Weighted prompt embeddings via Compel (prompts longer than 77 tokens are truncated unless Compel is created with `truncate_long_prompts=False`)
- **Batch Prompt Fetching**: Pre-fetch multiple prompts to reduce API calls
- **Overflow Management**: Smart usage of overflow tags from prompt API
- **Memory Optimized**: Works on Colab free tier

In [1]:
import os
import warnings

from dotenv import load_dotenv
from google.colab import drive

# Suppress noisy library warnings before heavier imports in later cells.
warnings.filterwarnings('ignore', message='Flax classes are deprecated')
warnings.filterwarnings('ignore', category=FutureWarning)

mount_path = '/content/drive'

# The mount-point directory can exist without Drive being mounted (e.g. after a
# runtime reset), so look for Drive's MyDrive folder instead of the bare directory.
if not os.path.isdir(os.path.join(mount_path, 'MyDrive')):
    print("Drive not mounted. Mounting now...")
    drive.mount(mount_path)
else:
    print("Drive already mounted.")

# .env file expected to contain a line like HF_TOKEN=...
file_path = '/content/drive/MyDrive/hf_token.env'

# load_dotenv returns False when nothing could be loaded (e.g. missing file).
if not load_dotenv(file_path):
    print(f"⚠ No variables loaded from {file_path}")

hf_token = os.getenv('HF_TOKEN')
# The model-loading cell reads the token as `huggingface_token`; bind both names.
huggingface_token = hf_token
if not hf_token:
    print("⚠ HF_TOKEN is not set; downloads of gated models may fail.")

NameError: name 'mount_path' is not defined

In [None]:
# Install required packages
!pip install -q diffusers transformers accelerate safetensors omegaconf invisible-watermark compel bitsandbytes

In [None]:
import requests
from datetime import datetime
import sys
import os
import gc
from PIL import Image
from io import BytesIO
import random
import json
from collections import defaultdict
import torch
from huggingface_hub import snapshot_download, login
from diffusers import (
    StableDiffusionXLPipeline,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    HeunDiscreteScheduler,
    UniPCMultistepScheduler
)
from compel import Compel, ReturnedEmbeddingsType

# Report the torch build and, when a GPU is present, its name and free memory.
cuda_available = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_available}")
if cuda_available:
    device_name = torch.cuda.get_device_name(0)
    print(f"GPU: {device_name}")
    free_gb = torch.cuda.mem_get_info()[0] / 1024**3
    print(f"Initial GPU Memory: {free_gb:.2f}GB free")

## Configuration

Sampler selection is part of the main configuration cell below (the `sampler` setting).

In [None]:
# This cell has been removed - sampler selection now in main config

In [None]:
#@markdown ### Paths and Authentication
base_path = "/content/drive/MyDrive/AI/" #@param {type:"string"}

#@markdown ### Model Configuration
# NOTE(review): models are loaded with StableDiffusionXLPipeline; the SD 1.5 entry
# in this list presumably will not load with it — confirm before selecting.
model_id = "Vvilams/pony-realism-v21main-sdxl" #@param ["stablediffusionapi/mklan-xxx-nsfw-pony","stablediffusionapi/duchaiten-real3d-nsfw-xl", "John6666/cyberrealistic-pony-v7-sdxl", "John6666/uber-realistic-porn-merge-xl-urpmxl-v6final-sdxl","John6666/fucktastic-real-checkpoint-pony-pdxl-porn-realistic-nsfw-sfw-21-sdxl","Vvilams/pony-realism-v21main-sdxl","John6666/sexoholic-real-pony-nsfw-v2-sdxl","John6666/wai-ani-nsfw-ponyxl-v11-sdxl", "John6666/duchaiten-pony-real-v20-sdxl", "stable-diffusion-v1-5/stable-diffusion-v1-5", "stabilityai/stable-diffusion-xl-base-1.0", "John6666/pornworks-real-porn-v03-sdxl", "UnfilteredAI/NSFW-GEN-ANIME", "UnfilteredAI/NSFW-gen-v2"]
download_model = False #@param {type:"boolean"}

#@markdown ### Prompt Configuration
base_prompt = "" #@param {type:"string"}
sfw = False #@param {type:"boolean"}
selected_categories = [] #@param {type:"raw"}
prompts_per_batch = 5 #@param {type:"integer"}
overflow_usage = 0.6 #@param {type:"slider", min:0.0, max:1.0, step:0.1}

#@markdown ### Generation Settings
num_of_images = 50 #@param {type:"integer"}
sampler = "DPM++ 2M Karras" #@param ["DPM++ 2M", "DPM++ 2M Karras", "Euler a", "Heun", "KDPM2 a", "UniPC"]
guidance_low = 5 #@param {type:"integer"}
guidance_high = 15 #@param {type:"integer"}
steps_low = 25 #@param {type:"integer"}
steps_high = 40 #@param {type:"integer"}

#@markdown ### Resolution Settings
resolution_mode = "classic" #@param ["classic", "sdxl"]

# Resolution pools: "classic" draws width and height independently from
# CLASSIC_RESOLUTIONS; "sdxl" draws one (width, height) pair from SDXL_RESOLUTIONS.
CLASSIC_RESOLUTIONS = [720, 768, 800, 1024]
SDXL_RESOLUTIONS = [
    (1024, 1024), (1152, 896), (896, 1152), (1216, 832),
    (832, 1216), (1344, 768), (768, 1344)
]

# Fail fast on reversed ranges: random.randint(steps_low, steps_high) in the
# generation loop raises ValueError when low > high, and a reversed guidance
# range is almost certainly a typo. Checked before any directory is created.
if guidance_low > guidance_high:
    raise ValueError(f"guidance_low ({guidance_low}) must not exceed guidance_high ({guidance_high})")
if steps_low > steps_high:
    raise ValueError(f"steps_low ({steps_low}) must not exceed steps_high ({steps_high})")

# Derived paths. os.path.join keeps these correct whether or not base_path ends
# with "/"; the trailing "" keeps save_directory ending in a separator.
model_path = os.path.join(base_path, "models", model_id)
save_directory = os.path.join(base_path, "images", datetime.now().strftime('%Y%m%d%H%M%S'), "")
os.makedirs(save_directory, exist_ok=True)

# Models requiring uniform scheduler
# NOTE(review): `use_uniform` is not referenced by any visible cell.
uniform_models = [
    "John6666/uber-realistic-porn-merge-xl-urpmxl-v6final-sdxl",
    "John6666/sexoholic-real-pony-nsfw-v2-sdxl",
    "John6666/duchaiten-pony-real-v20-sdxl",
    "stabilityai/sdxl-turbo",
    "John6666/pornworks-real-porn-v03-sdxl"
]
use_uniform = model_id in uniform_models

# Negative prompt (simple and proven)
NEGATIVE_PROMPT = "text, writing, bad teeth, deformed face and eyes, child, childish, young, deformed, uneven eyes, too many fingers"

print(f"✓ Configuration loaded")
print(f"  Model: {model_id}")
print(f"  Sampler: {sampler}")
print(f"  Guidance: {guidance_low}-{guidance_high}")
print(f"  Steps: {steps_low}-{steps_high}")
print(f"  Resolution: {resolution_mode}")
print(f"  Overflow usage: {overflow_usage*100:.0f}%")
print(f"  Save to: {save_directory}")

## Prompt Management System

In [None]:
class PromptManager:
    """Fetch prompts from the prompt API, cache them, and expand them with overflow tags.

    Attributes:
        api_url_template: Fully built request URL (query string included).
        cache: FIFO list of prompt dicts with keys 'prompt', 'overflow', 'refined'.
        cache_size: Maximum number of cached prompts; the oldest are dropped first.
        stats: Counters for 'fetched', 'failed' and 'used' prompts.
    """

    def __init__(self, api_url_template, cache_size=20):
        self.api_url_template = api_url_template
        self.cache = []
        self.cache_size = cache_size
        self.stats = defaultdict(int)

    @staticmethod
    def _parse_overflow(overflow_raw):
        """Normalize the API's 'overflow' field to a list of tags.

        The API returns a comma-separated string, which is split and stripped, with
        empty items dropped. A list is accepted as-is in case the API changes.
        Anything else yields an empty list.
        """
        if isinstance(overflow_raw, str):
            return [item.strip() for item in overflow_raw.split(',') if item.strip()]
        if isinstance(overflow_raw, list):
            return overflow_raw
        return []

    def fetch_prompts(self, count=5):
        """Fetch `count` prompts (one request each) and append them to the cache.

        Failures are printed and counted in stats['failed'] rather than raised.

        Returns:
            Number of prompts successfully fetched by this call.
        """
        prompts = []
        for _ in range(count):
            try:
                response = requests.get(self.api_url_template, timeout=10)
                if response.status_code == 200:
                    data = response.json()
                    prompts.append({
                        'prompt': data['prompt'],
                        'overflow': self._parse_overflow(data.get('overflow', '')),
                        'refined': data.get('refined', '')
                    })
                    self.stats['fetched'] += 1
                else:
                    print(f"Failed to fetch prompt: {response.status_code}")
                    self.stats['failed'] += 1
            except Exception as e:
                print(f"Error fetching prompt: {e}")
                self.stats['failed'] += 1

        self.cache.extend(prompts)
        # Keep only the newest `cache_size` entries.
        if len(self.cache) > self.cache_size:
            self.cache = self.cache[-self.cache_size:]

        return len(prompts)

    def get_prompt(self):
        """Pop the oldest cached prompt, refilling the cache when it runs low.

        Refills use the notebook-level `prompts_per_batch` setting when fewer than
        3 prompts remain, with a single-prompt emergency fetch if still empty.

        Returns:
            A prompt dict, or None if nothing could be fetched.
        """
        if len(self.cache) < 3:
            print(f"Cache low ({len(self.cache)} prompts), fetching more...")
            self.fetch_prompts(prompts_per_batch)

        if not self.cache:
            # Emergency fetch
            self.fetch_prompts(1)

        if self.cache:
            self.stats['used'] += 1
            return self.cache.pop(0)
        return None

    def build_weighted_prompt(self, prompt_data, overflow_usage=0.6):
        """
        Append a random subset of overflow tags to the base prompt.

        Args:
            prompt_data: Dict with 'prompt' and optionally 'overflow' (list of tags).
            overflow_usage: Target fraction of overflow tags to use (0.0-1.0). The
                actual fraction is drawn from [0.7x, 1.3x] of this value. At least
                one tag is used when the value is positive; 0.0 (or less) uses none.

        Returns:
            (final_prompt, number_of_overflow_tags_used)
        """
        prompt = prompt_data['prompt']
        overflow = prompt_data.get('overflow') or []

        # overflow_usage == 0.0 means "no overflow"; max(1, ...) below would
        # otherwise still force one tag in.
        if not overflow or overflow_usage <= 0:
            return prompt, 0

        fraction = random.uniform(overflow_usage * 0.7, overflow_usage * 1.3)
        count = min(len(overflow), max(1, int(len(overflow) * fraction)))
        selected = random.sample(overflow, k=count)

        overflow_text = ", ".join(selected)
        # Avoid a dangling leading ", " when the API returned an empty base prompt.
        final_prompt = f"{prompt}, {overflow_text}" if prompt else overflow_text

        return final_prompt, count

    def print_stats(self):
        """Print fetch/usage counters and the current cache size."""
        print(f"\nPrompt Manager Stats:")
        print(f"  Fetched: {self.stats['fetched']}")
        print(f"  Used: {self.stats['used']}")
        print(f"  Failed: {self.stats['failed']}")
        print(f"  Cached: {len(self.cache)}")

import urllib.parse

# Initialize prompt manager.
# base_prompt and the category list are free text; percent-encode them so spaces,
# '&', '#' or '=' cannot corrupt the query string.
prompt_query = urllib.parse.urlencode(
    {
        "sfw": sfw,
        "base_prompt": base_prompt,
        "selected_categories": " ".join(selected_categories),
    },
    quote_via=urllib.parse.quote,
)
prompt_url = f"https://prompt-gen.squigglypickle.co.uk/generate-prompt?{prompt_query}"
prompt_manager = PromptManager(prompt_url)

# Pre-fetch prompts
print(f"Pre-fetching {prompts_per_batch} prompts...")
fetched = prompt_manager.fetch_prompts(prompts_per_batch)
print(f"Fetched {fetched} prompts successfully")
if fetched == 0:
    print("⚠ No prompts fetched; check the prompt API URL and network access.")

## Advanced Negative Prompt Generator

In [None]:
import re

# Subject keywords are matched as whole words (optionally plural) so words that
# merely contain them, such as "many"/"manga" ("man") or "surface" ("face"),
# are not mistaken for a person in the scene.
_PORTRAIT_PATTERN = re.compile(r"\b(?:face|portrait|person|woman|man|girl|boy)s?\b")


def generate_dynamic_negative(prompt, sfw=False):
    """
    Generate a context-aware, randomized negative prompt.

    Always starts with core quality negatives, then adds:
      - style negatives opposite to the detected style (photo vs. artistic),
      - a random 3-5 item subset of generic quality negatives,
      - 4-6 hand/eye/face negatives when the prompt mentions a person,
      - a random 2-3 item subset of composition negatives,
      - explicit-content negatives when `sfw` is True.

    Args:
        prompt: Positive prompt text used to detect style and subject.
        sfw: When True, append NSFW-related negatives.

    Returns:
        Comma-separated negative prompt string (different on each call).
    """
    # Core quality negatives (always included)
    base = "low quality, worst quality, normal quality, blurry, hazy, out of focus"

    prompt_lower = prompt.lower()
    # Style cues keep substring matching on purpose, so "photography"/"photos" count too.
    is_photo = any(kw in prompt_lower for kw in ["photo", "photograph", "realistic", "photorealistic"])
    is_artistic = any(kw in prompt_lower for kw in ["painting", "drawing", "sketch", "anime", "cartoon", "illustration"])
    is_portrait = _PORTRAIT_PATTERN.search(prompt_lower) is not None

    # Style-specific negatives
    if is_photo:
        base += ", cartoon, anime, painting, drawing, 3d render, cgi, fake, artificial"
    elif is_artistic:
        base += ", photograph, photo, realistic"

    # Quality negatives pool (select 3-5)
    quality_pool = [
        "bad quality", "bad anatomy", "bad proportions", "deformed",
        "poorly drawn", "ugly", "distorted", "mutation",
        "disfigured", "malformed", "mutated"
    ]
    base += ", " + ", ".join(random.sample(quality_pool, k=random.randint(3, 5)))

    # Portrait-specific negatives (select 4-6)
    if is_portrait:
        portrait_negatives = [
            "bad hands", "poorly drawn hands", "bad fingers", "extra fingers",
            "missing fingers", "fused fingers", "bad eyes", "asymmetric eyes",
            "crossed eyes", "bad face", "poorly drawn face", "asymmetric face"
        ]
        base += ", " + ", ".join(random.sample(portrait_negatives, k=random.randint(4, 6)))

    # Composition negatives (select 2-3)
    composition_pool = [
        "cropped", "out of frame", "bad composition", "uncentered",
        "duplicate", "watermark", "text", "signature"
    ]
    base += ", " + ", ".join(random.sample(composition_pool, k=random.randint(2, 3)))

    # SFW-specific negatives
    if sfw:
        base += ", nsfw, nude, nudity, explicit, sexual content, naked"

    return base

## Model Loading with Memory Optimization

In [None]:
# Resolve the Hugging Face token from whichever name an earlier cell bound,
# falling back to the HF_TOKEN environment variable (loaded from the .env file).
huggingface_token = (
    globals().get("huggingface_token")
    or globals().get("hf_token")
    or os.getenv("HF_TOKEN")
)

# Download only when forced or when the local model folder is missing/empty.
model_exists = os.path.isdir(model_path) and len(os.listdir(model_path)) > 0
should_download = download_model or not model_exists

if should_download:
    if not model_exists:
        print(f"Model not found at {model_path}, downloading from HuggingFace...")
    else:
        print(f"Re-downloading model {model_id} (download_model checkbox enabled)...")

    if not huggingface_token:
        print("Warning: No HuggingFace token provided. Download may fail for gated models.")

    snapshot_download(
        repo_id=model_id,
        local_dir=model_path,
        token=huggingface_token or None,
        ignore_patterns=["*.safetensors.lock"]
    )
    print(f"✓ Model downloaded to: {model_path}")
else:
    print(f"✓ Model already exists at {model_path}, skipping download")

print(f"Loading model {model_id}...")

# Load the fp16-fixed SDXL VAE (fixes color/quality issues); fall back to the
# model's own VAE if it cannot be fetched.
from diffusers import AutoencoderKL

print("Loading optimized VAE for better color accuracy...")
try:
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix",
        torch_dtype=torch.float16
    )
    print("✓ Loaded fp16-fixed VAE for better quality")
except Exception as e:
    print(f"⚠ Could not load optimized VAE, using model default: {e}")
    vae = None

# Base load kwargs
load_kwargs = {
    "torch_dtype": torch.float16,
    "use_safetensors": True,
}
if vae is not None:
    load_kwargs["vae"] = vae


def _load_sdxl_pipeline(source, base_kwargs):
    """Load an SDXL pipeline from `source` (local dir or Hub id), preferring fp16 weights.

    Tries the "fp16" variant first, then the default weights. Re-raises the last
    error if both attempts fail.
    """
    last_error = None
    for variant in ("fp16", None):
        label = variant or "no_variant"
        kwargs = dict(base_kwargs)
        if variant is not None:
            kwargs["variant"] = variant
        try:
            pipeline = StableDiffusionXLPipeline.from_pretrained(source, **kwargs)
        except Exception as e:
            print(f"⚠ Could not load {source} ({label}): {e}")
            last_error = e
            continue
        print(f"✓ Model loaded successfully ({label})")
        return pipeline
    raise last_error


pipe = None
# Try both weight variants from the local copy before going to the network; the
# local copy is only treated as unusable when neither variant loads.
if os.path.isdir(model_path) and os.listdir(model_path):
    try:
        pipe = _load_sdxl_pipeline(model_path, load_kwargs)
        print(f"✓ Loaded from local path: {model_path}")
    except Exception as local_error:
        print(f"⚠ Local model unusable: {local_error}")
        print("Loading from HuggingFace instead...")

if pipe is None:
    pipe = _load_sdxl_pipeline(model_id, {**load_kwargs, "token": huggingface_token or None})
    print(f"✓ Loaded from HuggingFace: {model_id}")

# Move pipeline to GPU
pipe = pipe.to("cuda")

# Memory optimizations
print("Applying memory optimizations...")
pipe.enable_attention_slicing()
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()

# Try to enable xformers for better memory efficiency
try:
    pipe.enable_xformers_memory_efficient_attention()
    print("✓ xformers enabled")
except Exception as e:
    print(f"⚠ xformers not available: {e}")

# NOTE: enable_model_cpu_offload() is DISABLED because it conflicts with Compel
# Compel needs text encoders to stay on GPU for prompt weighting
print("⚠ CPU offload disabled (incompatible with Compel)")

# Initialize Compel for advanced prompt weighting.
# NOTE(review): Compel's truncate_long_prompts defaults to True, so prompts are cut
# at the encoder limit. Allowing longer prompts needs truncate_long_prompts=False
# plus padding positive/negative conditioning to the same length before the pipe call.
compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
    device="cuda"
)
print("✓ Compel initialized for advanced prompt weighting")

print(f"\n✓ Pipeline ready!")
if torch.cuda.is_available():
    print(f"GPU Memory: {torch.cuda.mem_get_info()[0] / 1024**3:.2f}GB free")

## Sampler Configuration

In [None]:
# Sampler classes
from diffusers import (
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    HeunDiscreteScheduler,
    UniPCMultistepScheduler
)

# Remember the scheduler config the checkpoint shipped with, once per loaded
# pipeline. The generation cell replaces pipe.scheduler, so reading
# pipe.scheduler.config afterwards would inherit the previous sampler's settings.
if not hasattr(pipe, "_base_scheduler_config"):
    pipe._base_scheduler_config = pipe.scheduler.config


def _scheduler_from_base(scheduler_cls, **overrides):
    """Build `scheduler_cls` from the pipeline's original scheduler config.

    Default-constructed schedulers use the library's default beta schedule rather
    than the one the checkpoint was trained with. from_config carries over the
    model's beta schedule, timestep spacing and prediction type; `overrides` are
    applied on top.
    """
    base_config = getattr(pipe, "_base_scheduler_config", pipe.scheduler.config)
    return scheduler_cls.from_config(base_config, **overrides)


# Sampler factory functions - each call creates a fresh scheduler instance.
# DPM++ entries pin algorithm/order/Karras explicitly so they match their labels
# regardless of what the checkpoint's own scheduler config specifies.
SAMPLERS = {
    "DPM++ 2M": lambda: _scheduler_from_base(
        DPMSolverMultistepScheduler, algorithm_type="dpmsolver++", solver_order=2, use_karras_sigmas=False
    ),
    "DPM++ 2M Karras": lambda: _scheduler_from_base(
        DPMSolverMultistepScheduler, algorithm_type="dpmsolver++", solver_order=2, use_karras_sigmas=True
    ),
    "Euler a": lambda: _scheduler_from_base(EulerAncestralDiscreteScheduler),
    "Heun": lambda: _scheduler_from_base(HeunDiscreteScheduler),
    "KDPM2 a": lambda: _scheduler_from_base(KDPM2AncestralDiscreteScheduler),
    "UniPC": lambda: _scheduler_from_base(UniPCMultistepScheduler),
}


def get_resolution():
    """Return a random (width, height) according to `resolution_mode`.

    "classic" picks width and height independently from CLASSIC_RESOLUTIONS;
    any other mode picks one pair from SDXL_RESOLUTIONS.
    """
    if resolution_mode == "classic":
        width = random.choice(CLASSIC_RESOLUTIONS)
        height = random.choice(CLASSIC_RESOLUTIONS)
        return (width, height)
    else:  # sdxl
        return random.choice(SDXL_RESOLUTIONS)

print(f"✓ Samplers configured: {len(SAMPLERS)} available")

## Image Generation

In [None]:
# Sampler testing removed - use settings above for generation

# Generation cell below

In [None]:
metadata_list = []
metadata_path = os.path.join(save_directory, "metadata.json")

# Setup sampler with clean initialization
print(f"Configuring sampler: {sampler}")
pipe.scheduler = SAMPLERS[sampler]()
print(f"✓ Scheduler initialized: {pipe.scheduler.__class__.__name__}")

# NEGATIVE_PROMPT is constant for the whole run: encode it once with Compel so
# positive and negative embeddings come from the same encoder path.
negative_conditioning, negative_pooled = compel(NEGATIVE_PROMPT)

print(f"\nStarting generation of {num_of_images} images...")
print(f"Save directory: {save_directory}\n")

try:
    for i in range(num_of_images):
        try:
            # Get prompt from manager
            prompt_data = prompt_manager.get_prompt()
            if not prompt_data:
                print(f"Failed to get prompt for image {i}, skipping...")
                continue

            # Build final prompt with overflow
            final_prompt, overflow_count = prompt_manager.build_weighted_prompt(
                prompt_data, overflow_usage
            )

            # Randomize parameters (guidance snapped to 0.5 increments)
            seed = random.randint(0, 2**32 - 1)
            width, height = get_resolution()
            guidance_scale = round(random.uniform(guidance_low, guidance_high) * 2) / 2
            num_steps = random.randint(steps_low, steps_high)

            # Encode the prompt with Compel (supports Compel weighting syntax).
            conditioning, pooled = compel(final_prompt)
            # Positive and negative embeddings are combined for classifier-free
            # guidance, so their token lengths must match; Compel pads the shorter one.
            conditioning, negative_embeds = compel.pad_conditioning_tensors_to_same_length(
                [conditioning, negative_conditioning]
            )

            # Generate image
            result = pipe(
                prompt_embeds=conditioning,
                pooled_prompt_embeds=pooled,
                negative_prompt_embeds=negative_embeds,
                negative_pooled_prompt_embeds=negative_pooled,
                num_images_per_prompt=1,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=num_steps,
                generator=torch.Generator(device="cuda").manual_seed(seed),
            ).images[0]

            # Save image
            filename = f"{datetime.now().strftime('%Y%m%d%H%M%S')}_{str(i).zfill(3)}.png"
            filepath = os.path.join(save_directory, filename)
            result.save(filepath)

            # Store metadata
            metadata = {
                "filename": filename,
                "model": model_id,
                "sampler": sampler,
                "base_prompt": base_prompt,
                "prompt": prompt_data['prompt'],
                "final_prompt": final_prompt,
                "overflow_total": len(prompt_data.get('overflow', [])),
                "overflow_used": overflow_count,
                "negative_prompt": NEGATIVE_PROMPT,
                "sfw": sfw,
                "seed": seed,
                "width": width,
                "height": height,
                "guidance_scale": guidance_scale,
                "num_steps": num_steps
            }
            metadata_list.append(metadata)

            # Progress update
            overflow_info = f" +{overflow_count} overflow" if overflow_count > 0 else ""
            print(f"[{i+1}/{num_of_images}] {width}x{height} | {num_steps} steps | G:{guidance_scale:.1f}{overflow_info}")

            # Memory cleanup
            del result, conditioning, pooled, negative_embeds
            torch.cuda.empty_cache()
            gc.collect()

        except Exception as e:
            print(f"Error generating image {i}: {e}")
            import traceback
            traceback.print_exc()
            continue
finally:
    # Runs even when the loop is interrupted (the stop button raises
    # KeyboardInterrupt, which `except Exception` does not catch), so metadata
    # for every image already written to disk is always saved.
    with open(metadata_path, 'w') as f:
        json.dump(metadata_list, f, indent=2)
    print(f"\n✓ {len(metadata_list)} images saved to: {save_directory}")
    print(f"✓ Metadata saved to: {metadata_path}")

print(f"✓ Generation complete!")

# Print stats
prompt_manager.print_stats()

if torch.cuda.is_available():
    print(f"\nFinal GPU Memory: {torch.cuda.mem_get_info()[0] / 1024**3:.2f}GB free")

## Statistics and Analysis

In [None]:
import matplotlib.pyplot as plt
from collections import Counter

# The sampler-test cell no longer exists, so `run_sampler_test` may be undefined;
# default to False instead of raising NameError.
sampler_test_mode = globals().get("run_sampler_test", False)

if not sampler_test_mode and globals().get("metadata_list"):
    total = len(metadata_list)

    sampler_counts = Counter(m['sampler'] for m in metadata_list)
    resolution_counts = Counter(f"{m['width']}x{m['height']}" for m in metadata_list)

    # Overflow usage
    avg_overflow_used = sum(m['overflow_used'] for m in metadata_list) / total
    avg_overflow_total = sum(m['overflow_total'] for m in metadata_list) / total

    print("\n=== Generation Statistics ===")
    print(f"\nSampler Usage:")
    # Loop variable is not named `sampler`: that would overwrite the configuration
    # setting read by the generation cell.
    for sampler_name, count in sampler_counts.most_common():
        print(f"  {sampler_name}: {count} ({count/total*100:.1f}%)")

    print(f"\nTop 5 Resolutions:")
    for res, count in resolution_counts.most_common(5):
        print(f"  {res}: {count} ({count/total*100:.1f}%)")

    print(f"\nOverflow Usage:")
    print(f"  Average used: {avg_overflow_used:.1f}")
    print(f"  Average total: {avg_overflow_total:.1f}")
    print(f"  Usage rate: {avg_overflow_used/avg_overflow_total*100:.1f}%" if avg_overflow_total > 0 else "  N/A")
elif sampler_test_mode:
    print("ℹ️  Statistics skipped (sampler test mode - see test results above)")
else:
    print("ℹ️  No metadata available for statistics")

## Display Results

In [None]:
#@markdown Display generated images in a grid
show_results = True #@param {type:"boolean"}
num_columns = 3 #@param {type:"integer"}
max_images_to_show = 12 #@param {type:"integer"}

# Imported here so this cell does not depend on the statistics cell having run.
import matplotlib.pyplot as plt

# `run_sampler_test` may be undefined (its cell was removed); default to False.
sampler_test_mode = globals().get("run_sampler_test", False)

if not sampler_test_mode and show_results and globals().get("metadata_list"):
    display_metadata = metadata_list[:max_images_to_show]
    num_images = len(display_metadata)
    columns = max(1, num_columns)  # guard against 0/negative form input
    num_rows = (num_images + columns - 1) // columns  # ceiling division

    fig = plt.figure(figsize=(5 * columns, 5 * num_rows))

    for idx, metadata in enumerate(display_metadata):
        filepath = os.path.join(save_directory, metadata['filename'])
        ax = fig.add_subplot(num_rows, columns, idx + 1)
        ax.axis('off')

        # save_directory changes whenever the config cell is re-run, so the file
        # may not be where the metadata says; show a placeholder instead of crashing.
        if not os.path.exists(filepath):
            ax.set_title(f"missing: {metadata['filename']}", fontsize=8)
            continue

        # The context manager closes the file handle once imshow has the pixels.
        with Image.open(filepath) as img:
            ax.imshow(img)

        # Title with key info
        title = f"{metadata['sampler']}\n{metadata['width']}x{metadata['height']} | {metadata['num_steps']} steps"
        ax.set_title(title, fontsize=8)

    plt.tight_layout()
    plt.show()

    print(f"Displayed {num_images} of {len(metadata_list)} images")
elif sampler_test_mode:
    print("ℹ️  Display skipped (sampler test mode - results shown above)")
elif not show_results:
    print("ℹ️  Display disabled (set show_results = True to enable)")
else:
    print("ℹ️  No images to display")

## Cleanup

In [None]:
#@markdown Clean up and optionally end session
end_session = False #@param {type:"boolean"}

# Drop the globals that keep model weights on the GPU. `vae` is the same module
# object that pipe.to("cuda") moved to the GPU, so deleting `pipe` alone does not
# release it. pop() instead of `del` keeps this cell safe to re-run.
for _name in ("pipe", "compel", "vae"):
    globals().pop(_name, None)
del _name

# Collect first so freed tensors are returned to the CUDA caching allocator,
# then release the cached blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()

print("Cleanup complete")

if torch.cuda.is_available():
    print(f"GPU Memory after cleanup: {torch.cuda.mem_get_info()[0] / 1024**3:.2f}GB free")

if end_session:
    print("Ending Colab session...")
    # Inside IPython, sys.exit() only raises SystemExit and the runtime keeps
    # running; unassigning the Colab runtime actually releases the VM.
    try:
        from google.colab import runtime
        runtime.unassign()
    except ImportError:
        sys.exit()