# GEPA Refiner Prompt Optimization

This notebook runs GEPA optimization on the contrastive verifier's refiner prompts using Google Colab with Drive.

## 1. Mount Google Drive & Setup

In [None]:
# Mount Google Drive at /content/drive so that the project directory referenced
# by DRIVE_PATH in the next cell is reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Set the path to your contrastive_verifier directory.
# Modify this path to match your Google Drive folder structure.
DRIVE_PATH = "/content/drive/MyDrive/contrastive_verifier"

import os
# Work from the project directory: the file names in CONFIG (checkpoint, JSONL
# inputs, output JSON) are relative and are resolved against it.
os.chdir(DRIVE_PATH)
# Echo the directory and its contents as a quick check that the path is right.
print(f"Working directory: {os.getcwd()}")
print(f"Files: {os.listdir('.')}")

In [None]:
# Install dependencies
!pip install -q transformers accelerate python-dotenv tqdm

## 2. Setup HuggingFace Token (for Gemma access)

In [None]:
# Provide the HuggingFace token without writing it into the notebook source,
# where it would be saved (and possibly shared) together with the notebook.
import os
from getpass import getpass

# Optional: load HF_TOKEN from a .env file first, if you keep one.
# from dotenv import load_dotenv
# load_dotenv()

# Reuse a token that is already configured; otherwise ask for it interactively
# (the input is not echoed).
if not os.environ.get("HF_TOKEN"):
    os.environ["HF_TOKEN"] = getpass("HuggingFace token: ").strip()

# Expose the same token under the second variable name as well.
os.environ["HUGGING_FACE_HUB_TOKEN"] = os.environ["HF_TOKEN"]

## 3. Import and Run GEPA Optimization

In [None]:
# Make the project directory and its gepa/src subdirectory importable.
import sys

# Every entry is pushed to the front of sys.path, so the last directory listed
# here ends up with the highest import priority.
for module_dir in (DRIVE_PATH, os.path.join(DRIVE_PATH, 'gepa/src')):
    sys.path.insert(0, module_dir)

# Project-local adapter and loader helpers used by the cells below.
from gepa_refiner_adapter import (
    ContrastiveVerifierAdapter,
    load_verifier,
    load_refiner_model,
    prepare_training_data,
)

print("Imports successful!")

In [None]:
# Configuration, grouped by concern and merged into the single CONFIG dict that
# the rest of the notebook reads.
OPTIMIZATION_SETTINGS = {
    "max_iterations": 25,
    "minibatch_size": 10,
    "max_train_samples": -1,  # -1 for all samples
    "score_threshold": 0.7,
}

# File names are relative to the working directory.
DATA_FILES = {
    "verifier_checkpoint": "verifier_best.pt",
    "scored_file": "scored_outputs.jsonl",
    "original_file": "generations_google-gemma-2b-it_0_-1.jsonl",
    "output_file": "evolved_prompts.json",
}

CONFIG = {**OPTIMIZATION_SETTINGS, **DATA_FILES, "device": "cuda"}

# Seed prompts (baseline) that the optimization starts from.
SEED_PROMPTS = {
    "critique_prompt": (
        "Review this solution step by step. "
        "Identify any errors in mathematical reasoning, calculations, or logic. "
        "Be specific about what went wrong."
    ),
    "refinement_prompt": (
        "Based on the critique above, provide a corrected step-by-step solution. "
        "End your final answer with \\boxed{answer}."
    ),
}

print("Configuration set!")

In [None]:
# Load models: the verifier from the checkpoint named in CONFIG, and the
# refiner language model. Both loaders return a (model, tokenizer) pair and are
# given the device configured in CONFIG.
print("="*60)
print("1. Loading verifier model...")
verifier_model, verifier_tokenizer = load_verifier(
    CONFIG["verifier_checkpoint"], 
    CONFIG["device"]
)

# The refiner model is also handed to the optimizer later as the reflection
# model, so only one language model is loaded.
# NOTE(review): the model id here is "google/gemma-2-2b-it" while
# CONFIG["original_file"] is named after "google-gemma-2b-it" - confirm that the
# refiner is meant to differ from the model that produced the generations.
print("\n2. Loading refiner model (also used for reflection)...")
refiner_model, refiner_tokenizer = load_refiner_model(
    "google/gemma-2-2b-it", 
    CONFIG["device"]
)

print("\nModels loaded!")

In [None]:
# Prepare training data: the helper returns a (trainset, valset) pair built
# from the files and limits configured in CONFIG.
print("3. Preparing training data...")
data_options = {
    "scored_file": CONFIG["scored_file"],
    "original_file": CONFIG["original_file"],
    "max_samples": CONFIG["max_train_samples"],
    "score_threshold": CONFIG["score_threshold"],
}
trainset, valset = prepare_training_data(**data_options)

In [None]:
# Create adapter: bundles the verifier and refiner (model + tokenizer each) in
# the object that the optimizer calls for evaluate() and
# make_reflective_dataset().
print("4. Creating GEPA adapter...")
adapter = ContrastiveVerifierAdapter(
    verifier_model=verifier_model,
    verifier_tokenizer=verifier_tokenizer,
    refiner_model=refiner_model,
    refiner_tokenizer=refiner_tokenizer,
    device=CONFIG["device"]
)
print("Adapter created!")

In [None]:
# Define the optimizer class inline in this cell, rather than importing it
# from the script, to avoid import issues.

import torch
import re
import random
from datetime import datetime
import json

# Reflection prompt template.
# Two placeholders are substituted with str.replace() in
# SimpleGEPAOptimizer._generate_reflection:
#   <curr_instructions>        - the prompt component currently being optimized
#   <inputs_outputs_feedback>  - the formatted examples with their feedback
# The template itself contains ``` fenced blocks and asks the model to answer
# inside ``` fences, so the answer has to be extracted from the model's
# continuation only, never from text that still includes this prompt.
REFLECTION_PROMPT_TEMPLATE = """I provided an assistant with the following instructions to generate critiques and refinements for math solutions:

```
<curr_instructions>
```

The following are examples of different inputs provided to the assistant, along with its critique/refinement outputs and feedback on how well the refinement worked:

```
<inputs_outputs_feedback>
```

Your task is to write improved instructions for the assistant.

Key observations from the examples:
1. Look at cases where refinement SUCCEEDED - what made the critique effective?
2. Look at cases where refinement FAILED - was the critique missing the actual error?
3. Look at REGRESSIONS where correct solutions became incorrect - what went wrong?

When writing new instructions:
- Include specific strategies that worked in successful cases
- Add warnings about common failure patterns observed
- Include domain-specific math reasoning guidance
- Be explicit about how to identify calculation vs. conceptual errors

Provide the new instructions within ``` blocks.
"""

class SimpleGEPAOptimizer:
    """Small reflective optimizer that evolves a dict of prompt components.

    Each iteration picks one component (round-robin), scores the current
    candidate on a random training minibatch, asks the reflection model to
    rewrite that component from the resulting feedback, and keeps the rewrite
    when it beats the current candidate on the same minibatch. Accepted
    candidates are then scored on the validation set; the best-scoring
    candidate seen is what ``optimize`` returns.

    Attributes:
        candidates: The seed candidate followed by every accepted candidate.
        scores: Validation score of each entry in ``candidates``.
        history: One dict per iteration with the minibatch scores, whether the
            proposal was accepted, and the best validation score so far.
    """

    def __init__(self, adapter, trainset, valset, reflection_model, reflection_tokenizer, device="cuda"):
        """Store the collaborators; no model call happens here.

        Args:
            adapter: Object providing ``evaluate(batch, candidate,
                capture_traces=...)`` (returning an object with ``.scores``)
                and ``make_reflective_dataset(candidate, eval_result,
                components)``.
            trainset: Sequence of training examples to sample minibatches from.
            valset: Examples used to score accepted candidates.
            reflection_model: Model whose ``generate`` writes the new prompts.
            reflection_tokenizer: Tokenizer matching ``reflection_model``.
            device: Device the tokenized reflection prompt is moved to.
        """
        self.adapter = adapter
        self.trainset = trainset
        self.valset = valset
        self.reflection_model = reflection_model
        self.reflection_tokenizer = reflection_tokenizer
        self.device = device
        self.candidates = []
        self.scores = []
        self.history = []

    @staticmethod
    def _mean_score(scores) -> float:
        """Return the arithmetic mean of ``scores``, or 0.0 when there are none."""
        return sum(scores) / len(scores) if scores else 0.0

    @staticmethod
    def _excerpt(record, key, limit) -> str:
        """Return at most ``limit`` characters of ``record[key]`` ('N/A' if absent)."""
        # str() keeps the slicing safe when a field is None or not a string.
        return str(record.get(key, 'N/A'))[:limit]

    @staticmethod
    def _extract_instructions(response) -> str:
        """Pull the proposed instructions out of the reflection model's answer.

        Preference order: the first closed ``` block (an optional language tag
        line is dropped), then everything after an unclosed ``` fence, then the
        whole response. The result is stripped of surrounding whitespace.
        """
        match = re.search(r'```(?:\w*\n)?(.+?)```', response, re.DOTALL)
        if match:
            return match.group(1).strip()
        if "```" in response:
            return response.split("```")[1].strip()
        return response.strip()

    def _evaluate_on_valset(self, candidate):
        """Return the mean adapter score of ``candidate`` on the validation set."""
        eval_result = self.adapter.evaluate(self.valset, candidate, capture_traces=False)
        return self._mean_score(eval_result.scores)

    def _sample_minibatch(self, size=5):
        """Draw up to ``size`` training examples at random, without replacement."""
        return random.sample(self.trainset, min(size, len(self.trainset)))

    def _generate_reflection(self, current_prompt, reflective_dataset, component_name):
        """Ask the reflection model for an improved version of ``current_prompt``.

        Args:
            current_prompt: Text of the component being optimized.
            reflective_dataset: Iterable of dict records; the keys read are
                'Question', 'Original_Solution', 'Critique_Generated',
                'Refined_Solution' and 'Feedback'.
            component_name: Name of the component (kept for interface
                compatibility; not used when building the prompt).

        Returns:
            The proposed instructions as a stripped string (may be empty).
        """
        formatted_examples = []
        for number, record in enumerate(reflective_dataset, start=1):
            formatted_examples.append("\n".join([
                f"### Example {number}",
                f"Question: {self._excerpt(record, 'Question', 200)}...",
                f"Original Solution (excerpt): {self._excerpt(record, 'Original_Solution', 200)}...",
                f"Critique Generated: {self._excerpt(record, 'Critique_Generated', 300)}...",
                f"Refined Solution (excerpt): {self._excerpt(record, 'Refined_Solution', 200)}...",
                f"Feedback: {record.get('Feedback', 'N/A')}",
                "",
            ]))

        examples_text = "\n\n".join(formatted_examples)
        reflection_prompt = REFLECTION_PROMPT_TEMPLATE.replace(
            "<curr_instructions>", current_prompt
        ).replace(
            "<inputs_outputs_feedback>", examples_text
        )

        # Chat markup with the user turn closed and the model turn left open.
        full_prompt = (
            f"<start_of_turn>user\n{reflection_prompt}<end_of_turn>\n"
            "<start_of_turn>model\n"
        )

        # NOTE(review): truncation presumably cuts from the end, which would drop
        # the open model turn for prompts longer than max_length - confirm.
        inputs = self.reflection_tokenizer(
            full_prompt, return_tensors="pt", truncation=True, max_length=4096
        ).to(self.device)

        with torch.no_grad():
            outputs = self.reflection_model.generate(
                **inputs, max_new_tokens=1024, temperature=0.7, do_sample=True, top_p=0.9,
                pad_token_id=self.reflection_tokenizer.eos_token_id
            )

        # The generated sequence starts with the prompt tokens. Decode only what
        # follows them: the prompt contains a fenced block holding the CURRENT
        # instructions, which the fence extraction would otherwise return
        # instead of the model's proposal.
        prompt_length = inputs["input_ids"].shape[1]
        response = self.reflection_tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
        return self._extract_instructions(response)

    def optimize(self, seed_candidate, max_iterations=10, minibatch_size=5, components_to_optimize=None):
        """Run the reflective optimization loop and return the best candidate.

        Args:
            seed_candidate: Dict mapping component name to prompt text. It is
                copied, never mutated.
            max_iterations: Number of propose/evaluate rounds.
            minibatch_size: Upper bound on training examples used per round.
            components_to_optimize: Component names to rewrite in round-robin
                order; defaults to the critique and refinement prompts.

        Returns:
            The candidate dict with the highest validation score seen.

        Raises:
            ValueError: If iterations are requested but there are no components
                to optimize or no training examples to sample from.
        """
        if components_to_optimize is None:
            components_to_optimize = ["critique_prompt", "refinement_prompt"]

        if max_iterations > 0:
            if not components_to_optimize:
                raise ValueError("components_to_optimize must name at least one component")
            if not self.trainset:
                raise ValueError("trainset is empty; cannot sample minibatches")

        current_candidate = seed_candidate.copy()
        self.candidates.append(current_candidate.copy())

        seed_score = self._evaluate_on_valset(current_candidate)
        self.scores.append(seed_score)
        print(f"Seed prompt score: {seed_score:.4f}")

        best_candidate = current_candidate.copy()
        best_score = seed_score

        for iteration in range(max_iterations):
            print(f"\n{'='*60}")
            print(f"Iteration {iteration + 1}/{max_iterations}")
            print(f"{'='*60}")

            # Round-robin over the components.
            component = components_to_optimize[iteration % len(components_to_optimize)]
            print(f"Optimizing component: {component}")

            minibatch = self._sample_minibatch(minibatch_size)
            eval_result = self.adapter.evaluate(minibatch, current_candidate, capture_traces=True)
            minibatch_score = self._mean_score(eval_result.scores)
            print(f"Minibatch score: {minibatch_score:.4f}")

            reflective_dataset = self.adapter.make_reflective_dataset(
                current_candidate, eval_result, [component]
            )

            print("Generating improved prompt via reflection...")
            new_prompt = self._generate_reflection(
                current_candidate[component],
                list(reflective_dataset.get(component, [])),
                component
            )

            new_minibatch_score = None  # stays None when the proposal is not evaluated
            accepted = False

            if not new_prompt or new_prompt == current_candidate[component]:
                # An empty or unchanged proposal cannot be an improvement, so do
                # not spend an evaluation on it or risk accepting it by chance.
                print("✗ Reflection produced no usable change - keeping current candidate")
            else:
                new_candidate = current_candidate.copy()
                new_candidate[component] = new_prompt

                print(f"\nProposed new {component}:\n{new_prompt[:300]}...")

                # Same minibatch as before, so the two scores are comparable.
                new_eval = self.adapter.evaluate(minibatch, new_candidate, capture_traces=False)
                new_minibatch_score = self._mean_score(new_eval.scores)
                print(f"New minibatch score: {new_minibatch_score:.4f}")

                accepted = new_minibatch_score > minibatch_score
                if accepted:
                    print("✓ Minibatch improved - evaluating on full valset...")
                    new_val_score = self._evaluate_on_valset(new_candidate)
                    print(f"Validation score: {new_val_score:.4f} (was {best_score:.4f})")

                    self.candidates.append(new_candidate.copy())
                    self.scores.append(new_val_score)
                    # The search continues from the accepted candidate even when
                    # its validation score is not a new best.
                    current_candidate = new_candidate.copy()

                    if new_val_score > best_score:
                        best_score = new_val_score
                        best_candidate = new_candidate.copy()
                        print(f"★ New best! Score: {best_score:.4f}")
                else:
                    print("✗ No improvement on minibatch - keeping current candidate")

            self.history.append({
                "iteration": iteration + 1,
                "component": component,
                "minibatch_score_before": minibatch_score,
                "minibatch_score_after": new_minibatch_score,
                "accepted": accepted,
                "best_val_score": best_score
            })

        print(f"\n{'='*60}")
        print(f"Optimization complete!")
        print(f"Best validation score: {best_score:.4f} (started at {seed_score:.4f})")
        print(f"Improvement: {(best_score - seed_score):.4f} ({(best_score - seed_score) / max(seed_score, 0.001) * 100:.1f}%)")
        print(f"{'='*60}")

        return best_candidate

print("Optimizer class defined!")

In [None]:
# Create and run optimizer.
# The refiner model and tokenizer are passed in as the reflection model, i.e.
# the same model that is being prompted also writes the improved prompts.
print("5. Setting up optimizer...")
optimizer = SimpleGEPAOptimizer(
    adapter=adapter,
    trainset=trainset,
    valset=valset,
    reflection_model=refiner_model,
    reflection_tokenizer=refiner_tokenizer,
    device=CONFIG["device"]
)

# optimize() works on a copy of SEED_PROMPTS and returns the candidate with the
# best validation score; per-iteration details end up in optimizer.history.
print("\n6. Running GEPA optimization...")
best_prompts = optimizer.optimize(
    seed_candidate=SEED_PROMPTS,
    max_iterations=CONFIG["max_iterations"],
    minibatch_size=CONFIG["minibatch_size"]
)

In [None]:
# Save results: the run configuration, seed and evolved prompts, the
# per-iteration history and the seed/best validation scores, as one JSON file.
print(f"\n7. Saving results to {CONFIG['output_file']}...")

recorded_scores = optimizer.scores
if recorded_scores:
    final_scores = {"seed": recorded_scores[0], "best": max(recorded_scores)}
else:
    # No score was recorded; fall back to zeros.
    final_scores = {"seed": 0, "best": 0}

results = {
    "timestamp": datetime.now().isoformat(),
    "config": CONFIG,
    "seed_prompts": SEED_PROMPTS,
    "evolved_prompts": best_prompts,
    "history": optimizer.history,
    "final_scores": final_scores,
}

with open(CONFIG["output_file"], 'w') as f:
    json.dump(results, f, indent=2)

print(f"Done! Results saved to {CONFIG['output_file']}")

In [None]:
# Display evolved prompts between two horizontal rules, one section per
# prompt component.
rule = "=" * 60

print("\n" + rule)
print("EVOLVED PROMPTS")
print(rule)
for name, prompt in best_prompts.items():
    print(f"\n### {name}:\n{prompt}")
print("\n" + rule)

In [None]:
# Plot optimization history: best validation score after each iteration.
import matplotlib.pyplot as plt

iterations = [entry['iteration'] for entry in optimizer.history]
best_scores = [entry['best_val_score'] for entry in optimizer.history]

# Same figure as before, built through the Axes object instead of the
# pyplot state machine.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(iterations, best_scores, 'b-o')
ax.set_xlabel('Iteration')
ax.set_ylabel('Best Validation Score')
ax.set_title('GEPA Optimization Progress')
ax.grid(True)
plt.show()