In [None]:
# @title 1. Install Libraries and Login

# Install necessary packages for transformers, training, and Hugging Face integration
!pip install -U -q transformers datasets accelerate peft bitsandbytes trl torch huggingface_hub ipywidgets Jinja2 tqdm openai

# Suppress warning messages to reduce clutter in output
import warnings
warnings.filterwarnings("ignore")

# Core Python and ML imports
import os
import json
import torch
import time
import logging
from tqdm.auto import tqdm

# Hugging Face & PEFT related imports for model loading and training
from datasets import Dataset, load_dataset
from transformers import (
    Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer,
    BitsAndBytesConfig, DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
)
from peft import (
    LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
)

# Hugging Face auth + OpenAI integration
from huggingface_hub import notebook_login, HfApi
from google.colab import userdata, files
import openai

# Root logging setup: timestamped INFO-level records for this notebook.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)

# Per-library verbosity: the two training libraries only report warnings,
# the Hub and OpenAI clients stay at INFO.
for _library_name, _library_level in (
    ("transformers", logging.WARNING),
    ("datasets", logging.WARNING),
    ("huggingface_hub", logging.INFO),
    ("openai", logging.INFO),
):
    logging.getLogger(_library_name).setLevel(_library_level)
del _library_name, _library_level  # do not leave loop temporaries in the notebook namespace

print("--- Hugging Face Login (Optional - Needed for Pushing Adapters) ---")
try:
    # Try to fetch HF token from Colab Secrets.
    # userdata.get() raises (it does not return None) when the secret is missing or the
    # notebook has not been granted access, so the lookup gets its own handler; otherwise
    # the interactive fallback below would never be reached.
    try:
        hf_token = userdata.get('HF_HUB_TOKEN')
    except Exception as secret_error:
        logger.info(f"HF_HUB_TOKEN could not be read from Colab Secrets: {secret_error}")
        hf_token = None

    if hf_token:
        print("Using HF token from Colab Secrets.")
        from huggingface_hub import login
        login(token=hf_token)
    else:
        # Fallback to interactive login if no token in secrets
        print("HF_HUB_TOKEN secret not found. Using interactive login:")
        notebook_login()
except Exception as e:
    # Login is optional (only needed for pushing adapters), so report and continue.
    print(f"HF login failed: {e}.")
print("-" * 30)

print("--- Configuring OpenAI Client ---")
try:
    # Read the API key from Colab Secrets; a missing/empty key is treated as a setup error.
    openai_api_key = userdata.get('OPENAI_API_KEY')
    if not openai_api_key:
        raise ValueError("OpenAI API Key not found in Colab Secrets. Please add it under the name 'OPENAI_API_KEY'.")
    openai.api_key = openai_api_key

    from openai import OpenAI
    client = OpenAI(api_key=openai.api_key)
except Exception as setup_error:
    # Later cells check `client` for None and skip teacher generation in that case.
    print(f"ERROR setting up OpenAI Client: {setup_error}")
    client = None
else:
    print("OpenAI client configured successfully.")
print("-" * 30)

In [None]:
# @title 2. Configuration Parameters

# --- Models -----------------------------------------------------------------
# Student: the model that receives LoRA adapters. Teacher: the OpenAI chat model
# whose responses become the training targets.
STUDENT_MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
OPENAI_TEACHER_MODEL = "gpt-4-turbo"

# --- Paths and Hub destination ----------------------------------------------
LOCAL_OUTPUT_DIR = "/content/medQwen-0.5b-distilled-adapters"    # final adapters + tokenizer
CHECKPOINT_DIR = "/content/distill_openai_checkpoints"           # Trainer checkpoints and logs
HUB_REPO_ID_DISTILL = "Vidush25/medQwen-0.5b-distilled-adapters"  # Hub repo the adapters are pushed to
KB_DATA_JSON_PATH = "/content/KB_data.json"                      # raw knowledge-base entries
BASE_DATASET_PATH = "/content/base_medical_dialogues.json"       # inputs sent to the teacher
KB_DATA_SIM_PATH = "/content/KB_data_sim.json"                   # cached, pre-processed KB lookup

# --- Training hyperparameters -----------------------------------------------
NUM_EPOCHS = 1
BATCH_SIZE = 2          # per-device batch size
GRAD_ACCUM_STEPS = 8    # gradients are accumulated over this many batches per optimizer step
LEARNING_RATE = 2e-5
MAX_SEQ_LENGTH = 1024   # token budget for prompt + teacher output

# --- Knowledge distillation parameters --------------------------------------
DISTILL_SAMPLE_LIMIT = 30   # maximum number of base examples sent to the teacher
OPENAI_MAX_TOKENS = 300     # completion token cap per teacher response
OPENAI_TEMPERATURE = 0.5

# --- LoRA configuration -----------------------------------------------------
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
# Names of the modules that get LoRA adapters attached.
LORA_TARGET_MODULES = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

In [None]:
# @title 3. Load KB Sim Data & Prompt Dataset

# Load or construct simulation data from a medical Knowledge Base (KB)
KB_sim_lookup = {}


def _dict_items_only(value):
    """Return the dict elements of ``value``; tolerates None and malformed (non-dict) rows."""
    return [item for item in (value or []) if isinstance(item, dict)]


def _append_unique(target_list, value):
    """Append ``value`` to ``target_list`` unless already present (keeps first-seen order)."""
    if value not in target_list:
        target_list.append(value)


def _build_KB_sim_lookup(KB_raw_data):
    """Transform raw KB entries into the symptom/medication lookup (pure function, no I/O)."""
    lookup = {}
    for entry in _dict_items_only(KB_raw_data):
        disease = entry.get('disease')
        if not disease:
            continue

        # Map each symptom to related diseases
        symptoms = entry.get('symptoms')
        if not isinstance(symptoms, (list, tuple)):
            symptoms = []  # a bare string would otherwise be iterated character by character
        for symptom in symptoms:
            if not isinstance(symptom, str):
                continue
            s_lower = symptom.lower().strip()
            # '_medications' is the reserved key for the medication map below.
            if s_lower and s_lower != '_medications':
                _append_unique(lookup.setdefault(s_lower, []), disease)

        # Prepare medication-related data: contraindications and interactions.
        # Lists with first-seen ordering are used instead of sets so that the result
        # (and the JSON cache written from it) is identical from run to run.
        med_data_map = lookup.setdefault('_medications', {})
        for med_info in _dict_items_only(entry.get("medications")):
            med_name = med_info.get("name")
            if med_name and med_name not in med_data_map:
                med_data_map[med_name] = {"contraindications": [], "interactions": []}
        for contra in _dict_items_only(entry.get("contraindications")):
            med_name = contra.get("medication")
            condition = contra.get("condition")
            if med_name in med_data_map and condition:
                _append_unique(med_data_map[med_name]["contraindications"], condition)
        for interact in _dict_items_only(entry.get("interactions")):
            m1 = interact.get("med1")
            m2 = interact.get("med2")
            # Interactions are recorded in both directions, but only for known medications.
            if m1 in med_data_map and m2:
                _append_unique(med_data_map[m1]["interactions"], m2)
            if m2 in med_data_map and m1:
                _append_unique(med_data_map[m2]["interactions"], m1)
    return lookup


def load_or_create_KB_sim(KB_json_path, sim_path=None):
    """
    Load precomputed KB simulation or generate it from raw JSON.
    This lookup helps relate symptoms to diseases and handle med interactions.

    Args:
        KB_json_path: Path to the raw KB JSON file (a list of entry objects with
            'disease', 'symptoms', and optional 'medications' / 'contraindications' /
            'interactions' fields).
        sim_path: Path of the pre-processed cache file. Defaults to KB_DATA_SIM_PATH.

    Returns:
        dict mapping lower-cased symptom -> list of diseases, plus a '_medications'
        key mapping medication name -> {"contraindications": [...], "interactions": [...]}.
        Returns {} when the raw file is missing or cannot be processed.

    Note:
        An existing cache file is returned as-is; it is not checked against
        ``KB_json_path``, so delete the cache after editing the raw KB.
    """
    if sim_path is None:
        sim_path = KB_DATA_SIM_PATH

    if os.path.exists(sim_path):
        try:
            with open(sim_path, 'r', encoding='utf-8') as f:
                cached_lookup = json.load(f)
            if isinstance(cached_lookup, dict):
                logger.info(f"Loaded pre-processed KB simulation from {sim_path}")
                return cached_lookup
            logger.warning(f"KB sim file {sim_path} does not contain a JSON object. Will attempt to generate.")
        except Exception as e:
            logger.warning(f"Error loading KB sim file {sim_path}: {e}. Will attempt to generate.")

    if not os.path.exists(KB_json_path):
        logger.error(f"Raw KB data file '{KB_json_path}' not found. Cannot create KB simulation.")
        return {}

    # Generate KB simulation from scratch
    try:
        with open(KB_json_path, 'r', encoding='utf-8') as f:
            KB_raw_data = json.load(f)
        if not isinstance(KB_raw_data, list):
            raise ValueError(f"expected a JSON list of KB entries, got {type(KB_raw_data).__name__}")
        KB_sim_lookup_func = _build_KB_sim_lookup(KB_raw_data)
    except Exception as e:
        logger.error(f"Error processing KB data for simulation: {e}", exc_info=True)
        return {}

    logger.info(f"Generated KB sim lookup. Symptoms: {len([k for k in KB_sim_lookup_func if k != '_medications'])}. Meds: {len(KB_sim_lookup_func.get('_medications', {}))}.")

    # Save for future use. A failed write must not discard the lookup that was just
    # built, so it is handled separately; empty lookups are not cached so that a
    # corrected raw file is picked up on the next run.
    if KB_sim_lookup_func:
        try:
            with open(sim_path, 'w', encoding='utf-8') as f:
                json.dump(KB_sim_lookup_func, f, indent=2)
        except Exception as e:
            logger.warning(f"Could not save KB sim lookup to {sim_path}: {e}. Continuing with the in-memory lookup.")

    return KB_sim_lookup_func

# Build (or load from cache) the lookup once; later cells read it as a module-level name.
KB_sim_lookup = load_or_create_KB_sim(KB_DATA_JSON_PATH)
if KB_sim_lookup:
    logger.info("KB Simulation data loaded/generated successfully.")
else:
    # An empty lookup is not fatal: prompts are still built, just without KB context.
    logger.warning("KB Simulation lookup is empty. Context generation and simulated verifiers will be limited.")

# Function to create a teacher prompt using context from the KB
def create_teacher_prompt(input_text, KB_sim_lookup):
    """
    Builds a prompt that includes KB context if symptoms match known conditions.

    Args:
        input_text: Free-text patient description.
        KB_sim_lookup: Mapping of lower-cased symptom -> list of diseases (the
            reserved '_medications' key is ignored). May be empty or None.

    Returns:
        The prompt string: the patient text, one context line per matched symptom
        (or a "no context" placeholder), and the <think>/<answer> formatting request.

    Note:
        Matching is a plain case-insensitive substring test, so a short symptom key
        can match inside a longer word of the input.
    """
    text_lower = input_text.lower()  # hoisted: invariant across all symptom checks
    matched_context = []
    for symptom, diseases in (KB_sim_lookup or {}).items():
        # Skip the medication map, and keys that cannot be matched meaningfully
        # (an empty string is a substring of every input).
        if symptom == '_medications' or not isinstance(symptom, str) or not symptom:
            continue
        if not isinstance(diseases, (list, tuple)):
            continue
        if symptom in text_lower:
            disease_names = ', '.join(map(str, diseases))
            matched_context.append(f"Symptom '{symptom}' may indicate: {disease_names}.")
    context_str = "\n".join(matched_context) if matched_context else "No additional context available."

    prompt = (
        f"Patient symptoms: {input_text}\n"
        f"Knowledge Graph Context:\n{context_str}\n\n"
        "Please provide a detailed consultation. Include a step-by-step reasoning "
        "in a <think> block and finish with a final assessment in an <answer> block."
    )
    return prompt

In [None]:
# @title 4. Data Generation (Using OpenAI API as Teacher)

def get_openai_teacher_response(prompt: str, retries=2, delay=5) -> "str | None":
    """
    Send prompt to OpenAI Chat API and return response text; includes basic retry logic.

    Args:
        prompt: User message sent to the teacher model.
        retries: Number of additional attempts after the first one.
        delay: Base wait in seconds; the wait doubles after every failed attempt.

    Returns:
        The stripped response text on success. On failure either None (client not
        initialised) or a string starting with "[ERROR" describing the failure;
        callers must check for both.
    """
    if not client:
        logger.error("OpenAI client not initialized. Cannot make API call.")
        return None

    # Structure messages with role-based context to guide the OpenAI assistant
    messages = [
        {"role": "system", "content": "You are an expert medical AI assistant providing detailed consultations. Reason step-by-step using a <think> block. Provide your final assessment, recommendation, or next question in an <answer> block. Be factually accurate and concise."},
        {"role": "user", "content": prompt}
    ]

    # 4xx statuses that can succeed on a later attempt (timeout, conflict, rate limit).
    # Every other 4xx (bad request, invalid key, unknown model, ...) will fail the same
    # way again, so retrying only wastes time on each sample.
    retryable_client_statuses = {408, 409, 429}

    for attempt in range(retries + 1):
        is_last_attempt = attempt >= retries
        wait_seconds = delay * (2 ** attempt)  # exponential backoff
        try:
            # Call OpenAI API with chat-completion format
            response = client.chat.completions.create(
                model=OPENAI_TEACHER_MODEL,
                messages=messages,
                max_tokens=OPENAI_MAX_TOKENS,
                temperature=OPENAI_TEMPERATURE,
                n=1,
                stop=None
            )
            # Check and return response content
            if response.choices:
                first_choice = response.choices[0]
                content = (first_choice.message.content or "").strip()
                if content:
                    if getattr(first_choice, "finish_reason", None) == "length":
                        # The completion hit the token cap, so the <answer> block may be missing.
                        logger.warning(f"OpenAI response was truncated at max_tokens={OPENAI_MAX_TOKENS}; the target may be incomplete.")
                    logger.debug(f"OpenAI response received (attempt {attempt+1}). Length: {len(content)}")
                    return content
                logger.warning(f"OpenAI response empty (attempt {attempt+1}).")
            else:
                logger.warning(f"Invalid response structure (attempt {attempt+1}). Response: {response}")
            failure_message = "[ERROR: OpenAI returned empty content]"

        # Handle rate limiting and retry
        except openai.RateLimitError as e:
            logger.warning(f"OpenAI Rate Limit hit (attempt {attempt+1}/{retries+1}). Error: {e}")
            if is_last_attempt:
                logger.error("OpenAI Rate Limit exceeded after retries.")
            failure_message = f"[ERROR: OpenAI Rate Limit: {e}]"

        # Handle other API-level failures
        except openai.APIError as e:
            logger.error(f"OpenAI API Error (attempt {attempt+1}/{retries+1}): {e}", exc_info=True)
            failure_message = f"[ERROR: OpenAI API Error: {e}]"
            status_code = getattr(e, "status_code", None)  # only present on HTTP status errors
            if (status_code is not None and 400 <= status_code < 500
                    and status_code not in retryable_client_statuses):
                return failure_message

        # Catch-all for unexpected errors
        except Exception as e:
            logger.error(f"Unexpected error calling OpenAI API (attempt {attempt+1}): {e}", exc_info=True)
            return f"[ERROR: Unexpected OpenAI call failure: {e}]"

        # Shared retry tail for every retryable failure above.
        if is_last_attempt:
            return failure_message
        logger.info(f"Retrying OpenAI call in {wait_seconds}s.")
        time.sleep(wait_seconds)

    # Only reached when `retries` is negative and the loop never runs.
    return None

def prepare_distillation_data_openai(base_dataset_path, KB_sim_lookup, sample_limit):
    """
    Loads base data, creates prompts, and calls the OpenAI API to get teacher outputs.

    Args:
        base_dataset_path: JSON file containing a list of examples; the text is read
            from the first non-empty of the 'input', 'query', 'question' keys.
        KB_sim_lookup: Symptom lookup passed to create_teacher_prompt.
        sample_limit: Maximum number of base examples to send to the teacher
            (None means no limit).

    Returns:
        A `Dataset` with 'input' (prompt) and 'output' (teacher response) columns,
        or None when no pairs could be produced.
    """
    logger.info("--- Generating Distillation Data using OpenAI Teacher API ---")
    distill_data_list = []

    if not os.path.exists(base_dataset_path):
        # Use fallback sample if file is missing (kept as a deliberate smoke-test path)
        logger.error(f"Base dataset not found: '{base_dataset_path}'. Using dummy entry.")
        base_data = [{"input": "Cough and fever."}]
    else:
        try:
            with open(base_dataset_path, 'r', encoding='utf-8') as f:
                base_data = json.load(f)
        except Exception as e:
            # Do not fabricate a placeholder example here: it would be sent to the
            # teacher and end up in the student's training data.
            logger.error(f"Error loading base dataset: {e}")
            return None
        if not isinstance(base_data, list):
            logger.error(f"Base dataset must be a JSON list of examples, got {type(base_data).__name__}.")
            return None
        logger.info(f"Loaded {len(base_data)} base examples.")

    # Limit the number of samples to avoid exhausting API quota
    base_data = base_data[:sample_limit]
    logger.info(f"Generating teacher outputs for {len(base_data)} samples using {OPENAI_TEACHER_MODEL}...")

    for i, example in enumerate(tqdm(base_data, desc="Generating Teacher Outputs")):
        # Try multiple common keys to get the input
        input_text = None
        if isinstance(example, dict):
            input_text = example.get('input') or example.get('query') or example.get('question')
        if not input_text or not isinstance(input_text, str):
            logger.warning(f"Skipping invalid base example {i}")
            continue

        # The student is trained on exactly the prompt the teacher saw, so it is built once.
        prompt_text = create_teacher_prompt(input_text, KB_sim_lookup)

        # Call OpenAI teacher to get response (None or "[ERROR..." signals failure)
        teacher_output_text = get_openai_teacher_response(prompt_text)
        if not teacher_output_text or teacher_output_text.startswith("[ERROR"):
            logger.error(f"Failed to get valid teacher response for sample {i}. Skipping.")
            continue

        # Store prompt-response pair for student training
        distill_data_list.append({
            "input": prompt_text,
            "output": teacher_output_text
        })

        time.sleep(0.5)  # Be polite to the API and avoid hitting rate limits

    logger.info(f"Finished generation. Prepared {len(distill_data_list)} distillation data pairs.")

    if not distill_data_list:
        return None

    try:
        return Dataset.from_list(distill_data_list)
    except Exception as e:
        logger.error(f"Failed to create Dataset: {e}")
        return None

# Teacher generation needs a working OpenAI client; without one the dataset stays None.
distillation_dataset = None
if not client:
    logger.error("OpenAI client not available. Cannot generate distillation data.")
else:
    distillation_dataset = prepare_distillation_data_openai(
        BASE_DATASET_PATH,
        KB_sim_lookup,
        sample_limit=DISTILL_SAMPLE_LIMIT,
    )

# Print the first prompt/target pair as a quick format check.
if not distillation_dataset:
    print("\nERROR: No distillation data generated to proceed.")
else:
    print("\n--- Sample Distillation Data Entry ---")
    print("Input Prompt:\n", distillation_dataset[0]['input'])
    print("\nTarget Output:\n", distillation_dataset[0]['output'])
print("-" * 30)

In [None]:
# @title 5. Load Student Model and Prepare for Training

student_model_peft = None
student_tokenizer = None

# Proceed only if distillation data exists
if distillation_dataset and len(distillation_dataset) > 0:
    print(f"\n--- Loading Student Model: {STUDENT_MODEL_ID} ---")
    try:
        # Pick the compute dtype the same way the training cell picks bf16 vs fp16, so the
        # quantized matmuls and the Trainer's mixed-precision mode agree on GPUs without bf16.
        student_compute_dtype = (
            torch.bfloat16
            if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
            else torch.float16
        )

        # Quantization config to load model in 4-bit precision for memory efficiency
        bnb_config_student = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=student_compute_dtype,
            bnb_4bit_use_double_quant=True
        )

        # Load base student model using quantization + automatic device placement.
        # NOTE(review): trust_remote_code=True allows code from the model repo to run; confirm it is needed.
        student_model = AutoModelForCausalLM.from_pretrained(
            STUDENT_MODEL_ID,
            quantization_config=bnb_config_student,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=student_compute_dtype
        )

        # Load tokenizer and ensure padding is handled correctly
        student_tokenizer = AutoTokenizer.from_pretrained(STUDENT_MODEL_ID, trust_remote_code=True)
        if student_tokenizer.pad_token is None:
            student_tokenizer.pad_token = student_tokenizer.eos_token
            # Right padding for training: the preprocessing step masks the prompt as a
            # prefix of each sequence, which only lines up when padding is appended.
            student_tokenizer.padding_side = "right"
            student_model.config.pad_token_id = student_tokenizer.eos_token_id
            logger.info("Set student pad token and padding side.")

        # Disable caching to support gradient checkpointing during training
        student_model.config.use_cache = False

        # Prepare model for 4-bit training with gradient checkpointing
        student_model = prepare_model_for_kbit_training(student_model, use_gradient_checkpointing=True)

        # Setup LoRA config for parameter-efficient fine-tuning
        lora_config_distill = LoraConfig(
            r=LORA_R,
            lora_alpha=LORA_ALPHA,
            lora_dropout=LORA_DROPOUT,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=LORA_TARGET_MODULES
        )

        # Wrap student model with LoRA adapters
        student_model_peft = get_peft_model(student_model, lora_config_distill)
        student_model_peft.print_trainable_parameters()  # Print a summary for debugging/logging
        logger.info("Student model loaded and prepared.")

    except Exception as e:
        # Later cells check student_model_peft for None and skip training.
        logger.error(f"Failed to load/prepare student model: {e}", exc_info=True)
        student_model_peft = None
else:
    logger.error("Cannot load student model - no distillation data.")
print("-" * 30)

In [None]:
# @title 6. Define Data Collator and Preprocessing Function

def preprocess_for_distill(examples, tokenizer, max_length):
    """
    Tokenizes inputs (prompt + teacher output) and sets the prompt tokens in the labels to -100.
    This version pads every sequence to max_length.

    Args:
        examples: Batch dict with parallel "input" (prompt) and "output" (teacher text) lists.
        tokenizer: Callable tokenizer exposing `eos_token` and returning `input_ids`
            and `attention_mask`.
        max_length: Length every sequence is truncated/padded to.

    Returns:
        Dict with "input_ids", "attention_mask" and "labels" lists. Labels equal the
        input ids on response tokens and -100 on prompt tokens and padding.

    Note:
        If the prompt alone fills `max_length`, every label of that row is -100 and
        the row contributes nothing to the loss.
    """
    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}

    for prompt, target_output in zip(examples["input"], examples["output"]):
        # Combine prompt and response, then tokenize the full sequence
        full_text = prompt + target_output + tokenizer.eos_token
        tokenized_full = tokenizer(full_text, max_length=max_length, truncation=True, padding="max_length")
        input_ids = list(tokenized_full["input_ids"])
        attention_mask = list(tokenized_full["attention_mask"])

        # Tokenize the prompt with the same length cap; it is used to find where the
        # prompt ends inside the full sequence.
        prompt_ids = tokenizer(prompt, max_length=max_length, truncation=True)["input_ids"]

        # Positions holding real (non-padding) tokens; works for left or right padding.
        real_positions = [pos for pos, keep in enumerate(attention_mask) if keep]

        # Start fully masked, then enable the loss on response tokens only. The prompt
        # region is the longest run of leading real tokens that matches the prompt's
        # own tokenization, so trailing special tokens or a merged boundary token do
        # not shift the mask, and an over-long prompt cannot index out of range.
        labels = [-100] * len(input_ids)
        inside_prompt = True
        for offset, pos in enumerate(real_positions):
            if inside_prompt and offset < len(prompt_ids) and input_ids[pos] == prompt_ids[offset]:
                continue
            inside_prompt = False
            labels[pos] = input_ids[pos]

        # Append processed input, attention mask, and labels
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append(attention_mask)
        model_inputs["labels"].append(labels)

    return model_inputs

def _mask_padding_in_labels(batch):
    """Set labels to -100 wherever attention_mask is 0 so padding never contributes to the loss."""
    batch["labels"] = [
        [label if keep else -100 for label, keep in zip(label_row, mask_row)]
        for label_row, mask_row in zip(batch["labels"], batch["attention_mask"])
    ]
    return batch

# Run preprocessing over the full dataset
tokenized_dataset = None
if distillation_dataset and student_tokenizer:
    logger.info("Tokenizing dataset for distillation...")
    try:
        tokenized_dataset = distillation_dataset.map(
            # Padding positions are masked here because the collator below keeps the
            # labels exactly as stored in the dataset.
            lambda examples: _mask_padding_in_labels(
                preprocess_for_distill(examples, student_tokenizer, MAX_SEQ_LENGTH)
            ),
            batched=True,
            remove_columns=distillation_dataset.column_names  # Clean up extraneous keys
        )
        logger.info(f"Dataset tokenized. Final size: {len(tokenized_dataset)}")

        # Show a quick preview of the first tokenized sequence
        if len(tokenized_dataset) > 0:
            print("\n--- Sample Tokenized Entry (Input IDs) ---")
            print(tokenized_dataset[0]['input_ids'][:60], "...")
        else:
            logger.warning("Tokenized dataset is empty.")
            tokenized_dataset = None
    except Exception as e:
        logger.error(f"Failed to tokenize dataset: {e}", exc_info=True)
else:
    logger.warning("Skipping tokenization.")

# Define the data collator.
# DataCollatorForLanguageModeling(mlm=False) rebuilds `labels` from `input_ids` for every
# batch, which silently throws away the prompt masking computed during preprocessing and
# trains the student on the prompt text as well. DataCollatorForSeq2Seq keeps the
# precomputed labels and pads them with -100.
data_collator = None
if student_tokenizer and student_model_peft:
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=student_tokenizer,
        padding=True,
        label_pad_token_id=-100,
    )
    logger.info("Data collator defined using DataCollatorForSeq2Seq (keeps prompt-masked labels).")
else:
    logger.warning("Skipping data collator definition.")
print("-" * 30)

In [None]:
# @title 7. Configure and Run Training

trainer = None
print("\n--- Configuring Distillation Training ---")

# Ensure all necessary components are available before proceeding
if tokenized_dataset and student_model_peft and data_collator and student_tokenizer and len(tokenized_dataset) > 0:
    try:
        # Resolve the mixed-precision mode once. The CUDA capability query is guarded
        # so a CPU-only runtime does not touch the CUDA API.
        cuda_available = torch.cuda.is_available()
        use_bf16 = cuda_available and torch.cuda.is_bf16_supported()
        use_fp16 = cuda_available and not use_bf16

        # Define training hyperparameters and logging behavior
        training_args = TrainingArguments(
            output_dir=CHECKPOINT_DIR,
            num_train_epochs=NUM_EPOCHS,
            per_device_train_batch_size=BATCH_SIZE,
            gradient_accumulation_steps=GRAD_ACCUM_STEPS,
            learning_rate=LEARNING_RATE,
            logging_dir=f"{CHECKPOINT_DIR}/logs",
            logging_strategy="steps",
            logging_steps=max(1, len(tokenized_dataset) // (BATCH_SIZE * GRAD_ACCUM_STEPS * 10)),  # log every ~10% of an epoch
            save_strategy="epoch",
            save_total_limit=1,  # keep only the most recent checkpoint
            bf16=use_bf16,  # use BF16 if supported
            fp16=use_fp16,  # fallback to FP16
            gradient_checkpointing=True,  # save memory
            report_to=["tensorboard"],  # enable TensorBoard logging
            optim="paged_adamw_8bit" if cuda_available else "adamw_torch",  # use 8-bit optimizer on GPU
            warmup_ratio=0.1,
            weight_decay=0.01,
            seed=42,
        )

        # Newer transformers releases take the tokenizer as `processing_class` and
        # deprecate the `tokenizer` keyword; since the install cell uses `-U`, pick
        # whichever name the installed Trainer accepts.
        import inspect
        trainer_init_params = inspect.signature(Trainer.__init__).parameters
        tokenizer_kwarg = "processing_class" if "processing_class" in trainer_init_params else "tokenizer"

        # Instantiate the Trainer class for supervised fine-tuning
        logger.info("Instantiating Trainer for causal LM distillation.")
        trainer = Trainer(
            model=student_model_peft,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=data_collator,
            **{tokenizer_kwarg: student_tokenizer},
        )
        logger.info("Trainer initialized.")

        # Start the training process
        logger.info("--- Starting distillation training... ---")
        logger.info("Note: Training student to mimic OpenAI teacher outputs using standard cross-entropy loss.")
        train_start_time = time.time()

        try:
            train_result = trainer.train()
            logger.info("--- Distillation training finished ---")
            logger.info(f"Train Result: {train_result}")
        except Exception as e:
            logger.error("--- ERROR during distillation training ---", exc_info=True)
            train_result = None

        # Log how long training took
        train_duration = time.time() - train_start_time
        logger.info(f"Training loop duration: {train_duration:.2f}s")

        # Save the trained model locally if training completed
        if train_result is not None:
            logger.info(f"Saving distilled adapters locally to: {LOCAL_OUTPUT_DIR}")
            adapters_saved = False
            try:
                os.makedirs(LOCAL_OUTPUT_DIR, exist_ok=True)
                trainer.save_model(LOCAL_OUTPUT_DIR)
                student_tokenizer.save_pretrained(LOCAL_OUTPUT_DIR)
                adapters_saved = True
                logger.info("Distilled adapters and tokenizer saved locally.")
            except Exception as e:
                logger.error(f"Error saving adapters locally: {e}", exc_info=True)

            # Optional: push the adapters to Hugging Face Hub. The upload reads
            # LOCAL_OUTPUT_DIR, so it only runs after a successful local save.
            if HUB_REPO_ID_DISTILL and trainer and adapters_saved:
                logger.info(f"Pushing distilled adapters to Hugging Face Hub: {HUB_REPO_ID_DISTILL}")
                try:
                    api = HfApi()
                    api.create_repo(repo_id=HUB_REPO_ID_DISTILL, exist_ok=True, private=True)
                    api.upload_folder(
                        folder_path=LOCAL_OUTPUT_DIR,
                        repo_id=HUB_REPO_ID_DISTILL,
                        repo_type="model",
                        commit_message=f"Distillation (Phase 1 - OpenAI Teacher {OPENAI_TEACHER_MODEL})"
                    )
                    logger.info(f"Successfully pushed adapters to {HUB_REPO_ID_DISTILL}")
                except Exception as e:
                    logger.error(f"ERROR pushing distilled adapters to Hub: {e}", exc_info=True)
            elif HUB_REPO_ID_DISTILL and not adapters_saved:
                logger.warning("Skipping Hub push because the local save failed.")
        else:
            logger.warning("Skipping adapter saving/pushing because training did not complete successfully.")
    except Exception as e:
        logger.error(f"Error during training setup or Trainer initialization: {e}", exc_info=True)
else:
    logger.error("--- Skipping Training: Missing prerequisites ---")

# Cleanup to free GPU memory and avoid lingering state.
# Each name is removed independently: a single `del a, b, c` stops at the first name
# that does not exist (e.g. `student_model` when loading was skipped) and would leave
# the remaining objects alive.
import gc
for _stale_name in ("student_model", "student_model_peft", "trainer", "tokenized_dataset", "distillation_dataset"):
    globals().pop(_stale_name, None)
del _stale_name
gc.collect()  # drop unreachable objects before asking CUDA to release cached blocks
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    logger.info("Cleared CUDA cache.")

print("\n--- Distillation Script Finished ---")


In [None]:
# @title 9. Evaluate the distilled model

# Import required libraries for evaluation, dataset handling, model loading, etc.
import evaluate
import nltk
import numpy as np
from datasets import load_dataset, Dataset
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import gc
import os
import json
import logging

# Define paths and constants for evaluation
DISTILLED_ADAPTERS_PATH = LOCAL_OUTPUT_DIR
BASE_STUDENT_MODEL_ID = STUDENT_MODEL_ID
TEST_DATASET_PATH = "/content/test_medical_dialogues.json"
EVAL_SAMPLE_LIMIT = 20
MAX_GENERATION_TOKENS = 350

# Ensure the NLTK tokenizer data is available.
# Recent NLTK releases load the 'punkt_tab' resource instead of 'punkt', so both are
# checked; whichever the installed version needs will then be present.
for nltk_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{nltk_resource}")
        logger.info(f"NLTK '{nltk_resource}' resource found.")
    except LookupError:
        logger.info(f"NLTK '{nltk_resource}' tokenizer not found. Downloading...")
        try:
            if nltk.download(nltk_resource, quiet=True):
                logger.info(f"NLTK '{nltk_resource}' downloaded successfully.")
            else:
                logger.warning(f"NLTK '{nltk_resource}' could not be downloaded.")
        except Exception as download_error:
            logger.error(f"Failed to download NLTK '{nltk_resource}' resource: {download_error}", exc_info=True)

# Set evaluation device to GPU if available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Evaluation using device: {device}")

# Load the test dataset from disk
logger.info(f"--- Loading Test Dataset from: {TEST_DATASET_PATH} ---")
test_data = []
if not os.path.exists(TEST_DATASET_PATH):
    logger.error(f"Test dataset not found: '{TEST_DATASET_PATH}'. Cannot perform evaluation.")
else:
    try:
        with open(TEST_DATASET_PATH, 'r', encoding='utf-8') as f:
            all_data = json.load(f)
        if not isinstance(all_data, list):
            raise ValueError(f"expected a JSON list of examples, got {type(all_data).__name__}")
        required_input_key = 'input'
        required_output_key = 'output'

        # Filter only entries with both input and output strings. Non-dict rows are
        # skipped individually so one malformed row does not discard the whole test set.
        for item in all_data:
            if not isinstance(item, dict):
                continue
            input_val = item.get(required_input_key)
            output_val = item.get(required_output_key)
            if isinstance(input_val, str) and input_val and isinstance(output_val, str) and output_val:
                test_data.append({"input": input_val, "output": output_val})

        logger.info(f"Loaded {len(all_data)} examples, filtered to {len(test_data)} with required keys.")
        if test_data:
            # Cap the sample size for fast evaluation
            test_data = test_data[:EVAL_SAMPLE_LIMIT]
            logger.info(f"Using {len(test_data)} samples for evaluation (limited by EVAL_SAMPLE_LIMIT).")
        else:
            logger.warning("Test dataset loaded but no valid examples found with 'input' and 'output' keys.")
    except Exception as e:
        logger.error(f"Error loading or processing test dataset: {e}")
        test_data = []

# Convert to Hugging Face Dataset format
test_dataset = None
if test_data:
    try:
        test_dataset = Dataset.from_list(test_data)
        logger.info("Test data converted to Hugging Face Dataset.")
    except Exception as e:
        logger.error(f"Failed to create test Dataset object: {e}")
else:
    logger.error("No valid test data loaded. Evaluation cannot proceed.")

print("-" * 30)

def load_model_for_eval(model_id, adapters_path=None):
    """
    Loads a model (optionally merging adapters) for evaluation.

    Args:
        model_id: Hub id or path of the base causal LM.
        adapters_path: Optional local directory containing PEFT adapters to merge
            into the base model.

    Returns:
        (model, tokenizer) with the model in eval mode, or (None, None) when loading
        fails (the error is logged, not raised).
    """
    logger.info(f"Loading model: {model_id}" + (f" with adapters from {adapters_path}" if adapters_path else ""))
    try:
        # Fail fast: validate the adapter directory before the expensive base-model load.
        if adapters_path and not os.path.isdir(adapters_path):
            logger.error(f"Adapter path does not exist or is not a directory: {adapters_path}")
            raise FileNotFoundError(f"Adapters not found at {adapters_path}")

        # Use bf16 only where the GPU supports it; otherwise fall back to fp16.
        compute_dtype = (
            torch.bfloat16
            if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
            else torch.float16
        )

        # Load model in 4-bit for efficient inference
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=True
        )
        # NOTE(review): trust_remote_code=True allows code from the model repo to run; confirm it is needed.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=compute_dtype
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

        # Ensure tokenizer and model can pad properly for batching
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            model.config.pad_token_id = tokenizer.eos_token_id
            logger.info(f"Set pad token for {model_id}")
        # Batched generation with a decoder-only model needs left padding whether or
        # not the tokenizer already defines a pad token.
        tokenizer.padding_side = "left"
        if model.config.pad_token_id is None:
            model.config.pad_token_id = tokenizer.pad_token_id

        # Load and merge LoRA adapters into the base model
        if adapters_path:
            logger.info(f"Loading and merging PEFT adapters from {adapters_path}...")
            model = PeftModel.from_pretrained(model, adapters_path)
            # NOTE(review): merging into 4-bit weights may introduce small rounding
            # differences versus keeping the adapters unmerged — confirm this is acceptable.
            model = model.merge_and_unload()
            logger.info("Adapters merged successfully.")

        model.eval()  # Set to eval mode for inference
        logger.info(f"Model {model_id}{'+Adapters' if adapters_path else ''} loaded successfully.")
        return model, tokenizer

    except Exception as e:
        logger.error(f"Failed to load model {model_id}: {e}", exc_info=True)
        return None, None

# Base student weights plus the adapters written by the training cell.
logger.info("--- Loading Distilled Student Model ---")
distilled_model, distilled_tokenizer = load_model_for_eval(
    BASE_STUDENT_MODEL_ID, adapters_path=DISTILLED_ADAPTERS_PATH
)

# Placeholders for an optional reference model; nothing is loaded for them here.
reference_model = None
reference_tokenizer = None

# Report current GPU memory (in MB) as a resource-usage check after loading.
if torch.cuda.is_available():
    logger.info(f"GPU Memory Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    logger.info(f"GPU Memory Reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

print("-" * 30)

In [None]:
# @title 10. Generation Function

import re

def generate_response(model, tokenizer, prompt_text, max_new_tokens=MAX_GENERATION_TOKENS):
    """
    Generate one sampled response from `model` for `prompt_text`.

    The prompt is wrapped as a single user turn and rendered with the
    tokenizer's chat template. If rendering fails, the raw prompt text is
    used instead. The prompt is passed through as-is: no extra instruction
    text is appended to it.

    Args:
        model: Model exposing `.device` and `.generate(...)`, or None.
        tokenizer: Tokenizer paired with `model`, or None.
        prompt_text: The user prompt to answer.
        max_new_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The decoded continuation (prompt tokens removed, special tokens
        skipped, surrounding whitespace stripped). On any failure a string
        starting with "[ERROR:" is returned instead of raising.
    """
    if model is None or tokenizer is None:
        return "[ERROR: Model or Tokenizer not loaded]"

    generation_prompt_text = prompt_text

    try:
        # Attempt to format the prompt using the chat template (if available for this tokenizer)
        messages = [{"role": "user", "content": generation_prompt_text}]
        prompt_formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # A rendered chat template already contains the special tokens it needs;
        # letting the tokenizer add them again would duplicate e.g. the BOS token.
        add_special_tokens = False
    except Exception as e:
        # Fallback to raw prompt if tokenizer doesn't support chat templates
        logger.warning(f"Could not apply chat template for {tokenizer.name_or_path}: {e}. Using raw prompt.")
        prompt_formatted = generation_prompt_text
        # Raw text has no special tokens yet, so keep the tokenizer default here.
        add_special_tokens = True

    try:
        # Ensure the input fits within the model's max sequence length
        # (the extra 5 tokens act as a small safety margin).
        max_input_length = MAX_SEQ_LENGTH - max_new_tokens - 5
        if max_input_length <= 0:
            # A non-positive max_length cannot hold any prompt tokens; fail explicitly
            # instead of handing an invalid length to the tokenizer.
            logger.error(
                f"max_new_tokens={max_new_tokens} leaves no room for the prompt "
                f"(MAX_SEQ_LENGTH={MAX_SEQ_LENGTH})."
            )
            return "[ERROR: max_new_tokens leaves no room for the prompt]"

        inputs = tokenizer(
            prompt_formatted,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_input_length,
            add_special_tokens=add_special_tokens,
        ).to(model.device)

        # Guard against prompt tokenization resulting in an empty sequence
        prompt_token_count = inputs['input_ids'].shape[1]
        if prompt_token_count == 0:
            logger.warning(f"Input tokens resulted in empty sequence for prompt: {generation_prompt_text[:100]}...")
            return "[ERROR: Empty token sequence]"

        # Generation configuration: top-p sampling with temperature
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "temperature": 0.6,
            "top_p": 0.9,
            "do_sample": True,
            "pad_token_id": tokenizer.pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }

        # Generate text without tracking gradients
        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)

        # Strip the prompt tokens and decode only the generated part
        output_tokens = outputs[0][prompt_token_count:]
        response = tokenizer.decode(output_tokens, skip_special_tokens=True)

        return response.strip()

    except Exception as e:
        logger.error(f"Error during generation with {tokenizer.name_or_path}: {e}", exc_info=True)
        return f"[ERROR: Generation failed - {e}]"

# Smoke-test generation on the first test example, but only when the model,
# tokenizer and a non-empty test set are all available
_ready_for_smoke_test = (
    distilled_model and distilled_tokenizer and test_dataset and len(test_dataset) > 0
)
if _ready_for_smoke_test:
    logger.info("--- Testing generation (distilled model) ---")
    sample_prompt_test = test_dataset[0]['input']
    sample_output_test = generate_response(
        distilled_model,
        distilled_tokenizer,
        sample_prompt_test,
    )
    logger.info(f"Sample Prompt:\n{sample_prompt_test}")
    logger.info(f"Sample Distilled Output:\n{sample_output_test}")

print("-" * 30)

In [None]:
# @title 11. Run Evaluation Loop and Calculate Metrics

import re
import evaluate
import nltk
import numpy as np
import torch
import gc
import logging

def extract_answer(text):
    """
    Return the stripped content of the first <answer>...</answer> block in `text`.

    Tag matching is case-insensitive and the content may span multiple lines.
    When no complete block is present the whole input is returned, stripped.
    Non-string input yields an empty string.
    """
    if not isinstance(text, str):
        return ""
    tagged = re.search(r'<answer>(.*?)</answer>', text, flags=re.IGNORECASE | re.DOTALL)
    payload = text if tagged is None else tagged.group(1)
    return payload.strip()

logger.info("--- Loading Evaluation Metrics (BLEU, ROUGE) ---")
# Assume failure until both metrics have been loaded
metrics_loaded = False
try:
    bleu_metric = evaluate.load("bleu")
    rouge_metric = evaluate.load("rouge")
except Exception as e:
    logger.error(f"Failed to load evaluation metrics: {e}", exc_info=True)
else:
    metrics_loaded = True
    logger.info("Metrics loaded successfully.")

# Per-model store of raw generations/references and their <answer>-extracted forms
results = {
    "distilled": {
        field: []
        for field in ("predictions", "references", "processed_predictions", "processed_references")
    },
}

# Evaluation needs loaded metrics, test data, and the distilled model + tokenizer
if not (metrics_loaded and test_dataset and distilled_model and distilled_tokenizer):
    logger.error("Evaluation skipped due to missing prerequisites (metrics, test data, or distilled model/tokenizer).")
else:
    logger.info(f"--- Running Evaluation on {len(test_dataset)} samples (Distilled Model Only) ---")
    distilled_results = results["distilled"]
    for sample in tqdm(test_dataset, desc="Evaluating Distilled Model"):
        user_prompt = sample['input']
        gold_raw = sample['output']

        # Store the reference both raw and reduced to its <answer> content
        distilled_results["references"].append(gold_raw)
        distilled_results["processed_references"].append(extract_answer(gold_raw))

        # Store the model generation both raw and reduced to its <answer> content
        generated_raw = generate_response(distilled_model, distilled_tokenizer, user_prompt)
        distilled_results["predictions"].append(generated_raw)
        distilled_results["processed_predictions"].append(extract_answer(generated_raw))

    logger.info("--- Evaluation Generation Complete ---")
    logger.info("--- Calculating Metrics (using extracted <answer> content) ---")

    final_scores = {}
    model_name = "distilled"

    predictions = results[model_name]["processed_predictions"]
    references = results[model_name]["processed_references"]

    # Metrics need two non-empty lists of equal length
    lists_usable = bool(predictions) and bool(references) and len(predictions) == len(references)
    if lists_usable:
        predictions = list(map(str, predictions))
        references = list(map(str, references))

        try:
            # BLEU takes a list of reference lists (one list per prediction)
            bleu_references_formatted = [[ref] for ref in references]
            bleu_score = bleu_metric.compute(predictions=predictions, references=bleu_references_formatted)
            logger.info(f"{model_name.capitalize()} BLEU Score: {bleu_score}")
        except Exception as e:
            logger.error(f"Error calculating BLEU for {model_name}: {e}", exc_info=True)
            bleu_score = {"bleu": "Error"}

        try:
            rouge_score = rouge_metric.compute(predictions=predictions, references=references, use_aggregator=True)
            logger.info(f"{model_name.capitalize()} ROUGE Scores: {rouge_score}")
        except Exception as e:
            logger.error(f"Error calculating ROUGE for {model_name}: {e}", exc_info=True)
            rouge_score = {"rouge1": "Error", "rouge2": "Error", "rougeL": "Error", "rougeLsum": "Error"}

        # Merge both metric dicts; ROUGE keys win on any name clash
        combined_scores = dict(bleu_score)
        combined_scores.update(rouge_score)
        final_scores[model_name] = combined_scores
    else:
        logger.error(f"Mismatch or empty lists for processed {model_name}. Cannot calculate scores.")
        final_scores[model_name] = {"error": "Invalid processed prediction/reference data"}

    # Report the final evaluation metrics
    print("\n--- Evaluation Results (Based on <answer> content) ---")
    print(f"Evaluated on {len(test_dataset)} samples.")
    print("-" * 25)

    if model_name not in final_scores:
        print(f"\n>> No scores calculated for {model_name.capitalize()} Model.")
    else:
        scores = final_scores[model_name]
        print(f"\n>> {model_name.capitalize()} Model ({BASE_STUDENT_MODEL_ID}+Adapters):")
        if "error" in scores:
            print(f"  Error calculating scores: {scores['error']}")
        else:
            # Numeric scores are shown with 4 decimals; anything else (e.g. "Error") verbatim
            score_rows = (("BLEU", "bleu"), ("ROUGE-1", "rouge1"), ("ROUGE-2", "rouge2"), ("ROUGE-L", "rougeL"))
            for label, key in score_rows:
                value = scores.get(key)
                rendered = f"{value:.4f}" if isinstance(value, (int, float)) else f"{value}"
                print(f"  {label}: {rendered}")
            print("-" * 25)

    # Show the first sample end to end: prompt, reference and generation
    if len(test_dataset) > 0 and "distilled" in results:
        first_sample = results['distilled']
        print("\n--- Example Generations (Sample 0) ---")
        print(f"Prompt:\n{test_dataset[0]['input']}\n")
        print(f"Ground Truth (Raw):\n{first_sample['references'][0]}\n")
        print(f"Ground Truth (Processed <answer>):\n{first_sample['processed_references'][0]}\n")
        print(f"Distilled Model (Raw):\n{first_sample['predictions'][0]}\n")
        print(f"Distilled Model (Processed <answer>):\n{first_sample['processed_predictions'][0]}\n")
        print("-" * 30)

# Drop the evaluation model/tokenizer bindings, then clear the CUDA cache
logger.info("Cleaning up evaluation models...")
for _stale_name in ("distilled_model", "distilled_tokenizer"):
    # Cell code runs at module level, so removing the binding from globals()
    # matches a guarded `del <name>`; names that are not bound are skipped.
    globals().pop(_stale_name, None)

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    logger.info("Cleared CUDA cache after evaluation.")

print("\n--- Evaluation Script Finished ---")