# Generate Before/After Images for Arcade AI Challenge

Creates all 16 required images: each of the 8 prompts is rendered once with the baseline pipeline and once with the optimized pipeline.

This notebook implements the complete optimization strategy:
- **LoRA adapters** for specific jewelry categories
- **Special tokens** (sks, phol) for enhanced grounding
- **Native diffusers attention weighting** for jewelry terms
- **Optimal parameters** from human evaluation research


## Setup and Imports


In [None]:
import torch
import os
from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
import time
from datetime import datetime
import matplotlib.pyplot as plt
from IPython.display import display, Image as IPImage
from PIL import Image

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")


## Configuration


In [None]:
# The 8 required prompts (verbatim from challenge)
REQUIRED_PROMPTS = [
    "channel-set diamond eternity band, 2 mm width, hammered 18k yellow gold, product-only white background",
    "14k rose-gold threader earrings, bezel-set round lab diamond ends, lifestyle macro shot, soft natural light",
    "organic cluster ring with mixed-cut sapphires and diamonds, brushed platinum finish, modern aesthetic",
    "A solid gold cuff bracelet with blue sapphire, with refined simplicity and intentionally crafted for everyday wear",
    "modern signet ring, oval face, engraved gothic initial 'M', high-polish sterling silver, subtle reflection",
    "delicate gold huggie hoops, contemporary styling, isolated on neutral background",
    "stack of three slim rings: twisted gold, plain platinum, black rhodium pavé, editorial lighting",
    "bypass ring with stones on it, with refined simplicity and intentionally crafted for everyday wear"
]

# LoRA adapter paths and configuration
LORA_ADAPTERS = {
    "channel_set": "../lora_adapters/channel-set/checkpoint/pytorch_lora_weights.safetensors",
    "threader": "../lora_adapters/threader/checkpoint/pytorch_lora_weights.safetensors", 
    "huggie": "../lora_adapters/huggie/checkpoint/pytorch_lora_weights.safetensors"
}

# Special tokens for enhanced grounding
SPECIAL_TOKENS = {
    "channel_set": "sks",
    "threader": "phol"
}

print(f"✅ Configuration loaded: {len(REQUIRED_PROMPTS)} prompts, {len(LORA_ADAPTERS)} LoRA adapters")
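`huggie` has a LoRA adapter but no special token. Assuming the script keys both dictionaries by category name, huggie prompts would get the adapter and no token. The sanity check below is a sketch, not part of the original script. It can catch a mismatched key or a missing weights file before the pipeline is loaded. `check_config` and `missing_adapter_files` are hypothetical helpers, and the dictionaries are copied from the cell above so the block runs on its own:

```python
import os

# Copies of the dictionaries from the configuration cell above.
LORA_ADAPTERS = {
    "channel_set": "../lora_adapters/channel-set/checkpoint/pytorch_lora_weights.safetensors",
    "threader": "../lora_adapters/threader/checkpoint/pytorch_lora_weights.safetensors",
    "huggie": "../lora_adapters/huggie/checkpoint/pytorch_lora_weights.safetensors",
}
SPECIAL_TOKENS = {"channel_set": "sks", "threader": "phol"}

def check_config(adapters, tokens):
    """Return (tokens with no adapter, adapters with no token), each sorted."""
    tokens_without_adapter = sorted(set(tokens) - set(adapters))
    adapters_without_token = sorted(set(adapters) - set(tokens))
    return tokens_without_adapter, adapters_without_token

def missing_adapter_files(adapters):
    """Return the categories whose weights file is not found on disk."""
    return sorted(name for name, path in adapters.items() if not os.path.isfile(path))

orphan_tokens, untokened = check_config(LORA_ADAPTERS, SPECIAL_TOKENS)
print("Tokens with no adapter:", orphan_tokens)    # []
print("Adapters with no token:", untokened)        # ['huggie']
print("Missing weight files:", missing_adapter_files(LORA_ADAPTERS))
```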


## Helper Functions

Load the helper functions from the Python script by executing everything above its `if __name__ == "__main__":` guard:


In [None]:
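# Execute every definition in generate_before_after.py up to its main guard, so the
# helpers land in this notebook's namespace without starting a full generation run.
# Note: this also re-runs the script's top-level assignments, which overwrite any
# same-named variables defined in earlier cells.
with open('../notebook_or_scripts/generate_before_after.py', encoding='utf-8') as f:
    script_source = f.read()
exec(script_source.split('if __name__ == "__main__":')[0])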

print("✅ All functions loaded from generate_before_after.py")
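Executing a slice of the source works. A different technique, not what this notebook does, is to import the script as a module from its file path with `importlib`. The module's `__name__` is not `"__main__"`, so its main block is skipped automatically. The script's globals also stay in their own namespace instead of overwriting notebook variables. `load_module_from_path` is a hypothetical helper, and the demo below uses a throwaway module so the block runs on its own:

```python
import importlib.util
import os
import sys
import tempfile

def load_module_from_path(path, name):
    """Import a .py file by path; its `if __name__ == "__main__":` block does not run."""
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module  # register before executing, following the importlib recipe
    spec.loader.exec_module(module)
    return module

# Demo: a throwaway module whose main block would raise if it were executed.
demo_source = (
    "def double(x):\n"
    "    return 2 * x\n"
    "\n"
    'if __name__ == "__main__":\n'
    '    raise SystemExit("main block should not run on import")\n'
)
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "demo_helpers.py")
    with open(demo_path, "w", encoding="utf-8") as f:
        f.write(demo_source)
    demo = load_module_from_path(demo_path, "demo_helpers")
    print(demo.double(21))  # 42

# In this notebook it would be (path taken from the cell above):
# gba = load_module_from_path("../notebook_or_scripts/generate_before_after.py", "generate_before_after")
```

With this approach the later cells would call `gba.setup_pipeline(device)` and `gba.generate_all_comparisons()` instead of the bare names.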


## Setup Pipeline


In [None]:
# Setup pipeline
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = setup_pipeline(device)


## Preview Prompts

First, preview how the first three prompts are enhanced:


In [None]:
# Preview the enhanced versions of the first three prompts
for i, prompt in enumerate(REQUIRED_PROMPTS[:3], 1):  # Show first 3 for preview
    category = detect_jewelry_category(prompt)
    enhanced = apply_jewelry_enhancement(prompt, category)
    
    print(f"=== PROMPT {i:02d} ===")
    print(f"Category: {'✅ LoRA: ' + category if category else '❌ No LoRA'}")
    print(f"Original: {prompt}")
    print(f"Enhanced: {enhanced}")
    print()
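`detect_jewelry_category` and `apply_jewelry_enhancement` come from the script, so their exact rules are not shown here. Purely as an illustration, a keyword-based version of the same idea might look like the sketch below. `CATEGORY_KEYWORDS`, both `sketch_*` function names, and the choice to prefix the special token are assumptions, not the script's actual logic:

```python
# Hypothetical keyword table; the real detect_jewelry_category() may use different rules.
CATEGORY_KEYWORDS = {
    "channel_set": ("channel-set", "channel set"),
    "threader": ("threader",),
    "huggie": ("huggie",),
}
# Copied from the configuration cell.
SPECIAL_TOKENS = {"channel_set": "sks", "threader": "phol"}

def sketch_detect_category(prompt):
    """Return the first category whose keyword appears in the prompt, else None."""
    text = prompt.lower()
    for category, keywords in CATEGORY_KEYWORDS.items():
        if any(keyword in text for keyword in keywords):
            return category
    return None

def sketch_add_token(prompt, category):
    """Prefix the category's special token if it has one (hypothetical placement)."""
    token = SPECIAL_TOKENS.get(category)
    return f"{token} {prompt}" if token else prompt

for prompt in [
    "channel-set diamond eternity band, 2 mm width, hammered 18k yellow gold, product-only white background",
    "delicate gold huggie hoops, contemporary styling, isolated on neutral background",
    "modern signet ring, oval face, engraved gothic initial 'M', high-polish sterling silver, subtle reflection",
]:
    category = sketch_detect_category(prompt)
    print(f"{category}: {sketch_add_token(prompt, category)}")
```

In this sketch the huggie prompt gets a category but no token, and the signet ring prompt falls through to `None` and would be rendered without a LoRA.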


## Test Single Image

Let's test with one prompt to make sure everything works:


In [None]:
# Test with the first prompt
test_prompt = REQUIRED_PROMPTS[0]
print(f"Testing with: {test_prompt}")

# Generate baseline and optimized
print("\n🔸 Generating baseline...")
baseline_image = generate_baseline_image(pipeline, test_prompt, seed=42)

print("🔹 Generating optimized...")
optimized_image, enhanced_prompt, category = generate_optimized_image(pipeline, test_prompt, seed=42)

print("\n✅ Test complete!")
print(f"Enhanced prompt: {enhanced_prompt}")
print(f"Category: {category}")

# Display images side by side
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

ax1.imshow(baseline_image)
ax1.set_title("Baseline (CFG=7.5)")
ax1.axis('off')

ax2.imshow(optimized_image)
ax2.set_title(f"Optimized (CFG=9.0, LoRA={category})")
ax2.axis('off')

plt.tight_layout()
plt.show()
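The plot titles label the baseline with CFG=7.5 and the optimized run with CFG=9.0. In diffusers this value is the `guidance_scale` argument of the `StableDiffusionPipeline` call. At each denoising step, classifier-free guidance combines the unconditional and text-conditioned noise predictions as `uncond + s * (cond - uncond)`, so a larger scale `s` pushes harder toward the prompt. The toy sketch below works on plain Python lists rather than latent tensors, and `apply_cfg` is a hypothetical name:

```python
def apply_cfg(noise_uncond, noise_cond, guidance_scale):
    """Classifier-free guidance, element-wise: uncond + scale * (cond - uncond)."""
    return [u + guidance_scale * (c - u) for u, c in zip(noise_uncond, noise_cond)]

uncond = [0.0, 1.0]
cond = [1.0, 1.0]  # differs from uncond only in the first element
print(apply_cfg(uncond, cond, 7.5))  # [7.5, 1.0]
print(apply_cfg(uncond, cond, 9.0))  # [9.0, 1.0]
print(apply_cfg(uncond, cond, 1.0))  # [1.0, 1.0], i.e. plain conditional prediction
```

Elements where the two predictions agree are left unchanged. Only the prompt-driven difference is amplified.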


## Generate All 16 Images

⚠️ **Warning**: This step is slow, especially on CPU. At roughly 1 to 3 minutes per image, the 16 images take about 16 to 48 minutes in total.
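The per-image figure above is a rough range. To replace it with a measurement from your own hardware, one option is to time the single-image test and scale it up. `timed` and `estimate_total_minutes` are hypothetical helpers, and the commented lines assume the `generate_baseline_image` call from the test cell:

```python
import time

def timed(fn, *args, **kwargs):
    """Call fn and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start

def estimate_total_minutes(num_images, per_image_minutes):
    """Scale a (low, high) per-image estimate in minutes to a total for num_images."""
    low, high = per_image_minutes
    return num_images * low, num_images * high

print(estimate_total_minutes(16, (1, 3)))  # (16, 48)

# With a measured time instead of the rough range:
# _, secs = timed(generate_baseline_image, pipeline, test_prompt, seed=42)
# print(estimate_total_minutes(16, (secs / 60, secs / 60)))
```

A single measurement does not include one-off costs such as pipeline and LoRA loading, so treat the scaled figure as a lower bound.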

Run the complete generation with `generate_all_comparisons()` from the script:


In [None]:
# Generate all 16 images (8 baseline + 8 optimized)
# This uses the generate_all_comparisons() function from the script

results = generate_all_comparisons()

print("\n🎯 GENERATION COMPLETE!")
print("✅ All 16 deliverable images ready!")
print("📁 Check: deliverables/before_after/")
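You can quickly confirm that the run produced everything. The check below assumes the images are written as PNG or JPEG files directly into `deliverables/before_after/`, relative to the notebook's working directory. The script's actual output location and format are not shown here, and `list_images` is a hypothetical helper:

```python
from pathlib import Path

def list_images(directory, extensions=(".png", ".jpg", ".jpeg")):
    """Return the sorted names of image files directly inside `directory` ([] if absent)."""
    folder = Path(directory)
    if not folder.is_dir():
        return []
    return sorted(
        p.name for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in extensions
    )

images = list_images("deliverables/before_after")
print(f"Found {len(images)} images (expected 16)")
```

If the count is short, compare the listed names against the eight prompts to see which baseline or optimized image is missing.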
