# Phase 5.2: Evaluate English Medical Retention

Check that the model retains English medical capabilities (no catastrophic forgetting).

## Contents
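0. Setup (imports, GPU, model paths)
1. English Medical Test Questions
2. Load Korean-Adapted Model
3. Evaluate Korean Model on English Questions
4. Compare with Original (Optional)
5. Qualitative English Test
6. Save Results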

In [None]:
# Setup
import sys
import os
sys.path.append("..")

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from tqdm import tqdm
import json

# GPU setup
# Project-local helpers; presumably resolvable because ".." was appended to
# sys.path above — confirm the location of the `config` package.
from config.gpu_utils import setup_gpu, print_memory_usage, clear_memory
# `device` holds whatever setup_gpu() returns. It is not referenced again in the
# visible cells (inputs are moved with `model.device` instead).
device = setup_gpu()

# Presumably reports a memory baseline before any model is loaded —
# confirm against config.gpu_utils.
print_memory_usage()

In [None]:
# Paths used throughout the notebook

# Evaluation artifacts are written here; create the directory up front.
RESULTS_DIR = "../results"
os.makedirs(RESULTS_DIR, exist_ok=True)

# Model under test. Primary choice: the instruction-tuned model.
KOREAN_MODEL_DIR = "../models/instruction_tuned"

# Alternative: evaluate the expanded model directly
# KOREAN_MODEL_DIR = "../models/final/korean_medgemma_expanded"

# Baseline for the optional comparison (if available)
ORIGINAL_MODEL = "../models/medgemma-4b-it"  # or "google/medgemma-4b-it"

print(f"Korean model: {KOREAN_MODEL_DIR}")
print(f"Original model: {ORIGINAL_MODEL}")

---
## 1. English Medical Test Questions

In [None]:
# Sample English medical questions for retention testing
#
# Each entry is a dict with three keys:
#   "question": the question stem shown to the model
#   "choices":  five option strings, each already prefixed with its letter ("A. " ... "E. ")
#   "answer":   the letter of the correct option
#
# NOTE(review): every answer key in this set is either "B" or "C", so a model
# biased towards those letters would score well. With only eight items the
# resulting accuracy is a coarse smoke test, not a benchmark.
english_medical_questions = [
    {
        "question": "A 45-year-old patient presents with chest pain radiating to the left arm, sweating, and shortness of breath. What is the most likely diagnosis?",
        "choices": ["A. Gastroesophageal reflux", "B. Acute myocardial infarction", "C. Panic attack", "D. Pneumonia", "E. Costochondritis"],
        "answer": "B"
    },
    {
        "question": "Which medication is considered first-line treatment for Type 2 Diabetes Mellitus?",
        "choices": ["A. Insulin glargine", "B. Metformin", "C. Glipizide", "D. Pioglitazone", "E. Sitagliptin"],
        "answer": "B"
    },
    {
        "question": "What is the most common cause of community-acquired pneumonia in adults?",
        "choices": ["A. Haemophilus influenzae", "B. Staphylococcus aureus", "C. Streptococcus pneumoniae", "D. Klebsiella pneumoniae", "E. Mycoplasma pneumoniae"],
        "answer": "C"
    },
    {
        "question": "A patient with hypertension should avoid which of the following?",
        "choices": ["A. Regular exercise", "B. High sodium diet", "C. Weight management", "D. Stress reduction", "E. Regular blood pressure monitoring"],
        "answer": "B"
    },
    {
        "question": "Which symptom is NOT typically associated with hyperthyroidism?",
        "choices": ["A. Weight loss", "B. Heat intolerance", "C. Bradycardia", "D. Tremor", "E. Anxiety"],
        "answer": "C"
    },
    {
        "question": "What is the gold standard for diagnosing peptic ulcer disease?",
        "choices": ["A. Barium swallow", "B. CT scan", "C. Upper GI endoscopy", "D. Ultrasound", "E. Blood test"],
        "answer": "C"
    },
    {
        "question": "Which vaccine is recommended annually for adults over 65?",
        "choices": ["A. HPV vaccine", "B. MMR vaccine", "C. Influenza vaccine", "D. Hepatitis B vaccine", "E. Varicella vaccine"],
        "answer": "C"
    },
    {
        "question": "A patient presents with sudden onset of severe headache, neck stiffness, and photophobia. What should be ruled out first?",
        "choices": ["A. Migraine", "B. Tension headache", "C. Subarachnoid hemorrhage", "D. Cluster headache", "E. Sinusitis"],
        "answer": "C"
    },
]

# Report how many questions were prepared.
print(f"Prepared {len(english_medical_questions)} English medical questions")

---
## 2. Load Korean-Adapted Model

In [None]:
# Load the Korean-adapted model and its tokenizer
print("Loading Korean-adapted model...")

# 4-bit NF4 weights with bfloat16 compute. The same config object is reused
# by the later cells that load or reload a model.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

# Tokenizer: fall back to EOS as the padding token when none is defined.
korean_tokenizer = AutoTokenizer.from_pretrained(KOREAN_MODEL_DIR)
if korean_tokenizer.pad_token is None:
    korean_tokenizer.pad_token = korean_tokenizer.eos_token

# Module.eval() returns the module itself, so it can be chained onto the load.
korean_model = AutoModelForCausalLM.from_pretrained(
    KOREAN_MODEL_DIR,
    device_map="auto",
    quantization_config=bnb_config,
    trust_remote_code=True,
).eval()

print("Korean model loaded!")
print_memory_usage()

---
## 3. Evaluate Korean Model on English Questions

In [None]:
def create_english_prompt(question_data):
    """Render one multiple-choice question as a chat-style prompt string.

    Args:
        question_data: Dict with a "question" string and a "choices" list of
            option strings; the options are joined one per line.

    Returns:
        The full prompt: a system turn, a user turn containing the
        instructions, question and choices, and an opened assistant turn.

    NOTE(review): the turn markers are ChatML-style (<|im_start|>/<|im_end|>).
    Whether they match the chat template the evaluated models were trained
    with cannot be confirmed from this notebook.
    """
    question = question_data["question"]
    choices = "\n".join(question_data["choices"])

    return f"""<|im_start|>system
You are a medical AI assistant. Provide accurate and helpful medical information.
<|im_end|>
<|im_start|>user
Answer the following medical question. Respond with only the letter of the correct answer (A, B, C, D, or E).

Question: {question}

Choices:
{choices}

Answer:
<|im_end|>
<|im_start|>assistant
"""

In [None]:
def extract_answer(response):
    """Extract the chosen answer letter (A-E) from a model response.

    The response is stripped and upper-cased, then scanned left to right for
    the first letter A-E that is not directly adjacent to another ASCII
    letter. This avoids picking up letters that are merely part of a word:
    a plain substring test would return "A" for "B. Acute ..." because of the
    "A" in "ACUTE".

    Args:
        response: Decoded model output. May be empty or None.

    Returns:
        One of "A", "B", "C", "D", "E", or None when no standalone answer
        letter is present.
    """
    import re  # local import so the function does not depend on another cell

    if not response:
        return None

    # The lookarounds reject letters embedded in English words while still
    # accepting a letter followed by punctuation, whitespace, digits, or
    # non-ASCII text (e.g. a Korean suffix attached directly to the letter).
    match = re.search(r"(?<![A-Z])[A-E](?![A-Z])", response.strip().upper())
    return match.group(0) if match else None

In [None]:
def evaluate_model(model, tokenizer, questions, model_name="Model"):
    """Evaluate a model on multiple-choice English medical questions.

    Each question is rendered with ``create_english_prompt``, answered with
    greedy decoding (at most 10 new tokens), and scored by comparing the
    letter returned by ``extract_answer`` with the question's "answer" key.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer paired with ``model``; its ``pad_token_id`` is
            forwarded to ``generate``.
        questions: Sequence of dicts with "question", "choices" and "answer".
        model_name: Label shown in the progress bar.

    Returns:
        Tuple ``(accuracy, results)``. ``accuracy`` is the percentage of
        correctly answered questions (0.0 for an empty question list) and
        ``results`` is a list with one dict per question holding the
        question text, predicted letter, correct letter, correctness flag
        and the raw decoded response.
    """
    # Nothing to score: report 0% instead of dividing by zero at the end.
    if len(questions) == 0:
        return 0.0, []

    results = []

    for q in tqdm(questions, desc=f"Evaluating {model_name}"):
        prompt = create_english_prompt(q)

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            # Greedy decoding keeps the evaluation deterministic. No
            # `temperature` is passed: it only affects sampling, so combining
            # it with do_sample=False has no effect on the output.
            outputs = model.generate(
                **inputs,
                max_new_tokens=10,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        response = tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True
        )

        predicted = extract_answer(response)
        is_correct = predicted == q["answer"]

        results.append({
            "question": q["question"],
            "predicted": predicted,
            "correct": q["answer"],
            "is_correct": is_correct,
            "response": response,
        })

    correct = sum(1 for r in results if r["is_correct"])
    accuracy = correct / len(questions) * 100
    return accuracy, results

In [None]:
# Score the Korean-adapted model on the English multiple-choice set
print("\nEvaluating Korean-adapted model on English medical questions...")

korean_accuracy, korean_results = evaluate_model(
    model=korean_model,
    tokenizer=korean_tokenizer,
    questions=english_medical_questions,
    model_name="Korean MedGemma",
)

print(f"\nKorean-adapted model accuracy: {korean_accuracy:.1f}%")

In [None]:
# Per-question breakdown for the Korean-adapted model
print("\nEnglish Medical Question Results:")
print("=" * 60)

for number, entry in enumerate(korean_results, start=1):
    # Tick for a correct prediction, cross otherwise.
    if entry["is_correct"]:
        status = "✓"
    else:
        status = "✗"
    # The question text is cut to 70 characters to keep each row short.
    print(f"\n{status} Q{number}: {entry['question'][:70]}...")
    print(f"   Predicted: {entry['predicted']}, Correct: {entry['correct']}")

---
## 4. Compare with Original (Optional)

In [None]:
# Optional: Compare with original MedGemma
# This requires loading the original model

compare_with_original = False  # Set to True if you want to compare

# NOTE(review): the os.path.exists() check means ORIGINAL_MODEL must be a local
# path; a Hugging Face hub ID would be skipped here rather than downloaded.
if compare_with_original and os.path.exists(ORIGINAL_MODEL):
    # Free the Korean model first so both models are not held in memory at once.
    # The qualitative cell below reloads it when needed.
    del korean_model
    clear_memory()

    print("\nLoading original MedGemma for comparison...")

    original_model = AutoModelForCausalLM.from_pretrained(
        ORIGINAL_MODEL,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )

    original_tokenizer = AutoTokenizer.from_pretrained(ORIGINAL_MODEL)
    if original_tokenizer.pad_token is None:
        original_tokenizer.pad_token = original_tokenizer.eos_token

    original_model.eval()

    original_accuracy, original_results = evaluate_model(
        original_model,
        original_tokenizer,
        english_medical_questions,
        "Original MedGemma"
    )

    print(f"\nOriginal MedGemma accuracy: {original_accuracy:.1f}%")
    print(f"Korean-adapted accuracy: {korean_accuracy:.1f}%")
    # Retention is a ratio against the baseline; it is undefined when the
    # baseline scored 0%, so avoid dividing by zero in that case.
    if original_accuracy:
        print(f"Retention: {korean_accuracy / original_accuracy * 100:.1f}%")
    else:
        print("Retention: n/a (original model scored 0%)")
else:
    print("\nSkipping comparison with original model")
    # Later cells test `original_accuracy` to decide whether to report retention.
    original_accuracy = None

---
## 5. Qualitative English Test

In [None]:
# Qualitative check: free-form English answers from the Korean-adapted model
english_open_questions = [
    "What are the main symptoms of COVID-19?",
    "How does hypertension affect the cardiovascular system?",
    "What lifestyle changes can help manage Type 2 Diabetes?",
]

# The optional comparison cell deletes `korean_model`; load it again if it is gone.
if "korean_model" not in globals():
    korean_model = AutoModelForCausalLM.from_pretrained(
        KOREAN_MODEL_DIR,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    korean_model.eval()

print("\nQualitative English evaluation:")
print("=" * 60)

for question in english_open_questions:
    prompt = f"""<|im_start|>system
You are a medical AI assistant.
<|im_end|>
<|im_start|>user
{question}
<|im_end|>
<|im_start|>assistant
"""

    encoded = korean_tokenizer(prompt, return_tensors="pt").to(korean_model.device)
    prompt_length = encoded["input_ids"].shape[1]

    # Sampling is enabled here, so the answers vary from run to run.
    with torch.no_grad():
        generated = korean_model.generate(
            **encoded,
            max_new_tokens=200,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
            pad_token_id=korean_tokenizer.pad_token_id,
        )

    # Keep only the tokens produced after the prompt.
    answer_text = korean_tokenizer.decode(
        generated[0][prompt_length:],
        skip_special_tokens=True,
    )

    print(f"\nQ: {question}")
    print(f"A: {answer_text[:400]}...")
    print("-" * 40)

---
## 6. Save Results

In [None]:
# Persist the English retention evaluation as JSON

# Retention is only defined when a non-zero baseline accuracy is available.
if original_accuracy:
    retention_rate = korean_accuracy / original_accuracy * 100
else:
    retention_rate = None

english_eval_results = {
    "model": KOREAN_MODEL_DIR,
    "benchmark": "English Medical QA (Custom)",
    "korean_model_accuracy": korean_accuracy,
    "original_model_accuracy": original_accuracy,
    "retention_rate": retention_rate,
    "total_questions": len(english_medical_questions),
    "results": korean_results,
}

results_path = f"{RESULTS_DIR}/english_retention_eval.json"
with open(results_path, "w", encoding="utf-8") as out_file:
    # ensure_ascii=False writes non-ASCII characters as-is instead of \u escapes.
    json.dump(english_eval_results, out_file, ensure_ascii=False, indent=2)

print(f"\nResults saved to {results_path}")

In [None]:
# Final summary of the English retention evaluation
banner = "=" * 60
print("\n" + banner)
print("English Retention Evaluation Complete!")
print(banner)

print(f"\nKorean-adapted model English accuracy: {korean_accuracy:.1f}%")

# Baseline figures exist only when the optional comparison ran and scored above 0%.
if original_accuracy:
    retention = korean_accuracy / original_accuracy * 100
    print(f"Original model English accuracy: {original_accuracy:.1f}%")
    print(f"Retention rate: {retention:.1f}%")

for line in (
    f"\nResults saved to: {results_path}",
    "\nPhase 5 Complete!",
    "\nNext steps:",
    "  Run phase6_deployment/01_quantize_awq.ipynb for deployment",
):
    print(line)