# 🧪 Comprehensive Strategy Test: All 8 Prompts

## Goal
Test different enhancement strategies across all 8 jewelry prompts to understand:
1. **Why Compel performed poorly** in the main experiment
2. **Which weighting levels work best** (`+`, `++`, `+++`)
3. **Alternative approaches** that might work better
4. **Prompt-specific patterns** - do some prompts benefit more than others?

## Strategy Testing Framework
- **Baseline**: No modifications
- **Light Compel**: Single `+` weighting 
- **Medium Compel**: Double `++` weighting (original approach)
- **Heavy Compel**: Triple `+++` weighting
- **Numeric Weights**: Specific numeric values like `(term)1.2`
- **Negative Focus**: Enhanced negative prompts instead of positive weighting
- **Style Focus**: Add photography/quality terms instead of jewelry weighting
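
For reference, Compel's syntax documentation describes each `+` as multiplying a term's weight by 1.1, so the three `+` levels correspond to fixed numeric weights. A minimal sketch of that arithmetic, assuming the documented 1.1 base (the helper name `plus_weight` is only illustrative):

```python
def plus_weight(n_plus: int, base: float = 1.1) -> float:
    """Effective weight for a term followed by n_plus '+' signs (assumed base 1.1)."""
    return base ** n_plus

for label, n in [("light (+)", 1), ("medium (++)", 2), ("heavy (+++)", 3)]:
    print(f"{label:<12} {plus_weight(n):.3f}")
```

On that assumption, `(gold)1.1` in the numeric strategy is equivalent to the light `+` level, while `('M')1.5` is stronger than `+++`.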

---


In [None]:
# Setup
import torch
from compel import Compel, ReturnedEmbeddingsType
from diffusers import StableDiffusionXLPipeline
import matplotlib.pyplot as plt
import numpy as np
import os
from datetime import datetime

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only practical on GPU; fall back to float32 on CPU
dtype = torch.float16 if device == "cuda" else torch.float32
print(f"🖥️ Device: {device}")

# Load SDXL
print("🔄 Loading SDXL pipeline...")
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16", torch_dtype=dtype
).to(device)

compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)

# Create output directory
os.makedirs("strategy_test_results", exist_ok=True)
print("✅ Setup complete!")


In [None]:
# Define all 8 test prompts
base_prompts = [
    "channel-set diamond eternity band, 2 mm width, hammered 18k yellow gold, product-only white background",
    "14k rose-gold threader earrings, bezel-set round lab diamond ends, lifestyle macro shot, soft natural light",
    "organic cluster ring with mixed-cut sapphires and diamonds, brushed platinum finish, modern aesthetic",
    "A solid gold cuff bracelet with blue sapphire, with refined simplicity and intentionally crafted for everyday wear",
    "modern signet ring, oval face, engraved gothic initial 'M', high-polish sterling silver, subtle reflection",
    "delicate gold huggie hoops, contemporary styling, isolated on neutral background",
    "stack of three slim rings: twisted gold, plain platinum, black rhodium pavé, editorial lighting",
    "bypass ring with stones on it, with refined simplicity and intentionally crafted for everyday wear"
]

# Strategy generation functions
import re

# Jewelry terms that the Compel strategies emphasize
WEIGHTED_TERMS = [
    "channel-set", "threader", "bezel-set", "eternity band", "huggie", "bypass",
    "pavé", "signet", "cuff", "cluster", "diamond", "sapphire", "gold", "platinum",
    "engraved", "initial", "'M'"
]

def apply_weights(prompt, weights):
    """Wrap each matched term as (term)<weight> so Compel weights the whole phrase.

    Matching is case-insensitive and takes in a hyphenated prefix or a plural "s",
    so "rose-gold" becomes "(rose-gold)+" and "diamonds" becomes "(diamonds)+".
    """
    lookup = {term.lower(): weight for term, weight in weights.items()}
    # Longest terms first so multi-word terms win over shorter alternatives
    alternation = "|".join(re.escape(t) for t in sorted(weights, key=len, reverse=True))
    pattern = re.compile(rf"(?<!\w)(?:\w+-)*({alternation})s?(?!\w)", re.IGNORECASE)
    return pattern.sub(lambda m: f"({m.group(0)}){lookup[m.group(1).lower()]}", prompt)

def create_light_compel(prompt):
    """Light weighting with single +"""
    return apply_weights(prompt, dict.fromkeys(WEIGHTED_TERMS, "+"))

def create_medium_compel(prompt):
    """Medium weighting with ++ (original approach)"""
    return apply_weights(prompt, dict.fromkeys(WEIGHTED_TERMS, "++"))

def create_heavy_compel(prompt):
    """Heavy weighting with +++"""
    return apply_weights(prompt, dict.fromkeys(WEIGHTED_TERMS, "+++"))

def create_numeric_weights(prompt):
    """Numeric weights like (term)1.2"""
    weights = {
        "channel-set": "1.3", "threader": "1.2", "bezel-set": "1.3",
        "eternity band": "1.2", "huggie": "1.2", "bypass": "1.2",
        "pavé": "1.3", "signet": "1.3", "cuff": "1.2", "cluster": "1.2",
        "diamond": "1.2", "sapphire": "1.2", "gold": "1.1", "platinum": "1.1",
        "engraved": "1.4", "initial": "1.3", "'M'": "1.5"
    }
    return apply_weights(prompt, weights)

def create_style_focus(prompt):
    """Add photography/quality terms instead of jewelry weighting"""
    return prompt + ", professional jewelry photography, macro lens, studio lighting, high-end luxury, premium quality"

# Enhanced negative prompt strategy
enhanced_negative = "vintage, ornate, fussy, cheap, low quality, blurry, deformed, ugly, amateur photography, poor lighting, plastic, fake, costume jewelry"

print("✅ Strategy functions defined!")
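
Plain substring replacement can leave a weight marker inside a longer word, for example turning "diamonds" into "diamond+s". A small check along these lines could flag such prompts before any GPU time is spent. This is a sketch only: `find_split_words` is a hypothetical helper, and the pattern covers just the `+` and `(term)1.2` forms used in this notebook.

```python
import re

# A weight marker ("+", "++", or ")1.2") immediately followed by a letter
# suggests the marker was inserted in the middle of a word.
SPLIT_WORD = re.compile(r"(?:\++|\)\d+(?:\.\d+)?)(?=[^\W\d_])")

def find_split_words(prompt):
    """Return the weight markers in `prompt` that are directly followed by a letter."""
    return [m.group(0) for m in SPLIT_WORD.finditer(prompt)]

for candidate in ["mixed-cut sapphire+s and diamond++s", "(diamonds)+ and gold+"]:
    print(candidate, "->", find_split_words(candidate))
```

An empty list means no marker sits inside a word; any hit is worth inspecting by eye before generation.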


In [None]:
# Run comprehensive test across all strategies and prompts
strategies = {
    "baseline": lambda p: p,
    "light_compel": create_light_compel,
    "medium_compel": create_medium_compel,
    "heavy_compel": create_heavy_compel,
    "numeric_weights": create_numeric_weights,
    "style_focus": create_style_focus,
    "negative_focus": lambda p: p  # prompt unchanged; paired with enhanced_negative below
}

print("🚀 Starting comprehensive strategy test...")
print(f"📊 Testing {len(strategies)} strategies × {len(base_prompts)} prompts = {len(strategies) * len(base_prompts)} generations")
print("⏱️ Estimated time: ~30-40 minutes")

# Store all results
all_results = {}

for strategy_name, strategy_func in strategies.items():
    print(f"\n🧪 Testing strategy: {strategy_name}")
    all_results[strategy_name] = {}
    
    for prompt_idx, base_prompt in enumerate(base_prompts, 1):
        print(f"  📝 Prompt {prompt_idx}/{len(base_prompts)}: {base_prompt[:50]}...")
        
        # Apply strategy
        modified_prompt = strategy_func(base_prompt)
        
        # Choose negative prompt
        neg_prompt = enhanced_negative if strategy_name == "negative_focus" else "vintage, ornate, fussy, cheap, low quality, blurry"
        
        try:
            if strategy_name in ("baseline", "style_focus", "negative_focus"):
                # Standard generation (prompt contains no Compel weight syntax)
                image = pipe(
                    prompt=modified_prompt,
                    negative_prompt=neg_prompt,
                    num_inference_steps=25,
                    guidance_scale=5.0,
                    width=768, height=768,
                    generator=torch.Generator(device=device).manual_seed(100 + prompt_idx)
                ).images[0]
            else:
                # Compel generation
                cond, pooled = compel([modified_prompt, neg_prompt])
                image = pipe(
                    prompt_embeds=cond[0:1],
                    pooled_prompt_embeds=pooled[0:1],
                    negative_prompt_embeds=cond[1:2],
                    negative_pooled_prompt_embeds=pooled[1:2],
                    num_inference_steps=25,
                    guidance_scale=5.0,
                    width=768, height=768,
                    generator=torch.Generator(device=device).manual_seed(100 + prompt_idx)
                ).images[0]
            
            # Save image
            filename = f"strategy_test_results/p{prompt_idx:02d}_{strategy_name}.png"
            image.save(filename)
            
            # Store result
            all_results[strategy_name][prompt_idx] = {
                'original_prompt': base_prompt,
                'modified_prompt': modified_prompt,
                'image': image,
                'filepath': filename
            }
            
            print(f"    ✅ Generated and saved: p{prompt_idx:02d}_{strategy_name}.png")
            
        except Exception as e:
            print(f"    ❌ Failed: {e}")
            continue

print(f"\n🎉 Comprehensive test completed!")
print(f"📁 Results saved in: strategy_test_results/")

# Quick summary
total_generated = sum(len(results) for results in all_results.values())
print(f"📊 Successfully generated: {total_generated} images")


In [None]:
# Create visual comparison grids for each prompt
print("🖼️ Creating visual comparison grids...")

for prompt_idx in range(1, len(base_prompts) + 1):
    # Check which strategies have results for this prompt
    available_strategies = []
    strategy_images = []
    
    for strategy_name in strategies.keys():
        if strategy_name in all_results and prompt_idx in all_results[strategy_name]:
            available_strategies.append(strategy_name)
            strategy_images.append(all_results[strategy_name][prompt_idx]['image'])
    
    if len(available_strategies) >= 2:  # Only create grid if we have multiple results
        # Create subplot grid
        cols = 3
        rows = (len(available_strategies) + cols - 1) // cols
        # squeeze=False always returns a 2-D array, so flatten() works for any grid shape
        fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
        axes = axes.flatten()
        
        # Plot images
        for i, (strategy_name, image) in enumerate(zip(available_strategies, strategy_images)):
            if i < len(axes):
                axes[i].imshow(image)
                axes[i].set_title(f"{strategy_name}", fontweight='bold', fontsize=12)
                axes[i].axis('off')
        
        # Hide unused subplots
        for i in range(len(available_strategies), len(axes)):
            axes[i].axis('off')
        
        # Add main title
        original_prompt = base_prompts[prompt_idx - 1]
        fig.suptitle(f"Prompt {prompt_idx}: {original_prompt[:60]}...", fontsize=14, fontweight='bold')
        
        plt.tight_layout()
        plt.savefig(f"strategy_test_results/comparison_grid_p{prompt_idx:02d}.png", dpi=150, bbox_inches='tight')
        plt.show()
        
        print(f"✅ Created comparison grid for prompt {prompt_idx}")

print("🎨 Visual comparison grids completed!")


In [None]:
# Export detailed results to CSV for analysis
import pandas as pd

print("📊 Exporting results to CSV...")

csv_data = []
for strategy_name, strategy_results in all_results.items():
    for prompt_idx, result in strategy_results.items():
        row = {
            'prompt_id': prompt_idx,
            'strategy': strategy_name,
            'original_prompt': result['original_prompt'],
            'modified_prompt': result['modified_prompt'],
            'image_path': result['filepath']
        }
        csv_data.append(row)

# Create DataFrame
df = pd.DataFrame(csv_data)
csv_path = "strategy_test_results/comprehensive_strategy_results.csv"
df.to_csv(csv_path, index=False)

print(f"💾 Saved comprehensive results to: {csv_path}")
print(f"📋 Total entries: {len(df)}")

# Display summary statistics
print(f"\n📈 Results Summary:")
print(f"{'Strategy':<18} {'Images Generated':<18} {'Success Rate':<12}")
print("-" * 50)

for strategy_name in strategies.keys():
    if strategy_name in all_results:
        generated = len(all_results[strategy_name])
        success_rate = (generated / len(base_prompts)) * 100
        print(f"{strategy_name:<18} {generated:<18} {success_rate:.1f}%")

# Preview of results
print(f"\n📋 Sample Results:")
print(df[['prompt_id', 'strategy', 'original_prompt', 'modified_prompt']].head(3).to_string(max_colwidth=40))


## 📋 Analysis Framework

### 🔍 **What to Look For in Results:**

**Visual Quality Indicators:**
- **Clarity & sharpness** - Are images crisp and well-defined?
- **Jewelry details** - Are specific features (channel-set, engraving, etc.) visible?
- **Lighting & composition** - Professional vs amateur appearance
- **Material rendering** - Do metals look realistic?

**Strategy-Specific Questions:**
1. **Light Compel (`+`)** - Does lighter weighting work better than the original medium weighting (`++`)?
2. **Heavy Compel (`+++`)** - Does triple weighting help or hurt?
3. **Numeric Weights** - Are specific values like `(term)1.2` more effective?
4. **Negative Focus** - Does a longer negative prompt help more than positive weighting?
5. **Style Focus** - Does adding photography terms improve overall quality?

**Prompt-Specific Patterns:**
- **Prompt 5 (Signet 'M')** - Which strategy makes the letter most visible?
- **Complex prompts** - Do simpler or more detailed strategies work better?
- **Material-heavy prompts** - How do different approaches handle multiple metals?

### 🎯 **Key Success Metrics:**
1. **Visual appeal** - Which images look most professional/luxury?
2. **Prompt adherence** - Which capture the specific jewelry features best?
3. **Consistency** - Which strategy works across multiple prompt types?
4. **Detail preservation** - Which maintains fine jewelry details?
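
These metrics are judged by eye, so one way to compare strategies is to record a 1-5 rating per image and aggregate afterwards. The sketch below assumes that manual rating scheme. The rating values are made-up placeholders that only show the data shape, and `summarize` is a hypothetical helper. The mean covers overall quality, and the spread across prompts serves as a rough consistency measure (lower is more consistent).

```python
from statistics import mean, pstdev

# Placeholder 1-5 ratings keyed by (strategy, prompt_id). These numbers are
# made up to show the data shape; they are not results from this experiment.
example_ratings = {
    ("baseline", 1): {"visual_appeal": 3, "prompt_adherence": 3, "detail_preservation": 3},
    ("baseline", 2): {"visual_appeal": 4, "prompt_adherence": 4, "detail_preservation": 4},
    ("style_focus", 1): {"visual_appeal": 4, "prompt_adherence": 4, "detail_preservation": 4},
    ("style_focus", 2): {"visual_appeal": 4, "prompt_adherence": 4, "detail_preservation": 4},
}

def summarize(ratings):
    """Average the metrics per image, then report mean and spread per strategy."""
    per_strategy = {}
    for (strategy, _prompt_id), scores in ratings.items():
        per_strategy.setdefault(strategy, []).append(mean(list(scores.values())))
    return {
        strategy: {"mean": mean(values), "spread": pstdev(values)}
        for strategy, values in per_strategy.items()
    }

for strategy, stats in summarize(example_ratings).items():
    print(f"{strategy:<15} mean={stats['mean']:.2f} spread={stats['spread']:.2f}")
```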

### 💡 **Next Steps After Analysis:**
- Identify the best-performing strategy overall
- Note prompt-specific patterns (some prompts might need different approaches)
- Test hybrid approaches combining successful elements
- Update the main pipeline with winning strategies
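
A hybrid approach can be expressed as a composition of existing strategy functions. The sketch below uses two stand-in functions (`weight_signet` and `add_style_terms` are hypothetical simplifications, not the notebook's own functions) to show the idea; `compose` is likewise an illustrative helper.

```python
def compose(*strategy_funcs):
    """Return a strategy that applies each given strategy in order."""
    def combined(prompt):
        for func in strategy_funcs:
            prompt = func(prompt)
        return prompt
    return combined

# Stand-ins for a weighting strategy and a style strategy (illustrative only)
def weight_signet(prompt):
    return prompt.replace("signet", "(signet)+")

def add_style_terms(prompt):
    return prompt + ", studio lighting"

hybrid = compose(weight_signet, add_style_terms)
print(hybrid("modern signet ring"))
```

In the notebook, something like `compose(create_light_compel, create_style_focus)` could be added to the `strategies` dict under a new key. Because the resulting prompt contains Compel syntax, it would need to go through the Compel generation path.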

---

**🏆 Goal**: Find the optimal enhancement strategy that consistently improves jewelry generation quality!
