# Fine-tune FLAN-T5-Base for Clinical Query Summarization

This notebook fine-tunes the `google/flan-t5-base` model on a clinical query summarization dataset.
We will:
1. Load the train/validation/test JSON files and flatten them into (article, abstractive summary) pairs.
2. Generate baseline summaries with the pre-trained model.
3. Fine-tune the model.
4. Generate summaries with the fine-tuned model.
5. Evaluate and compare using ROUGE and BERTScore.

In [24]:
# Install necessary libraries
# Uncomment the line below if you need to install these packages
!pip3 install transformers datasets evaluate rouge_score bert_score torch accelerate

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting bert_score
Collecting bert_score
  Downloading bert_score-0.3.13-py3-none-any.whl.metadata (15 kB)
  Downloading bert_score-0.3.13-py3-none-any.whl.metadata (15 kB)
Collecting matplotlib (from bert_score)
Collecting matplotlib (from bert_score)
  Downloading matplotlib-3.10.7-cp313-cp313-macosx_11_0_arm64.whl.metadata (11 kB)
  Downloading matplotlib-3.10.7-cp313-cp313-macosx_11_0_arm64.whl.metadata (11 kB)
Collecting contourpy>=1.0.1 (from matplotlib->bert_score)
  Downloading contourpy-1.3.3-cp313-cp313-macosx_11_0_arm64.whl.metadata (5.5 kB)
Collecting contourpy>=1.0.1 (from matplotlib->bert_score)
  Downloading contourpy-1.3.3-cp313-cp313-macosx_11_0_arm64.whl.metadata (5.5 kB)
Collecting cycler>=0.10 (from matplotlib->bert_score)
  Downloading cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting cycler>=0.10 (from matplotlib->bert_score)
  Downloading cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib->bert_score)
  Dow

In [16]:
import json
import os

# The HF fast tokenizers warn (and disable their own thread pool) whenever the process forks
# after they have run in parallel, as the repeated warnings in this notebook's outputs show.
# Opt out before the tokenizer libraries are imported; an existing user setting is respected.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

import evaluate
import numpy as np
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq

print("Libraries imported successfully!")

Libraries imported successfully!


In [17]:
# Define paths.
# DATA_DIR can be overridden with the CLINICAL_QS_DATA_DIR environment variable so the notebook
# runs on machines other than the author's; the default keeps the original location.
DATA_DIR = os.environ.get(
    "CLINICAL_QS_DATA_DIR",
    "/Users/tapajit/projects/clinical-query-summarization/flan-t5-base/data",
)
TRAIN_FILE = os.path.join(DATA_DIR, "train.json")
VAL_FILE = os.path.join(DATA_DIR, "validation.json")
TEST_FILE = os.path.join(DATA_DIR, "test.json")
MODEL_CHECKPOINT = "google/flan-t5-base"
OUTPUT_DIR = "./flan-t5-finetuned-clinical"

print(f"Data directory: {DATA_DIR}")
print(f"Model checkpoint: {MODEL_CHECKPOINT}")

Data directory: /Users/tapajit/projects/clinical-query-summarization/flan-t5-base/data
Model checkpoint: google/flan-t5-base


In [18]:
def parse_data(file_path):
    """
    Parse one split's nested JSON file into a flat list of examples.

    The top-level object maps question ids to objects holding an ``"answers"`` dict. Every
    answer that has both a non-empty ``"article"`` and a non-empty ``"answer_abs_summ"``
    becomes one example ``{"article": ..., "summary": ..., "id": <answer id>}``.

    Args:
        file_path: path to a UTF-8 encoded JSON file.

    Returns:
        list[dict]: the parsed examples, in file order.
    """
    print(f"Loading data from {file_path}...")
    # JSON is UTF-8 by spec; don't depend on the platform's default locale encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    parsed_data = []
    for q_data in data.values():
        # `or {}` also tolerates an explicit "answers": null.
        answers = q_data.get("answers") or {}
        for ans_id, ans_data in answers.items():
            article = ans_data.get("article", "")
            summary = ans_data.get("answer_abs_summ", "")

            # Skip answers missing either side of the (input, target) pair.
            if article and summary:
                parsed_data.append({
                    "article": article,
                    "summary": summary,
                    "id": ans_id
                })

    print(f"Loaded {len(parsed_data)} examples from {file_path}")
    return parsed_data

# Parse each split's JSON file (train, validation, test order), then wrap the
# resulting lists of records as Hugging Face Datasets.
train_data, val_data, test_data = (
    parse_data(split_path) for split_path in (TRAIN_FILE, VAL_FILE, TEST_FILE)
)

train_dataset, val_dataset, test_dataset = (
    Dataset.from_list(records) for records in (train_data, val_data, test_data)
)

print("Datasets created successfully.")

Loading data from /Users/tapajit/projects/clinical-query-summarization/flan-t5-base/data/train.json...
Loaded 392 examples from /Users/tapajit/projects/clinical-query-summarization/flan-t5-base/data/train.json
Loading data from /Users/tapajit/projects/clinical-query-summarization/flan-t5-base/data/validation.json...
Loaded 51 examples from /Users/tapajit/projects/clinical-query-summarization/flan-t5-base/data/validation.json
Loading data from /Users/tapajit/projects/clinical-query-summarization/flan-t5-base/data/test.json...
Loaded 109 examples from /Users/tapajit/projects/clinical-query-summarization/flan-t5-base/data/test.json
Datasets created successfully.


In [19]:
# Load the pre-trained tokenizer and seq2seq model from the Hub checkpoint.
print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)

# Pick an accelerator: Apple MPS takes precedence, then CUDA, else CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = model.to(device)

print(f"Model loaded on {device}")

Loading tokenizer and model...
Model loaded on mps
Model loaded on mps


In [27]:
# Baseline Generation Function
def generate_summaries(model, tokenizer, dataset, device, max_input_length=1024, max_target_length=128,
                       prefix="Summarize the following clinical text: ", num_beams=4):
    """
    Generate one summary per example with beam search.

    Args:
        model: seq2seq model exposing ``eval()`` and ``generate()``.
        tokenizer: tokenizer matching ``model``.
        dataset: iterable of dicts with ``"article"`` and ``"summary"`` keys.
        device: device the encoded inputs are moved to (should match the model's device).
        max_input_length: the prompt (prefix + article) is truncated to this many tokens.
        max_target_length: ``max_length`` passed to ``generate``.
        prefix: instruction prepended to every article; keep it identical to the training prompt.
        num_beams: beam width passed to ``generate``.

    Returns:
        tuple[list[str], list[str]]: decoded predictions and reference summaries, in dataset order.
    """
    print("Generating summaries...")
    # generate() does not switch the model's mode; a model that was trained in this session
    # may still be in training mode, where dropout would make outputs noisy and non-deterministic.
    model.eval()
    predictions = []
    references = []

    for i, example in enumerate(dataset):
        input_text = prefix + example["article"]
        inputs = tokenizer(input_text, return_tensors="pt", max_length=max_input_length, truncation=True).to(device)

        with torch.no_grad():
            # Pass the tokenizer's attention_mask explicitly instead of letting generate() infer it.
            outputs = model.generate(**inputs, max_length=max_target_length, num_beams=num_beams, early_stopping=True)

        predictions.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
        references.append(example["summary"])

        if (i + 1) % 50 == 0:
            print(f"Processed {i + 1} examples...")

    return predictions, references

# Run the zero-shot (pre-trained) model over the full test split.
print("Running baseline inference on test set...")
baseline_preds, references = generate_summaries(model, tokenizer, test_dataset, device)

# Pair every test example with its reference and baseline summary, then save to disk.
baseline_results = [
    {
        "id": row["id"],
        "article": row["article"],
        "reference_summary": ref_summary,
        "baseline_summary": pred_summary,
    }
    for row, ref_summary, pred_summary in zip(test_dataset, references, baseline_preds)
]

with open("baseline_results.json", "w") as out_file:
    json.dump(baseline_results, out_file, indent=2)
print("Baseline results saved to baseline_results.json")

Running baseline inference on test set...
Generating summaries...
Processed 50 examples...
Processed 50 examples...
Processed 100 examples...
Processed 100 examples...
Baseline results saved to baseline_results.json
Baseline results saved to baseline_results.json


In [25]:
# Score the baseline predictions against the references with ROUGE and BERTScore.
# Both metric objects are reused later for the fine-tuned model.
print("Evaluating baseline...")
rouge, bertscore = evaluate.load("rouge"), evaluate.load("bertscore")

rouge_results = rouge.compute(predictions=baseline_preds, references=references)
print("Baseline ROUGE scores:", rouge_results)

# BERTScore loads a large encoder (roberta-large for lang="en") and runs it over every pair, so it is slow.
bertscore_results = bertscore.compute(
    predictions=baseline_preds,
    references=references,
    lang="en",
)
baseline_f1_mean = np.mean(bertscore_results["f1"])
print(f"Baseline BERTScore F1 Mean: {baseline_f1_mean}")

Evaluating baseline...


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Matplotlib is building the font cache; this may take a moment.
Matplotlib is building the font cache; this may take a moment.


Baseline ROUGE scores: {'rouge1': np.float64(0.1781592138425609), 'rouge2': np.float64(0.06418875928507285), 'rougeL': np.float64(0.1523361099760447), 'rougeLsum': np.float64(0.15194549927768705)}


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Baseline BERTScore F1 Mean: 0.8450807781394468


In [26]:
# Preprocessing for Training
def preprocess_function(examples, prefix="Summarize the following clinical text: ",
                        max_input_length=1024, max_target_length=128):
    """
    Tokenize a batch of examples for seq2seq training (used with ``Dataset.map(batched=True)``).

    Args:
        examples: batch dict with list-valued ``"article"`` and ``"summary"`` columns.
        prefix: instruction prepended to each article; keep it identical to the inference prompt.
        max_input_length: truncation length for the prompt.
        max_target_length: truncation length for the summary labels.

    Returns:
        The tokenizer's encoding of the prompts plus a ``"labels"`` entry holding the summary token ids.
    """
    inputs = [prefix + doc for doc in examples["article"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # text_target= is the documented way to tokenize labels (it applies any target-side
    # tokenizer settings); for the T5 tokenizer this should yield the same ids as a plain call.
    labels = tokenizer(text_target=examples["summary"], max_length=max_target_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

print("Tokenizing datasets...")
# Apply the same batched preprocessing to each split, in train/validation/test order.
tokenized_train, tokenized_val, tokenized_test = (
    split.map(preprocess_function, batched=True)
    for split in (train_dataset, val_dataset, test_dataset)
)

print("Tokenization complete.")

Tokenizing datasets...


Map: 100%|██████████| 392/392 [00:00<00:00, 940.23 examples/s]
Map: 100%|██████████| 392/392 [00:00<00:00, 940.23 examples/s]
Map: 100%|██████████| 51/51 [00:00<00:00, 1447.38 examples/s]
Map: 100%|██████████| 51/51 [00:00<00:00, 1447.38 examples/s]
Map: 100%|██████████| 109/109 [00:00<00:00, 1731.95 examples/s]

Tokenization complete.





In [29]:
# Fine-tuning Setup
print("Setting up training arguments...")

# Define training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    eval_strategy="epoch",  # `evaluation_strategy` was renamed in newer transformers versions
    learning_rate=2e-5,
    per_device_train_batch_size=4,  # Adjust based on GPU memory
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    fp16=torch.cuda.is_available(),  # Mixed precision only on CUDA (not MPS/CPU)
    logging_dir=f"{OUTPUT_DIR}/logs",
    # Log training loss once per epoch so it lines up with the per-epoch validation loss.
    # 392 training examples at batch size 4 is 98 steps per epoch on a single device, so the
    # previous logging_steps=100 left epoch 1 without any training-loss entry.
    logging_strategy="epoch",
    seed=42,  # explicit for reproducibility
    report_to="none",  # Disable wandb/mlflow for simplicity
)

# Data collator: pads inputs and labels per batch
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Initialize Trainer. `tokenizer=` is deprecated on Trainer in recent transformers releases
# (it emits a deprecation warning); `processing_class=` is its replacement.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=data_collator,
    processing_class=tokenizer,
)

print("Trainer initialized.")

Setting up training arguments...
Trainer initialized.


  trainer = Seq2SeqTrainer(


In [None]:
# Fine-tune, then write the final weights to OUTPUT_DIR so the next cell can reload them.
print("Starting training...")
trainer.train()
print("Training complete.")

save_dir = OUTPUT_DIR
trainer.save_model(save_dir)
print(f"Model saved to {save_dir}")

Starting training...




Epoch,Training Loss,Validation Loss


In [None]:
# Inference with Fine-tuned Model
print("Running inference with fine-tuned model...")

# Reload the final model that trainer.save_model() wrote to OUTPUT_DIR in the previous cell,
# and place it on the same device used for the baseline.
finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(OUTPUT_DIR).to(device)

finetuned_preds, _ = generate_summaries(finetuned_model, tokenizer, test_dataset, device)

# One record per test example, side by side with its reference and baseline summaries.
finetuned_results = [
    {
        "id": row["id"],
        "article": row["article"],
        "reference_summary": ref_summary,
        "baseline_summary": base_summary,
        "finetuned_summary": ft_summary,
    }
    for row, ref_summary, base_summary, ft_summary in zip(
        test_dataset, references, baseline_preds, finetuned_preds
    )
]

with open("finetuned_results.json", "w") as out_file:
    json.dump(finetuned_results, out_file, indent=2)
print("Fine-tuned results saved to finetuned_results.json")

In [None]:
# Final Evaluation and Comparison
print("Evaluating fine-tuned model...")


def _print_rouge(title, scores):
    """Print a title line, then the four ROUGE variants from an `evaluate` ROUGE result, one per line."""
    print(title)
    for label, key in (("ROUGE-1", "rouge1"), ("ROUGE-2", "rouge2"),
                       ("ROUGE-L", "rougeL"), ("ROUGE-Lsum", "rougeLsum")):
        # Pad "LABEL:" to 11 chars so all values start in the same column.
        print(f"  {label + ':':<11} {scores[key]:.4f}")


# --- ROUGE (fine-tuned) ---
ft_rouge_results = rouge.compute(
    predictions=finetuned_preds,
    references=references
)
_print_rouge("Fine-tuned ROUGE scores:", ft_rouge_results)

# --- BERTScore (fine-tuned) ---
# `bertscore` is an `evaluate` module (same object as in the baseline cell), not the
# `bert_score.score` function: it is not callable and compute() returns a dict with a
# per-example "f1" list, so use it exactly like the baseline evaluation did.
ft_bertscore_results = bertscore.compute(
    predictions=finetuned_preds,
    references=references,
    lang="en"
)
ft_bert_f1 = float(np.mean(ft_bertscore_results["f1"]))
print(f"\nFine-tuned BERTScore F1 Mean: {ft_bert_f1:.4f}")

# --- Comparison ---
print("\n--- Comparison ---")
_print_rouge("Baseline ROUGE scores:", rouge_results)
_print_rouge("\nFine-tuned ROUGE scores:", ft_rouge_results)

# Baseline BERTScore
base_bert_f1 = float(np.mean(bertscore_results["f1"]))
print(f"\nBaseline BERTScore F1: {base_bert_f1:.4f}")
print(f"Fine-tuned BERTScore F1: {ft_bert_f1:.4f}")
