# T5-Small Instruction Fine-Tuning for Clinical Queries

Fine-tune Hugging Face `t5-small` on the custom clinical QA dataset under `t5-small/data`.
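Before fine-tuning, the notebook scores the pretrained model zero-shot as a baseline. Both the baseline and the fine-tuned model are evaluated on the validation and test splits with ROUGE and BERTScore, and their predictions are saved as JSON for later review.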

In [None]:
# Install the third-party packages imported by the following cells; rouge-score and
# bert-score are the backends behind the `evaluate` ROUGE / BERTScore metrics.
%pip install -q transformers datasets evaluate accelerate bert-score rouge-score
print("Finished installing the project dependencies.")


In [None]:
import json
import random
from pathlib import Path

import evaluate
import numpy as np
import torch
from datasets import Dataset
from tqdm.auto import tqdm
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    Trainer,
    TrainingArguments,
)

# Seed every random number generator the notebook relies on so reruns are reproducible.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(42)

# Project layout: inputs are read from t5-small/data; everything written goes under t5-small/outputs.
BASE_DIR = Path('t5-small')
DATA_DIR = BASE_DIR.joinpath('data')
OUTPUT_DIR = BASE_DIR.joinpath('outputs')
PRED_DIR = OUTPUT_DIR.joinpath('predictions')
for _out_dir in (OUTPUT_DIR, PRED_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

# Checkpoint plus token budgets: prompt truncation, label truncation, and generation cap.
MODEL_NAME = 't5-small'
MAX_INPUT_LENGTH, MAX_TARGET_LENGTH, GEN_MAX_NEW_TOKENS = 1024, 256, 160

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')
print(f'Data directory: {DATA_DIR.resolve()}')

In [None]:
def load_split(json_path):
    """Load one dataset split and flatten it into one record per usable answer.

    The JSON file maps question ids to payloads holding a ``question`` string
    and an ``answers`` dict keyed by answer id. Each answer contributes one
    record built from its ``article`` (model input) and ``answer_abs_summ``
    (target summary). Answers whose article or summary is missing or blank
    are skipped.

    Args:
        json_path: path to the split's JSON file.

    Returns:
        list of dicts with keys ``id`` (``"<qid>_<aid>"``), ``question``,
        ``article``, ``prompt`` (instruction + question + article) and
        ``summary`` (the reference target used for training and scoring).
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        raw_data = json.load(f)
    records = []
    for qid, payload in raw_data.items():
        question = (payload.get('question') or '').strip()
        # `or {}` also tolerates an explicit `null` for answers in the JSON.
        answers = payload.get('answers') or {}
        for aid, answer_payload in answers.items():
            article = (answer_payload.get('article') or '').strip()
            summary = (answer_payload.get('answer_abs_summ') or '').strip()
            if not article or not summary:
                continue
            prompt = (
                "Summarize the following medical article to answer the clinical question.\n"
                f"Question: {question}\n"
                f"Article: {article}"
            )
            records.append(
                {
                    'id': f'{qid}_{aid}',
                    'question': question,
                    'article': article,
                    'prompt': prompt,
                    # Downstream cells (peek, tokenize_batch, run_batch_generation) read this key.
                    'summary': summary,
                }
            )
    return records

In [None]:
# Each split file expands into one record per answer that has both an article and a summary.
train_records, val_records, test_records = (
    load_split(DATA_DIR / f'{split_file}.json')
    for split_file in ('train', 'validation', 'test')
)
print(f'Train/Val/Test sizes -> {len(train_records)} / {len(val_records)} / {len(test_records)}')

def peek(records, name):
    """Print the first record of a split: a 300-character prompt preview and its reference.

    Prints a short notice instead when ``records`` is empty.
    """
    if records:
        first = records[0]
        snippet = first['prompt'][:300]
        print(f"\n{name} sample prompt (first 300 chars):\n{snippet}...")
        print(f"{name} sample reference:\n{first['summary']}\n")
    else:
        print(f'No records available for {name}.')

# Sanity-check the prompt format on the first training record.
peek(records=train_records, name='Train')

In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print('Tokenizer loaded.')

# Wrap each split's list of record dicts in a Hugging Face Dataset.
hf_train, hf_val, hf_test = (
    Dataset.from_list(split_records)
    for split_records in (train_records, val_records, test_records)
)

def tokenize_batch(batch):
    """Tokenize a batched slice of records into T5 encoder inputs and training labels.

    ``batch`` is the column dict passed by ``Dataset.map(batched=True)``. Its
    ``prompt`` column becomes the encoder input and its ``summary`` column the
    target. Both are truncated and padded to fixed lengths
    (MAX_INPUT_LENGTH / MAX_TARGET_LENGTH).

    Padded label positions are set to -100, the label value that Hugging Face
    seq2seq models ignore in the loss. If they kept the pad id, the model would
    be trained to emit padding, and the loss would be diluted by many trivial
    targets per example.

    Returns:
        the tokenizer's outputs for the prompts, plus a ``labels`` list.
    """
    model_inputs = tokenizer(
        batch['prompt'],
        max_length=MAX_INPUT_LENGTH,
        truncation=True,
        padding='max_length'
    )
    labels = tokenizer(
        batch['summary'],
        max_length=MAX_TARGET_LENGTH,
        truncation=True,
        padding='max_length'
    )
    # The label attention mask is 0 exactly on the padding positions added above.
    model_inputs['labels'] = [
        [token_id if keep else -100 for token_id, keep in zip(ids, mask)]
        for ids, mask in zip(labels['input_ids'], labels['attention_mask'])
    ]
    return model_inputs

# Replace the raw text columns of every split with the tokenizer outputs and labels.
tokenized_train, tokenized_val, tokenized_test = (
    raw_split.map(
        tokenize_batch,
        batched=True,
        remove_columns=raw_split.column_names,
        desc=f'Tokenizing {split_tag}',
    )
    for raw_split, split_tag in ((hf_train, 'train'), (hf_val, 'val'), (hf_test, 'test'))
)

In [None]:
# Metric backends are loaded once here and used by compute_text_metrics.
rouge, bertscore = (evaluate.load(metric_id) for metric_id in ('rouge', 'bertscore'))
print('Loaded ROUGE and BERTScore evaluators.')

def compute_text_metrics(preds, refs):
    """Score generated texts against references with ROUGE and BERTScore.

    Args:
        preds: generated strings.
        refs: reference strings, aligned index-by-index with ``preds``.

    Returns:
        dict with one ``rouge_<key>`` entry per score returned by the ROUGE
        metric (rounded to 4 decimals) plus ``bertscore_f1`` (mean F1 over all
        pairs), all as plain Python floats. Empty input returns ``{}``.

    Raises:
        ValueError: if ``preds`` and ``refs`` have different lengths.
    """
    preds, refs = list(preds), list(refs)
    if len(preds) != len(refs):
        raise ValueError(f'Got {len(preds)} predictions but {len(refs)} references.')
    if not preds:
        # Nothing to score; np.mean over an empty F1 list would otherwise yield NaN.
        return {}
    rouge_result = rouge.compute(predictions=preds, references=refs, use_stemmer=True)
    bert_result = bertscore.compute(predictions=preds, references=refs, lang='en')
    # float() keeps numpy scalars out of the dict so it prints and serializes cleanly.
    metrics = {f'rouge_{k}': round(float(v), 4) for k, v in rouge_result.items()}
    metrics['bertscore_f1'] = float(np.mean(bert_result['f1']))
    return metrics

def run_batch_generation(model, records, split_name, output_path, batch_size=4):
    """Generate beam-search predictions for ``records`` and dump them to JSON.

    Prompts are tokenized ``batch_size`` at a time, padded to the longest
    prompt in the batch and truncated to MAX_INPUT_LENGTH. They are decoded
    with 4-beam search, capped at GEN_MAX_NEW_TOKENS new tokens. Each output
    row pairs the stripped prediction with the record's ``summary`` as the
    reference.

    Returns:
        list of dicts with keys ``id``, ``question``, ``prediction`` and
        ``reference``. The same list is written to ``output_path``.
    """
    model.eval()
    collected = []
    batch_starts = range(0, len(records), batch_size)
    for offset in tqdm(batch_starts, desc=f'Generating {split_name}', leave=False):
        chunk = records[offset:offset + batch_size]
        if len(chunk) == 0:
            continue
        encoded = tokenizer(
            [item['prompt'] for item in chunk],
            max_length=MAX_INPUT_LENGTH,
            truncation=True,
            padding=True,
            return_tensors='pt',
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            generated = model.generate(
                **encoded,
                max_new_tokens=GEN_MAX_NEW_TOKENS,
                num_beams=4,
                length_penalty=1.0,
            )
        texts = tokenizer.batch_decode(generated, skip_special_tokens=True)
        collected.extend(
            {
                'id': item['id'],
                'question': item['question'],
                'prediction': text.strip(),
                'reference': item['summary'],
            }
            for item, text in zip(chunk, texts)
        )
    with open(output_path, 'w') as f:
        json.dump(collected, f, indent=2)
    print(f'Saved {len(collected)} predictions to {output_path}')
    return collected

In [None]:
baseline_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)


def _run_baseline_split(records, split_word, tag):
    """Generate zero-shot predictions for one split, score them and report the metrics."""
    print(f'Running baseline generation on {split_word} data...')
    results = run_batch_generation(
        baseline_model,
        records,
        f'baseline-{tag}',
        PRED_DIR / f'baseline_{tag}_predictions.json',
    )
    metrics = compute_text_metrics(
        [row['prediction'] for row in results],
        [row['reference'] for row in results],
    )
    print(f'Baseline {split_word} metrics:', metrics)
    return results, metrics


baseline_val_results, baseline_val_metrics = _run_baseline_split(val_records, 'validation', 'val')
baseline_test_results, baseline_test_metrics = _run_baseline_split(test_records, 'test', 'test')

In [None]:
finetune_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
# Pads labels with -100 when needed and derives decoder_input_ids from the labels.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=finetune_model)

finetune_dir = str(OUTPUT_DIR / 't5_small_finetune')

# T5 checkpoints are known to be numerically fragile in fp16, so prefer bf16 when
# the GPU supports it. Keep fp16 only as the fallback on older CUDA GPUs.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

# predict_with_generate / generation_* are Seq2SeqTrainingArguments options; plain
# TrainingArguments rejects them. `eval_strategy` replaced the removed
# `evaluation_strategy` keyword in recent transformers releases.
training_args = Seq2SeqTrainingArguments(
    output_dir=finetune_dir,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=3e-4,
    num_train_epochs=1,
    eval_strategy='epoch',
    save_strategy='epoch',
    logging_steps=50,
    predict_with_generate=True,
    # For encoder-decoder models max_length counts the decoder start token, so +1
    # gives the same budget as the baseline's max_new_tokens=GEN_MAX_NEW_TOKENS.
    generation_max_length=GEN_MAX_NEW_TOKENS + 1,
    # Match the 4-beam decoding used for the baseline so the comparison is fair.
    generation_num_beams=4,
    load_best_model_at_end=True,
    warmup_steps=100,
    bf16=use_bf16,
    fp16=torch.cuda.is_available() and not use_bf16,
    report_to=[],
)


def trainer_compute_metrics(eval_pred):
    """Decode generated ids and labels from an evaluation pass and score them.

    Both arrays can contain -100. In labels it marks ignored positions, and the
    trainer uses it to pad generated sequences of different lengths when it
    concatenates batches. -100 is not a valid token id, so it is mapped to the
    pad token (dropped by skip_special_tokens) before decoding. References here
    are the decoded (truncated) labels; this is only used for per-epoch eval
    logging during training.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        # Some models return extra outputs alongside the generated ids.
        predictions = predictions[0]
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = [
        text.strip() for text in tokenizer.batch_decode(predictions, skip_special_tokens=True)
    ]
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = [
        text.strip() for text in tokenizer.batch_decode(labels, skip_special_tokens=True)
    ]
    return compute_text_metrics(decoded_preds, decoded_labels)


# Seq2SeqTrainer (not Trainer) is what calls generate() during evaluate/predict.
trainer = Seq2SeqTrainer(
    model=finetune_model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    processing_class=tokenizer,  # replaces the deprecated `tokenizer=` (transformers >= 4.46)
    data_collator=data_collator,
    compute_metrics=trainer_compute_metrics,
)

train_result = trainer.train()
trainer.save_model(finetune_dir)
tokenizer.save_pretrained(finetune_dir)
print('Training finished and model saved.')

In [None]:
def decode_sequences(sequences):
    """Decode a batch of token-id sequences into stripped strings.

    Trainer outputs use -100 as padding (for labels, and for generated
    sequences of different lengths concatenated across batches). -100 is not
    a valid vocabulary id, so it is mapped to the pad token before decoding;
    skip_special_tokens=True then drops it.
    """
    sequences = np.asarray(sequences)
    sequences = np.where(sequences != -100, sequences, tokenizer.pad_token_id)
    decoded = tokenizer.batch_decode(sequences, skip_special_tokens=True)
    return [text.strip() for text in decoded]


def _evaluate_finetuned_split(tokenized_split, records, split_label, filename):
    """Generate with the fine-tuned trainer, score, save JSON; return (metrics, results).

    References are the raw ``summary`` strings, the same references
    run_batch_generation uses for the baseline. Decoded labels would be
    truncated to MAX_TARGET_LENGTH and round-tripped through the tokenizer,
    which would make the baseline vs. fine-tuned comparison unfair.
    """
    output = trainer.predict(tokenized_split)
    predictions = output.predictions
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    decoded_preds = decode_sequences(predictions)
    if len(decoded_preds) != len(records):
        raise RuntimeError(
            f'Expected {len(records)} {split_label} predictions, got {len(decoded_preds)}.'
        )
    # trainer.predict keeps dataset order, which matches the records it was built from.
    references = [record['summary'] for record in records]
    metrics = compute_text_metrics(decoded_preds, references)
    print(f'Fine-tuned {split_label} metrics:', metrics)

    results = [
        {
            'id': record['id'],
            'question': record['question'],
            'prediction': pred,
            'reference': ref,
        }
        for record, pred, ref in zip(records, decoded_preds, references)
    ]
    with open(PRED_DIR / filename, 'w') as f:
        json.dump(results, f, indent=2)
    print(f'Saved fine-tuned {split_label} predictions.')
    return metrics, results


finetuned_val_metrics, finetuned_val_results = _evaluate_finetuned_split(
    tokenized_val, val_records, 'validation', 'finetuned_val_predictions.json'
)
finetuned_test_metrics, finetuned_test_results = _evaluate_finetuned_split(
    tokenized_test, test_records, 'test', 'finetuned_test_predictions.json'
)

In [None]:
def summarize_split(name, baseline_metrics, finetuned_metrics):
    """Print a side-by-side table of baseline vs. fine-tuned metrics for one split.

    Rows cover the union of metric names, sorted alphabetically. A metric
    missing from either dict is shown as ``--``, and its delta is ``--`` too.
    """
    def cell(value, width=12):
        return "--".ljust(width) if value is None else f"{value:.4f}".ljust(width)

    print(f"\n{name} metrics")
    print("Metric".ljust(20), "Baseline".ljust(12), "Finetuned".ljust(12), "Delta")
    print("-" * 60)
    for metric in sorted(baseline_metrics.keys() | finetuned_metrics.keys()):
        before = baseline_metrics.get(metric)
        after = finetuned_metrics.get(metric)
        if before is None or after is None:
            change = "--"
        else:
            change = f"{after - before:+.4f}"
        print(metric.ljust(20), cell(before), cell(after), change)

# This comparison cell reads the metric dicts produced by the baseline and fine-tuning cells.
required_vars = [
    'baseline_val_metrics',
    'baseline_test_metrics',
    'finetuned_val_metrics',
    'finetuned_test_metrics',
]
missing = [var for var in required_vars if var not in globals()]
if missing:
    # Name the absent variables so it is obvious which cell still needs to run.
    raise RuntimeError(
        "Please execute the baseline and fine-tuned evaluation cells before running this comparison block. "
        f"Missing: {', '.join(missing)}"
    )

summarize_split('Validation', baseline_val_metrics, finetuned_val_metrics)
summarize_split('Test', baseline_test_metrics, finetuned_test_metrics)
print("\nDone comparing baseline vs. fine-tuned performance.")
