# Image Style Matching Demo

This demo uses GraphGen to optimize a workflow for generating Pokemon-style images.

**What this demo does:**
1. Creates a dataset with Pokemon-style image generation tasks
2. Runs GraphGen optimization to find the best prompt workflow
3. Downloads the optimized graph (saved to `results/optimized_graph.txt`) and runs inference on new subjects
4. Saves generated images to the `results/` folder

In [None]:
# Papermill parameters cell: leave these as None to fall back to environment
# variables (resolved in the "Configure Backend" and "Get API Key" steps).
BACKEND_URL: "str | None" = None  # e.g. "https://api.usesynth.ai/api"
API_KEY: "str | None" = None  # None -> SYNTH_API_KEY env var, else a minted demo key

In [None]:
# Step 1: Imports and Setup
import os
import sys
import json
import base64
import uuid
from pathlib import Path

import httpx
from dotenv import load_dotenv

# Load variables from a .env file (if one is found) into os.environ so the
# SYNTH_API_KEY / BACKEND_BASE_URL / LOCAL_BACKEND lookups below can see them.
load_dotenv()

# Put the directory two levels above the current working directory first on
# sys.path, presumably so the in-repo synth_ai package is importable.
# NOTE(review): resolved relative to the CWD, so this assumes the notebook is
# run from its own folder — confirm.
sys.path.insert(0, str(Path('.').resolve().parent.parent))

from synth_ai.sdk.api.train.graphgen import GraphGenJob, load_graphgen_taskset
from synth_ai.sdk.graphs.completions import GraphCompletionsSyncClient

print('Imports loaded successfully')

In [None]:
# Step 2: Configure Backend


def _normalize_api_base(url: str) -> str:
    """Return *url* without trailing slashes and with exactly one '/api' suffix.

    Later steps build endpoints as f'{SYNTH_API_BASE}/...' and expect the base
    to end in '/api'. Stripping slashes first keeps '.../api/' from becoming
    '.../api/api'.
    """
    url = url.rstrip('/')
    return url if url.endswith('/api') else url + '/api'


# Precedence: papermill parameter > LOCAL_BACKEND flag > BACKEND_BASE_URL > default.
if BACKEND_URL:
    SYNTH_API_BASE = _normalize_api_base(BACKEND_URL)
elif os.environ.get('LOCAL_BACKEND', '').lower() in ('1', 'true', 'yes'):
    SYNTH_API_BASE = 'http://127.0.0.1:8000/api'
else:
    SYNTH_API_BASE = _normalize_api_base(
        os.environ.get('BACKEND_BASE_URL') or 'https://api.usesynth.ai'
    )

print(f'Backend: {SYNTH_API_BASE}')

# Check backend health and stop here if it is not OK; later steps depend on it.
r = httpx.get(f'{SYNTH_API_BASE}/health', timeout=30)
if r.status_code == 200:
    print(f'Backend health: {r.json()}')
else:
    raise RuntimeError(f'Backend not healthy: status {r.status_code}')

In [None]:
# Step 3: Get API Key

# Precedence: papermill parameter > SYNTH_API_KEY env var > freshly minted demo key.
SYNTH_API_KEY = API_KEY or os.environ.get('SYNTH_API_KEY', '')

if not SYNTH_API_KEY:
    print('No SYNTH_API_KEY found, minting demo key...')
    # SYNTH_API_BASE is expected to end in '/api' already, so the route is
    # appended directly; adding another '/api' would hit '.../api/api/demo/keys'.
    resp = httpx.post(f'{SYNTH_API_BASE}/demo/keys', json={'ttl_hours': 4}, timeout=30)
    resp.raise_for_status()
    SYNTH_API_KEY = resp.json()['api_key']
    print(f'Demo API Key: {SYNTH_API_KEY[:25]}...')
else:
    # Only a prefix is printed to avoid echoing the full secret.
    print(f'Using API Key: {SYNTH_API_KEY[:20]}...')

In [None]:
# Step 4: Create Dataset

def _load_gold_image(filename: str) -> str:
    """Load a Pokemon image from gold_images folder as base64 data URL."""
    img_path = Path("gold_images") / filename
    with open(img_path, "rb") as f:
        img_data = f.read()
    return f"data:image/png;base64,{base64.b64encode(img_data).decode('ascii')}"

# One training task per creature subject; ids and descriptions follow the
# "pokemon_<subject>" / "A <subject> creature in Pokemon art style" pattern.
tasks = [
    {
        "id": f"pokemon_{subject}",
        "input": {
            "subject": subject,
            "style": "pokemon",
            "description": f"A {subject} creature in Pokemon art style",
        },
    }
    for subject in ("dragon", "cat", "bird")
]

# Real Pokemon reference images, listed in the same order as `tasks`.
gold_image_files = ["charizard_dragon.png", "meowth_cat.png", "pidgeot_bird.png"]
gold_outputs = [
    {
        "task_id": task["id"],
        "output": {
            "image_url": _load_gold_image(image_file),
            "note": f"Reference for {task['input']['subject']}",
        },
    }
    for task, image_file in zip(tasks, gold_image_files)
]

# JSON Schemas for task inputs and generated outputs.
input_schema = {
    "type": "object",
    "properties": {
        "subject": {"type": "string"},
        "style": {"type": "string"},
        "description": {"type": "string"},
    },
    "required": ["subject", "style", "description"],
}
output_schema = {
    "type": "object",
    "properties": {
        "image_url": {"type": "string", "description": "Base64-encoded image data URL"},
    },
    "required": ["image_url"],
}

# Rubric criteria as (name, description, weight) rows.
rubric_criteria = [
    {"name": name, "description": description, "weight": weight}
    for name, description, weight in (
        ("pokemon_style_match", "Image matches Pokemon art style", 1.0),
        ("subject_recognition", "Subject is recognizable", 0.8),
        ("visual_quality", "Image is high quality", 0.5),
    )
]

# Random suffix keeps dataset names unique across runs.
unique_id = uuid.uuid4().hex[:8]
dataset = {
    "version": "1.0",
    "metadata": {
        "name": f"pokemon-style-matching-{unique_id}",
        "description": "Pokemon art style matching with contrastive VLM judge.",
        "input_schema": input_schema,
        "output_schema": output_schema,
    },
    "initial_prompt": "Generate an image.",
    "tasks": tasks,
    "gold_outputs": gold_outputs,
    "input_schema": input_schema,
    "output_schema": output_schema,
    "default_rubric": {"outcome": {"criteria": rubric_criteria}},
    "judge_config": {"mode": "contrastive", "model": "gpt-4.1-nano", "provider": "openai"},
}

# Persist the dataset so it can be loaded by load_graphgen_taskset in Step 5.
dataset_path = Path("image_style_matching_dataset.json")
dataset_path.write_text(json.dumps(dataset, indent=2))

judge = dataset["judge_config"]
print(f'Created dataset: {dataset["metadata"]["name"]}')
print(f'  Tasks: {len(tasks)}, Gold outputs: {len(gold_outputs)}')
print(f'  Gold images: {gold_image_files}')
print(f'  Judge: {judge["mode"]} with {judge["model"]}')

In [None]:
# Step 5: Run GraphGen Optimization

dataset_obj = load_graphgen_taskset(dataset_path)

# Natural-language task description passed to the GraphGen job.
problem_spec = (
    "Generate images that match Pokemon art style. "
    "Use Gemini for image generation and a VLM judge for style matching."
)

# Optimization-search knobs, forwarded as keyword arguments below.
search_settings = {
    "rollout_budget": 10,
    "proposer_effort": "medium",
    "population_size": 2,
    "num_generations": 1,
}

print('Creating GraphGen job...')
job = GraphGenJob.from_dataset(
    dataset=dataset_obj,
    policy_model="gemini-2.5-flash-image",  # Image generation model
    problem_spec=problem_spec,
    backend_url=SYNTH_API_BASE,
    api_key=SYNTH_API_KEY,
    auto_start=True,
    **search_settings,
)

print(f'  Policy model: {job.config.policy_model}')
print(f'  Rollout budget: {job.config.rollout_budget}')

result = job.submit()
print(f'\nJob submitted: {result.graphgen_job_id}')

# Poll every 5.0 with an overall timeout of 600.0 (units presumably seconds —
# confirm against the SDK), printing progress while waiting.
final_status = job.poll_until_complete(timeout=600.0, interval=5.0, progress=True)

status = final_status.get("status")
print(f'\nFINAL STATUS: {status}')
if status == "succeeded":
    print(f'BEST SCORE: {final_status.get("best_score")}')
elif status == "failed":
    print(f'ERROR: {final_status.get("error")}')

In [None]:
# Step 6: Download Optimized Graph

if final_status.get("status") == "succeeded":
    banner = '=' * 60
    print(banner)
    print('OPTIMIZED GRAPH')
    print(banner)

    graph_txt = job.download_graph_txt()
    print(graph_txt)

    # Save graph. Write UTF-8 explicitly: the text comes from the backend and
    # may contain non-ASCII characters that the platform default encoding
    # (e.g. cp1252 on Windows) cannot represent.
    results_dir = Path("results")
    results_dir.mkdir(exist_ok=True)
    graph_path = results_dir / "optimized_graph.txt"
    graph_path.write_text(graph_txt, encoding="utf-8")
    print(f'\nSaved graph to: {graph_path}')
else:
    print(f'Job did not succeed: {final_status.get("status")}')

In [None]:
# Step 7: Run Inference and Save Images

def save_image_from_data_url(data_url: str, output_path: Path) -> bool:
    """Decode a base64 image ``data:`` URL and write the raw bytes to a file.

    Args:
        data_url: A URL of the form ``data:image/<type>;base64,<payload>``.
        output_path: Destination file (``str`` or ``Path``); overwritten if present.

    Returns:
        True if the image was written. False, with nothing written, when
        *data_url* is not a base64-encoded ``data:image`` URL.

    Raises:
        binascii.Error: If the base64 payload is malformed. The payload is
            decoded before the file is opened, so no empty file is left behind.
    """
    if not data_url.startswith("data:image"):
        return False
    header, sep, encoded = data_url.partition(",")
    # A data URL without ';base64' carries raw (percent-encoded) data, not base64.
    if not sep or not header.endswith(";base64"):
        return False
    image_bytes = base64.b64decode(encoded)
    Path(output_path).write_bytes(image_bytes)
    return True

# Minimum length for a string to count as inline image data; shorter strings
# (placeholders, plain URLs) are ignored.
MIN_IMAGE_URL_CHARS = 5000


def _find_image_url(output):
    """Return ``(source_key, image_url)`` for the first inline image in *output*.

    Node outputs (keys ending in ``_output`` whose value is a dict with an
    ``image_url``) are checked first, then a top-level ``image_url``.
    Returns ``(None, None)`` when no string longer than MIN_IMAGE_URL_CHARS is found.
    """
    for key, val in output.items():
        if key.endswith("_output") and isinstance(val, dict) and "image_url" in val:
            candidate = val["image_url"]
            if isinstance(candidate, str) and len(candidate) > MIN_IMAGE_URL_CHARS:
                return key, candidate
    candidate = output.get("image_url", "")
    if isinstance(candidate, str) and len(candidate) > MIN_IMAGE_URL_CHARS:
        return "image_url", candidate
    return None, None


if final_status.get("status") == "succeeded":
    results_dir = Path("results")
    results_dir.mkdir(exist_ok=True)

    # Create typed sync client for graph completions
    graph_client = GraphCompletionsSyncClient(
        base_url=SYNTH_API_BASE,
        api_key=SYNTH_API_KEY,
        timeout=180.0,  # Image generation can take 2-3 minutes
    )

    job_id = result.graphgen_job_id

    # Subjects that do not appear in the training tasks (dragon/cat/bird).
    test_inputs = [
        {"subject": "wolf", "style": "pokemon", "description": "A wolf creature in Pokemon art style"},
        {"subject": "fox", "style": "pokemon", "description": "A fox creature in Pokemon art style"},
        {"subject": "rabbit", "style": "pokemon", "description": "A rabbit creature in Pokemon art style"},
    ]

    print(f'Running inference on {len(test_inputs)} test inputs...')
    print(f'Job ID: {job_id}\n')

    for i, test_input in enumerate(test_inputs, start=1):
        subject = test_input["subject"]
        print(f'Test {i}: {subject}')
        try:
            response = graph_client.run(job_id=job_id, input_data=test_input)
            print(f'  Cache: {response.cache_status or "unknown"}')

            source, image_url = _find_image_url(response.output)
            if not image_url:
                print('  No valid image found')
                continue
            print(f'  Found image in {source}: {len(image_url)} chars')

            img_path = results_dir / f"test_{i}_{subject}.png"
            if save_image_from_data_url(image_url, img_path):
                print(f'  Saved: {img_path}')
            else:
                # Report the failure instead of silently skipping the image.
                print(f'  Not saved: not a base64 image data URL ({image_url[:30]}...)')
        except Exception as e:
            # Demo boundary: report and continue with the remaining test inputs.
            print(f'  Error: {e}')

    print(f'\nResults saved to: {results_dir}')
else:
    print('Skipping inference (job did not succeed)')

In [None]:
# Cleanup temporary dataset file
if dataset_path.exists():
    dataset_path.unlink()
    print(f'Cleaned up: {dataset_path}')

print('\nDemo complete!')