## Step 0: Mounting Google Drive and Importing Libraries

In [None]:
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/multimodal-xray-agent
!ls

In [None]:
!pip install -U trl bitsandbytes -q

In [None]:
!pip install sacremoses -q

In [None]:
import os
import json
import torch
import random
import shutil
import pandas as pd

from tqdm import tqdm
from pathlib import Path
from trl import SFTTrainer
from torch.utils.data import DataLoader
from datasets import load_dataset, DatasetDict, concatenate_datasets, Dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, get_peft_model_state_dict
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, EvalPrediction, default_data_collator

%load_ext tensorboard

## Step 1: Verifying GPU and Environment

In [5]:
if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(0)
    device = torch.device("cuda")
    print(f"GPU detected: {device_name}")
else:
    device = torch.device("cpu")
    print("GPU not detected. Falling back to CPU.")

print(f"Running on device: {device}")

GPU detected: NVIDIA A100-SXM4-40GB
Running on device: cuda


## Step 2: Load & Preprocess Full Q/A Dataset

In [6]:
# Setting paths

PROJECT_ROOT = Path("/content/drive/MyDrive/multimodal-xray-agent")
QA_DIR = PROJECT_ROOT / "data" / "qapairs"

TRAIN_PATH = QA_DIR / "train.jsonl"
VAL_PATH = QA_DIR / "val.jsonl"

ADAPTER_SAVE_PATH = PROJECT_ROOT / "models" / "biogpt_lora_adapter"
OUTPUT_PATH = PROJECT_ROOT / "data" / "qapairs" / "validation_predictions.jsonl"
METRICS_PATH = PROJECT_ROOT / "logs" / "epoch_metrics.csv"

SOURCE_LOG_DIR = Path("./logs")
DEST_LOG_DIR = PROJECT_ROOT / "logs" / "biogpt_qlora_run"

DEST_LOG_DIR.mkdir(parents=True, exist_ok=True)
ADAPTER_SAVE_PATH.mkdir(parents=True, exist_ok=True)
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
METRICS_PATH.parent.mkdir(parents=True, exist_ok=True)

In [None]:
# Load train and val datasets

train_raw = load_dataset("json", data_files=TRAIN_PATH.as_posix(), split="train")
val_raw = load_dataset("json", data_files=VAL_PATH.as_posix(), split="train")

In [8]:
len(train_raw), len(val_raw)

(1248, 312)

In [None]:
tokenizer = AutoTokenizer.from_pretrained("microsoft/BioGPT-Large")

In [10]:
tokenizer.pad_token = tokenizer.eos_token

In [11]:
# Preprocessing function for causal LM
def preprocess(example):
    prompt = f"### Question:\n{example['question']}\n\n### Answer:\n{example['answer']}"
    tokenized = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=256,
    )
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized

In [None]:
# Tokenize train and validation datasets
train_dataset = train_raw.map(
    preprocess,
    batched=False,
    remove_columns=train_raw.column_names,
    load_from_cache_file=False,
    keep_in_memory=True,
)

eval_dataset = val_raw.map(
    preprocess,
    batched=False,
    remove_columns=val_raw.column_names,
    load_from_cache_file=False,
    keep_in_memory=True,
)

In [13]:
print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(eval_dataset)}")

Train dataset size: 1248
Validation dataset size: 312


In [14]:
print(train_dataset[0])

{'input_ids': [2, 2045, 2045, 2045, 4950, 32691, 52, 4871, 34, 8, 11439, 15950, 752, 2045, 2045, 2045, 2454, 5895, 953, 52, 13156, 885, 11, 7459, 452, 463, 16754, 3869, 126, 449, 7719, 113, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0

## Step 3: Model + Tokenizer Setup (QLoRA + FlashAttention)

In [15]:
# Set quantization config for QLoRA
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",  # NormalFloat4: best for LLMs
    bnb_4bit_compute_dtype=torch.float16,
)

In [None]:
# Load BioGPT model with FlashAttention if supported
base_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/BioGPT-Large",
    quantization_config=bnb_config,
    device_map="auto"
)

In [17]:
base_model = prepare_model_for_kbit_training(base_model)

## Step 4: LoRA Configuration + PEFT Wrapping

In [18]:
# Target modules for GPT2-style transformers (BioGPT)
target_modules = ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"]

In [19]:
# LoRA configuration (QLoRA-optimized)
peft_config = LoraConfig(
    r=64,                          # Rank of the LoRA decomposition
    lora_alpha=16,                # Scaling factor
    target_modules=target_modules,
    lora_dropout=0.05,            # Regularization
    bias="none",                  # Do not fine-tune bias terms
    task_type="CAUSAL_LM",        # Language modeling
)

In [20]:
# Inject LoRA adapters into the base model
model = get_peft_model(base_model, peft_config)

In [21]:
model.gradient_checkpointing_enable()

In [22]:
model.config.use_cache = False

In [23]:
tokenizer.pad_token = tokenizer.eos_token

In [41]:
tokenizer.padding_side = "left"

In [24]:
model.resize_token_embeddings(len(tokenizer))

BioGptScaledWordEmbedding(57717, 1600, padding_idx=1)

In [25]:
model.print_trainable_parameters()

trainable params: 88,473,600 || all params: 1,659,662,400 || trainable%: 5.3308


In [26]:
model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): BioGptForCausalLM(
      (biogpt): BioGptModel(
        (embed_tokens): BioGptScaledWordEmbedding(57717, 1600, padding_idx=1)
        (embed_positions): BioGptLearnedPositionalEmbedding(2050, 1600)
        (layers): ModuleList(
          (0-47): 48 x BioGptDecoderLayer(
            (self_attn): BioGptSdpaAttention(
              (k_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=1600, out_features=1600, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1600, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=1600, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): Par

## Step 5: TrainingArguments configuration

In [27]:
training_args = TrainingArguments(
    output_dir="./models/lora_adapter",       # Save path
    per_device_train_batch_size=8,            # Empirically stable for A100 with QLoRA
    per_device_eval_batch_size=4,             # Same for validation
    gradient_accumulation_steps=4,            # Effective batch size = 12 × 2 = 24
    eval_strategy="epoch",                    # Evaluate once per epoch
    save_strategy="epoch",                    # Save checkpoint once per epoch
    logging_strategy="steps",                 # Log losses periodically
    logging_dir="./logs",                     # Save logs
    logging_steps=20,                         # Log every 20 steps
    num_train_epochs=5,                       # Number of fine-tuning epochs
    learning_rate=3e-4,                       # Higher LR often better for small LoRA adapters
    warmup_steps=100,                         # Small warmup to stabilize first few updates
    lr_scheduler_type="cosine",               # Smooth decay
    save_total_limit=2,                       # Retain 2 best checkpoints only
    load_best_model_at_end=True,              # Restore best checkpoint (lowest val loss)
    report_to="tensorboard",                  # Log to TensorBoard
    run_name="biogpt-qlora-run",              # Appears in TensorBoard dashboard
    fp16=True,                                # Use mixed precision (saves memory, faster)
    group_by_length=False,                    # Efficient packing of similar-length samples
    gradient_checkpointing=True,              # Redundant with model setup, but safe to keep
    eval_accumulation_steps=1,
    remove_unused_columns=False               # Required for TRL's SFTTrainer
)

## Step 8: Adding Perplexity as an Evaluation Metric

In [28]:
def compute_metrics(eval_pred: EvalPrediction) -> dict:
    with torch.no_grad():
        logits = torch.tensor(eval_pred.predictions).cpu()
        labels = torch.tensor(eval_pred.label_ids).cpu()

        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous().long()

        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1)
        )

        perplexity = torch.exp(loss)
        return {
            "eval_loss": loss.item(),
            "eval_perplexity": perplexity.item()
        }

##  Step 6: Fine-Tuning the BioGPT Model with LoRA using SFTTrainer

In [None]:
# Define the trainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    compute_metrics=compute_metrics
)

In [30]:
trainer.train()

Epoch,Training Loss,Validation Loss,Perplexity
1,4.7915,0.580075,1.786169
2,0.5571,0.315742,1.371274
3,0.3337,0.207221,1.230254
4,0.2411,0.19027,1.209576
5,0.2209,0.188066,1.206912


TrainOutput(global_step=195, training_loss=0.819185122465476, metrics={'train_runtime': 2075.751, 'train_samples_per_second': 3.006, 'train_steps_per_second': 0.094, 'total_flos': 1.4990714339328e+16, 'train_loss': 0.819185122465476})

In [None]:
%reload_ext tensorboard
%tensorboard --logdir ./logs

In [None]:
for file in SOURCE_LOG_DIR.glob("*"):
    shutil.copy(file, DEST_LOG_DIR)

## Step 7: Save LoRA Adapter Weights

In [34]:
trainer.save_model(ADAPTER_SAVE_PATH.as_posix())
print(f"LoRA adapter saved to: {ADAPTER_SAVE_PATH}")

LoRA adapter saved to: /content/drive/MyDrive/multimodal-xray-agent/models/biogpt_lora_adapter


## Step 8: Generate Validation Predictions

In [48]:
# Load validation samples
with open(VAL_PATH, "r") as f:
    samples = [json.loads(line) for line in f]

In [49]:
# Switch model to eval mode and disable gradients
model.eval()
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7b1cac41be50>

In [50]:
# Prep batched dataloader
eval_loader = DataLoader(
    samples, batch_size=8, shuffle=False, collate_fn=lambda batch: batch
)

results = []

for batch in tqdm(eval_loader, desc="Batched Generation"):
    prompts = [f"### Question:\n{item['question']}\n\n### Answer:\n" for item in batch]
    references = [item["answer"] for item in batch]
    uuids = [item["uuid"] for item in batch]

    # Tokenize batched prompts
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(model.device)

    # Batched generation
    output_ids = model.generate(
        **inputs,
        max_new_tokens=64,
        do_sample=False,
        num_beams=1,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode each and clean
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    cleaned_outputs = [out[len(prompt):].strip() if out.startswith(prompt) else out for prompt, out in zip(prompts, outputs)]

    for uuid, q, r, g in zip(uuids, prompts, references, cleaned_outputs):
        results.append({
            "uuid": uuid,
            "question": q,
            "reference_answer": r,
            "generated_answer": g,
        })

Batched Generation: 100%|██████████| 39/39 [00:53<00:00,  1.36s/it]


In [54]:
prompt = "State the impression clearly in two sentences."

# Tokenize raw prompt only
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Generate answer
output_ids = model.generate(
    **inputs,
    max_new_tokens=64,
    do_sample=False,
    num_beams=1,
    pad_token_id=tokenizer.eos_token_id,
)

# Decode
decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(decoded)

State the impression clearly in two sentences. # # Answer: 1. Heart size is normal. 2. Lungs are clear. 3. No acute cardiopulmonary abnormality. 4. No acute cardiopulmonary abnormality.


In [51]:
results

[{'uuid': 'iu_1888',
  'question': '### Question:\nWhat is the radiologic impression?\n\n### Answer:\n',
  'reference_answer': 'No acute cardiopulmonary abnormalities',
  'generated_answer': '# # # Question: What is the radiologic impression? # # # Answer: 1. No acute cardiopulmonary abnormality'},
 {'uuid': 'iu_1888',
  'question': '### Question:\nSummarize the key thoracic findings.\n\n### Answer:\n',
  'reference_answer': 'No acute cardiopulmonary abnormalities',
  'generated_answer': '# # # Question: Summarize the key thoracic findings. # # # Answer: 1. No acute cardiopulmonary abnormality'},
 {'uuid': 'iu_1888',
  'question': '### Question:\nState the impression.\n\n### Answer:\n',
  'reference_answer': 'Normal chest X-ray; no acute cardiopulmonary issues.',
  'generated_answer': '# # # Question: State the impression. # # # Answer: No acute cardiopulmonary abnormality.'},
 {'uuid': 'iu_5141',
  'question': '### Question:\nWhat is the radiologic impression?\n\n### Answer:\n',
  're

In [None]:
state_dict = get_peft_model_state_dict(model)
for name, param in state_dict.items():
    print(name, param.abs().mean().item())

In [43]:
# Save predictions
with open(OUTPUT_PATH, "w") as f:
    for example in results:
        f.write(json.dumps(example) + "\n")

print(f"Saved validation predictions to {OUTPUT_PATH}")

Saved validation predictions to /content/drive/MyDrive/multimodal-xray-agent/data/qapairs/validation_predictions.jsonl


## Step 9: Final Metrics + Summary Reporting

In [45]:
# Extract training + eval logs (every log step)
records = trainer.state.log_history

# Convert to DataFrame
df = pd.DataFrame(records)

# Filter only epoch-level logs (those with eval_loss or epoch key)
epoch_logs = df[df["epoch"].notnull()][["epoch", "loss", "eval_loss", "eval_perplexity"]]

# Drop duplicates and keep last record per epoch (in case of multiple entries)
epoch_logs = epoch_logs.groupby("epoch").last().reset_index()

# Save
epoch_logs.to_csv(METRICS_PATH, index=False)

print(f"Epoch-level metrics saved to: {METRICS_PATH.resolve()}")
display(epoch_logs)

Epoch-level metrics saved to: /content/drive/MyDrive/multimodal-xray-agent/logs/epoch_metrics.csv


Unnamed: 0,epoch,loss,eval_loss,eval_perplexity
0,0.512821,4.7915,,
1,1.0,,0.580075,1.786169
2,1.025641,0.7693,,
3,1.538462,0.5571,,
4,2.0,,0.315742,1.371274
5,2.051282,0.4287,,
6,2.564103,0.3337,,
7,3.0,,0.207221,1.230254
8,3.076923,0.2511,,
9,3.589744,0.2411,,
