In [1]:
"""
Evaluate Fine-tuned Llama2 on MCQ Dataset
Using LoRA adapter directly (no merge needed)
Handles variable number of options (2-7+)
"""

import torch
import re
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm


In [2]:
#============================================================================
# CONFIGURATION
# ============================================================================

# Path to your trained LoRA adapter
ADAPTER_PATH = "llama2-mcq-best"  # Your best model

# Dataset file
TEST_FILE = "test_finetune.jsonl"  # or "test_finetune.jsonl"

# Evaluation settings
MAX_SAMPLES = None  # Set to None to evaluate all samples
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print("="*80)
print("LOADING MODEL AND TOKENIZER")
print("="*80)

LOADING MODEL AND TOKENIZER


In [3]:
# ============================================================================
# LOAD MODEL WITH LORA ADAPTER (NO MERGE!)
# ============================================================================
print(f"Loading LoRA adapter from: {ADAPTER_PATH}")

model = AutoPeftModelForCausalLM.from_pretrained(
    ADAPTER_PATH,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print(f"✓ Model loaded on: {model.device}")
print(f"✓ Model type: {type(model)}")
print("="*80 + "\n")

Loading LoRA adapter from: llama2-mcq-best




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✓ Model loaded on: cuda:0
✓ Model type: <class 'peft.peft_model.PeftModelForCausalLM'>



In [4]:
# ============================================================================
# LOAD TEST DATASET
# ============================================================================
print(f"Loading test dataset from: {TEST_FILE}")
test_dataset = load_dataset("json", data_files=TEST_FILE, split="train")
print(f"✓ Loaded {len(test_dataset)} samples\n")


Loading test dataset from: test_finetune.jsonl


Generating train split: 0 examples [00:00, ? examples/s]

✓ Loaded 2512 samples



In [5]:
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================

def extract_valid_options(choices_str):
    """
    Extract valid option letters from the choices string
    Handles formats like: "(A) text (B) text" or "A. text B. text"
    Returns a set of valid letters: {'A', 'B', 'C', ...}
    """
    # Find all option letters (A-Z) that appear to be option labels
    # Look for patterns like "(A)", "A.", "A)", or standalone "A"
    pattern = r'\(([A-Z])\)|\b([A-Z])[\.\)]|\b([A-Z])\s'
    matches = re.findall(pattern, choices_str.upper())
    
    # Flatten the matches (regex groups) and remove empty strings
    valid_options = set()
    for match_tuple in matches:
        for letter in match_tuple:
            if letter:
                valid_options.add(letter)
    
    # If no pattern found, try to find any uppercase letters A-Z
    if not valid_options:
        valid_options = set(re.findall(r'\b([A-Z])\b', choices_str.upper()))
    
    return valid_options

In [8]:
def extract_answer_from_response(response, valid_options):
    """
    Extract the answer letter from model response
    Args:
        response: The model's generated text
        valid_options: Set of valid option letters for this question
    Returns:
        Extracted letter or None
    
    Priority order:
    1. If response is just one letter → return it
    2. Look for first valid letter in the response
    3. Handle common formats (A), [A], A. as fallback
    """
    response_upper = response.upper().strip()
    
    # Method 1: Response is just a single letter (ideal case)
    response_cleaned = re.sub(r'[^\w]', '', response_upper)
    if len(response_cleaned) == 1 and response_cleaned in valid_options:
        return response_cleaned
    
    # Method 2: First valid letter in response (works for "A" or "(A)" or "The answer is A")
    for char in response_upper:
        if char in valid_options:
            return char
    
    # Method 3: No valid option found
    return None

In [9]:
# ============================================================================
# EVALUATION FUNCTION
# ============================================================================
def evaluate_mcq(model, tokenizer, dataset, max_samples=None):
    """
    Evaluate the model on MCQ dataset with variable number of options
    
    Args:
        model: The fine-tuned model with LoRA adapter
        tokenizer: The tokenizer
        dataset: Dataset with Context, question, answerChoices, correctAnswer
        max_samples: Maximum number of samples to evaluate (None = all)
    
    Returns:
        accuracy: Accuracy score
        results: List of dictionaries with predictions and ground truth
    """
    correct = 0
    total = len(dataset) if max_samples is None else min(len(dataset), max_samples)
    results = []
    
    # Track option distribution
    option_counts = {}
    
    model.eval()  # Set to evaluation mode
    
    print(f"Evaluating on {total} samples...")
    print("="*80)
    
    with torch.no_grad():  # Disable gradient calculation for faster inference
        for i in tqdm(range(total), desc="Evaluating"):
            example = dataset[i]
            context = example["Context"]
            question = example["question"]
            choices = example["answerChoices"]
            correct_answer = example["correctAnswer"].upper().strip()
            
            # Extract valid options for this specific question
            valid_options = extract_valid_options(choices)
            num_options = len(valid_options)
            
            # Track option distribution
            option_counts[num_options] = option_counts.get(num_options, 0) + 1
            
            # SMART TRUNCATION: Truncate context if needed, but keep instruction intact
            max_context_tokens = 1800  # Reserve tokens for instruction + question + options
            
            # Tokenize context separately to check length
            context_tokens = tokenizer(context, add_special_tokens=False)['input_ids']
            
            # If context is too long, truncate it
            if len(context_tokens) > max_context_tokens:
                context_tokens = context_tokens[:max_context_tokens]
                context = tokenizer.decode(context_tokens, skip_special_tokens=True)
                context = context + "... [context truncated]"
            
            # Format prompt - same as training format but with clearer instructions
            prompt = f"""<s>[INST]<<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible using the context text provided. Your answer must be exactly one letter corresponding to the correct option. Do not include any explanation, punctuation, or extra text.
<</SYS>>

Context: {context}
Question: {question}
Options: {choices}
Answer: [/INST]"""
            
            # Tokenize (should not need truncation now, but keep as safety)
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(model.device)
            
            # CRITICAL CHECK: Ensure [/INST] is in the tokenized input
            decoded_check = tokenizer.decode(inputs['input_ids'][0])
            if "[/INST]" not in decoded_check:
                # Fallback: Use shorter context
                context_short = context[:500] + "... [truncated]"
                prompt = f"""<s>[INST]<<SYS>>
You are a helpful assistant.
<</SYS>>

Context: {context_short}
Question: {question}
Options: {choices}
Answer: [/INST]"""
                inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(model.device)
            
            # Generate
            outputs = model.generate(
                **inputs,
                max_new_tokens=5,  # Very short - we only need one letter
                temperature=0.1,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
            
            # Decode
            decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
            
            # Extract response (after [/INST])
            if "[/INST]" in decoded:
                response = decoded.split("[/INST]")[-1].strip()
            else:
                # Model failed to generate proper response (likely due to truncation issues)
                response = ""
                pred = None
                
                # Store result with error flag
                results.append({
                    'index': i,
                    'question': question,
                    'choices': choices,
                    'valid_options': sorted(list(valid_options)),
                    'num_options': num_options,
                    'correct_answer': correct_answer,
                    'predicted_answer': None,
                    'full_response': '[ERROR: No response generated - likely truncation issue]',
                    'is_correct': False,
                    'error': 'truncation'
                })
                
                # Print warning for first few cases
                if i < 10 and len([r for r in results if r.get('error') == 'truncation']) <= 3:
                    print(f"\n⚠️ Warning: Question {i} - No valid response generated (truncation issue)")
                
                continue
            
            # Extract prediction using valid options for this question
            pred = extract_answer_from_response(response, valid_options)
            
            # Compare with correct answer
            is_correct = (pred == correct_answer)
            if is_correct:
                correct += 1
            
            # Store result
            results.append({
                'index': i,
                'question': question,
                'choices': choices,
                'valid_options': sorted(list(valid_options)),
                'num_options': num_options,
                'correct_answer': correct_answer,
                'predicted_answer': pred,
                'full_response': response,
                'is_correct': is_correct
            })
            
            # Print first 5 examples for debugging
            if i < 5:
                print(f"\n{'='*60}")
                print(f"Example {i+1}")
                print(f"{'='*60}")
                print(f"Question: {question[:80]}...")
                print(f"Valid options: {sorted(list(valid_options))} ({num_options} options)")
                print(f"Correct: {correct_answer}")
                print(f"Predicted: {pred}")
                print(f"Response: '{response}'")
                print(f"Status: {'✓ CORRECT' if is_correct else '✗ WRONG'}")
    
    accuracy = correct / total
    
    print("\n" + "="*80)
    print("EVALUATION COMPLETE")
    print("="*80)
    print(f"Total samples: {total}")
    print(f"Correct predictions: {correct}")
    
    # Count truncation errors
    truncation_errors = sum(1 for r in results if r.get('error') == 'truncation')
    if truncation_errors > 0:
        print(f"⚠️ Truncation errors: {truncation_errors} ({truncation_errors/total*100:.1f}%)")
        print(f"   These questions were too long and couldn't be processed properly")
        print(f"   Effective accuracy (excluding errors): {correct}/{total-truncation_errors} = {correct/(total-truncation_errors)*100:.1f}%")
    
    print(f"Accuracy: {accuracy:.2%}")
    print("\nOption Distribution:")
    for num_opts in sorted(option_counts.keys()):
        count = option_counts[num_opts]
        print(f"  {num_opts} options: {count} questions ({count/total*100:.1f}%)")
    print("="*80 + "\n")
    
    return accuracy, results


In [10]:
# ============================================================================
# RUN EVALUATION
# ============================================================================
accuracy, results = evaluate_mcq(
    model=model,
    tokenizer=tokenizer,
    dataset=test_dataset,
    max_samples=MAX_SAMPLES
)

Evaluating on 2512 samples...


Evaluating:   0%|          | 1/2512 [00:01<45:19,  1.08s/it]


Example 1
Question: Steps of the scientific method include all of the following except...
Valid options: ['A', 'B', 'C', 'D'] (4 options)
Correct: D
Predicted: D
Response: 'D.
Brie'
Status: ✓ CORRECT


Evaluating:   0%|          | 2/2512 [00:01<26:35,  1.57it/s]


Example 2
Question: Why do scientists call the Big Bang a theory?...
Valid options: ['A', 'B', 'C', 'D'] (4 options)
Correct: C
Predicted: C
Response: 'C.
Question:'
Status: ✓ CORRECT


Evaluating:   0%|          | 3/2512 [00:01<20:33,  2.03it/s]


Example 3
Question: The data collected in an experiment should always be...
Valid options: ['A', 'B', 'C', 'D'] (4 options)
Correct: D
Predicted: D
Response: 'D.
Question:'
Status: ✓ CORRECT


Evaluating:   0%|          | 4/2512 [00:02<18:03,  2.32it/s]


Example 4
Question: Which of the following is not a scientific model?...
Valid options: ['A', 'B', 'C', 'D'] (4 options)
Correct: B
Predicted: B
Response: 'B. A B is'
Status: ✓ CORRECT


Evaluating:   0%|          | 5/2512 [00:02<16:21,  2.55it/s]


Example 5
Question: If the results of an experiment disprove a hypothesis, then the...
Valid options: ['A', 'B', 'C', 'D'] (4 options)
Correct: D
Predicted: B
Response: 'B.
Both'
Status: ✗ WRONG


Evaluating: 100%|██████████| 2512/2512 [10:37<00:00,  3.94it/s]


EVALUATION COMPLETE
Total samples: 2512
Correct predictions: 2060
Accuracy: 82.01%

Option Distribution:
  2 options: 912 questions (36.3%)
  3 options: 1 questions (0.0%)
  4 options: 1273 questions (50.7%)
  5 options: 13 questions (0.5%)
  6 options: 4 questions (0.2%)
  7 options: 301 questions (12.0%)
  8 options: 8 questions (0.3%)






In [11]:
# ============================================================================
# DETAILED ANALYSIS
# ============================================================================
print("="*80)
print("DETAILED ANALYSIS")
print("="*80)

# Accuracy by number of options
from collections import defaultdict
accuracy_by_options = defaultdict(lambda: {'correct': 0, 'total': 0})

for result in results:
    num_opts = result['num_options']
    accuracy_by_options[num_opts]['total'] += 1
    if result['is_correct']:
        accuracy_by_options[num_opts]['correct'] += 1

print("\nAccuracy by Number of Options:")
for num_opts in sorted(accuracy_by_options.keys()):
    stats = accuracy_by_options[num_opts]
    acc = stats['correct'] / stats['total'] if stats['total'] > 0 else 0
    print(f"  {num_opts} options: {stats['correct']}/{stats['total']} = {acc:.2%}")

# Show incorrect predictions
print("\n" + "="*80)
print("ERROR ANALYSIS")
print("="*80)

incorrect = [r for r in results if not r['is_correct']]
print(f"\nIncorrect predictions: {len(incorrect)}/{len(results)} ({len(incorrect)/len(results)*100:.1f}%)")

if len(incorrect) > 0:
    print("\nFirst 5 incorrect predictions:")
    for i, result in enumerate(incorrect[:5]):
        print(f"\n{'-'*60}")
        print(f"{i+1}. Question: {result['question'][:100]}...")
        print(f"   Valid options: {result['valid_options']}")
        print(f"   Correct: {result['correct_answer']}")
        print(f"   Predicted: {result['predicted_answer']}")
        print(f"   Response: '{result['full_response']}'")

# Check for cases where model didn't produce valid options
no_prediction = [r for r in results if r['predicted_answer'] is None]
if no_prediction:
    print(f"\n⚠️ Warning: {len(no_prediction)} cases where no valid option was extracted")
    print("First 3 cases:")
    for i, result in enumerate(no_prediction[:3]):
        print(f"\n{i+1}. Response: '{result['full_response']}'")
        print(f"   Valid options were: {result['valid_options']}")

print("="*80)

DETAILED ANALYSIS

Accuracy by Number of Options:
  2 options: 753/912 = 82.57%
  3 options: 0/1 = 0.00%
  4 options: 1023/1273 = 80.36%
  5 options: 4/13 = 30.77%
  6 options: 4/4 = 100.00%
  7 options: 269/301 = 89.37%
  8 options: 7/8 = 87.50%

ERROR ANALYSIS

Incorrect predictions: 452/2512 (18.0%)

First 5 incorrect predictions:

------------------------------------------------------------
1. Question: If the results of an experiment disprove a hypothesis, then the...
   Valid options: ['A', 'B', 'C', 'D']
   Correct: D
   Predicted: B
   Response: 'B.
Both'

------------------------------------------------------------
2. Question: Which of the following are good measures to follow when working in the field?...
   Valid options: ['A', 'B', 'C', 'D']
   Correct: D
   Predicted: A
   Response: 'A, B, C'

------------------------------------------------------------
3. Question: A theory will still remain even if conflicting data is discovered....
   Valid options: ['A', 'B']
   Corre

In [12]:
# ============================================================================
# SAVE RESULTS (OPTIONAL)
# ============================================================================
SAVE_RESULTS = True

if SAVE_RESULTS:
    import json
    import pandas as pd
    
    # Convert sets to lists for JSON serialization
    for result in results:
        result['valid_options'] = sorted(list(result['valid_options']))
    
    # Save as JSON
    with open("evaluation_results.json", "w") as f:
        json.dump({
            'accuracy': accuracy,
            'total_samples': len(results),
            'correct': sum(1 for r in results if r['is_correct']),
            'accuracy_by_options': {
                str(k): {'correct': v['correct'], 'total': v['total'], 
                        'accuracy': v['correct']/v['total']}
                for k, v in accuracy_by_options.items()
            },
            'results': results
        }, f, indent=2)
    
    # Save as CSV
    df = pd.DataFrame(results)
    df.to_csv("evaluation_test_results.csv", index=False)
    
    print(f"\n✓ Results saved to evaluation_results.json and evaluation_results.csv")

print("\nDone! 🎉")



✓ Results saved to evaluation_results.json and evaluation_results.csv

Done! 🎉
