# GEPA Summarization Optimization with LLM Judge Evaluation
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/)

## Introduction

This notebook demonstrates how to optimize summarization prompts using GEPA (Generate, Evaluate, Propose, Adapt) with the Together AI Evaluations API. We'll:

1. Load the CNN/DailyMail dataset containing news articles
2. Start with a baseline summarization prompt
3. Use an optimizer LLM to iteratively improve the prompt
4. Compare prompts head-to-head using a judge model
5. Track improvement over multiple iterations

**Concepts Covered:**
- **GEPA Optimization**: Iterative prompt engineering using LLM feedback
- **LLM-as-a-Judge**: Using a language model to evaluate and compare outputs
- **Batch Evaluation**: Efficient comparison of multiple summaries
- **Prompt Engineering**: Systematic improvement of instruction prompts

## 📦 Setup and Installation

In [None]:
!pip install -qU together dspy-ai datasets tqdm

In [None]:
import together
import json
import random
import os
import re
import time
from pathlib import Path
from typing import List, Dict, Tuple
from datetime import datetime

import dspy
from datasets import load_dataset
from tqdm import tqdm

## ⚙️ Configuration

Set up your API key and configure the models we'll use:
- **Summarizer Model**: Generates the summaries
- **Judge Model**: Evaluates which summary is better
- **Optimizer Model**: Proposes improvements to the prompt

In [None]:
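# Read the API key from the environment; it is reused when the models are set up later
TOGETHER_API_KEY = os.environ.get("TOGETHER_API_KEY")
client = together.Client(api_key=TOGETHER_API_KEY)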

# Model configuration
SUMMARIZER_MODEL = "openai/gpt-oss-20b"
JUDGE_MODEL = "deepseek-ai/DeepSeek-V3"
OPTIMIZER_MODEL = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"

# Data splits
TRAIN_SIZE = 150
VAL_SIZE = 300
TEST_SIZE = 300

RANDOM_SEED = 42

print("✓ Configuration complete")

## 📝 Baseline and Judge Prompts

We start with a simple baseline prompt for summarization. The GEPA process will iteratively improve this prompt based on performance feedback.

In [None]:
BASELINE_PROMPT = """Summarize this news article in 3-5 key points.

Write a brief summary covering:
- The main news event
- Key people or organizations involved
- Important details or outcomes
- Any significant context

Keep it to 3-5 sentences total."""

JUDGE_PROMPT = """Compare these two summaries of the same news article.

Which summary better:
- Captures the main news story
- Includes important details
- Is clear and concise
- Avoids unnecessary information

Choose A or B and explain why briefly."""

print("Baseline Prompt:")
print(BASELINE_PROMPT)
print("\nJudge Prompt:")
print(JUDGE_PROMPT)

## 📂 Loading the CNN/DailyMail Dataset

The CNN/DailyMail dataset contains news articles paired with human-written highlights. We'll use the articles as our source text and split the data into train, validation, and test sets.

**Dataset Structure:**
- `article`: The full news article text
- `highlights`: Human-written bullet-point summary
- We'll use the articles for summarization and evaluate our generated summaries

In [None]:
def load_and_split_data():
    """Load CNN/DailyMail dataset for summarization."""
    print("\n" + "=" * 80)
    print("📂 LOADING DATA")
    print("=" * 80)

    print("Loading CNN/DailyMail dataset...")
    dataset = load_dataset("abisee/cnn_dailymail", "3.0.0", trust_remote_code=True)
    data = dataset['test']

    print(f"✓ Loaded {len(data)} examples")
    print(f"  Sample article: {data[0]['article'][:100]}...")
    print(f"  Sample highlights: {data[0]['highlights'][:100]}...")

    all_data = []
    for i, item in enumerate(data):
        all_data.append({
            'id': f"cnn_{i}",
            'text': item['article'],
            'reference_summary': item['highlights']
        })

    print(f"✓ Converted to {len(all_data)} items")

    # Shuffle and split
    random.seed(RANDOM_SEED)
    random.shuffle(all_data)

    train_data = all_data[:TRAIN_SIZE]
    val_data = all_data[TRAIN_SIZE:TRAIN_SIZE + VAL_SIZE]
    test_data = all_data[TRAIN_SIZE + VAL_SIZE:TRAIN_SIZE + VAL_SIZE + TEST_SIZE]

    print(f"✓ Split: Train={len(train_data)}, Val={len(val_data)}, Test={len(test_data)}")

    # Verify
    assert len(val_data) > 0, "Val data is empty!"
    assert len(test_data) > 0, "Test data is empty!"

    return train_data, val_data, test_data

# Load the data
train_data, val_data, test_data = load_and_split_data()
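
The shuffle-then-slice split used above can be checked in isolation on toy data. The following is a minimal sketch, not the notebook's loader: `split_dataset` and the toy records are hypothetical, and it uses a local `random.Random(seed)` instead of seeding the global generator so that the split is reproducible regardless of other random calls.

```python
import random


def split_dataset(items, train_size, val_size, test_size, seed=42):
    """Shuffle a copy of `items` with a fixed seed, then slice three disjoint splits."""
    if train_size + val_size + test_size > len(items):
        raise ValueError("Requested splits exceed dataset size")
    shuffled = list(items)                 # leave the caller's list untouched
    random.Random(seed).shuffle(shuffled)  # local generator, deterministic for a given seed
    val_end = train_size + val_size
    train = shuffled[:train_size]
    val = shuffled[train_size:val_end]
    test = shuffled[val_end:val_end + test_size]
    return train, val, test


# Toy records shaped like the notebook's items (only the id is kept here).
toy = [{"id": f"cnn_{i}"} for i in range(20)]
train, val, test = split_dataset(toy, 5, 10, 5)
print(len(train), len(val), len(test))
```

Because the three slices come from one shuffled list, no article can appear in more than one split, and the size check fails loudly instead of silently returning a short test set.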

## 🤖 Summarization Module

We create a DSPy module that wraps our summarization task. This module can be configured with different instruction prompts, which is key to the GEPA optimization process.

In [None]:
class Summarizer(dspy.Signature):
    """Generate a summary."""
    text = dspy.InputField()
    summary = dspy.OutputField()


class SummarizationModule(dspy.Module):
    """Summarization module."""

    def __init__(self, instructions=None):
        super().__init__()
        # Fall back to the baseline prompt so the predictor always runs with
        # the same instructions that self.instructions reports.
        self.instructions = instructions or BASELINE_PROMPT

        # with_instructions returns a copy of the signature whose
        # instructions are replaced by the given text.
        self.predictor = dspy.Predict(
            Summarizer.with_instructions(self.instructions)
        )

    def forward(self, text):
        return self.predictor(text=text)

print("✓ Summarization module defined")
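
DSPy takes a signature's instructions from its class docstring, which is why swapping the instruction text changes the prompt the model sees. The sketch below illustrates only that docstring mechanism with plain Python classes; `BaseSignature` and `make_signature` are hypothetical names and nothing here calls DSPy.

```python
class BaseSignature:
    """Generate a summary."""


def make_signature(instructions: str) -> type:
    """Return a subclass whose docstring is the given instruction text."""

    class CustomSignature(BaseSignature):
        # A class body can read variables from the enclosing function,
        # so the docstring can be set from a runtime value.
        __doc__ = instructions

    return CustomSignature


sig = make_signature("Summarize the article in 3-5 sentences.")
print(sig.__doc__)
```

Each call produces a fresh class, so several prompt variants can coexist without overwriting the default docstring on the base class.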

## 📊 Batch Summary Generation

This function generates summaries for a batch of articles using a given prompt. It includes error handling and progress tracking.

In [None]:
def generate_summaries_batch(
        summarizer: SummarizationModule,
        data: List[Dict],
        desc: str = "Generating"
) -> List[Dict]:
    """Generate summaries for a batch of texts."""
    results = []
    errors = 0
    error_details = []

    # Print the prompt being used (once per batch)
    if len(data) > 0:
        print(f"  Using prompt: {summarizer.instructions[:100]}...")

    for item in tqdm(data, desc=desc):
        try:
            pred = summarizer(text=item['text'][:5000])

            if pred is None:
                raise ValueError("Model returned None")

            if hasattr(pred, 'summary') and pred.summary:
                summary = pred.summary
            elif isinstance(pred, str):
                summary = pred
            else:
                print(f"\n  DEBUG: pred type={type(pred)}, hasattr summary={hasattr(pred, 'summary')}")
                raise ValueError(f"Cannot extract summary from {type(pred)}")

            summary = summary.strip()
            if len(summary) < 20:
                raise ValueError("Summary too short")

        except Exception as e:
            errors += 1
            error_details.append(str(e)[:100])

            if errors <= 5:
                print(f"\n⚠️  Error: {str(e)[:80]}")

            summary = "Error generating summary."

        results.append({
            'id': item['id'],
            'text': item['text'],
            'summary': summary
        })

    if errors > 0:
        print(f"\n⚠️  Total errors: {errors}/{len(data)} ({errors / len(data) * 100:.1f}%)")
        from collections import Counter
        common_errors = Counter(error_details).most_common(3)
        print(f"  Most common errors:")
        for err, count in common_errors:
            print(f"    - {err[:60]}... ({count}x)")

    return results

print("✓ Batch generation function defined")
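
The error-handling pattern in the batch function (substitute a fallback, keep going, tally what failed) can be exercised without any model call. This is a minimal sketch under that assumption: `summarize_all` and `stub_summarize` are hypothetical stand-ins, and the stub replaces the real summarizer.

```python
from collections import Counter

FALLBACK = "Error generating summary."


def summarize_all(texts, summarize, min_chars=20):
    """Run `summarize` over texts; on failure use a fallback and tally the error."""
    summaries, errors = [], Counter()
    for text in texts:
        try:
            summary = summarize(text).strip()
            if len(summary) < min_chars:
                raise ValueError("Summary too short")
        except Exception as exc:
            errors[str(exc)[:100]] += 1  # group identical messages together
            summary = FALLBACK
        summaries.append(summary)
    return summaries, errors


def stub_summarize(text):
    # Hypothetical stand-in for the model call: fails on empty input.
    if not text:
        raise RuntimeError("empty input")
    return text[:40]


summaries, errors = summarize_all(
    ["A long enough article body about a news event.", "", "short"],
    stub_summarize,
)
print(errors.most_common(3))
```

Keeping one output per input, even on failure, matters here because the later comparison step pairs summaries from two prompts by position.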

## 🧠 Optimizer LLM Wrapper

This wrapper allows us to use an LLM to propose improvements to our summarization prompt based on current performance.

In [None]:
class SimpleOptimizerLM:
    """Wrapper for optimizer LLM."""

    def __init__(self, model: str, api_key: str):
        self.client = together.Client(api_key=api_key)
        self.model = model

    def __call__(self, prompt: str) -> str:
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=4000
        )
        return response.choices[0].message.content

print("✓ Optimizer LLM wrapper defined")

## 🤔 Reflection and Prompt Improvement

This function uses the optimizer LLM to analyze the current prompt and performance, then propose an improved version.

**Key Constraints:**
- Keep prompts under 150 words for clarity
- Focus on simple, direct instructions
- Target 4-6 sentence summaries
- Avoid overly complex requirements

In [None]:
def reflect_and_improve_prompt(
        current_prompt: str,
        current_score: float,
        optimizer_lm: SimpleOptimizerLM,
        iteration: int
) -> str:
    """Use LLM to propose improved prompt."""

    print(f"\n🤔 REFLECTION (Iteration {iteration})")

    reflection_prompt = f"""You are optimizing a summarization prompt for CNN/DailyMail news articles.

Current Prompt:
```
{current_prompt}
```

Current Performance: {current_score:.1%} win rate

Your task: Propose a SIMPLE improved version that generates better summaries.

CRITICAL CONSTRAINTS:
- Keep the prompt under 150 words
- Make it clear and direct (NOT overly complex)
- Target 4-6 sentence summaries
- Avoid excessive instructions or formatting requirements
- The prompt should be easy for the model to follow

Focus on:
- Should it emphasize different aspects (accuracy, brevity, completeness)?
- Are the current guidelines clear?
- Is anything missing or unnecessary?

Output ONLY the improved prompt within ``` blocks. Keep it simple and clear."""

    response = optimizer_lm(reflection_prompt)

    # Extract prompt
    match = re.search(r'```(.*?)```', response, re.DOTALL)
    if match:
        new_prompt = match.group(1).strip()
        # Remove language tags
        for tag in ['markdown', 'text', 'python', 'plaintext']:
            if new_prompt.startswith(f'{tag}\n'):
                new_prompt = '\n'.join(new_prompt.split('\n')[1:])

        # Validate length: the optimizer is asked for under 150 words,
        # but up to 200 are tolerated before the candidate is rejected
        word_count = len(new_prompt.split())
        if word_count > 200:
            print(f"  ⚠️  Generated prompt too long ({word_count} words), using current")
            return current_prompt

        print(f"✓ Generated new prompt ({word_count} words)")
        return new_prompt

    print("⚠️  Could not extract prompt")
    return current_prompt

print("✓ Reflection function defined")
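
The parsing step of the reflection function (take the first fenced block, drop a leading language tag, reject overly long results) can be tested offline. The sketch below is a simplified variant rather than the notebook's exact code: `extract_prompt` is a hypothetical helper that returns `None` instead of falling back to the current prompt, and the fence marker is built programmatically only so the example can sit inside a fenced block.

```python
import re
from typing import Optional

FENCE = "`" * 3  # three backticks, assembled at runtime
LANGUAGE_TAGS = ("markdown", "text", "python", "plaintext")


def extract_prompt(response: str, max_words: int = 200) -> Optional[str]:
    """Return the text of the first fenced block, or None if missing or too long."""
    match = re.search(FENCE + r"(.*?)" + FENCE, response, re.DOTALL)
    if match is None:
        return None
    body = match.group(1).strip()
    # Drop a language tag such as "text" if it is the block's first line.
    first_line, _, rest = body.partition("\n")
    if first_line.strip() in LANGUAGE_TAGS:
        body = rest.strip()
    if len(body.split()) > max_words:
        return None
    return body


reply = (
    "Here is my suggestion:\n"
    f"{FENCE}text\nSummarize the article in 4-6 sentences.\n{FENCE}\n"
    "Hope that helps."
)
print(extract_prompt(reply))
```

Returning `None` makes the failure explicit to the caller, which can then decide whether to keep the current prompt or retry the optimizer.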

## 🔄 Head-to-Head Prompt Comparison

This function compares two prompts by:
1. Generating summaries with both prompts
2. Creating a comparison dataset
3. Using the Together AI evaluation API with a judge model
4. Computing win rates

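The evaluation uses a two-pass approach to reduce position bias, the tendency of a judge to favor a summary because of where it appears in the prompt rather than because of its quality.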
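
To make the idea of position bias concrete, here is one plausible way a two-pass comparison can be aggregated. This is an illustrative assumption, not a description of how the Evaluations API combines its passes: `two_pass_verdict` and both judges are hypothetical, and a pair only counts as a win when the two orderings agree.

```python
def two_pass_verdict(judge, summary_a, summary_b):
    """Judge a pair in both orders; return "A", "B", or "tie" if the passes disagree."""
    pass_one = judge(summary_a, summary_b)  # A is shown first
    pass_two = judge(summary_b, summary_a)  # B is shown first
    a_wins_one = pass_one == "first"
    a_wins_two = pass_two == "second"
    if a_wins_one and a_wins_two:
        return "A"
    if not a_wins_one and not a_wins_two:
        return "B"
    return "tie"


def position_biased_judge(first, second):
    # Hypothetical judge that always prefers whichever summary is shown first.
    return "first"


def shorter_wins_judge(first, second):
    # Hypothetical judge that prefers the shorter summary regardless of position.
    return "first" if len(first) <= len(second) else "second"


print(two_pass_verdict(position_biased_judge, "x", "y"))
print(two_pass_verdict(shorter_wins_judge, "short", "a much longer summary"))
```

Under this scheme a judge that only ever picks the first slot produces ties instead of wins, so its bias cannot tilt the win rate toward either prompt.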

In [None]:
def compare_two_prompts_on_batch(
        data: List[Dict],
        prompt_a: str,
        prompt_b: str,
        summarizer_lm: dspy.LM,
        eval_name: str
) -> Tuple[float, float, Dict]:
    """
    Compare two summarization prompts.

    1. Generate summaries with prompt A
    2. Generate summaries with prompt B
    3. Use judge to compare them
    4. Return win rate for prompt A
    """

    print(f"\n{'=' * 80}")
    print(f"🔄 COMPARING PROMPTS: {eval_name}")
    print(f"{'=' * 80}")

    # Step 1: Generate with both prompts
    dspy.configure(lm=summarizer_lm)

    summarizer_a = SummarizationModule(prompt_a)
    summarizer_b = SummarizationModule(prompt_b)

    print("Generating summaries with Prompt A...")
    summaries_a = generate_summaries_batch(summarizer_a, data, "Prompt A")

    print("Generating summaries with Prompt B...")
    summaries_b = generate_summaries_batch(summarizer_b, data, "Prompt B")

    # Step 2: Prepare comparison data
    temp_file = f"temp_compare_{eval_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"

    with open(temp_file, 'w') as f:
        for summary_a, summary_b in zip(summaries_a, summaries_b):
            formatted = {
                "prompt": f"Source article: {summary_a['text'][:5000]}",
                "model_a_output": summary_a['summary'],
                "model_b_output": summary_b['summary'],
                "id": summary_a['id']
            }
            f.write(json.dumps(formatted) + '\n')

    # Step 3: Upload and evaluate
    print("📤 Uploading for comparison...")
    file_response = client.files.upload(file=temp_file, purpose="eval")
    file_id = file_response.id

    print("🚀 Launching comparison...")
    eval_response = client.evaluation.create(
        type="compare",
        input_data_file_path=file_id,
        judge_model=JUDGE_MODEL,
        judge_model_source="serverless",
        judge_system_template=JUDGE_PROMPT,
        model_a="model_a_output",
        model_b="model_b_output"
    )

    # Step 4: Wait and get results
    print(f"⏳ Waiting (ID: {eval_response.workflow_id})...")
    while True:
        status = client.evaluation.status(eval_response.workflow_id)
        if status.status.value == "completed":
            break
        elif status.status.value == "failed":
            raise RuntimeError(f"Evaluation {eval_response.workflow_id} failed")
        time.sleep(30)

    a_wins = status.results.get('A_wins', 0)
    b_wins = status.results.get('B_wins', 0)
    ties = status.results.get('Ties', 0)

    # Win rate for prompt A
    decisive_total = a_wins + b_wins
    if decisive_total > 0:
        a_win_rate = a_wins / decisive_total
        b_win_rate = b_wins / decisive_total
    else:
        a_win_rate = b_win_rate = 0.5

    print(f"✓ Results: Prompt A wins={a_wins}, Prompt B wins={b_wins}, Ties={ties}")
    print(f"✓ Prompt A win rate: {a_win_rate:.2%}")

    os.remove(temp_file)

    return a_win_rate, b_win_rate, {
        'a_wins': a_wins,
        'b_wins': b_wins,
        'ties': ties,
        'a_win_rate': a_win_rate
    }

print("✓ Comparison function defined")
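
The win-rate arithmetic at the end of the comparison function excludes ties from the denominator. A small worked example shows what that means in practice; `win_rates` is a hypothetical helper and the counts are made up for illustration.

```python
def win_rates(a_wins: int, b_wins: int, ties: int = 0):
    """Win rates over decisive comparisons only; ties do not enter the denominator."""
    decisive = a_wins + b_wins
    if decisive == 0:
        return 0.5, 0.5  # nothing decisive: treat the prompts as even
    return a_wins / decisive, b_wins / decisive


# Illustrative counts for 300 comparisons, 30 of which were ties.
a_rate, b_rate = win_rates(a_wins=120, b_wins=150, ties=30)
print(f"A: {a_rate:.2%}, B: {b_rate:.2%}")  # prints "A: 44.44%, B: 55.56%"
```

With these counts the rates are computed over 270 decisive pairs. Had ties been included in the denominator, prompt B would score exactly 50% (150 of 300), so the choice affects whether a candidate clears the 50% mark.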

## 🧬 GEPA Optimization Loop

This is the main optimization loop that implements the GEPA algorithm:

1. **Generate**: Create summaries with current prompt
2. **Evaluate**: Compare against the current best prompt (initially the baseline) using the judge model
3. **Propose**: Use optimizer LLM to suggest improvements
4. **Adapt**: Accept improvements that increase win rate

The process repeats for multiple iterations, tracking the best prompt found.

In [None]:
def run_manual_gepa(
        train_data: List[Dict],
        val_data: List[Dict],
        test_data: List[Dict],
        summarizer_lm: dspy.LM,
        optimizer_lm: SimpleOptimizerLM,
        max_iterations: int = 5
):
    """Manual GEPA-style optimization."""

    start_time = time.time()

    print("\n" + "=" * 80)
    print("🧬 MANUAL GEPA OPTIMIZATION")
    print("=" * 80)

    # Track best prompt
    best_prompt = BASELINE_PROMPT
    best_val_score = 0.5  # Start at 50% (neutral)

    for i in range(max_iterations):
        print(f"\n{'=' * 80}")
        print(f"ITERATION {i + 1}/{max_iterations}")
        print(f"{'=' * 80}")

        if i == 0:
            print("First iteration: the baseline is the current best (no comparison yet)")
            continue

        # Generate new candidate prompt
        new_prompt = reflect_and_improve_prompt(
            best_prompt,
            best_val_score,
            optimizer_lm,
            i
        )

        if new_prompt == best_prompt:
            print("⚠️  No change in prompt, stopping")
            break

        print(f"✓ Generated candidate prompt ({len(new_prompt)} chars)")

        # Compare best_prompt vs new_prompt on validation set
        best_win_rate, new_prompt_win_rate, metrics = compare_two_prompts_on_batch(
            val_data,
            prompt_a=best_prompt,
            prompt_b=new_prompt,
            summarizer_lm=summarizer_lm,
            eval_name=f"iter{i}_val"
        )

        print(f"\n  Current best: {best_win_rate:.2%}")
        print(f"  New candidate: {new_prompt_win_rate:.2%}")

        # The comparison is head-to-head against the current best, so the
        # candidate is better whenever it wins more than half of the
        # decisive comparisons.
        if new_prompt_win_rate > 0.5:
            margin = new_prompt_win_rate - 0.5
            print(f"  🎉 New best! (+{margin * 100:.2f}pp over 50%)")
            best_prompt = new_prompt
            best_val_score = new_prompt_win_rate
        else:
            print("  No improvement")

    # Calculate total time
    total_time = time.time() - start_time
    hours = int(total_time // 3600)
    minutes = int((total_time % 3600) // 60)
    seconds = int(total_time % 60)

    # Final test evaluation
    print("\n" + "=" * 80)
    print("📊 FINAL TEST EVALUATION")
    print("=" * 80)

    print(f"\n⏱️  OPTIMIZATION TIME:")
    if hours > 0:
        print(f"  Total: {hours}h {minutes}m {seconds}s")
    elif minutes > 0:
        print(f"  Total: {minutes}m {seconds}s")
    else:
        print(f"  Total: {seconds}s")

    baseline_test_win_rate, optimized_test_win_rate, _ = compare_two_prompts_on_batch(
        test_data,
        prompt_a=BASELINE_PROMPT,
        prompt_b=best_prompt,
        summarizer_lm=summarizer_lm,
        eval_name="final_test"
    )

    # Display results
    print("\n" + "=" * 80)
    print("🎉 FINAL RESULTS")
    print("=" * 80)

    print(f"\nTEST SET:")
    print(f"  Baseline prompt:  {baseline_test_win_rate:.2%}")
    print(f"  Optimized prompt: {optimized_test_win_rate:.2%}")
    print(f"  Improvement:      {(optimized_test_win_rate - 0.5) * 100:+.2f}pp from neutral")

    # Save results
    output_dir = Path("results")
    output_dir.mkdir(exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    with open(output_dir / f"prompts_{timestamp}.txt", 'w') as f:
        f.write("BASELINE:\n" + "=" * 80 + "\n")
        f.write(BASELINE_PROMPT)
        f.write("\n\nOPTIMIZED:\n" + "=" * 80 + "\n")
        f.write(best_prompt)
        f.write(f"\n\nRESULTS:\n" + "=" * 80 + "\n")
        f.write(f"Baseline: {baseline_test_win_rate:.2%}\n")
        f.write(f"Optimized: {optimized_test_win_rate:.2%}\n")

    print(f"\n💾 Saved to: results/prompts_{timestamp}.txt")

    return {
        'baseline_test': baseline_test_win_rate,
        'optimized_test': optimized_test_win_rate,
        'best_prompt': best_prompt
    }

print("✓ GEPA optimization function defined")
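
Stripped of model calls, the loop is a simple hill climb: propose a candidate, compare it head-to-head with the current best, and keep it only if it wins. The sketch below shows that control flow with stubbed components. It assumes a candidate is accepted when it wins more than half of its comparisons; `hill_climb`, `propose`, and `head_to_head` are hypothetical names, and the stub scores are invented for illustration.

```python
def hill_climb(initial, propose, head_to_head, max_iterations=5):
    """Keep a candidate only if it wins more than half of its head-to-head comparisons."""
    best = initial
    history = []
    for step in range(1, max_iterations + 1):
        candidate = propose(best)
        if candidate == best:
            break  # the optimizer produced nothing new
        win_rate = head_to_head(best, candidate)  # candidate's win rate against best
        accepted = win_rate > 0.5
        history.append((step, candidate, win_rate, accepted))
        if accepted:
            best = candidate
    return best, history


def propose(prompt):
    # Hypothetical optimizer: appends a marker to the current best prompt.
    return prompt + "!"


def head_to_head(best, candidate):
    # Hypothetical judge: candidates win 60% until they exceed 3 characters.
    return 0.6 if len(candidate) <= 3 else 0.4


best, history = hill_climb("a", propose, head_to_head)
print(best)
for step, candidate, win_rate, accepted in history:
    print(step, candidate, win_rate, accepted)
```

Recording every step in `history`, including rejected candidates, makes it easy to see afterwards how many proposals were tried and where the search stalled.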

## 🚀 Run the Optimization

Now we'll execute the full GEPA optimization process. This will:
1. Set up the summarizer and optimizer models
2. Run multiple iterations of prompt improvement
3. Evaluate the final optimized prompt on the test set
4. Display comprehensive results

In [None]:
print("="*80)
print("🎯 GEPA SUMMARIZATION - TOGETHER AI BATCH EVAL")
print("="*80)

if not TOGETHER_API_KEY or TOGETHER_API_KEY == 'your_api_key_here':
    print("❌ Set TOGETHER_API_KEY")
else:
    # Setup models
    summarizer_lm = dspy.LM(
        f"together_ai/{SUMMARIZER_MODEL}",
        api_key=TOGETHER_API_KEY,
        temperature=0.5,
        max_tokens=1024
    )

    optimizer_lm = SimpleOptimizerLM(
        model=OPTIMIZER_MODEL,
        api_key=TOGETHER_API_KEY
    )

    # Run optimization
    results = run_manual_gepa(
        train_data,
        val_data,
        test_data,
        summarizer_lm,
        optimizer_lm,
        max_iterations=5
    )

    print("\n✅ Complete!")

## 📊 Analyzing the Results

Let's examine the optimized prompt and compare it to the baseline.

In [None]:
print("=" * 80)
print("📝 PROMPT COMPARISON")
print("=" * 80)

print("\nBASELINE PROMPT:")
print("-" * 80)
print(BASELINE_PROMPT)

print("\n\nOPTIMIZED PROMPT:")
print("-" * 80)
print(results['best_prompt'])

print("\n\nPERFORMANCE COMPARISON:")
print("-" * 80)
print(f"Baseline Win Rate:  {results['baseline_test']:.2%}")
print(f"Optimized Win Rate: {results['optimized_test']:.2%}")
print(f"Improvement:        {(results['optimized_test'] - 0.5) * 100:+.2f} percentage points from neutral")

## 🔑 Key Findings

**GEPA Optimization Process:**
- Iteratively improves prompts through LLM-guided reflection
- Uses head-to-head comparisons with a judge model
- Tracks the best prompt found so far and accepts only candidates that improve on it

**Benefits of This Approach:**
1. **Automated**: No manual prompt engineering required
2. **Data-driven**: Decisions based on actual performance metrics
3. **Scalable**: Can optimize for any task with appropriate data
4. **Transparent**: Clear tracking of improvements across iterations

**Next Steps:**
- Try with different datasets or domains
- Experiment with different judge criteria
- Adjust the optimizer's reflection prompt
- Increase iterations for potentially better results