In [1]:
### Evaluation on medqa

def test_medqa(ensemble_model, baseline_model, ensemble_tokenizer, baseline_tokenizer, dataset, output_dir, device = 'cpu'):
    """
    Evaluate performance of ensemble model vs baseline model and random guessing on MedQA dataset.
    
    Args:
        ensemble_model: Your custom ensemble model
        baseline_model: The baseline huggingface model for comparison
        dataset: MedQA formatted dataset with 'question' and 'answer' columns
        output_dir: Directory to save results and visualizations
        
    Returns:
        dict: Dictionary containing all metrics and results
    """
    import torch
    import re
    import numpy as np
    import random
    from tqdm import tqdm
    import json
    import os
    import matplotlib.pyplot as plt
    from collections import Counter
    
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    # Ensure models are in evaluation mode
    ensemble_model.eval() # <------------------------------------------------------------------------- TODO
    baseline_model.eval()
    
    # Helper function to extract the answer letter from generated text
    def extract_answer(text):
        # Look for patterns like "The correct answer is: X)" or just "X)"
        match = re.search(r'(?:The correct answer is:?\s*)?([A-E])\)', text)
        if match:
            letter = match.group(1).upper()
            return letter
            
        # If no match with the above pattern, try a simpler pattern to find any letter
        letter_match = re.search(r'\b([A-E])\b', text)
        if letter_match:
            letter = letter_match.group(1).upper()
            return letter
            
        return None
    
    # Initialize metrics
    metrics = {
        "total_questions": len(dataset),
        "ensemble": {
            "correct": 0,
            "letter_found": 0,
            "generated_lengths": [],
            "letter_distribution": Counter()
        },
        "baseline": {
            "correct": 0,
            "letter_found": 0,
            "generated_lengths": [],
            "letter_distribution": Counter()
        },
        "random": {
            "correct": 0,
            "letter_distribution": Counter()
        },
        "examples": []
    }
    
    # Generation parameters optimized for multiple-choice answer generation
    gen_params = {
        "max_new_tokens": 512,         # Enough for the answer format
        "num_return_sequences": 1,
        "temperature": 0.3,             # Lower temperature for more deterministic outputs
        "top_p": 0.92,                  # Slightly increased for better coverage
        "top_k": 40,                    # Restrict to only the most likely tokens
        "num_beams": 3,                 # Use beam search for better results
        "no_repeat_ngram_size": 2,      # Prevent repetition
        "repetition_penalty": 1.2       # Penalize repeated tokens
    }

    example_indices = random.sample(range(len(dataset)), min(5, len(dataset)))
    
    # Process all test examples
    print(f"Testing on {len(dataset)} examples...")
    
    for i, example in enumerate(tqdm(dataset, desc="Testing models")):
        # Extract the question and correct answer
        question = example['question']
        correct_answer = example['answer']
        
        # Extract correct answer letter
        correct_letter = extract_answer(correct_answer)
        if not correct_letter:
            print(f"Warning: Could not extract letter from correct answer: {correct_answer}")
            correct_letter = "Unknown"  # Fall back
        
        # Format prompt (just use the question as is)
        prompt = f'I want you to reason and answer on the following question.\n{question}\n\nStart by reasoning about the question, then provide the answer in the format "The answer is X) <answer>\n Reasoning:"'
        
        # Track if this is an example to save
        is_example = i in example_indices
        example_data = {"question": question, "correct_answer": correct_answer} if is_example else None
        
        # Generate random choice - simulate random guessing baseline
        # Assume A, B, C, D, E are possible options
        available_options = ["A", "B", "C", "D", "E"]
        random_letter = random.choice(available_options)
        
        # Record random baseline metrics
        metrics["random"]["letter_distribution"][random_letter] += 1
        if random_letter == correct_letter:
            metrics["random"]["correct"] += 1
        
        # Test ensemble model
        try:
            # Generate the answer with the ensemble model
            ensemble_inputs = ensemble_tokenizer(prompt, return_tensors="pt").to(device)
            
            with torch.no_grad():
                ensemble_output = ensemble_model.generate( # <-------------------------------------------------- TODO
                    **ensemble_inputs,
                    **gen_params
                )
            
            # Process output
            ensemble_generated = ensemble_tokenizer.decode(ensemble_output[0], skip_special_tokens=True)
            
            # Get the model's answer by taking the part after the original prompt
            ensemble_answer = ensemble_generated[len(prompt):].strip()
            
            # Extract the letter
            ensemble_letter = extract_answer(ensemble_answer)
            
            # Record metrics
            metrics["ensemble"]["generated_lengths"].append(len(ensemble_answer.split()))
            
            if ensemble_letter:
                metrics["ensemble"]["letter_found"] += 1
                metrics["ensemble"]["letter_distribution"][ensemble_letter] += 1
                if ensemble_letter == correct_letter:
                    metrics["ensemble"]["correct"] += 1
                    
            if is_example:
                example_data["ensemble_answer"] = ensemble_answer
                example_data["ensemble_letter"] = ensemble_letter
                
        except Exception as e:
            print(f"Error with ensemble model: {e}")
            if is_example:
                example_data["ensemble_answer"] = f"Error: {str(e)}"
                example_data["ensemble_letter"] = None
        
        # Test baseline model
        try:
            # Generate the answer with the baseline model
            baseline_inputs = baseline_tokenizer(prompt, return_tensors="pt").to(device)
            
            with torch.no_grad():
                baseline_output = baseline_model.generate(
                    **baseline_inputs,
                    **gen_params
                )
            
            # Process output
            baseline_generated = baseline_tokenizer.decode(baseline_output[0], skip_special_tokens=True)
            
            # Get the model's answer by taking the part after the original prompt
            baseline_answer = baseline_generated[len(prompt):].strip()
            
            # Extract the letter
            baseline_letter = extract_answer(baseline_answer)
            
            # Record metrics
            metrics["baseline"]["generated_lengths"].append(len(baseline_answer.split()))
            
            if baseline_letter:
                metrics["baseline"]["letter_found"] += 1
                metrics["baseline"]["letter_distribution"][baseline_letter] += 1
                if baseline_letter == correct_letter:
                    metrics["baseline"]["correct"] += 1
                    
            if is_example:
                example_data["baseline_answer"] = baseline_answer
                example_data["baseline_letter"] = baseline_letter
                
        except Exception as e:
            print(f"Error with baseline model: {e}")
            if is_example:
                example_data["baseline_answer"] = f"Error: {str(e)}"
                example_data["baseline_letter"] = None
        
        # Save the example if selected
        if is_example:
            example_data["random_letter"] = random_letter
            metrics["examples"].append(example_data)
    
    # Calculate overall metrics
    for model_type in ["ensemble", "baseline"]:
        # Calculate accuracy
        metrics[model_type]["accuracy"] = metrics[model_type]["correct"] / metrics["total_questions"]
        metrics[model_type]["letter_detection_rate"] = metrics[model_type]["letter_found"] / metrics["total_questions"]
        
        # Calculate average lengths
        metrics[model_type]["avg_generated_length"] = np.mean(metrics[model_type]["generated_lengths"]) if metrics[model_type]["generated_lengths"] else 0
        
        # Convert letter distribution to dictionary for JSON serialization
        metrics[model_type]["letter_distribution"] = dict(metrics[model_type]["letter_distribution"])
    
    # Calculate random baseline accuracy
    metrics["random"]["accuracy"] = metrics["random"]["correct"] / metrics["total_questions"]
    metrics["random"]["letter_detection_rate"] = 1.0  # Always detects a letter by definition
    metrics["random"]["letter_distribution"] = dict(metrics["random"]["letter_distribution"])
    
    # Calculate improvement over random and baseline
    metrics["improvement"] = {
        "over_random": metrics["ensemble"]["accuracy"] - metrics["random"]["accuracy"],
        "over_baseline": metrics["ensemble"]["accuracy"] - metrics["baseline"]["accuracy"],
    }
    
    # Print results
    print("\n" + "="*80)
    print("TESTING RESULTS SUMMARY")
    print("="*80)
    print(f"Total test questions: {metrics['total_questions']}")
    
    print("\nAccuracy (correct letter):")
    print(f"  Ensemble model: {metrics['ensemble']['correct']} / {metrics['total_questions']} = {metrics['ensemble']['accuracy']:.2%}")
    print(f"  Baseline model: {metrics['baseline']['correct']} / {metrics['total_questions']} = {metrics['baseline']['accuracy']:.2%}")
    print(f"  Random baseline: {metrics['random']['correct']} / {metrics['total_questions']} = {metrics['random']['accuracy']:.2%}")
    print(f"  Improvement over random: {metrics['improvement']['over_random']:.2%}")
    print(f"  Improvement over baseline: {metrics['improvement']['over_baseline']:.2%}")
    
    print("\nLetter detection rate:")
    print(f"  Ensemble model: {metrics['ensemble']['letter_found']} / {metrics['total_questions']} = {metrics['ensemble']['letter_detection_rate']:.2%}")
    print(f"  Baseline model: {metrics['baseline']['letter_found']} / {metrics['total_questions']} = {metrics['baseline']['letter_detection_rate']:.2%}")
    print(f"  Random baseline: 100% (by definition)")
    
    print("\nAverage generated length (words):")
    print(f"  Ensemble model: {metrics['ensemble']['avg_generated_length']:.2f}")
    print(f"  Baseline model: {metrics['baseline']['avg_generated_length']:.2f}")
    
    print("\nLetter distribution (ensemble model):")
    for letter, count in sorted(metrics['ensemble']['letter_distribution'].items()):
        percentage = count/metrics['ensemble']['letter_found'] if metrics['ensemble']['letter_found'] > 0 else 0
        print(f"  {letter}: {count} ({percentage:.2%})")
    
    print("\nLetter distribution (baseline model):")
    for letter, count in sorted(metrics['baseline']['letter_distribution'].items()):
        percentage = count/metrics['baseline']['letter_found'] if metrics['baseline']['letter_found'] > 0 else 0
        print(f"  {letter}: {count} ({percentage:.2%})")
    
    print("\nLetter distribution (random baseline):")
    total_random = sum(metrics['random']['letter_distribution'].values())
    for letter, count in sorted(metrics['random']['letter_distribution'].items()):
        percentage = count/total_random if total_random > 0 else 0
        print(f"  {letter}: {count} ({percentage:.2%})")
    
    # Print example comparison
    print("\n" + "="*80)
    print("EXAMPLE COMPARISONS")
    print("="*80)
    
    for i, example in enumerate(metrics["examples"]):
        print(f"\nExample {i+1}:")
        print(f"Question:\n{example['question']}")
        print(f"Correct answer: {example['correct_answer']}")
        print(f"\nEnsemble model answer:\n{example.get('ensemble_answer', 'N/A')}")
        print(f"Ensemble letter: {example.get('ensemble_letter', 'N/A')}")
        print(f"\nBaseline model answer:\n{example.get('baseline_answer', 'N/A')}")
        print(f"Baseline letter: {example.get('baseline_letter', 'N/A')}")
        print(f"\nRandom baseline letter: {example.get('random_letter', 'N/A')}")
        print("-"*40)
    
    # Create and save visualizations
    try:
        plt.figure(figsize=(15, 10))
        
        # Accuracy comparison
        plt.subplot(2, 2, 1)
        models = ['Ensemble', 'Baseline', 'Random']
        accuracies = [metrics['ensemble']['accuracy'], metrics['baseline']['accuracy'], metrics['random']['accuracy']]
        colors = ['green', 'blue', 'red']
        plt.bar(models, accuracies, color=colors)
        plt.title('Accuracy Comparison')
        plt.ylabel('Accuracy')
        plt.ylim(0, max(max(accuracies) * 1.2, 1.0))  # Set y-limit with headroom
        
        # Add accuracy values on top of bars
        for i, v in enumerate(accuracies):
            plt.text(i, v + 0.01, f"{v:.2%}", ha='center', fontweight='bold')
        
        # Letter detection rate
        plt.subplot(2, 2, 2)
        detection_rates = [
            metrics['ensemble']['letter_detection_rate'], 
            metrics['baseline']['letter_detection_rate'], 
            1.0  # Random always has a letter by definition
        ]
        plt.bar(models, detection_rates, color=colors)
        plt.title('Letter Detection Rate')
        plt.ylabel('Rate')
        plt.ylim(0, 1.1)  # Set y-limit with some headroom
        
        # Letter distribution for ensemble model
        plt.subplot(2, 2, 3)
        ensemble_letters = sorted(metrics['ensemble']['letter_distribution'].keys())
        ensemble_counts = [metrics['ensemble']['letter_distribution'].get(letter, 0) for letter in ensemble_letters]
        plt.bar(ensemble_letters, ensemble_counts, color='green')
        plt.title('Ensemble Model Letter Distribution')
        plt.ylabel('Count')
        
        # Letter distribution comparison (all models)
        plt.subplot(2, 2, 4)
        
        # Get all unique letters across all models
        all_letters = sorted(set(
            list(metrics['ensemble']['letter_distribution'].keys()) +
            list(metrics['baseline']['letter_distribution'].keys()) +
            list(metrics['random']['letter_distribution'].keys())
        ))
        
        # Width of each bar
        bar_width = 0.25
        
        # Set position of bars on x axis
        r1 = np.arange(len(all_letters))
        r2 = [x + bar_width for x in r1]
        r3 = [x + bar_width for x in r2]
        
        # Create bars
        ensemble_counts = [metrics['ensemble']['letter_distribution'].get(letter, 0) for letter in all_letters]
        baseline_counts = [metrics['baseline']['letter_distribution'].get(letter, 0) for letter in all_letters]
        random_counts = [metrics['random']['letter_distribution'].get(letter, 0) for letter in all_letters]
        
        plt.bar(r1, ensemble_counts, width=bar_width, label='Ensemble', color='green')
        plt.bar(r2, baseline_counts, width=bar_width, label='Baseline', color='blue')
        plt.bar(r3, random_counts, width=bar_width, label='Random', color='red')
        
        plt.xlabel('Letter')
        plt.ylabel('Count')
        plt.title('Letter Distribution Comparison')
        plt.xticks([r + bar_width for r in range(len(all_letters))], all_letters)
        plt.legend()
        
        plt.tight_layout()
        
        # Save chart
        chart_file = os.path.join(output_dir, 'performance_chart.png')
        plt.savefig(chart_file)
        print(f"Chart saved to {chart_file}")
        
        plt.close()
    except Exception as e:
        print(f"Could not create visualization: {e}")
    
    # Save detailed results
    results_file = os.path.join(output_dir, 'results.json')
    with open(results_file, 'w') as f:
        json.dump(metrics, f, indent=2)
    print(f"Detailed results saved to {results_file}")
    
    return metrics

In [2]:
def test_gsm8k(ensemble_model, baseline_model, ensemble_tokenizer, baseline_tokenizer, dataset, output_dir, device='cpu'):
    """
    Evaluate performance of ensemble model vs baseline model on GSM8K math reasoning dataset.
    
    Args:
        ensemble_model: Your custom ensemble model
        baseline_model: The baseline huggingface model for comparison
        ensemble_tokenizer: Tokenizer for the ensemble model
        baseline_tokenizer: Tokenizer for the baseline model
        dataset: GSM8K dataset with 'question' and 'answer' columns
        output_dir: Directory to save results and visualizations
        sample_size: Optional, number of random samples to use (default: use all)
        device: Device to run models on ('cuda', 'mps', or 'cpu')
        
    Returns:
        dict: Dictionary containing all metrics and results
    """
    import torch
    import re
    import numpy as np
    import random
    from tqdm import tqdm
    import json
    import os
    import matplotlib.pyplot as plt
    from collections import Counter, defaultdict
    
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    # Ensure models are in evaluation mode
    ensemble_model.eval()
    baseline_model.eval()
    
    # Helper function to extract the numeric answer from generated text
    def extract_answer(text):
        """
        Returns a numeric answer from formats like:
        - "Therefore, the answer is: # <number>"
        - "# <number>"
        
        Returns:
            float or int or None: The extracted answer (number or None if no match)
        """
        # Try to match number patterns with "Therefore, the answer is: # <number>"
        number_match = re.search(r'[Tt]herefore,\s+the\s+answer\s+is:?\s+#\s*([-+]?\d*\.?\d+)', text)
        if number_match:
            number_str = number_match.group(1)
            # Convert to float or int as appropriate
            return float(number_str) if '.' in number_str else int(number_str)
        
        # Try to match simpler number pattern "# <number>"
        simple_number_match = re.search(r'#\s*([-+]?\d*\.?\d+)', text)
        if simple_number_match:
            number_str = simple_number_match.group(1)
            # Convert to float or int as appropriate
            return float(number_str) if '.' in number_str else int(number_str)
        
        # Try to match "The answer is: <number>"
        answer_is_match = re.search(r'[Tt]he\s+answer\s+is:?\s*([-+]?\d*\.?\d+)', text)
        if answer_is_match:
            number_str = answer_is_match.group(1)
            return float(number_str) if '.' in number_str else int(number_str)
        
        return None
    
    # Initialize metrics
    metrics = {
        "total_questions": len(dataset),
        "ensemble": {
            "correct": 0,
            "answer_found": 0,
            "generated_lengths": [],
            "answer_distribution": Counter()
        },
        "baseline": {
            "correct": 0,
            "answer_found": 0,
            "generated_lengths": [],
            "answer_distribution": Counter()
        },
        "examples": []
    }
    
    # Generation parameters optimized for math reasoning
    gen_params = {
        "max_new_tokens": 512,          # Enough for detailed reasoning
        "num_return_sequences": 1,
        "temperature": 0.3,            # Lower temperature for more deterministic outputs
        "top_p": 0.92,                 # Slightly increased for better coverage
        "top_k": 40,                   # Restrict to only the most likely tokens
        "do_sample": True,             # Enable sampling for diversity
        "num_beams": 3,                # Use beam search for better results
        "repetition_penalty": 1.2      # Penalize repeated tokens
    }

    # Choose a few random examples for detailed analysis
    example_indices = random.sample(range(len(dataset)), min(5, len(dataset)))
    
    # Track error cases
    error_cases = {
        "ensemble": [],
        "baseline": []
    }
    
    # Process all test examples
    print(f"Testing on {len(dataset)} examples...")
    
    for i, example in enumerate(tqdm(dataset, desc="Testing models")):
        # Extract the question and correct answer
        question = example['question']
        correct_answer_text = example['answer']
        
        # Extract correct numeric answer from the reference answer
        correct_answer = extract_answer(correct_answer_text)
        if correct_answer is None:
            print(f"Warning: Could not extract numeric answer from: {correct_answer_text}")
            correct_answer = "Unknown"  # Fall back
        
        # Format prompt with chain-of-thought instruction
        prompt = f'I want you to solve the following math problem step by step.\n{question}\n\nStart by reasoning about the problem, then provide the numerical answer in the format "Therefore, the answer is: # <number>"'
        
        # Track if this is an example to save
        is_example = i in example_indices
        example_data = {"question": question, "correct_answer_text": correct_answer_text, "correct_answer": correct_answer} if is_example else None
        
        # Test ensemble model
        try:
            # Generate the answer with the ensemble model
            print(f"Testing ensemble model on example {i}...")
            ensemble_inputs = ensemble_tokenizer(prompt, return_tensors="pt").to(device)
            
            with torch.no_grad():
                ensemble_output = ensemble_model.generate(
                    **ensemble_inputs,
                    **gen_params
                )
            print(f"Ensemble model output: {ensemble_output}")
            # Process output
            ensemble_generated = ensemble_tokenizer.decode(ensemble_output[0], skip_special_tokens=True)
            print(f"Ensemble model generated: {ensemble_generated}")
            
            # Get the model's answer by taking the part after the original prompt
            ensemble_answer_text = ensemble_generated[len(prompt):].strip()
            print(f"Ensemble model answer text: {ensemble_answer_text}")
            # Extract the numeric answer
            ensemble_answer = extract_answer(ensemble_answer_text)
            
            # Record metrics
            metrics["ensemble"]["generated_lengths"].append(len(ensemble_answer_text.split()))
            
            if ensemble_answer is not None:
                metrics["ensemble"]["answer_found"] += 1
                metrics["ensemble"]["answer_distribution"][str(ensemble_answer)] += 1
                
                # Check if the answer is correct (allow for small floating point differences)
                if isinstance(ensemble_answer, (int, float)) and isinstance(correct_answer, (int, float)):
                    if abs(ensemble_answer - correct_answer) < 1e-6:  # For floating point comparison
                        metrics["ensemble"]["correct"] += 1
                elif ensemble_answer == correct_answer:  # Direct comparison for non-numeric
                    metrics["ensemble"]["correct"] += 1
                    
            if is_example:
                example_data["ensemble_answer_text"] = ensemble_answer_text
                example_data["ensemble_answer"] = ensemble_answer
                
        except Exception as e:
            print(f"Error with ensemble model on example {i}: {e}")
            error_cases["ensemble"].append({"index": i, "question": question, "error": str(e)})
            if is_example:
                example_data["ensemble_answer_text"] = f"Error: {str(e)}"
                example_data["ensemble_answer"] = None
        
        # Test baseline model
        try:
            # Generate the answer with the baseline model
            print(f"Testing baseline model on example {i}...")
            baseline_inputs = baseline_tokenizer(prompt, return_tensors="pt").to(device)
            
            with torch.no_grad():
                baseline_output = baseline_model.generate(
                    **baseline_inputs,
                    **gen_params
                )
            
            # Process output
            baseline_generated = baseline_tokenizer.decode(baseline_output[0], skip_special_tokens=True)
            
            # Get the model's answer by taking the part after the original prompt
            baseline_answer_text = baseline_generated[len(prompt):].strip()
            
            # Extract the numeric answer
            baseline_answer = extract_answer(baseline_answer_text)
            
            # Record metrics
            metrics["baseline"]["generated_lengths"].append(len(baseline_answer_text.split()))
            
            if baseline_answer is not None:
                metrics["baseline"]["answer_found"] += 1
                metrics["baseline"]["answer_distribution"][str(baseline_answer)] += 1
                
                # Check if the answer is correct (allow for small floating point differences)
                if isinstance(baseline_answer, (int, float)) and isinstance(correct_answer, (int, float)):
                    if abs(baseline_answer - correct_answer) < 1e-6:  # For floating point comparison
                        metrics["baseline"]["correct"] += 1
                elif baseline_answer == correct_answer:  # Direct comparison for non-numeric
                    metrics["baseline"]["correct"] += 1
                    
            if is_example:
                example_data["baseline_answer_text"] = baseline_answer_text
                example_data["baseline_answer"] = baseline_answer
                
        except Exception as e:
            print(f"Error with baseline model on example {i}: {e}")
            error_cases["baseline"].append({"index": i, "question": question, "error": str(e)})
            if is_example:
                example_data["baseline_answer_text"] = f"Error: {str(e)}"
                example_data["baseline_answer"] = None
        
        # Save the example if selected
        if is_example:
            metrics["examples"].append(example_data)
    
    # Calculate overall metrics
    for model_type in ["ensemble", "baseline"]:
        # Calculate accuracy
        metrics[model_type]["accuracy"] = metrics[model_type]["correct"] / metrics["total_questions"]
        metrics[model_type]["answer_detection_rate"] = metrics[model_type]["answer_found"] / metrics["total_questions"]
        
        # Calculate average lengths
        metrics[model_type]["avg_generated_length"] = np.mean(metrics[model_type]["generated_lengths"]) if metrics[model_type]["generated_lengths"] else 0
        
        # Convert answer distribution to dictionary for JSON serialization
        # Sort by frequency for better visualization
        sorted_dist = dict(sorted(metrics[model_type]["answer_distribution"].items(), 
                                  key=lambda x: x[1], reverse=True)[:20])  # Keep top 20 most frequent
        metrics[model_type]["answer_distribution"] = sorted_dist
    
    # Calculate improvement over baseline
    metrics["improvement"] = {
        "over_baseline": metrics["ensemble"]["accuracy"] - metrics["baseline"]["accuracy"],
    }
    
    # Add error statistics
    metrics["errors"] = {
        "ensemble": len(error_cases["ensemble"]),
        "baseline": len(error_cases["baseline"])
    }
    
    # Print results
    print("\n" + "="*80)
    print("TESTING RESULTS SUMMARY")
    print("="*80)
    print(f"Total test questions: {metrics['total_questions']}")
    
    print("\nAccuracy (correct numeric answer):")
    print(f"  Ensemble model: {metrics['ensemble']['correct']} / {metrics['total_questions']} = {metrics['ensemble']['accuracy']:.2%}")
    print(f"  Baseline model: {metrics['baseline']['correct']} / {metrics['total_questions']} = {metrics['baseline']['accuracy']:.2%}")
    print(f"  Improvement over baseline: {metrics['improvement']['over_baseline']:.2%}")
    
    print("\nAnswer detection rate:")
    print(f"  Ensemble model: {metrics['ensemble']['answer_found']} / {metrics['total_questions']} = {metrics['ensemble']['answer_detection_rate']:.2%}")
    print(f"  Baseline model: {metrics['baseline']['answer_found']} / {metrics['total_questions']} = {metrics['baseline']['answer_detection_rate']:.2%}")
    
    print("\nAverage generated length (words):")
    print(f"  Ensemble model: {metrics['ensemble']['avg_generated_length']:.2f}")
    print(f"  Baseline model: {metrics['baseline']['avg_generated_length']:.2f}")
    
    print("\nErrors:")
    print(f"  Ensemble model: {metrics['errors']['ensemble']} errors")
    print(f"  Baseline model: {metrics['errors']['baseline']} errors")
    
    print("\nTop 5 most common answers (ensemble model):")
    for i, (answer, count) in enumerate(list(metrics['ensemble']['answer_distribution'].items())[:5]):
        percentage = count/metrics['ensemble']['answer_found'] if metrics['ensemble']['answer_found'] > 0 else 0
        print(f"  {answer}: {count} ({percentage:.2%})")
    
    print("\nTop 5 most common answers (baseline model):")
    for i, (answer, count) in enumerate(list(metrics['baseline']['answer_distribution'].items())[:5]):
        percentage = count/metrics['baseline']['answer_found'] if metrics['baseline']['answer_found'] > 0 else 0
        print(f"  {answer}: {count} ({percentage:.2%})")
    
    # Print example comparison
    print("\n" + "="*80)
    print("EXAMPLE COMPARISONS")
    print("="*80)
    
    for i, example in enumerate(metrics["examples"]):
        print(f"\nExample {i+1}:")
        print(f"Question:\n{example['question']}")
        print(f"Correct answer text: {example['correct_answer_text']}")
        print(f"Correct numeric answer: {example['correct_answer']}")
        
        print(f"\nEnsemble model answer:")
        print(f"{example.get('ensemble_answer_text', 'N/A')}")
        print(f"Extracted answer: {example.get('ensemble_answer', 'N/A')}")
        
        print(f"\nBaseline model answer:")
        print(f"{example.get('baseline_answer_text', 'N/A')}")
        print(f"Extracted answer: {example.get('baseline_answer', 'N/A')}")
        
        print("-"*80)
    
    # Create and save visualizations
    try:
        plt.figure(figsize=(15, 10))
        
        # Accuracy comparison
        plt.subplot(2, 2, 1)
        models = ['Ensemble', 'Baseline']
        accuracies = [metrics['ensemble']['accuracy'], metrics['baseline']['accuracy']]
        colors = ['green', 'blue']
        plt.bar(models, accuracies, color=colors)
        plt.title('Accuracy Comparison')
        plt.ylabel('Accuracy')
        plt.ylim(0, max(max(accuracies) * 1.2, 1.0))  # Set y-limit with headroom
        
        # Add accuracy values on top of bars
        for i, v in enumerate(accuracies):
            plt.text(i, v + 0.01, f"{v:.2%}", ha='center', fontweight='bold')
        
        # Answer detection rate
        plt.subplot(2, 2, 2)
        detection_rates = [
            metrics['ensemble']['answer_detection_rate'], 
            metrics['baseline']['answer_detection_rate']
        ]
        plt.bar(models, detection_rates, color=colors)
        plt.title('Answer Detection Rate')
        plt.ylabel('Rate')
        plt.ylim(0, 1.1)  # Set y-limit with some headroom
        
        # Answer distribution for ensemble model (top 10)
        plt.subplot(2, 2, 3)
        top_ensemble_answers = list(metrics['ensemble']['answer_distribution'].keys())[:10]
        ensemble_counts = [metrics['ensemble']['answer_distribution'][answer] for answer in top_ensemble_answers]
        
        # Create a horizontal bar chart for better readability of answer values
        plt.barh(range(len(top_ensemble_answers)), ensemble_counts, color='green')
        plt.yticks(range(len(top_ensemble_answers)), top_ensemble_answers)
        plt.title('Top 10 Ensemble Model Answers')
        plt.xlabel('Count')
        
        # Answer distribution for baseline model (top 10)
        plt.subplot(2, 2, 4)
        top_baseline_answers = list(metrics['baseline']['answer_distribution'].keys())[:10]
        baseline_counts = [metrics['baseline']['answer_distribution'][answer] for answer in top_baseline_answers]
        
        # Create a horizontal bar chart
        plt.barh(range(len(top_baseline_answers)), baseline_counts, color='blue')
        plt.yticks(range(len(top_baseline_answers)), top_baseline_answers)
        plt.title('Top 10 Baseline Model Answers')
        plt.xlabel('Count')
        
        plt.tight_layout()
        
        # Save chart
        chart_file = os.path.join(output_dir, 'gsm8k_performance_chart.png')
        plt.savefig(chart_file)
        print(f"Chart saved to {chart_file}")
        
        plt.close()
        
        # Additional visualization: Compare most common answers between models
        plt.figure(figsize=(12, 8))
        
        # Combine top answers from both models
        combined_answers = set(list(metrics['ensemble']['answer_distribution'].keys())[:5] + 
                              list(metrics['baseline']['answer_distribution'].keys())[:5])
        
        # Create a mapping for position on x-axis
        x_positions = np.arange(len(combined_answers))
        
        # Width of each bar
        bar_width = 0.35
        
        # Get counts for each model
        ensemble_counts = [metrics['ensemble']['answer_distribution'].get(answer, 0) for answer in combined_answers]
        baseline_counts = [metrics['baseline']['answer_distribution'].get(answer, 0) for answer in combined_answers]
        
        # Create grouped bars
        plt.bar(x_positions - bar_width/2, ensemble_counts, bar_width, label='Ensemble', color='green')
        plt.bar(x_positions + bar_width/2, baseline_counts, bar_width, label='Baseline', color='blue')
        
        plt.xlabel('Answer')
        plt.ylabel('Count')
        plt.title('Most Common Answers Comparison')
        plt.xticks(x_positions, list(combined_answers))
        plt.legend()
        
        plt.tight_layout()
        
        # Save chart
        comparison_chart_file = os.path.join(output_dir, 'gsm8k_answer_comparison.png')
        plt.savefig(comparison_chart_file)
        print(f"Answer comparison chart saved to {comparison_chart_file}")
        
        plt.close()
        
    except Exception as e:
        print(f"Could not create visualization: {e}")
    
    # Save detailed results and error cases
    results_file = os.path.join(output_dir, 'gsm8k_results.json')
    with open(results_file, 'w') as f:
        # Create a copy that includes a sample of error cases (not all to keep file size reasonable)
        save_metrics = metrics.copy()
        save_metrics["error_cases"] = {
            "ensemble": error_cases["ensemble"][:10],  # Save the first 10 errors
            "baseline": error_cases["baseline"][:10]
        }
        json.dump(save_metrics, f, indent=2)
    print(f"Detailed results saved to {results_file}")
    
    return metrics

In [8]:
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import random
from src.architecture import MultiModelWithScalarHeads
from src.models import MODELS
from tqdm import tqdm

# Evaluation configuration: size of the random MedQA subset and the seed that makes it reproducible.
MEDQA_N_SAMPLES = 20
MEDQA_SEED = 42
# NOTE(review): the baseline is a GSM8K fine-tune (per its name) evaluated on MedQA — confirm intended.
BASELINE_CHECKPOINT = "axel-datos/Llama-3.2-1B_gsm8k_full-finetuning"

full_medqa_test = load_dataset("cola13/medqa_formatted", split="test")

# Draw a reproducible random subset; cap the size so sample() never raises on a small split.
n_medqa = min(MEDQA_N_SAMPLES, len(full_medqa_test))
random_indices = random.Random(MEDQA_SEED).sample(range(len(full_medqa_test)), n_medqa)
medqa_test = full_medqa_test.select(random_indices)

# Report the size actually selected rather than a hard-coded number.
print(f"Selected {len(medqa_test)} random samples from {len(full_medqa_test)} total examples")

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# `ensamble_model` spelling kept so other cells that reference this name keep working.
ensamble_model = MultiModelWithScalarHeads.from_pretrained(MODELS, "exp1/best_model_heads.pt")
finetuned_model = AutoModelForCausalLM.from_pretrained(BASELINE_CHECKPOINT)

ensemble_tokenizer = ensamble_model.get_tokenizer()
finetuned_tokenizer = AutoTokenizer.from_pretrained(BASELINE_CHECKPOINT)

test_medqa(ensamble_model, finetuned_model, ensemble_tokenizer, finetuned_tokenizer, medqa_test, "medqa_results", device)

2025-05-16 22:06:29,201 - src.model_loader - INFO - All models unloaded
2025-05-16 22:06:29,202 - src.model_loader - INFO - ModelLoader instance deleted
2025-05-16 22:06:29,202 - src.model_loader - INFO - Loading tokenizer for Llama-3.2-1B General (LoRA) from meta-llama/Llama-3.2-1B


Selected 20 random samples from 1272 total examples
Loading ensemble model on mps...


2025-05-16 22:06:34,080 - src.model_loader - INFO - Loading model Llama-3.2-1B General (LoRA) from meta-llama/Llama-3.2-1B


KeyboardInterrupt: 

In [4]:
import random

from datasets import load_dataset

# NOTE(review): `test_gsm8k`, `general_model`, `other_model`, `general_tokenizer`,
# `other_tokenizer` and `device` are defined in other cells; confirm this cell still
# runs under Restart Kernel -> Run All (its execution count precedes the MedQA cell).

# Evaluation configuration: size of the random GSM8K subset and the seed that makes it reproducible.
GSM8K_N_SAMPLES = 20
GSM8K_SEED = 42

full_gsm8k_test = load_dataset("cola13/gsm8k_formatted", split="test")

# Sample indices from the GSM8K split itself (previously the MedQA split length was used,
# which left the tail of GSM8K unreachable and depended on another cell's variable).
n_gsm8k = min(GSM8K_N_SAMPLES, len(full_gsm8k_test))
random_indices = random.Random(GSM8K_SEED).sample(range(len(full_gsm8k_test)), n_gsm8k)
gsm8k_test = full_gsm8k_test.select(random_indices)

# Report the size actually selected rather than a hard-coded number.
print(f"Selected {len(gsm8k_test)} random samples from {len(full_gsm8k_test)} total examples")

test_gsm8k(general_model, other_model, general_tokenizer, other_tokenizer, gsm8k_test, "gsm8k_results", device)

Selected 20 random samples from 1319 total examples
Testing on 20 examples...


Testing models:   0%|          | 0/20 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Testing ensemble model on example 0...


: 