# Optimized Image Generator

This notebook fetches configuration from the prompt-generator API and generates images using SDXL.

**Features:**
- Configuration pulled from the web UI via the API (no manual parameter editing, except the LoRA settings in Cell 1)
- Three prompt modes: Full, Refined, Dual-Encoder
- Compel integration for long/weighted prompts (Full and Refined modes; Dual-Encoder passes its two prompts to the pipeline directly)
- Memory optimized for Colab free tier

In [None]:
# Cell 1: Setup & Configuration Fetch
import os
from dotenv import load_dotenv
from google.colab import drive
import requests
import warnings

# Suppress noisy library deprecation warnings
warnings.filterwarnings('ignore', message='Flax classes are deprecated')
warnings.filterwarnings('ignore', category=FutureWarning)

# Mount Google Drive (skipped when a previous run already mounted it)
mount_path = '/content/drive'
if not os.path.exists(mount_path):
    print("Mounting Google Drive...")
    drive.mount(mount_path)
else:
    print("Drive already mounted.")

# Load HuggingFace token from a dotenv file on Drive (None when not set)
env_path = '/content/drive/MyDrive/AI/hf_token.env'
load_dotenv(env_path)
huggingface_token = os.getenv('HUGGINGFACE_TOKEN')

# API Configuration
API_BASE = "https://prompt-gen.squigglypickle.co.uk"

# Defaults used when the API is unreachable AND to fill any key the API
# response omits — later cells index these keys directly (config['...']).
DEFAULT_CONFIG = {
    'prompt_mode': 'refined',
    'base_prompt': '',
    'sfw': False,
    'selected_categories': [],
    'model_id': 'John6666/wai-ani-nsfw-ponyxl-v11-sdxl',
    'download_model': False,
    'schedulers': ['DPM++ 2M', 'Euler'],
    'guidance_low': 5,
    'guidance_high': 15,
    'steps_low': 30,
    'steps_high': 40,
    'resolution_mode': 'sdxl',
    'num_images': 25,
    'prompts_per_batch': 3,
}

# Fetch configuration from API
print("\nFetching configuration from API...")
try:
    config_response = requests.get(f"{API_BASE}/image-config", timeout=10)
    config_response.raise_for_status()
    fetched_config = config_response.json()['config']
    if not isinstance(fetched_config, dict):
        raise TypeError(f"expected a JSON object, got {type(fetched_config).__name__}")
    missing_keys = sorted(set(DEFAULT_CONFIG) - set(fetched_config))
    # API values win; defaults only fill the gaps
    config = {**DEFAULT_CONFIG, **fetched_config}
    print("\u2713 Configuration loaded successfully")
    if missing_keys:
        print(f"\u26a0 API config missing {', '.join(missing_keys)}; using defaults for those")
except (requests.RequestException, ValueError, KeyError, TypeError) as e:
    # RequestException: network/HTTP errors; ValueError: invalid JSON;
    # KeyError/TypeError: unexpected response shape.
    print(f"\u2717 Failed to fetch config: {e}")
    print("Using default configuration...")
    config = dict(DEFAULT_CONFIG)  # copy so the LoRA keys below don't touch DEFAULT_CONFIG

# LoRA Configuration (set these manually — not fetched from API)
lora_enabled = False #@param {type:"boolean"}
lora_path = "Loras/stoo_tee/output/stoo_tee.safetensors" #@param {type:"string"}
lora_scale = 0.8 #@param {type:"slider", min:0.1, max:1.5, step:0.1}
lora_trigger_word = "stoo_tee" #@param {type:"string"}
lora_prepend_trigger = True #@param {type:"boolean"}

config['lora_enabled'] = lora_enabled
config['lora_path'] = f"/content/drive/MyDrive/{lora_path}" if lora_enabled else None
config['lora_scale'] = lora_scale
config['lora_trigger_word'] = lora_trigger_word
config['lora_prepend_trigger'] = lora_prepend_trigger

# Display configuration
print("\n=== Configuration ===")
print(f"  Prompt Mode: {config['prompt_mode']}")
print(f"  Model: {config['model_id']}")
print(f"  Schedulers: {', '.join(config['schedulers'])}")
print(f"  Images: {config['num_images']} x {len(config['schedulers'])} = {config['num_images'] * len(config['schedulers'])}")
print(f"  Guidance: {config['guidance_low']}-{config['guidance_high']}")
print(f"  Steps: {config['steps_low']}-{config['steps_high']}")
print(f"  Resolution: {config['resolution_mode']}")
print(f"  SFW: {config['sfw']}")
if lora_enabled:
    print(f"  LoRA: {lora_path} (scale={lora_scale}, trigger='{lora_trigger_word}', auto-prepend={'on' if lora_prepend_trigger else 'off'})")
else:
    print("  LoRA: disabled")

# Derived paths: model snapshots live under models/<repo id>; each run
# writes into a fresh timestamped images/ subdirectory.
from datetime import datetime
base_path = "/content/drive/MyDrive/AI/"
model_path = base_path + "models/" + config['model_id']
save_directory = f"{base_path}images/{datetime.now().strftime('%Y%m%d%H%M%S')}/"
os.makedirs(save_directory, exist_ok=True)
print(f"\nSave directory: {save_directory}")

In [9]:
# Cell 2: Install & Import Dependencies
# NOTE: `!pip` is IPython/Colab shell syntax, so this cell only runs inside a notebook.
!pip install -q diffusers transformers accelerate safetensors compel

import torch
import random
import json
import gc
from PIL import Image
from collections import defaultdict
from huggingface_hub import snapshot_download
# Pipeline, VAE and the scheduler classes that Cell 4 maps to display names
from diffusers import (
    StableDiffusionXLPipeline,
    AutoencoderKL,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    KDPM2DiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    HeunDiscreteScheduler,
    LMSDiscreteScheduler,
    DDIMScheduler,
    PNDMScheduler,
    UniPCMultistepScheduler
)
from compel import CompelForSDXL

# Report runtime versions and free GPU memory to help diagnose environment issues
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.mem_get_info()[0] / 1024**3:.2f}GB free")

PyTorch: 2.9.0+cu126
CUDA: True
GPU: Tesla T4
Memory: 14.17GB free


In [None]:
# Cell 3: Load Model & Pipeline
model_id = config['model_id']
# A non-empty directory at model_path counts as an existing local snapshot
model_exists = os.path.isdir(model_path) and len(os.listdir(model_path)) > 0

if config['download_model'] or not model_exists:
    print(f"Downloading model {model_id}...")
    snapshot_download(
        repo_id=model_id,
        local_dir=model_path,
        token=huggingface_token if huggingface_token else None,
        ignore_patterns=["*.safetensors.lock"]
    )
    print("\u2713 Model downloaded")
else:
    print(f"\u2713 Model exists at {model_path}")

# Load pipeline. Try the fp16 weight variant first and fall back to the
# default weights when the checkpoint has no fp16 variant. Within each
# attempt, prefer the local snapshot and fall back to the Hub repo id.
print("\nLoading pipeline...")
load_kwargs = {
    "torch_dtype": torch.float16,
    "use_safetensors": True,
}

pipe = None
for attempt in ["fp16", "no_variant"]:
    try:
        if attempt == "fp16":
            load_kwargs["variant"] = "fp16"
        else:
            load_kwargs.pop("variant", None)

        try:
            pipe = StableDiffusionXLPipeline.from_pretrained(model_path, **load_kwargs)
            print("\u2713 Loaded from local path")
        except (OSError, FileNotFoundError):
            load_kwargs["token"] = huggingface_token
            pipe = StableDiffusionXLPipeline.from_pretrained(model_id, **load_kwargs)
            print("\u2713 Loaded from HuggingFace")
        break
    except ValueError as e:
        # Only a missing fp16 variant is retried; anything else propagates.
        if "variant" in str(e) and attempt == "fp16":
            continue
        raise

# Load enhanced VAE (best effort: keep the model's own VAE on failure)
print("Loading enhanced VAE...")
try:
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix",
        torch_dtype=torch.float16
    )
    pipe.vae = vae
    print("\u2713 Enhanced VAE applied")
except Exception as e:
    print(f"\u26a0 Could not load enhanced VAE: {e}")

# Move to GPU and optimize
pipe = pipe.to("cuda")
pipe.enable_attention_slicing()
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()

try:
    pipe.enable_xformers_memory_efficient_attention()
    print("\u2713 xformers enabled")
except Exception:  # optional dependency; a bare except would also swallow KeyboardInterrupt
    print("\u26a0 xformers not available")

# Load LoRA if enabled
if config['lora_enabled'] and config['lora_path']:
    lora_file = config['lora_path']
    if os.path.exists(lora_file):
        print(f"Loading LoRA from {lora_file}...")
        pipe.load_lora_weights(lora_file)
        pipe.fuse_lora(lora_scale=config['lora_scale'])
        # The LoRA is now merged into the base weights; drop the separate
        # LoRA layers so they don't occupy GPU memory during generation.
        pipe.unload_lora_weights()
        print(f"\u2713 LoRA loaded and fused (scale={config['lora_scale']})")
        print(f"  Trigger word: '{config['lora_trigger_word']}' — include this in your prompts")
    else:
        print(f"\u2717 LoRA file not found: {lora_file}")
        config['lora_enabled'] = False

# Initialize Compel for prompt weighting
compel = CompelForSDXL(pipe)
print("\u2713 Compel initialized")

print(f"\n\u2713 Pipeline ready! GPU Memory: {torch.cuda.mem_get_info()[0] / 1024**3:.2f}GB free")

In [11]:
# Cell 4: Initialize Schedulers
def initialize_samplers(base_config, default_scheduler_cls=None):
    """Build a mapping of display name -> zero-argument scheduler factory.

    Each factory returns a fresh scheduler built with ``from_config`` from
    ``base_config`` (normally ``pipe.scheduler.config``).

    Args:
        base_config: Scheduler config mapping to construct every scheduler from.
        default_scheduler_cls: Class used for the "Model Default" entry. When
            None, the class of the global ``pipe.scheduler`` is captured right
            now. The generation loop later replaces ``pipe.scheduler``, so
            looking it up lazily would return whichever scheduler ran last.

    Returns:
        dict mapping scheduler display names to callables.
    """
    if default_scheduler_cls is None:
        default_scheduler_cls = type(pipe.scheduler)

    # The Karras variants are built without 'beta_schedule', so the scheduler
    # class's own default applies. NOTE(review): SDXL configs typically use
    # 'scaled_linear'; confirm that dropping it is intended.
    clean_config = {k: v for k, v in base_config.items() if k != 'beta_schedule'}

    # NOTE(review): "DPM++ SDE" maps to DPMSolverSinglestepScheduler here;
    # confirm that is the intended sampler for that name.
    return {
        "Model Default": lambda: default_scheduler_cls.from_config(base_config),
        "DPM++ 2M": lambda: DPMSolverMultistepScheduler.from_config(base_config),
        "DPM++ 2M Karras": lambda: DPMSolverMultistepScheduler.from_config(clean_config, use_karras_sigmas=True),
        "DPM++ SDE": lambda: DPMSolverSinglestepScheduler.from_config(base_config),
        "DPM++ SDE Karras": lambda: DPMSolverSinglestepScheduler.from_config(clean_config, use_karras_sigmas=True),
        "Euler": lambda: EulerDiscreteScheduler.from_config(base_config),
        "Euler a": lambda: EulerAncestralDiscreteScheduler.from_config(base_config),
        "Heun": lambda: HeunDiscreteScheduler.from_config(base_config),
        "KDPM2": lambda: KDPM2DiscreteScheduler.from_config(base_config),
        "KDPM2 a": lambda: KDPM2AncestralDiscreteScheduler.from_config(base_config),
        "LMS": lambda: LMSDiscreteScheduler.from_config(base_config),
        "DDIM": lambda: DDIMScheduler.from_config(base_config),
        "PNDM": lambda: PNDMScheduler.from_config(base_config),
        "UniPC": lambda: UniPCMultistepScheduler.from_config(base_config),
    }

SAMPLERS = initialize_samplers(pipe.scheduler.config)

# Validate selected schedulers: accept a single name or a list, drop
# duplicates (keeping the first occurrence) and report unrecognised names
# instead of discarding them silently.
requested_schedulers = config['schedulers']
if isinstance(requested_schedulers, str):
    requested_schedulers = [requested_schedulers]
requested_schedulers = list(dict.fromkeys(requested_schedulers))

unknown_schedulers = [s for s in requested_schedulers if s not in SAMPLERS]
if unknown_schedulers:
    print(f"\u26a0 Ignoring unknown schedulers: {', '.join(unknown_schedulers)}")

selected_schedulers = [s for s in requested_schedulers if s in SAMPLERS]
if not selected_schedulers:
    selected_schedulers = ["Euler"]
    print("\u26a0 No valid schedulers found, defaulting to Euler")

print(f"\u2713 {len(selected_schedulers)} schedulers configured: {', '.join(selected_schedulers)}")

✓ 1 schedulers configured: DPM++ 2M


In [12]:
# Cell 5: Prompt Manager
# Fallback negative prompt, used when an API response has no 'negative_prompt'
NEGATIVE_PROMPT = ", ".join([
    "text",
    "writing",
    "bad teeth",
    "deformed face",
    "child",
    "childish",
    "young",
    "deformed",
    "extra fingers",
])

# Resolution pools.
# 'classic' mode draws width and height independently from this list.
CLASSIC_RESOLUTIONS = [720, 768, 800, 1024]
# 'sdxl' mode picks a whole (width, height) pair: the square size first,
# then each landscape size immediately followed by its portrait twin.
SDXL_RESOLUTIONS = [(1024, 1024)] + [
    dims
    for w, h in ((1152, 896), (1216, 832), (1344, 768))
    for dims in ((w, h), (h, w))
]

class PromptManager:
    """Fetch prompts from the prompt-generator API and hand them out in FIFO order.

    Reads these keys from ``config``: 'sfw', 'base_prompt',
    'selected_categories', 'prompts_per_batch' and 'prompt_mode'.
    ``stats`` counts 'fetched', 'failed' and 'used' prompts.
    """

    # get_prompt() refills the cache when fewer than this many prompts remain
    LOW_WATER_MARK = 3

    def __init__(self, api_base, config):
        self.api_base = api_base
        self.config = config
        self.cache = []  # prompt dicts not yet handed out, oldest first
        self.stats = defaultdict(int)

    def _build_params(self):
        """Return the query parameters for one /generate-refined request."""
        params = {
            # requests serialises a Python bool as 'True'/'False'; send the
            # lowercase form so this flag matches 'encoder-split' below.
            'sfw': 'true' if self.config['sfw'] else 'false',
            'encoder-split': 'true',  # Always get dual-encoder data
        }
        if self.config['base_prompt']:
            params['base_prompt'] = self.config['base_prompt']
        if self.config['selected_categories']:
            params['selected_categories'] = ' '.join(self.config['selected_categories'])
        return params

    def fetch_prompts(self, count=3):
        """Fetch up to ``count`` prompts (one request each) and append them to the cache.

        A failed request is reported and counted in ``stats['failed']`` and is
        not retried. Failures include network errors, non-200 responses and
        unparseable bodies.

        Returns:
            The number of prompts actually fetched (0..count).
        """
        prompts = []
        for _ in range(count):
            try:
                response = requests.get(
                    f"{self.api_base}/generate-refined",
                    params=self._build_params(),
                    timeout=300,
                )
                if response.status_code != 200:
                    # Count non-200 responses as failures rather than ignoring them
                    print(f"Error fetching prompt: HTTP {response.status_code}")
                    self.stats['failed'] += 1
                    continue
                data = response.json()
                prompts.append({
                    'full': data.get('full_prompt', ''),
                    'refined': data.get('refined_prompt', ''),
                    'openclip_g': data.get('openclip_g', ''),
                    'clip_l': data.get('clip_l', ''),
                    'negative': data.get('negative_prompt', NEGATIVE_PROMPT)
                })
                self.stats['fetched'] += 1
            except Exception as e:  # best effort: one bad request must not abort the batch
                print(f"Error fetching prompt: {e}")
                self.stats['failed'] += 1

        self.cache.extend(prompts)
        return len(prompts)

    def get_prompt(self):
        """Pop the oldest cached prompt, topping the cache up first when it is low.

        Returns:
            A prompt dict, or None if the cache is still empty after a refill.
        """
        if len(self.cache) < self.LOW_WATER_MARK:
            print(f"Cache low ({len(self.cache)}), fetching more...")
            self.fetch_prompts(self.config['prompts_per_batch'])

        if not self.cache:
            return None
        self.stats['used'] += 1
        return self.cache.pop(0)

    def get_final_prompt(self, prompt_data):
        """Select the prompt text(s) to send to the pipeline for the configured mode.

        - 'dual_encoder': returns type 'dual' with 'prompt' (OpenCLIP-G text)
          and 'prompt_2' (CLIP-L text). Each falls back to the refined prompt,
          then to the full prompt, when empty.
        - 'refined': returns type 'single' with the refined prompt, falling
          back to the full prompt when the refined one is empty.
        - any other mode: returns type 'single' with the full prompt.
        """
        mode = self.config['prompt_mode']

        if mode == 'dual_encoder':
            return {
                'type': 'dual',
                'prompt': prompt_data.get('openclip_g') or prompt_data.get('refined') or prompt_data['full'],
                'prompt_2': prompt_data.get('clip_l') or prompt_data.get('refined') or prompt_data['full'],
                'negative': prompt_data['negative']
            }
        if mode == 'refined' and prompt_data.get('refined'):
            return {
                'type': 'single',
                'prompt': prompt_data['refined'],
                'negative': prompt_data['negative']
            }
        return {
            'type': 'single',
            'prompt': prompt_data['full'],
            'negative': prompt_data['negative']
        }

def get_resolution():
    """Pick a random (width, height) according to config['resolution_mode'].

    'classic' draws width, then height, independently from
    CLASSIC_RESOLUTIONS. Any other mode picks one whole pair from
    SDXL_RESOLUTIONS.
    """
    if config['resolution_mode'] != 'classic':
        return random.choice(SDXL_RESOLUTIONS)
    width = random.choice(CLASSIC_RESOLUTIONS)
    height = random.choice(CLASSIC_RESOLUTIONS)
    return (width, height)

# Initialize: build the manager and warm its cache before generation starts
prompt_manager = PromptManager(API_BASE, config)
_prefetch_count = config['prompts_per_batch']
print(f"Pre-fetching {_prefetch_count} prompts...")
fetched = prompt_manager.fetch_prompts(_prefetch_count)
print(f"\u2713 Fetched {fetched} prompts")

Pre-fetching 3 prompts...
Error fetching prompt: HTTPSConnectionPool(host='prompt-gen.squigglypickle.co.uk', port=443): Read timed out. (read timeout=30)
✓ Fetched 2 prompts


In [None]:
# Cell 6: Generation Loop
# For each prompt, draw one set of (resolution, guidance, steps) and render it
# with every selected scheduler. Metadata is rewritten after every image so a
# crash mid-run still leaves an up-to-date metadata.json.
metadata_list = []
metadata_path = os.path.join(save_directory, "metadata.json")

num_images = config['num_images']
total_images = num_images * len(selected_schedulers)

print(f"Starting generation: {num_images} prompts x {len(selected_schedulers)} schedulers = {total_images} images")
if config['lora_enabled']:
    if config['lora_prepend_trigger']:
        print(f"LoRA active: trigger word '{config['lora_trigger_word']}' will be prepended to prompts")
    else:
        print(f"LoRA active: auto-prepend OFF — include '{config['lora_trigger_word']}' in your prompts manually")
print(f"Save directory: {save_directory}\n")

image_counter = 0
for i in range(num_images):
    try:
        # Get prompt
        prompt_data = prompt_manager.get_prompt()
        if not prompt_data:
            print(f"[{i+1}/{num_images}] Failed to get prompt, skipping...")
            continue

        # Get final prompt based on mode
        final = prompt_manager.get_final_prompt(prompt_data)

        # Prepend the LoRA trigger word when auto-prepend is on, unless the
        # prompt already contains it (avoids "trigger, ... trigger ...").
        if config['lora_enabled'] and config['lora_prepend_trigger'] and config['lora_trigger_word']:
            trigger = config['lora_trigger_word']
            if trigger not in final['prompt']:
                final['prompt'] = f"{trigger}, {final['prompt']}"
            if final.get('prompt_2') and trigger not in final['prompt_2']:
                final['prompt_2'] = f"{trigger}, {final['prompt_2']}"

        # Random parameters shared by every scheduler for this prompt
        # (the seed is still drawn separately per scheduler below).
        width, height = get_resolution()
        guidance_scale = round(random.uniform(config['guidance_low'], config['guidance_high']) * 2) / 2  # nearest 0.5
        num_steps = random.randint(config['steps_low'], config['steps_high'])

        # Text conditioning depends only on the prompt, not the scheduler, so
        # encode it once per prompt instead of once per scheduler.
        conditioning = None
        if final['type'] != 'dual':
            conditioning = compel(final['prompt'], negative_prompt=final['negative'])

        # Generate with each scheduler
        succeeded = 0
        for scheduler_name in selected_schedulers:
            try:
                pipe.scheduler = SAMPLERS[scheduler_name]()
                seed = random.randint(0, 2**32 - 1)
                generator = torch.Generator(device="cuda").manual_seed(seed)

                if final['type'] == 'dual':
                    # Dual-encoder mode: pass separate prompts (no Compel weighting)
                    result = pipe(
                        prompt=final['prompt'],
                        prompt_2=final['prompt_2'],
                        negative_prompt=final['negative'],
                        width=width,
                        height=height,
                        guidance_scale=guidance_scale,
                        num_inference_steps=num_steps,
                        generator=generator,
                    ).images[0]
                else:
                    # Single prompt mode: use the precomputed Compel embeddings
                    result = pipe(
                        prompt_embeds=conditioning.embeds,
                        pooled_prompt_embeds=conditioning.pooled_embeds,
                        negative_prompt_embeds=conditioning.negative_embeds,
                        negative_pooled_prompt_embeds=conditioning.negative_pooled_embeds,
                        width=width,
                        height=height,
                        guidance_scale=guidance_scale,
                        num_inference_steps=num_steps,
                        generator=generator,
                    ).images[0]

                # Save image; the zero-padded counter keeps filenames unique
                # even when several images share the same second.
                scheduler_short = scheduler_name.replace(' ', '_').replace('+', 'p')
                filename = f"{datetime.now().strftime('%Y%m%d%H%M%S')}_{str(image_counter).zfill(4)}_{scheduler_short}.png"
                result.save(os.path.join(save_directory, filename))

                # Save metadata
                metadata = {
                    "filename": filename,
                    "model": config['model_id'],
                    "scheduler": scheduler_name,
                    "prompt_mode": config['prompt_mode'],
                    "prompt": final['prompt'],
                    "prompt_2": final.get('prompt_2', ''),
                    "negative_prompt": final['negative'],
                    "sfw": config['sfw'],
                    "seed": seed,
                    "width": width,
                    "height": height,
                    "guidance_scale": guidance_scale,
                    "num_steps": num_steps,
                    "prompt_index": i,
                    "lora_enabled": config['lora_enabled'],
                    "lora_path": config.get('lora_path', ''),
                    "lora_scale": config.get('lora_scale', 0),
                    "lora_trigger_word": config.get('lora_trigger_word', ''),
                }
                metadata_list.append(metadata)

                with open(metadata_path, 'w') as f:
                    json.dump(metadata_list, f, indent=2)

                del result
                torch.cuda.empty_cache()
                image_counter += 1
                succeeded += 1

            except Exception as e:
                print(f"Error with {scheduler_name}: {e}")
                # Release cached blocks (e.g. after an OOM) before the next scheduler
                torch.cuda.empty_cache()
                continue

        # Progress: report how many schedulers actually produced an image
        print(f"[{i+1}/{num_images}] {width}x{height} | {num_steps} steps | G:{guidance_scale:.1f} | {succeeded}/{len(selected_schedulers)} schedulers")

        conditioning = None  # drop this prompt's embeddings before the next one
        gc.collect()

    except Exception as e:
        print(f"Error on image {i}: {e}")
        continue

print("\n\u2713 Generation complete!")
print(f"\u2713 {len(metadata_list)} images saved to: {save_directory}")
print(f"\u2713 Metadata saved to: {metadata_path}")
print(f"\nPrompt stats: Fetched={prompt_manager.stats['fetched']}, Used={prompt_manager.stats['used']}, Failed={prompt_manager.stats['failed']}")

In [None]:
# Cell 7: Display Results
# Show up to max_images of this run's images in a grid, titled with
# scheduler and resolution from the metadata.
import matplotlib.pyplot as plt

num_columns = 3
max_images = 12

if metadata_list:
    display_meta = metadata_list[:max_images]
    num_rows = (len(display_meta) + num_columns - 1) // num_columns  # ceiling division

    plt.figure(figsize=(5 * num_columns, 5 * num_rows))

    for idx, meta in enumerate(display_meta):
        image_path = os.path.join(save_directory, meta['filename'])
        plt.subplot(num_rows, num_columns, idx + 1)
        try:
            # Context manager closes the file handle once matplotlib has the pixels
            with Image.open(image_path) as img:
                plt.imshow(img)
        except OSError as e:  # missing or unreadable file: leave this cell blank
            print(f"\u26a0 Could not open {image_path}: {e}")
        plt.axis('off')
        plt.title(f"{meta['scheduler']}\n{meta['width']}x{meta['height']}", fontsize=8)

    plt.tight_layout()
    plt.show()
    print(f"Displayed {len(display_meta)} of {len(metadata_list)} images")
else:
    print("No images to display")

In [None]:
# Cell 8: Cleanup
# Drop every notebook global that may still reference GPU tensors/modules.
# The standalone `vae` global would otherwise keep the VAE alive after
# `pipe` is gone. pop(..., None) keeps this cell safe to re-run.
for _name in ("pipe", "compel", "vae", "conditioning", "generator", "result"):
    globals().pop(_name, None)
del _name

# Collect first so freed tensors go back to PyTorch's allocator, then
# release the cached blocks to the driver.
gc.collect()
torch.cuda.empty_cache()

print("\u2713 Cleanup complete")
if torch.cuda.is_available():
    print(f"GPU Memory: {torch.cuda.mem_get_info()[0] / 1024**3:.2f}GB free")