# Plaited World Agent GRPO Training

Train the world agent with Group Relative Policy Optimization (GRPO) using browser execution for reward computation.

## Why Local?

GRPO requires browser execution for reward computation. The model generates UI code, which is validated by:
1. **Tier 1: Static analysis** - Fast pattern checks (free)
2. **Tier 3: Browser execution** - Story tests in Playwright (ground truth)

Colab doesn't provide browser access, so GRPO runs locally with Playwright.

## Prerequisites

1. **Complete SFT/DPO training** on Colab first (see `plaited-world-agent-training.ipynb`)
2. **Download trained model** from HuggingFace
3. **Install dependencies**:

```bash
pip install unsloth trl datasets transformers torch
bun install  # For Plaited workshop
bunx playwright install chromium
```

## Environment Variables

```bash
export HF_TOKEN="your-token"
export HF_USERNAME="your-username"
export HF_MODEL_NAME="plaited-world-agent-lora"
```
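
A missing variable otherwise only surfaces later as an authentication or download error, so it can help to fail fast. The following is a minimal sketch; the helper name `require_env` is hypothetical and not part of the notebook.

```python
import os
from typing import Optional


def require_env(name: str, default: Optional[str] = None) -> str:
    """Return an environment variable, or raise a clear error if it is unset.

    Hypothetical helper for illustration; the notebook itself reads the
    variables with os.environ.get().
    """
    value = os.environ.get(name, default)
    if not value:
        raise RuntimeError(f"Environment variable {name} is not set")
    return value
```

For example, `hf_token = require_env("HF_TOKEN")` would stop Cell 1 immediately when the token is absent, while `require_env("HF_USERNAME", "plaited")` keeps the notebook's default.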

In [None]:
# Cell 1: Load Environment and Model
import os
from unsloth import FastLanguageModel

hf_token = os.environ.get('HF_TOKEN')
hf_username = os.environ.get('HF_USERNAME', 'plaited')
hf_model_name = os.environ.get('HF_MODEL_NAME', 'plaited-world-agent-lora')

# Load the SFT/DPO trained model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=f"{hf_username}/{hf_model_name}",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
    token=hf_token,
)

print(f"Loaded model: {hf_username}/{hf_model_name}")
print(f"Parameters: {model.num_parameters():,}")

In [None]:
# Cell 2: Load Intent Dataset
from datasets import load_dataset

# Load trajectories but only use intents for GRPO
# The model generates responses, browser validates
dataset = load_dataset("json", data_files="trajectories.jsonl", split="train")

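# Extract just the intents for GRPO.
# Index 1 assumes each record is ordered [system, user, ...];
# adjust the index if your trajectories are laid out differently.
intent_dataset = dataset.map(lambda x: {
    "prompt": x["messages"][1]["content"]  # User message with intent
})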

print(f"Loaded {len(intent_dataset)} intents for GRPO")
print(f"Sample intent: {intent_dataset[0]['prompt'][:100]}...")
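
If some trajectories lack a system message, a fixed index would pick the wrong turn. A small sketch of a role-based lookup follows. It assumes OpenAI-style messages with `role` and `content` keys; the `role` key and the helper name `first_user_message` are assumptions, not confirmed by the dataset.

```python
from typing import Optional


def first_user_message(messages: list[dict]) -> Optional[str]:
    """Return the content of the first message whose role is 'user'."""
    for message in messages:
        if message.get("role") == "user":
            return message.get("content")
    return None


# Illustrative record; the contents are made up for this example
example = [
    {"role": "system", "content": "You generate Plaited UI code."},
    {"role": "user", "content": "Create an accessible submit button"},
    {"role": "assistant", "content": "<button>Submit</button>"},
]
print(first_user_message(example))  # Create an accessible submit button
```

Used inside `dataset.map`, this would replace the indexed lookup with `first_user_message(x["messages"])`.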

In [None]:
# Cell 3: Static Analysis Integration
import subprocess
import json
import tempfile
import os

def run_static_analysis(code: str) -> dict:
    """
    Run Tier 1 static analysis via Bun.
    
    Returns:
        {"passed": bool, "checks": [...], "tier": 1}
    """
    # Write code to temp file
    with tempfile.NamedTemporaryFile(mode='w', suffix='.tsx', delete=False) as f:
        f.write(code)
        temp_path = f.name
    
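    try:
        # Embed the path as a JSON string literal so quotes or backslashes
        # in the temp path cannot break the generated script
        script = f"""
import {{ runStaticAnalysis }} from 'plaited/agent-next'
const code = await Bun.file({json.dumps(temp_path)}).text()
const result = runStaticAnalysis(code)
console.log(JSON.stringify(result))
"""
        # Run static analysis script
        result = subprocess.run(
            ["bun", "run", "-e", script],
            capture_output=True,
            text=True,
            timeout=10
        )

        if result.returncode == 0:
            return json.loads(result.stdout.strip())
        else:
            return {"passed": False, "tier": 1, "checks": [{"name": "parse", "passed": False, "message": result.stderr}]}
    except (subprocess.TimeoutExpired, json.JSONDecodeError) as exc:
        # A hung process or unparseable output counts as a failed check
        # instead of aborting the training run
        return {"passed": False, "tier": 1, "checks": [{"name": "runner", "passed": False, "message": str(exc)}]}
    finally:
        os.unlink(temp_path)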

# Test static analysis
test_code = '<button aria-label="Test">Click me</button>'
result = run_static_analysis(test_code)
print(f"Static analysis test: {result}")
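
When a generation is rejected at Tier 1, it is useful to see which checks failed. The sketch below summarises a result of the shape documented above (`passed`, `checks`, `tier`). The per-check keys `name`, `passed`, and `message` follow the fallback dictionary in `run_static_analysis`. The check name `aria` and the helper `failed_checks` are made up for illustration.

```python
def failed_checks(result: dict) -> list[str]:
    """List 'name: message' for every check that did not pass."""
    return [
        f"{check.get('name', 'unknown')}: {check.get('message', '')}"
        for check in result.get("checks", [])
        if not check.get("passed", False)
    ]


# Illustrative result; real check names come from the static analyser
sample = {
    "passed": False,
    "tier": 1,
    "checks": [
        {"name": "parse", "passed": False, "message": "Unexpected token"},
        {"name": "aria", "passed": True, "message": ""},
    ],
}
print(failed_checks(sample))  # ['parse: Unexpected token']
```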

In [None]:
# Cell 4: Browser Reward Function
import subprocess
import json
import tempfile
import os

def run_browser_test(code: str) -> dict:
    """
    Run Tier 3 browser test via workshop CLI.
    
    Returns:
        {"passed": bool, "a11yPassed": bool, "totalAssertions": int, "passedAssertions": int}
    """
    # Write generated code to temp story file
    with tempfile.NamedTemporaryFile(mode='w', suffix='.stories.tsx', delete=False) as f:
        # Wrap code in story format
        story_code = f"""
import {{ story }} from 'plaited/testing'

{code}

export const Generated = story({{
  template: () => GeneratedTemplate(),
  intent: 'Generated for GRPO training',
  play: async ({{ assert }}) => {{
    await assert.a11y()
  }}
}})
"""
        f.write(story_code)
        temp_path = f.name
    
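    try:
        # Run workshop test
        result = subprocess.run(
            ["bun", "plaited", "test", temp_path, "--json"],
            capture_output=True,
            text=True,
            timeout=30
        )

        if result.returncode == 0:
            return json.loads(result.stdout)
        else:
            return {
                "passed": False,
                "a11yPassed": False,
                "totalAssertions": 0,
                "passedAssertions": 0,
                "error": result.stderr
            }
    except (subprocess.TimeoutExpired, json.JSONDecodeError) as exc:
        # A hung browser or unparseable output counts as a failed test
        # instead of aborting the training run
        return {
            "passed": False,
            "a11yPassed": False,
            "totalAssertions": 0,
            "passedAssertions": 0,
            "error": str(exc)
        }
    finally:
        os.unlink(temp_path)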

print("Browser test function ready")

In [None]:
# Cell 5: GRPO Reward Function

def compute_reward(intent: str, generation: str) -> float:
    """
    Compute reward from tiered analysis.
    
    Reward breakdown:
    - Tier 1 (Static): 0.0 if fail, continue if pass (fast rejection)
    - Tier 3 (Browser): 
        - Story passed: +0.5
        - A11y passed: +0.3
        - Assertion ratio: +0.2 * (passed/total)
    
    Args:
        intent: The user intent that prompted generation
        generation: The model's generated code
    
    Returns:
        Reward value between 0.0 and 1.0
    """
    # Tier 1: Static analysis (fast rejection)
    static_result = run_static_analysis(generation)
    if not static_result.get("passed", False):
        return 0.0  # Fast rejection for invalid code
    
    # Tier 3: Browser test (ground truth)
    browser_result = run_browser_test(generation)
    
    reward = 0.0
    
    # Story passed: +0.5
    if browser_result.get("passed", False):
        reward += 0.5
    
    # A11y passed: +0.3
    if browser_result.get("a11yPassed", False):
        reward += 0.3
    
    # Assertion ratio: +0.2 * ratio
    total = browser_result.get("totalAssertions", 0)
    passed = browser_result.get("passedAssertions", 0)
    if total > 0:
        reward += 0.2 * (passed / total)
    
    return reward

print("Reward function ready")
print("Reward breakdown: Static(0.0 if fail) + Story(0.5) + A11y(0.3) + Assertions(0.2)")
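
To make the arithmetic concrete, the sketch below reproduces the same weighting as a pure function that takes already-computed results, so it can be checked without Bun or a browser. The name `reward_from_results` is hypothetical; the weights are the ones listed in the docstring above.

```python
def reward_from_results(static_passed: bool, browser_result: dict) -> float:
    """Combine tier results with the weights 0.5 / 0.3 / 0.2 described above."""
    if not static_passed:
        return 0.0  # Tier 1 failure rejects the generation outright

    reward = 0.0
    if browser_result.get("passed", False):
        reward += 0.5
    if browser_result.get("a11yPassed", False):
        reward += 0.3
    total = browser_result.get("totalAssertions", 0)
    passed = browser_result.get("passedAssertions", 0)
    if total > 0:
        reward += 0.2 * (passed / total)
    return reward


# Story failed, a11y passed, 3 of 4 assertions passed:
# 0.3 + 0.2 * 0.75, approximately 0.45
partial = reward_from_results(
    True,
    {"passed": False, "a11yPassed": True, "totalAssertions": 4, "passedAssertions": 3},
)
print(round(partial, 2))
```

A generation that passes everything scores 0.5 + 0.3 + 0.2 = 1.0, and one that fails static analysis scores 0.0 regardless of what the browser would have reported.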

In [None]:
# Cell 6: Configure GRPO Training
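from trl import GRPOConfig, GRPOTrainer

def grpo_reward(prompts, completions, **kwargs):
    """Adapt compute_reward to TRL's batched signature: one float per completion."""
    return [compute_reward(p, c) for p, c in zip(prompts, completions)]

grpo_config = GRPOConfig(
    output_dir="./grpo-output",
    num_train_epochs=1,
    # The generation batch must be divisible by num_generations;
    # the exact rule depends on the installed TRL version
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=1e-6,  # Very low LR for GRPO
    logging_steps=10,
    save_steps=50,
    fp16=True,
    # Cap on generated tokens; 1024 is an assumed value. Keep prompt +
    # completion within the 2048-token max_seq_length set in Cell 1
    max_completion_length=1024,
    # GRPO specific
    num_generations=4,  # Generate 4 candidates per intent
    temperature=0.8,    # Higher temperature for diversity
)

grpo_trainer = GRPOTrainer(
    model=model,
    args=grpo_config,
    train_dataset=intent_dataset,
    reward_funcs=grpo_reward,    # called with batches of prompts and completions
    processing_class=tokenizer,
)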

print("GRPO Trainer configured")
print(f"Generating {grpo_config.num_generations} candidates per intent")
print(f"Temperature: {grpo_config.temperature}")
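
The reason `num_generations` and `temperature` matter is that GRPO scores each candidate relative to the other candidates for the same intent. The sketch below illustrates that group-relative normalisation in plain Python. It is an illustration of the idea, not TRL's code; the library's exact standard-deviation estimator and epsilon may differ.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Normalise rewards within one group: (r - mean) / (std + eps)."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; an assumption for this sketch
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four candidates for one intent with different rewards
mixed = group_relative_advantages([0.0, 0.5, 0.8, 1.0])

# Four identical candidates: every advantage is zero, so there is no signal
flat = group_relative_advantages([0.5, 0.5, 0.5, 0.5])
print(flat)  # [0.0, 0.0, 0.0, 0.0]
```

When all candidates in a group earn the same reward, the advantages collapse to zero and that intent contributes nothing to the update, which is why sampling diversity is worth the extra generations.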

In [None]:
# Cell 7: Train with GRPO
print("Starting GRPO training...")
print("This will generate code and validate in browser for each intent.")
print("Progress is logged every 10 steps.")
print()

grpo_trainer.train()

print()
print("GRPO training complete!")

In [None]:
# Cell 8: Save and Push to Hub
import os

hf_token = os.environ.get('HF_TOKEN')
hf_username = os.environ.get('HF_USERNAME', 'plaited')
hf_model_name = os.environ.get('HF_MODEL_NAME', 'plaited-world-agent-lora')

MODEL_NAME = f"{hf_username}/{hf_model_name}-grpo"

# Save locally first
model.save_pretrained(f"./{hf_model_name}-grpo")
tokenizer.save_pretrained(f"./{hf_model_name}-grpo")
print(f"Saved locally to ./{hf_model_name}-grpo")

# Push to HuggingFace Hub
model.push_to_hub(MODEL_NAME, token=hf_token)
tokenizer.push_to_hub(MODEL_NAME, token=hf_token)

print(f"\nModel pushed to: https://huggingface.co/{MODEL_NAME}")

## Next Steps

1. **Deploy to HuggingFace Inference Endpoints** with vLLM
2. **Connect from TypeScript**:

```typescript
import { InferenceClient } from '@huggingface/inference'
import { useWorldAgent, createCoreTools } from 'plaited/agent'

const client = new InferenceClient(process.env.HF_TOKEN)

const trigger = await useWorldAgent({
  tools: createCoreTools({ outputDir: './generated' }),
  model: {
    inference: async (intent, schemas) => {
      const response = await client.chatCompletion({
        model: 'your-username/plaited-world-agent-lora-grpo',
        endpointUrl: 'https://xxx.endpoints.huggingface.cloud',
        messages: [{ role: 'user', content: intent }],
        tools: schemas.map(s => ({ type: 'function', function: s }))
      })
      return response.choices[0]?.message?.tool_calls ?? []
    }
  }
})
```

## Troubleshooting

### Browser Tests Failing
- Ensure Playwright is installed: `bunx playwright install chromium`
- Check workshop CLI works: `bun plaited test training/stories`

### Out of Memory
- Reduce `per_device_train_batch_size` to 1
- Reduce `num_generations` to 2

### Slow Training
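- GRPO is inherently slower than SFT/DPO here because every candidate is validated in a browser: each intent produces `num_generations` candidates, and each browser test can run up to its 30-second timeout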
- Consider running overnight for larger datasets
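
One possible mitigation is to cache rewards, so that identical candidates for the same intent are validated only once. The sketch below wraps any `(intent, generation) -> float` function with `functools.lru_cache`. It assumes the browser test is deterministic for a given piece of code; `make_cached_reward` and `slow_reward` are hypothetical names, and `slow_reward` stands in for `compute_reward`.

```python
from functools import lru_cache


def make_cached_reward(reward_fn, maxsize: int = 4096):
    """Wrap a reward function so repeated (intent, generation) pairs hit a cache."""
    @lru_cache(maxsize=maxsize)
    def cached(intent: str, generation: str) -> float:
        return reward_fn(intent, generation)
    return cached


calls = []

def slow_reward(intent: str, generation: str) -> float:
    """Stand-in for an expensive browser-backed reward."""
    calls.append(generation)
    return 1.0 if "aria-label" in generation else 0.0


cached_reward = make_cached_reward(slow_reward)
cached_reward("button", '<button aria-label="Go">Go</button>')
cached_reward("button", '<button aria-label="Go">Go</button>')
print(len(calls))  # 1
```

The cache only helps when the model repeats itself, which becomes more likely as `temperature` is lowered or as training converges.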