## Import Libraries

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import Dataset, load_dataset
from tqdm import tqdm
import re
import random
import json
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Tuple
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

PyTorch version: 2.9.0+cu126
CUDA available: True
CUDA device: Tesla T4
CUDA memory: 15.83 GB


## Configuration

In [4]:
@dataclass
class Config:
    """
    Complete configuration for ReflectEvo implementation
    Based on the original paper specifications
    """
    # Model settings
    model_name: str = "google/gemma-2-2b-it"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Dataset settings
    max_train_samples: int = 10  # ******(200) Paper uses full dataset (~7k), start small for testing ******
    max_test_samples: int = 5 # ******(100). ******

    # Stage 1: Solution Generation
    num_candidates_per_question: int = 5  # ******(5) Paper uses 5-10 candidates per question ******
    stage1_temperature: float = 0.8       # Higher temp for diversity
    stage1_top_p: float = 0.95
    stage1_max_tokens: int = 300

    # Stage 2: Reflection Generation
    stage2_temperature: float = 0.7
    stage2_max_tokens: int = 150

    # Stage 3: Refinement Generation
    stage3_temperature: float = 0.7
    stage3_max_tokens: int = 300

    # Evolutionary Algorithm Settings (THE "EVO" PART)
    num_generations: int = 3              # Paper uses 3-5 generations
    population_size: int = 8              # Number of prompt variants
    elite_ratio: float = 0.5              # Top 50% survive to next generation
    mutation_rate: float = 0.3            # 30% chance of mutation
    crossover_rate: float = 0.7           # 70% chance of crossover

    # Quality Filtering
    min_solution_quality: float = 0.3     # Keep solutions with quality > 0.3
    max_solutions_per_question: int = 3   # Diversity: max 3 solutions per question

    # Training settings
    num_epochs: int = 3
    batch_size: int = 4
    gradient_accumulation_steps: int = 4
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.1

    # Output directories
    output_dir_reflective: str = "./models/reflectevo_final"
    output_dir_standard: str = "./models/standard_baseline"
    checkpoint_dir: str = "./checkpoints"

config = Config()

print("="*80)
print("CONFIGURATION LOADED")
print("="*80)
print(f"Model: {config.model_name}")
print(f"Device: {config.device}")
print(f"Train samples: {config.max_train_samples}")
print(f"Test samples: {config.max_test_samples}")
print(f"Generations: {config.num_generations}")
print(f"Population size: {config.population_size}")
print(f"Candidates per question: {config.num_candidates_per_question}")
print("="*80)

CONFIGURATION LOADED
Model: google/gemma-2-2b-it
Device: cuda
Train samples: 10
Test samples: 5
Generations: 3
Population size: 8
Candidates per question: 5


## Utility Function

In [5]:
def extract_final_answer(text: str) -> str:
    """
    Extract the final numeric answer from a solution.
    Handles multiple formats: #### format, "Answer:", last number
    """
    # Priority 1: Look for #### format (GSM8K standard)
    match = re.search(r'####\s*(\d+\.?\d*)', text)
    if match:
        return match.group(1)

    # Priority 2: Look for "Answer:" format
    match = re.search(r'Answer:\s*\$?\s*(\d+\.?\d*)', text, re.IGNORECASE)
    if match:
        return match.group(1)

    # Priority 3: Look for "The answer is" format
    match = re.search(r'[Tt]he answer is\s*\$?\s*(\d+\.?\d*)', text)
    if match:
        return match.group(1)

    # Priority 4: Last number in text
    numbers = re.findall(r'\b(\d+\.?\d*)\b', text)
    if numbers:
        return numbers[-1]

    return None

def extract_gt_answer(gt_text: str) -> str:
    """Extract ground truth answer from GSM8K format"""
    match = re.search(r'####\s*(\d+\.?\d*)', gt_text)
    return match.group(1) if match else None

def is_correct(solution: str, gt_answer: str) -> bool:
    """
    Check if a solution produces the correct final answer.
    Returns True if predicted answer matches ground truth.
    """
    predicted = extract_final_answer(solution)
    ground_truth = extract_gt_answer(gt_answer)

    if predicted is None or ground_truth is None:
        return False

    # Handle both integer and float comparisons
    try:
        return abs(float(predicted) - float(ground_truth)) < 0.01
    except:
        return predicted == ground_truth

# Test the utility functions
print("Testing utility functions...")
test_cases = [
    ("The total is 72 clips. #### 72", "#### 72", True),
    ("So the answer is 50.", "#### 50", True),
    ("Therefore, 48 + 24 = 72", "#### 72", True),
    ("The result is 100", "#### 50", False),
]

for solution, gt, expected in test_cases:
    result = is_correct(solution, gt)
    status = "✅" if result == expected else "❌"
    print(f"  {status} is_correct('{solution[:30]}...', '{gt}') = {result}")

print("\n✅ Utility functions ready!")

Testing utility functions...
  ✅ is_correct('The total is 72 clips. #### 72...', '#### 72') = True
  ✅ is_correct('So the answer is 50....', '#### 50') = True
  ✅ is_correct('Therefore, 48 + 24 = 72...', '#### 72') = True
  ✅ is_correct('The result is 100...', '#### 50') = False

✅ Utility functions ready!


## Quality Scoring Criteria

In [6]:
def score_solution_quality(solution: str, gt_answer: str = None) -> float:
    """
    Score solution quality based on multiple criteria.
    Used for filtering low-quality solutions.

    Scoring criteria:
    1. Correctness (50%): Is the answer correct?
    2. Reasoning detail (30%): Length and detail of explanation
    3. Structure (20%): Has clear step-by-step organization

    Returns: Quality score between 0.0 and 1.0
    """
    total_score = 0.0

    # Criterion 1: Correctness (0.5 weight)
    if gt_answer is not None:
        correctness_score = 1.0 if is_correct(solution, gt_answer) else 0.0
        total_score += correctness_score * 0.5

    # Criterion 2: Reasoning detail (0.3 weight)
    # Prefer solutions between 50-250 words (not too short, not too verbose)
    word_count = len(solution.split())
    if word_count < 50:
        detail_score = word_count / 50  # Penalize very short solutions
    elif word_count <= 250:
        detail_score = 1.0  # Optimal length
    else:
        detail_score = max(0.5, 1.0 - (word_count - 250) / 250)  # Penalize verbose

    total_score += detail_score * 0.3

    # Criterion 3: Has clear structure (0.2 weight)
    # Check for step markers, clear organization
    structure_indicators = [
        'step', 'first', 'second', 'third', 'finally',
        'therefore', 'so', 'thus', 'next', 'then'
    ]

    solution_lower = solution.lower()
    structure_count = sum(1 for indicator in structure_indicators if indicator in solution_lower)
    structure_score = min(structure_count / 3, 1.0)  # At least 3 indicators = full score

    total_score += structure_score * 0.2

    return total_score

# Test quality scoring
print("Testing quality scoring...")
test_solutions = [
    ("First, we calculate 48/2 = 24 clips. Then, we add 48 + 24 = 72. #### 72", "#### 72"),
    ("72", "#### 72"),  # Too short
    ("Let me explain this in great detail with many words..." * 50, "#### 72"),  # Too long
]

for sol, gt in test_solutions:
    score = score_solution_quality(sol, gt)
    print(f"  Quality: {score:.2f} - {sol[:50]}...")

print("\n✅ Quality scoring ready!")

Testing quality scoring...
  Quality: 0.74 - First, we calculate 48/2 = 24 clips. Then, we add ...
  Quality: 0.51 - 72...
  Quality: 0.15 - Let me explain this in great detail with many word...

✅ Quality scoring ready!


## Diversity Filter Function

In [19]:
def filter_for_diversity(solutions: List[Dict], max_per_question: int = 3) -> List[Dict]:
    """
    Ensure diversity in training data by limiting solutions per question.
    Keeps the highest quality solutions for each question.

    This prevents the model from overfitting to specific question patterns.

    Args:
        solutions: List of solution dictionaries
        max_per_question: Maximum solutions to keep per unique question

    Returns:
        Filtered list with at most max_per_question solutions per question
    """
    # Group solutions by question
    by_question = defaultdict(list)
    for sol in solutions:
        by_question[sol['question']].append(sol)

    # For each question, keep only top N by quality
    diverse_solutions = []

    for question, question_solutions in by_question.items():
        # Sort by quality score (if available) or correctness
        sorted_sols = sorted(
            question_solutions,
            key=lambda x: x.get('quality_score', 0.5),
            reverse=True
        )

        # Keep top N
        diverse_solutions.extend(sorted_sols[:max_per_question])

    return diverse_solutions

print("✅ Diversity filter ready!")

✅ Diversity filter ready!


## Evaluation Prompt Optimiser

In [7]:
class PromptEvolutionEngine:
    """
    Evolutionary algorithm for optimizing reflection prompts.

    This is the core "Evo" component of ReflectEvo.

    Key operations:
    1. Mutation: Randomly modify prompt text
    2. Crossover: Combine elements from two parent prompts
    3. Selection: Keep best-performing prompts based on fitness
    4. Fitness evaluation: Measure how well prompts improve solutions

    The algorithm evolves prompts across generations to find the most
    effective reflection strategies.
    """

    def __init__(self, seed_prompts: List[str], population_size: int = 8):
        """
        Initialize the evolutionary engine.

        Args:
            seed_prompts: Initial prompt templates to start evolution
            population_size: Number of prompts in each generation
        """
        # Replicate seed prompts to fill population
        self.population = []
        while len(self.population) < population_size:
            self.population.extend(seed_prompts)
        self.population = self.population[:population_size]

        # Track fitness (performance) of each prompt
        self.fitness_scores = [0.0] * len(self.population)

        # Mutation vocabulary - words to swap in during mutation
        self.mutation_vocab = {
            'verbs': [
                'Identify', 'Analyze', 'Examine', 'Detect', 'Pinpoint',
                'Locate', 'Find', 'Discover', 'Uncover', 'Investigate'
            ],
            'adjectives': [
                'critical', 'key', 'important', 'significant', 'major',
                'fundamental', 'essential', 'crucial', 'primary', 'main'
            ],
            'error_types': [
                'computational errors', 'logical flaws', 'reasoning mistakes',
                'calculation errors', 'conceptual errors', 'procedural mistakes',
                'arithmetic errors', 'algebraic mistakes'
            ]
        }

        print(f"✅ Evolutionary engine initialized")
        print(f"   Population size: {len(self.population)}")
        print(f"   Seed prompts: {len(seed_prompts)}")

    def mutate(self, prompt: str) -> str:
        """
        Mutate a prompt by randomly replacing words/phrases.

        This introduces variation to explore new prompt strategies.
        """
        sentences = prompt.split('. ')

        if len(sentences) < 2:
            return prompt

        # Randomly select a sentence to mutate
        mutate_idx = random.randint(0, len(sentences) - 1)
        sentence = sentences[mutate_idx]

        # Apply random mutation type
        mutation_type = random.choice(['verb', 'adjective', 'error_type'])

        if mutation_type == 'verb':
            # Replace action verbs
            new_verb = random.choice(self.mutation_vocab['verbs'])
            words = sentence.split()
            if len(words) > 0:
                words[0] = new_verb
                sentences[mutate_idx] = ' '.join(words)

        elif mutation_type == 'adjective':
            # Insert or replace adjectives
            new_adj = random.choice(self.mutation_vocab['adjectives'])
            if 'error' in sentence:
                sentence = sentence.replace('error', f'{new_adj} error', 1)
                sentences[mutate_idx] = sentence

        elif mutation_type == 'error_type':
            # Replace error type mentions
            for old_error in self.mutation_vocab['error_types']:
                if old_error in sentence.lower():
                    new_error = random.choice(self.mutation_vocab['error_types'])
                    sentence = sentence.replace(old_error, new_error)
                    sentences[mutate_idx] = sentence
                    break

        return '. '.join(sentences)

    def crossover(self, parent1: str, parent2: str) -> str:
        """
        Combine two parent prompts to create offspring.

        Takes the beginning from parent1 and end from parent2,
        or vice versa, to explore combinations of strategies.
        """
        # Split both prompts into sentences
        sentences1 = [s.strip() for s in parent1.split('.') if s.strip()]
        sentences2 = [s.strip() for s in parent2.split('.') if s.strip()]

        if len(sentences1) < 2 or len(sentences2) < 2:
            return parent1  # Fallback if not enough structure

        # Crossover point (middle)
        point1 = len(sentences1) // 2
        point2 = len(sentences2) // 2

        # Create offspring: first half of parent1 + second half of parent2
        offspring = sentences1[:point1] + sentences2[point2:]

        return '. '.join(offspring) + '.'

    def evaluate_fitness(self, prompt_idx: int, stage3_results: List[Dict]) -> float:
        """
        Evaluate how well a prompt performs.

        Fitness = percentage of incorrect solutions that become correct
                  after using this prompt for reflection + refinement

        Args:
            prompt_idx: Index of prompt in population
            stage3_results: Results from Stage 3 (refined solutions)

        Returns:
            Fitness score between 0.0 and 1.0
        """
        # Filter results generated using this specific prompt
        results_for_prompt = [
            r for r in stage3_results
            if r.get('prompt_idx') == prompt_idx
        ]

        if not results_for_prompt:
            return 0.0

        # Calculate improvement rate
        improvements = sum(
            1 for item in results_for_prompt
            if item.get('is_refined_correct', False)
        )

        fitness = improvements / len(results_for_prompt)

        return fitness

    def select_parents(self) -> List[str]:
        """
        Select elite prompts (top performers) to be parents for next generation.

        Uses elitism: best prompts always survive.
        """
        # Sort population by fitness
        sorted_indices = sorted(
            range(len(self.fitness_scores)),
            key=lambda i: self.fitness_scores[i],
            reverse=True
        )

        # Keep top 50% (elite_ratio)
        elite_count = max(1, int(len(self.population) * config.elite_ratio))
        elite_indices = sorted_indices[:elite_count]

        elite_prompts = [self.population[i] for i in elite_indices]

        return elite_prompts

    def evolve_population(self) -> List[str]:
        """
        Create next generation using genetic operations.

        Process:
        1. Select elite prompts (best performers)
        2. Fill remaining population with:
           - Mutations of elite prompts (30%)
           - Crossovers between elite prompts (70%)

        Returns:
            New population for next generation
        """
        # Get elite prompts
        elite = self.select_parents()

        # Start new population with elite
        new_population = elite.copy()

        # Fill remaining slots
        while len(new_population) < len(self.population):
            if random.random() < config.mutation_rate:
                # Mutation: modify a random elite prompt
                parent = random.choice(elite)
                offspring = self.mutate(parent)
                new_population.append(offspring)
            else:
                # Crossover: combine two random elite prompts
                parent1, parent2 = random.sample(elite, 2)
                offspring = self.crossover(parent1, parent2)
                new_population.append(offspring)

        # Update population
        self.population = new_population[:len(self.population)]

        return self.population

    def get_best_prompt(self) -> Tuple[str, float]:
        """
        Get the best performing prompt and its fitness score.

        Returns:
            (best_prompt, fitness_score)
        """
        best_idx = max(range(len(self.fitness_scores)), key=lambda i: self.fitness_scores[i])
        return self.population[best_idx], self.fitness_scores[best_idx]

    def display_population(self):
        """Display current population and fitness scores"""
        print("\n" + "="*80)
        print("CURRENT PROMPT POPULATION")
        print("="*80)
        for i, (prompt, fitness) in enumerate(zip(self.population, self.fitness_scores)):
            print(f"\nPrompt {i+1} (Fitness: {fitness:.3f}):")
            print(f"  {prompt[:150]}...")
        print("="*80)

print("✅ Evolutionary engine class defined!")

✅ Evolutionary engine class defined!


## Seed Prompt

In [9]:
# These are the initial prompt templates that will evolve
# Based on the paper's approach to reflection

SEED_REFLECTION_PROMPTS = [
    """Analyze this mathematical solution carefully. Identify any computational errors, logical flaws, or incorrect reasoning steps. Focus on where the solution deviates from correct mathematical principles.""",

    """Examine the proposed solution for mistakes. Look for calculation errors, wrong assumptions, missing steps, or faulty logic. Explain what went wrong in the reasoning process.""",

    """Review this solution step-by-step. Pinpoint specific errors in arithmetic, algebra, or logical reasoning. Identify which step first introduces an incorrect result.""",

    """Investigate the solution for mathematical errors. Check each calculation and logical step. Determine where and why the solution fails to arrive at the correct answer.""",
]

print(f"✅ {len(SEED_REFLECTION_PROMPTS)} seed prompts defined")

✅ 4 seed prompts defined


## Initialise Evaluation Engine

In [10]:
# Create the evolutionary engine with seed prompts
prompt_evolution = PromptEvolutionEngine(
    seed_prompts=SEED_REFLECTION_PROMPTS,
    population_size=config.population_size
)

# Display initial population
prompt_evolution.display_population()

print(f"\n✅ Ready to evolve prompts across {config.num_generations} generations!")

✅ Evolutionary engine initialized
   Population size: 8
   Seed prompts: 4

CURRENT PROMPT POPULATION

Prompt 1 (Fitness: 0.000):
  Analyze this mathematical solution carefully. Identify any computational errors, logical flaws, or incorrect reasoning steps. Focus on where the solut...

Prompt 2 (Fitness: 0.000):
  Examine the proposed solution for mistakes. Look for calculation errors, wrong assumptions, missing steps, or faulty logic. Explain what went wrong in...

Prompt 3 (Fitness: 0.000):
  Review this solution step-by-step. Pinpoint specific errors in arithmetic, algebra, or logical reasoning. Identify which step first introduces an inco...

Prompt 4 (Fitness: 0.000):
  Investigate the solution for mathematical errors. Check each calculation and logical step. Determine where and why the solution fails to arrive at the...

Prompt 5 (Fitness: 0.000):
  Analyze this mathematical solution carefully. Identify any computational errors, logical flaws, or incorrect reasoning steps. Focus 

## Model Factory and Tokeniser

In [11]:
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'  # Important for batch generation

def load_base_model():
    """
    Factory function to load a fresh copy of the base model.
    Used multiple times during training to avoid contamination.
    """
    model = AutoModelForCausalLM.from_pretrained(
        config.model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    model.gradient_checkpointing_enable()
    return model

print("✅ Tokenizer loaded successfully")
print(f"   Vocab size: {len(tokenizer)}")
print(f"   Pad token: {tokenizer.pad_token}")
print(f"   EOS token: {tokenizer.eos_token}")
print("\n✅ Model factory function ready")

Loading tokenizer...


config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

✅ Tokenizer loaded successfully
   Vocab size: 256000
   Pad token: <eos>
   EOS token: <eos>

✅ Model factory function ready


## Load GSM8K Dataset

In [12]:
print("Loading GSM8K dataset...")
print("="*80)

# Load full GSM8K dataset
gsm_full = load_dataset("gsm8k", "main")

# Select subsets for faster experimentation
# For production: use full dataset (7473 train, 1319 test)
train_data = gsm_full["train"].select(range(config.max_train_samples))
test_data = gsm_full["test"].select(range(config.max_test_samples))

print(f"✅ Dataset loaded successfully")
print(f"   Train samples: {len(train_data)} / {len(gsm_full['train'])} total")
print(f"   Test samples: {len(test_data)} / {len(gsm_full['test'])} total")

# Display a sample
print("\n" + "="*80)
print("SAMPLE TRAINING EXAMPLE")
print("="*80)
sample = train_data[0]
print(f"Question:\n{sample['question']}\n")
print(f"Ground Truth:\n{sample['answer']}")
print("="*80)

Loading GSM8K dataset...


README.md: 0.00B [00:00, ?B/s]

main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

✅ Dataset loaded successfully
   Train samples: 10 / 7473 total
   Test samples: 5 / 1319 total

SAMPLE TRAINING EXAMPLE
Question:
Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?

Ground Truth:
Natalia sold 48/2 = <<48/2=24>>24 clips in May.
Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.
#### 72


## Stage 1 Generate Candidate Solutions

In [None]:
def generate_candidate_solutions(model, data, num_candidates: int = 5) -> List[Dict]:
    """
    STAGE 1: Generate multiple candidate solutions for each question.

    This stage creates diversity in solutions by sampling multiple answers
    per question. Some will be correct, many will be incorrect - the incorrect
    ones are valuable for learning through reflection.

    Args:
        model: The language model
        data: Dataset to generate solutions for
        num_candidates: Number of solution attempts per question

    Returns:
        List of dictionaries containing questions, solutions, and metadata
    """
    print("\n" + "="*80)
    print("STAGE 1: GENERATE CANDIDATE SOLUTIONS")
    print("="*80)
    print(f"Generating {num_candidates} candidates per question...")

    stage1_data = []

    for sample in tqdm(data, desc="Stage 1 Generation"):
        question = sample["question"].strip()
        gt_answer = sample["answer"].strip()

        # Create standardized prompt
        prompt = f"Question: {question}\nSolution:"
        input_ids = tokenizer(prompt, return_tensors="pt").to(config.device)

        # Generate multiple candidate solutions
        with torch.no_grad():
            outputs = model.generate(
                **input_ids,
                do_sample=True,  # Sampling for diversity
                temperature=config.stage1_temperature,
                top_p=config.stage1_top_p,
                num_return_sequences=num_candidates,
                max_new_tokens=config.stage1_max_tokens,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Extract each generated solution (remove prompt)
        for output in outputs:
            generated_tokens = output[input_ids["input_ids"].shape[-1]:]
            solution = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

            # Check correctness
            correct = is_correct(solution, gt_answer)

            # Calculate quality score
            quality = score_solution_quality(solution, gt_answer)

            stage1_data.append({
                "question": question,
                "solution": solution,
                "gt_answer": gt_answer,
                "is_correct": correct,
                "quality_score": quality
            })

    # Calculate statistics
    total = len(stage1_data)
    correct_count = sum(1 for item in stage1_data if item["is_correct"])
    avg_quality = sum(item["quality_score"] for item in stage1_data) / total

    print("\n" + "="*80)
    print("STAGE 1 RESULTS")
    print("="*80)
    print(f"Total solutions generated: {total}")
    print(f"Correct solutions: {correct_count} ({correct_count/total*100:.1f}%)")
    print(f"Incorrect solutions: {total-correct_count} ({(total-correct_count)/total*100:.1f}%)")
    print(f"Average quality score: {avg_quality:.3f}")
    print("="*80)

    return stage1_data

print("✅ Stage 1 function defined")

## Stage 2 - Generate Reflections with Evolved Prompts

In [14]:
def generate_reflections_with_evolution(
    model,
    stage1_data: List[Dict],
    prompt_evolution: PromptEvolutionEngine
) -> List[Dict]:
    """
    STAGE 2: Generate reflections using evolved prompts.

    For INCORRECT solutions only, generate reflections that identify
    what went wrong. Uses all prompts in the current population.

    This is where the "Evo" part happens - different prompts are tried
    and their effectiveness is tracked for evolution.

    Args:
        model: The language model
        stage1_data: Solutions from Stage 1
        prompt_evolution: The evolutionary engine with current prompts

    Returns:
        List of dictionaries with reflections and prompt tracking
    """
    print("\n" + "="*80)
    print("STAGE 2: GENERATE REFLECTIONS (WITH EVOLUTION)")
    print("="*80)

    # Filter: Keep ONLY incorrect solutions
    # Key insight: We learn from mistakes, not successes
    incorrect_solutions = [item for item in stage1_data if not item["is_correct"]]

    print(f"Filtered {len(incorrect_solutions)} incorrect solutions for reflection")
    print(f"Using {len(prompt_evolution.population)} evolved prompts\n")

    stage2_data = []

    # For each prompt in the population, generate reflections
    for prompt_idx, evolved_prompt in enumerate(prompt_evolution.population):
        print(f"Processing with Prompt {prompt_idx + 1}/{len(prompt_evolution.population)}...")

        # Use a subset of incorrect solutions per prompt (for efficiency)
        # In production, you might use all incorrect solutions
        solutions_subset = incorrect_solutions[:len(incorrect_solutions)]  # Use all for now

        for item in tqdm(solutions_subset, desc=f"  Prompt {prompt_idx+1}", leave=False):
            # Construct full reflection prompt
            full_prompt = f"""{evolved_prompt}

Question: {item['question']}

Proposed Solution: {item['solution']}

Reflection:"""

            inputs = tokenizer(full_prompt, return_tensors="pt").to(config.device)

            # Generate reflection
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    do_sample=True,
                    temperature=config.stage2_temperature,
                    max_new_tokens=config.stage2_max_tokens,
                    pad_token_id=tokenizer.eos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )

            # Extract reflection (remove prompt)
            generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
            reflection = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

            stage2_data.append({
                "question": item["question"],
                "solution": item["solution"],
                "gt_answer": item["gt_answer"],
                "reflection": reflection,
                "prompt_idx": prompt_idx,  # Track which prompt generated this
                "evolved_prompt": evolved_prompt
            })

    print("\n" + "="*80)
    print("STAGE 2 RESULTS")
    print("="*80)
    print(f"Total reflections generated: {len(stage2_data)}")
    print(f"Reflections per prompt: ~{len(stage2_data) // len(prompt_evolution.population)}")
    print("="*80)

    return stage2_data

print("✅ Stage 2 function defined")

✅ Stage 2 function defined


##  Stage 3 - Generate Refined Solutions

In [15]:
REFINEMENT_TEMPLATE = """Question: {question}

Initial Solution (incorrect):
{solution}

Reflection on errors:
{reflection}

Based on the reflection, provide a corrected solution with proper step-by-step reasoning:"""

def generate_refined_solutions(model, stage2_data: List[Dict]) -> List[Dict]:
    """
    STAGE 3: Generate refined/corrected solutions using reflections.

    Takes the original (wrong) solution and reflection, and generates
    a corrected solution. This teaches the model to learn from mistakes.

    Args:
        model: The language model
        stage2_data: Data from Stage 2 (reflections)

    Returns:
        List of dictionaries with refined solutions and improvement metrics
    """
    print("\n" + "="*80)
    print("STAGE 3: GENERATE REFINED SOLUTIONS")
    print("="*80)

    stage3_data = []

    for item in tqdm(stage2_data, desc="Stage 3 Refinement"):
        # Construct refinement prompt
        prompt = REFINEMENT_TEMPLATE.format(
            question=item["question"],
            solution=item["solution"],
            reflection=item["reflection"]
        )

        inputs = tokenizer(prompt, return_tensors="pt").to(config.device)

        # Generate refined solution
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                do_sample=True,
                temperature=config.stage3_temperature,
                top_p=0.95,
                max_new_tokens=config.stage3_max_tokens,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Extract refined solution
        generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        refined_solution = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

        # Check if refinement fixed the solution
        is_refined_correct = is_correct(refined_solution, item["gt_answer"])

        # Calculate quality of refined solution
        refined_quality = score_solution_quality(refined_solution, item["gt_answer"])

        stage3_data.append({
            "question": item["question"],
            "original_solution": item["solution"],
            "reflection": item["reflection"],
            "refined_solution": refined_solution,
            "gt_answer": item["gt_answer"],
            "is_refined_correct": is_refined_correct,
            "refined_quality": refined_quality,
            "prompt_idx": item["prompt_idx"],  # Track prompt performance
        })

    # Calculate improvement statistics
    improved = sum(1 for item in stage3_data if item["is_refined_correct"])
    avg_quality = sum(item["refined_quality"] for item in stage3_data) / len(stage3_data)

    print("\n" + "="*80)
    print("STAGE 3 RESULTS")
    print("="*80)
    print(f"Refined solutions generated: {len(stage3_data)}")
    print(f"Now correct after refinement: {improved} ({improved/len(stage3_data)*100:.1f}%)")
    print(f"Average refined quality: {avg_quality:.3f}")
    print("="*80)

    return stage3_data

print("✅ Stage 3 function defined")

✅ Stage 3 function defined


## Load the Model

In [16]:
print("Loading base model for generation...")
base_model = load_base_model()

Loading base model for generation...


`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Downloading (incomplete total...): 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading weights:   0%|          | 0/288 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

## Multi Generation training Loop

In [17]:
def run_reflectevo_pipeline(num_generations: int = 3):
    """
    Run the complete ReflectEvo pipeline with evolutionary optimization.

    This is the MAIN ALGORITHM from the paper.

    For each generation:
    1. Generate candidate solutions (Stage 1)
    2. Generate reflections with current prompts (Stage 2)
    3. Generate refined solutions (Stage 3)
    4. Evaluate prompt fitness
    5. Evolve prompts using genetic algorithm
    6. Accumulate high-quality training data

    After all generations:
    7. Train final model on accumulated data

    Args:
        num_generations: Number of evolutionary generations

    Returns:
        Tuple of (all_training_data, final_model)
    """
    print("\n" + "#"*80)
    print("#" + " "*78 + "#")
    print("#" + " "*20 + "REFLECTEVO FULL PIPELINE" + " "*34 + "#")
    print("#" + " "*78 + "#")
    print("#"*80)
    print(f"\nRunning {num_generations} generations of evolution...")
    print(f"Starting with {len(prompt_evolution.population)} prompts in population\n")

    # Load base model for data generation
    print("Loading base model for generation...")
    base_model = load_base_model()

    # Accumulate training data across all generations
    all_reflective_training_data = []

    # Store generation statistics
    generation_stats = []

    # ========================================================================
    # EVOLUTIONARY LOOP
    # ========================================================================
    for gen in range(num_generations):
        print("\n" + "█"*80)
        print(f"█  GENERATION {gen + 1}/{num_generations}")
        print("█"*80)

        gen_stats = {"generation": gen + 1}

        # ====================================================================
        # STAGE 1: Generate Candidate Solutions
        # ====================================================================
        stage1_data = generate_candidate_solutions(
            base_model,
            train_data,
            num_candidates=config.num_candidates_per_question
        )
        gen_stats["stage1_correct"] = sum(1 for x in stage1_data if x["is_correct"])
        gen_stats["stage1_total"] = len(stage1_data)

        # ====================================================================
        # STAGE 2: Generate Reflections with Evolved Prompts
        # ====================================================================
        stage2_data = generate_reflections_with_evolution(
            base_model,
            stage1_data,
            prompt_evolution
        )
        gen_stats["stage2_reflections"] = len(stage2_data)

        # ====================================================================
        # STAGE 3: Generate Refined Solutions
        # ====================================================================
        stage3_data = generate_refined_solutions(base_model, stage2_data)
        gen_stats["stage3_improved"] = sum(1 for x in stage3_data if x["is_refined_correct"])
        gen_stats["stage3_total"] = len(stage3_data)

        # ====================================================================
        # EVALUATE FITNESS: How well did each prompt perform?
        # ====================================================================
        print("\n" + "="*80)
        print("EVALUATING PROMPT FITNESS")
        print("="*80)

        for prompt_idx in range(len(prompt_evolution.population)):
            fitness = prompt_evolution.evaluate_fitness(prompt_idx, stage3_data)
            prompt_evolution.fitness_scores[prompt_idx] = fitness
            print(f"  Prompt {prompt_idx + 1}: Fitness = {fitness:.3f}")

        # Get best prompt this generation
        best_prompt, best_fitness = prompt_evolution.get_best_prompt()
        gen_stats["best_fitness"] = best_fitness

        print(f"\n✅ Best prompt this generation: Fitness = {best_fitness:.3f}")
        print(f"   Preview: {best_prompt[:100]}...")

        # ====================================================================
        # PREPARE TRAINING DATA: Filter and combine high-quality solutions
        # ====================================================================
        print("\n" + "="*80)
        print("PREPARING TRAINING DATA")
        print("="*80)

        # Collect correct solutions from Stage 1
        correct_from_stage1 = [
            item for item in stage1_data
            if item["is_correct"] and item["quality_score"] >= config.min_solution_quality
        ]

        # Collect improved solutions from Stage 3
        improved_from_stage3 = [
            item for item in stage3_data
            if item["is_refined_correct"] and item["refined_quality"] >= config.min_solution_quality
        ]

        print(f"Correct from Stage 1: {len(correct_from_stage1)}")
        print(f"Improved from Stage 3: {len(improved_from_stage3)}")

        # Format training examples
        gen_training_data = []

        # Add correct solutions
        for item in correct_from_stage1:
            gen_training_data.append({
                "text": f"Question: {item['question']}\nSolution: {item['solution']}",
                "quality_score": item["quality_score"]
            })

        # Add refined solutions
        for item in improved_from_stage3:
            gen_training_data.append({
                "text": f"Question: {item['question']}\nSolution: {item['refined_solution']}",
                "quality_score": item["refined_quality"]
            })

        # Apply diversity filter
        gen_training_data = filter_for_diversity(
            gen_training_data,
            max_per_question=config.max_solutions_per_question
        )

        print(f"After quality filtering & diversity: {len(gen_training_data)} examples")
        gen_stats["training_examples"] = len(gen_training_data)

        # Accumulate for final training
        all_reflective_training_data.extend(gen_training_data)

        # ====================================================================
        # EVOLVE PROMPTS: Create next generation (except on last generation)
        # ====================================================================
        if gen < num_generations - 1:
            print("\n" + "="*80)
            print("EVOLVING PROMPTS FOR NEXT GENERATION")
            print("="*80)

            prompt_evolution.evolve_population()
            print(f"✅ Prompts evolved!")
            print(f"   New population ready for Generation {gen + 2}")

            # Display evolved population
            prompt_evolution.display_population()

        # Save generation stats
        generation_stats.append(gen_stats)

        # Print generation summary
        print("\n" + "█"*80)
        print(f"█  GENERATION {gen + 1} COMPLETE")
        print("█"*80)
        print(f"   Stage 1: {gen_stats['stage1_correct']}/{gen_stats['stage1_total']} correct")
        print(f"   Stage 3: {gen_stats['stage3_improved']}/{gen_stats['stage3_total']} improved")
        print(f"   Training examples: {gen_stats['training_examples']}")
        print(f"   Best fitness: {gen_stats['best_fitness']:.3f}")
        print("█"*80)

    # ========================================================================
    # CLEAN UP
    # ========================================================================
    del base_model
    torch.cuda.empty_cache()
    print("\n🧹 Base model cleaned from memory")

    # ========================================================================
    # FINAL SUMMARY
    # ========================================================================
    print("\n" + "#"*80)
    print("#" + " "*78 + "#")
    print("#" + " "*25 + "EVOLUTION COMPLETE" + " "*35 + "#")
    print("#" + " "*78 + "#")
    print("#"*80)

    print(f"\nTotal generations: {num_generations}")
    print(f"Total training examples accumulated: {len(all_reflective_training_data)}")

    # Show fitness progression
    print("\nFitness progression across generations:")
    for stat in generation_stats:
        print(f"  Gen {stat['generation']}: Best Fitness = {stat['best_fitness']:.3f}")

    # Get final best prompt
    final_best_prompt, final_best_fitness = prompt_evolution.get_best_prompt()
    print(f"\n✅ Final best prompt (Fitness: {final_best_fitness:.3f}):")
    print(f"   {final_best_prompt}")

    return all_reflective_training_data, generation_stats

print("✅ ReflectEvo pipeline function defined!")

✅ ReflectEvo pipeline function defined!


## Run the ReflectEvo Pipeline

In [18]:
# This is where the magic happens!
# Run the complete ReflectEvo pipeline with evolutionary optimization

print("🚀 Starting ReflectEvo pipeline...")
print("="*80)

reflective_training_data, generation_stats = run_reflectevo_pipeline(
    num_generations=config.num_generations
)

print("\n✅ ReflectEvo pipeline completed successfully!")
print(f"   Generated {len(reflective_training_data)} high-quality training examples")

🚀 Starting ReflectEvo pipeline...

################################################################################
#                                                                              #
#                    REFLECTEVO FULL PIPELINE                                  #
#                                                                              #
################################################################################

Running 3 generations of evolution...
Starting with 8 prompts in population

Loading base model for generation...


Loading weights:   0%|          | 0/288 [00:00<?, ?it/s]


████████████████████████████████████████████████████████████████████████████████
█  GENERATION 1/3
████████████████████████████████████████████████████████████████████████████████

STAGE 1: GENERATE CANDIDATE SOLUTIONS
Generating 5 candidates per question...


Stage 1 Generation: 100%|██████████| 10/10 [01:48<00:00, 10.86s/it]



STAGE 1 RESULTS
Total solutions generated: 50
Correct solutions: 34 (68.0%)
Incorrect solutions: 16 (32.0%)
Average quality score: 0.704

STAGE 2: GENERATE REFLECTIONS (WITH EVOLUTION)
Filtered 16 incorrect solutions for reflection
Using 8 evolved prompts

Processing with Prompt 1/8...




Processing with Prompt 2/8...




Processing with Prompt 3/8...




Processing with Prompt 4/8...




Processing with Prompt 5/8...




Processing with Prompt 6/8...




Processing with Prompt 7/8...




Processing with Prompt 8/8...





STAGE 2 RESULTS
Total reflections generated: 128
Reflections per prompt: ~16

STAGE 3: GENERATE REFINED SOLUTIONS


Stage 3 Refinement: 100%|██████████| 128/128 [15:22<00:00,  7.21s/it]


STAGE 3 RESULTS
Refined solutions generated: 128
Now correct after refinement: 16 (12.5%)
Average refined quality: 0.397

EVALUATING PROMPT FITNESS
  Prompt 1: Fitness = 0.125
  Prompt 2: Fitness = 0.188
  Prompt 3: Fitness = 0.062
  Prompt 4: Fitness = 0.250
  Prompt 5: Fitness = 0.062
  Prompt 6: Fitness = 0.062
  Prompt 7: Fitness = 0.188
  Prompt 8: Fitness = 0.062

✅ Best prompt this generation: Fitness = 0.250
   Preview: Investigate the solution for mathematical errors. Check each calculation and logical step. Determine...

PREPARING TRAINING DATA
Correct from Stage 1: 34
Improved from Stage 3: 16





NameError: name 'filter_for_diversity' is not defined

## Prepare Standard Baseline Training Dataset

In [None]:
def prepare_standard_baseline_data(train_data):
    """
    Prepare standard training data (no reflection/refinement).

    This is the BASELINE to compare against ReflectEvo.
    Just uses ground truth solutions directly.

    Args:
        train_data: Original training dataset

    Returns:
        List of training examples
    """
    print("\n" + "="*80)
    print("PREPARING STANDARD BASELINE DATA")
    print("="*80)

    standard_training_data = []

    for sample in train_data:
        # Use ground truth answer directly
        gt_solution = sample["answer"].strip()

        standard_training_data.append({
            "text": f"Question: {sample['question']}\nSolution: {gt_solution}"
        })

    print(f"✅ Standard baseline data prepared")
    print(f"   Training examples: {len(standard_training_data)}")
    print("="*80)

    return standard_training_data

# Prepare baseline data
standard_training_data = prepare_standard_baseline_data(train_data)

## Training Function

In [None]:
def tokenize_function(examples):
    """Tokenize text examples for training"""
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=512,
        padding="max_length",
    )

def train_model(train_examples: List[Dict], output_dir: str, model_name: str = "Model"):
    """
    Train a model on prepared training data.

    Args:
        train_examples: List of training dictionaries with 'text' field
        output_dir: Where to save the trained model
        model_name: Name for logging

    Returns:
        Trained model
    """
    print("\n" + "="*80)
    print(f"TRAINING: {model_name}")
    print("="*80)

    # Load fresh model
    print("Loading fresh base model...")
    model = load_base_model()

    # Convert to HuggingFace Dataset
    print("Converting to dataset...")
    train_dataset = Dataset.from_list(train_examples)

    print(f"Training examples: {len(train_dataset)}")

    # Tokenize
    print("Tokenizing...")
    tokenized_dataset = train_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=train_dataset.column_names,
        desc="Tokenizing"
    )

    # Add labels (for causal LM, labels = input_ids)
    def add_labels(examples):
        examples["labels"] = examples["input_ids"].copy()
        return examples

    tokenized_dataset = tokenized_dataset.map(add_labels, batched=True)

    print(f"✅ Dataset prepared: {len(tokenized_dataset)} examples")

    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=config.num_epochs,
        per_device_train_batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        warmup_ratio=config.warmup_ratio,
        logging_steps=10,
        save_steps=500,
        save_total_limit=2,
        fp16=True,
        dataloader_num_workers=0,
        report_to="none",
        logging_dir=f"{output_dir}/logs",
        remove_unused_columns=False,
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
    )

    # Train
    print(f"\n🚀 Starting training ({config.num_epochs} epochs)...\n")
    trainer.train()

    # Save model
    print(f"\n💾 Saving model to {output_dir}...")
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    print(f"✅ {model_name} training complete!")
    print("="*80)

    return model

print("✅ Training functions defined")

## Train ReflectEvo Model

In [None]:
# Train the ReflectEvo model on evolved, high-quality data

print("\n" + "#"*80)
print("#" + " "*20 + "TRAINING REFLECTEVO MODEL" + " "*34 + "#")
print("#"*80)

reflectevo_model = train_model(
    train_examples=reflective_training_data,
    output_dir=config.output_dir_reflective,
    model_name="ReflectEvo (Evolutionary)"
)

print("\n✅ ReflectEvo model training complete!")

## Train Standard Baseline Model

In [None]:
# Clean up memory
del reflectevo_model
torch.cuda.empty_cache()
print("🧹 Cleared ReflectEvo model from memory\n")

# Train the standard baseline model

print("\n" + "#"*80)
print("#" + " "*20 + "TRAINING STANDARD BASELINE" + " "*33 + "#")
print("#"*80)

standard_model = train_model(
    train_examples=standard_training_data,
    output_dir=config.output_dir_standard,
    model_name="Standard Baseline (No Reflection)"
)

print("\n✅ Standard baseline training complete!")

## Evaluation Function

In [None]:
def evaluate_model_on_testset(model, test_data, model_name: str = "Model"):
    """
    Evaluate a trained model on the test set.

    For each test question:
    1. Generate a solution
    2. Extract the final answer
    3. Compare with ground truth

    Args:
        model: Trained model to evaluate
        test_data: Test dataset
        model_name: Name for logging

    Returns:
        Tuple of (accuracy, detailed_results)
    """
    print("\n" + "="*80)
    print(f"EVALUATION: {model_name}")
    print("="*80)

    model.eval()  # Set to evaluation mode

    correct = 0
    total = len(test_data)
    results = []

    for sample in tqdm(test_data, desc=f"Evaluating {model_name}"):
        question = sample["question"].strip()
        gt_answer = sample["answer"].strip()

        # Generate solution
        prompt = f"Question: {question}\nSolution:"
        inputs = tokenizer(prompt, return_tensors="pt").to(config.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                do_sample=False,  # Greedy decoding for deterministic evaluation
                max_new_tokens=256,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Extract generated solution
        generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        solution = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

        # Check correctness
        is_correct_answer = is_correct(solution, gt_answer)

        if is_correct_answer:
            correct += 1

        results.append({
            "question": question,
            "generated_solution": solution,
            "gt_answer": gt_answer,
            "predicted_answer": extract_final_answer(solution),
            "gt_final_answer": extract_gt_answer(gt_answer),
            "is_correct": is_correct_answer
        })

    accuracy = (correct / total) * 100

    print("\n" + "="*80)
    print(f"{model_name} RESULTS")
    print("="*80)
    print(f"Correct: {correct}/{total}")
    print(f"Accuracy: {accuracy:.2f}%")
    print("="*80)

    return accuracy, results

print("✅ Evaluation function defined")

##  Evaluate ReflectEvo Model

In [None]:
# Load and evaluate ReflectEvo model

print("\n📊 Loading ReflectEvo model for evaluation...")
reflectevo_model = AutoModelForCausalLM.from_pretrained(
    config.output_dir_reflective,
    torch_dtype=torch.float16,
    device_map="auto"
)

reflectevo_accuracy, reflectevo_results = evaluate_model_on_testset(
    reflectevo_model,
    test_data,
    model_name="ReflectEvo (Evolutionary)"
)

# Clean up
del reflectevo_model
torch.cuda.empty_cache()

## Evaluate Standard Baseline Model

In [None]:
# Load and evaluate Standard model

print("\n📊 Loading Standard Baseline model for evaluation...")
standard_model = AutoModelForCausalLM.from_pretrained(
    config.output_dir_standard,
    torch_dtype=torch.float16,
    device_map="auto"
)

standard_accuracy, standard_results = evaluate_model_on_testset(
    standard_model,
    test_data,
    model_name="Standard Baseline"
)

# Clean up
del standard_model
torch.cuda.empty_cache()

## Final Comparison & Analysis

In [None]:
print("\n" + "█"*80)
print("█" + " "*78 + "█")
print("█" + " "*25 + "FINAL RESULTS" + " "*41 + "█")
print("█" + " "*78 + "█")
print("█"*80)

print("\n" + "="*80)
print("ACCURACY COMPARISON")
print("="*80)
print(f"ReflectEvo (Evolutionary):  {reflectevo_accuracy:.2f}%")
print(f"Standard Baseline:          {standard_accuracy:.2f}%")
print(f"{'─'*80}")
print(f"Absolute Improvement:       {reflectevo_accuracy - standard_accuracy:+.2f}%")

if standard_accuracy > 0:
    relative_improvement = ((reflectevo_accuracy - standard_accuracy) / standard_accuracy) * 100
    print(f"Relative Improvement:       {relative_improvement:+.2f}%")

print("="*80)

# Determine winner
if reflectevo_accuracy > standard_accuracy:
    print("\n🏆 WINNER: ReflectEvo")
    print("   ✅ Evolutionary reflection significantly improves performance!")
elif reflectevo_accuracy < standard_accuracy:
    print("\n⚠️  WINNER: Standard Baseline")
    print("   ⚠️  ReflectEvo underperformed (may need more training data/generations)")
else:
    print("\n🤝 TIE")
    print("   Both methods perform equally")

# Calculate agreement/disagreement
agreements = sum(
    1 for r1, r2 in zip(reflectevo_results, standard_results)
    if r1["is_correct"] == r2["is_correct"]
)

print(f"\nAgreement on {agreements}/{len(test_data)} test cases ({agreements/len(test_data)*100:.1f}%)")

# Find cases where ReflectEvo succeeds but Standard fails
reflectevo_wins = [
    i for i, (r1, r2) in enumerate(zip(reflectevo_results, standard_results))
    if r1["is_correct"] and not r2["is_correct"]
]

standard_wins = [
    i for i, (r1, r2) in enumerate(zip(reflectevo_results, standard_results))
    if r2["is_correct"] and not r1["is_correct"]
]

print(f"\nReflectEvo wins: {len(reflectevo_wins)} cases")
print(f"Standard wins: {len(standard_wins)} cases")

print("\n" + "█"*80)

## Detailed Example Comparisons (Optional)

In [None]:
# Show specific examples where the models differ

print("\n" + "="*80)
print("DETAILED EXAMPLE COMPARISONS")
print("="*80)

num_examples = min(3, len(reflectevo_wins))

if num_examples > 0:
    print(f"\n{'─'*80}")
    print("CASES WHERE REFLECTEVO SUCCEEDS BUT STANDARD FAILS")
    print(f"{'─'*80}")

    for idx in reflectevo_wins[:num_examples]:
        print(f"\nExample {idx + 1}:")
        print(f"{'─'*80}")

        r_evo = reflectevo_results[idx]
        r_std = standard_results[idx]

        print(f"Question:")
        print(f"  {r_evo['question'][:150]}...")

        print(f"\nGround Truth Answer: {r_evo['gt_final_answer']}")

        print(f"\nReflectEvo Solution:")
        print(f"  {r_evo['generated_solution'][:200]}...")
        print(f"  Predicted: {r_evo['predicted_answer']} ✅")

        print(f"\nStandard Solution:")
        print(f"  {r_std['generated_solution'][:200]}...")
        print(f"  Predicted: {r_std['predicted_answer']} ❌")

        print(f"{'─'*80}")

else:
    print("\nNo cases where ReflectEvo uniquely succeeds.")

print("\n" + "="*80)

## Visualize Results (Optional)

In [None]:
import matplotlib.pyplot as plt

# Create visualization
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Accuracy Comparison
models = ['ReflectEvo', 'Standard']
accuracies = [reflectevo_accuracy, standard_accuracy]
colors = ['#4CAF50', '#2196F3']

axes[0].bar(models, accuracies, color=colors, alpha=0.8, edgecolor='black')
axes[0].set_ylabel('Accuracy (%)', fontsize=12)
axes[0].set_title('Model Accuracy Comparison', fontsize=14, fontweight='bold')
axes[0].set_ylim(0, 100)
axes[0].grid(axis='y', alpha=0.3)

for i, (model, acc) in enumerate(zip(models, accuracies)):
    axes[0].text(i, acc + 2, f'{acc:.1f}%', ha='center', fontweight='bold')

# Plot 2: Fitness Evolution Across Generations
if generation_stats:
    generations = [stat['generation'] for stat in generation_stats]
    best_fitness = [stat['best_fitness'] for stat in generation_stats]

    axes[1].plot(generations, best_fitness, marker='o', linewidth=2, markersize=8, color='#FF5722')
    axes[1].set_xlabel('Generation', fontsize=12)
    axes[1].set_ylabel('Best Prompt Fitness', fontsize=12)
    axes[1].set_title('Prompt Evolution Progress', fontsize=14, fontweight='bold')
    axes[1].grid(alpha=0.3)
    axes[1].set_ylim(0, 1)

plt.tight_layout()
plt.savefig('reflectevo_results.png', dpi=150, bbox_inches='tight')
plt.show()

print("✅ Visualization saved as reflectevo_results.png")

## Save Results to JSON (Optional)

In [None]:
import json
from datetime import datetime

# Compile all results
results_summary = {
    "timestamp": datetime.now().isoformat(),
    "configuration": {
        "model_name": config.model_name,
        "train_samples": config.max_train_samples,
        "test_samples": config.max_test_samples,
        "num_generations": config.num_generations,
        "population_size": config.population_size,
        "num_candidates": config.num_candidates_per_question,
        "num_epochs": config.num_epochs,
    },
    "reflectevo": {
        "accuracy": reflectevo_accuracy,
        "correct": sum(1 for r in reflectevo_results if r["is_correct"]),
        "total": len(reflectevo_results),
        "training_examples": len(reflective_training_data),
    },
    "standard": {
        "accuracy": standard_accuracy,
        "correct": sum(1 for r in standard_results if r["is_correct"]),
        "total": len(standard_results),
        "training_examples": len(standard_training_data),
    },
    "comparison": {
        "absolute_improvement": reflectevo_accuracy - standard_accuracy,
        "relative_improvement": ((reflectevo_accuracy - standard_accuracy) / standard_accuracy * 100) if standard_accuracy > 0 else 0,
        "reflectevo_unique_wins": len(reflectevo_wins),
        "standard_unique_wins": len(standard_wins),
    },
    "generation_stats": generation_stats,
}

# Save to file
output_file = "reflectevo_experiment_results.json"
with open(output_file, "w") as f:
    json.dump(results_summary, f, indent=2)

print(f"✅ Results saved to {output_file}")

# Display summary
print("\n" + "="*80)
print("EXPERIMENT SUMMARY")
print("="*80)
print(json.dumps(results_summary, indent=2))
print("="*80)