<a href="https://colab.research.google.com/github/viknes86/Alternative-Assignment-Medical-VQA-Comparison-25056315/blob/main/03_Phase2_LlavaMed_Hybrid_Specialist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Medical Visual Question Answering using LLaVA-Med
## Advanced Machine Learning - Final Project
**Student Name:** J. Vikneswaran A/L Palaniandy
**Student ID:** 25056315

**GitHub:** https://github.com/viknes86/Alternative-Assignment-Medical-VQA-Comparison-25056315

**Google Drive (Data & Weights):** https://drive.google.com/drive/folders/1SPnKmP3lWkdrAqWBtg1vo0aeEugNk2K7?usp=sharing

### Project Objective
To compare the performance of a generative vision-language model (LLaVA-1.5-7B) against a traditional discriminative baseline (CNN-LSTM) on the VQA-RAD dataset. This notebook fine-tunes the final LLaVA configuration (Exp 5) using 4-bit quantization with LoRA adapters, then evaluates it on the held-out test split for the final comparison with the CNN-LSTM.

Mount Google Drive

In [None]:
# ==============================================================================
# SECTION 0: MOUNT GOOGLE DRIVE
# Purpose: Connect to Google Drive to access the dataset and save results.
# ==============================================================================

from google.colab import drive
import os

# Attach Google Drive to the Colab filesystem.
drive.mount('/content/drive')

# Location of the project folder inside Drive.
# UPDATE THIS PATH if your folder name is different
project_path = '/content/drive/MyDrive/AML_FinalProject'

# Report the missing-folder case first; otherwise make the project folder the
# working directory so that relative paths resolve inside it.
if not os.path.exists(project_path):
    print(f"‚ùå Warning: Folder not found at {project_path}")
    print("Please check your Google Drive folder name.")
else:
    print(f"‚úÖ Success! Project folder found at: {project_path}")
    os.chdir(project_path)
    print(f"üìÇ Current Working Directory: {os.getcwd()}")

Imports & Environment Setup

In [None]:
# ==============================================================================
# SECTION 1: ENVIRONMENT SETUP
# Purpose: Install the specific library versions used for training.
# ==============================================================================

# Install required packages (Exact configuration from training)
print("‚è≥ Installing Dependencies...")
!pip install -q --upgrade transformers
!pip install -q --upgrade peft
!pip install -q --upgrade accelerate
!pip install -q --upgrade bitsandbytes
!pip install -q --upgrade torch torchvision torchaudio
!pip install -q datasets nltk rouge_score matplotlib seaborn # Added for Evaluation/Plotting
!pip install bert_score
!pip install evaluate rouge_score bert_score
!pip install -U bitsandbytes accelerate peft transformers

import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig, TrainingArguments, Trainer
from transformers.trainer_utils import get_last_checkpoint
from peft import LoraConfig, get_peft_model
from tqdm import tqdm
from peft import LoraConfig, get_peft_model, PeftModel, TaskType
from PIL import Image
import gc
from transformers import TrainerCallback
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
import os
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from bert_score import score
import evaluate


# Pick the compute device: the GPU when CUDA is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"‚úÖ Using Device: {device}")
if device == "cuda":
    print(f"‚úÖ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Project locations on Google Drive.
# UPDATE THIS if your path is different
PROJECT_PATH = '/content/drive/MyDrive/AML_FinalProject'
# Image folder and question/answer JSON both live directly under PROJECT_PATH.
IMAGE_DIR, JSON_FILE = (
    os.path.join(PROJECT_PATH, entry)
    for entry in ('VQA_RAD Image Folder', 'VQA_RAD Dataset Public.json')
)

print("‚úÖ Environment Ready.")

Dataset Loading & Processing

In [None]:
# ==============================================================================
# SECTION 3: DATA PIPELINE
# Purpose: Load and preprocess VQA-RAD images and text.
# ==============================================================================

# --- 1. INITIALIZE PROCESSOR FIRST ---
# The processor must exist before the datasets are created: every LlavaRADDataset
# receives it in its constructor and calls it from __getitem__.
print("‚è≥ Initializing Processor...")
# Same checkpoint ID that the model-loading cells pass to
# LlavaForConditionalGeneration.from_pretrained.
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# 1. Define Dataset Class (Modified to accept LIST of data)
class LlavaRADDataset(Dataset):
    """Torch dataset that turns VQA-RAD records into LLaVA-1.5 training examples.

    Each record is a dict with the keys ``image_name``, ``question`` and ``answer``.
    ``__getitem__`` loads the image, builds the ``USER: ... ASSISTANT: ...`` prompt
    that already contains the answer, and runs both through ``processor``.

    Args:
        data_list: Sequence of records (an already-split list, not a file path).
        img_dir: Folder that contains the image files.
        processor: Callable used to tokenize the text and preprocess the image.
        max_length: Length every tokenized sequence is padded/truncated to.
    """

    def __init__(self, data_list, img_dir, processor, max_length=1024):
        self.data = data_list  # Accepts the split list directly
        self.img_dir = img_dir
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        """Return the number of question/answer records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return one example as a dict of tensors.

        Returns:
            dict with ``input_ids``, ``attention_mask``, ``pixel_values`` and
            ``labels``. ``labels`` is a copy of ``input_ids`` in which every
            padding position (``attention_mask == 0``) is replaced by -100, so
            that padding does not contribute to the loss or to metrics that
            skip -100.
        """
        item = self.data[idx]

        # Load image; append the extension when the record stores the bare name.
        img_path = os.path.join(self.img_dir, item['image_name'])
        if not img_path.endswith('.jpg'):
            img_path += '.jpg'
        # The context manager closes the file handle; convert() returns an
        # independent in-memory copy that stays valid afterwards.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Prepare text
        question = item['question']
        answer = str(item['answer'])

        # LLaVA 1.5 prompt format (the answer is included for teacher forcing)
        text_prompt = f"USER: <image>\n{question}\nASSISTANT: {answer}"

        # Tokenize text and preprocess the image in one call
        inputs = self.processor(
            text=text_prompt,
            images=image,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length
        )

        # The processor returns a batch of size 1; drop the batch axis.
        input_ids = inputs.input_ids[0]
        attention_mask = inputs.attention_mask[0]
        pixel_values = inputs.pixel_values[0]

        # Every sequence is padded to max_length, so without masking the padding
        # tokens would be trained on as if they were targets.
        # NOTE(review): the prompt tokens (question part) are still used as
        # targets; mask them as well if only the answer should be learned.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "labels": labels
        }

# 2. LOAD & SPLIT DATA
print("‚è≥ Loading and Splitting Data...")
# Read the JSON with an explicit encoding instead of the platform default.
with open(JSON_FILE, 'r', encoding='utf-8') as f:
    full_data = json.load(f)

# SPLIT: 80% Train / 20% Test (Random State 42 matches Baseline)
# NOTE(review): the split is done per question/answer record. If several records
# share one image, the same image can land in both splits - confirm that this is
# acceptable for the comparison.
train_data, test_data = train_test_split(full_data, test_size=0.2, random_state=42)

print(f"‚úÖ Data Split Complete:")
print(f"   - Training Samples: {len(train_data)}")
print(f"   - Test Samples:     {len(test_data)}")

# 3. Create Datasets
train_dataset = LlavaRADDataset(train_data, IMAGE_DIR, processor)
# NOTE(review): eval_dataset wraps the test split, and the training cell passes it
# to the Trainer as eval_dataset with load_best_model_at_end=True, so checkpoint
# selection and the final test report use the same records.
eval_dataset = LlavaRADDataset(test_data, IMAGE_DIR, processor)

# Exp 5 - The Final Refined Version
**Note:** Training was completed previously. The code below is provided for reference and reproducibility.

In [None]:
# ==============================================================================
# SECTION 4: THE FINAL SHOWDOWN (Exp 5 - Combined Best Features)
# ==============================================================================
import torch
import os
import pandas as pd
import gc
import numpy as np
from transformers import (
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig,
    LlavaForConditionalGeneration,
    TrainerCallback
)
from transformers.trainer_utils import get_last_checkpoint
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# 1. Clean Slate
# Run the garbage collector first so that tensors kept alive only by reference
# cycles are actually freed, then release the now-unused cached CUDA blocks.
# In the opposite order, memory freed by gc.collect() would stay in the cache.
gc.collect()
torch.cuda.empty_cache()

# 2. Setup Live Logging
class LiveLoggingCallback(TrainerCallback):
    """Trainer callback that mirrors the running log history to a CSV file.

    On every log event the full ``state.log_history`` is rewritten to
    ``<args.output_dir>/training_log_live.csv``.
    """

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Write the log history to disk (local main process only)."""
        if not state.is_local_process_zero:
            return
        # The first log event can arrive before any checkpoint has been written,
        # so make sure the target folder exists instead of assuming it does.
        os.makedirs(args.output_dir, exist_ok=True)
        log_path = os.path.join(args.output_dir, "training_log_live.csv")
        pd.DataFrame(state.log_history).to_csv(log_path, index=False)

# 3. Load Model (Skip-Quant Mode)
# Loads the base LLaVA checkpoint with a 4-bit bitsandbytes configuration and then
# passes it through PEFT's k-bit training preparation. The order matters: the model
# must be loaded (quantised) before prepare_model_for_kbit_training is applied.
print("‚è≥ Loading Base LLaVA Model...")
# 4-bit NF4 quantisation with double quantisation enabled and bfloat16 as the
# compute dtype. "multi_modal_projector" is listed in llm_int8_skip_modules -
# presumably so that the projector stays unquantised and can be trained in full
# through modules_to_save in the LoRA config; TODO confirm.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    llm_int8_skip_modules=["multi_modal_projector"]
)

# device_map={"": 0} maps the root module to device index 0
# (assumed to be the single Colab GPU - confirm on multi-GPU machines).
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    quantization_config=bnb_config,
    device_map={"": 0}
)

# --- 3a. PREPARE FOR TRAINING ---
# Called with its default arguments; see the PEFT documentation for exactly what
# this changes on the quantised model.
model = prepare_model_for_kbit_training(model)

# 4. CONFIGURATION
# Experiment name and the folder (under PROJECT_PATH/results) that receives
# checkpoints, the live training log and the final adapter.
EXP_NAME = "Exp5_Final_Validation"
OUTPUT_DIR = f"{PROJECT_PATH}/results/{EXP_NAME}"

# LoRA adapter settings: rank 64 with alpha 128 and 5% adapter dropout, no bias
# training. modules_to_save additionally keeps "multi_modal_projector" as a
# trained and saved module.
# NOTE(review): target_modules is matched by module name; names such as q_proj /
# k_proj / v_proj may also exist outside the language model (e.g. in the vision
# tower) - check print_trainable_parameters() output to confirm what is adapted.
peft_config = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    modules_to_save=["multi_modal_projector"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# Order matters here: input gradients are enabled on the base model, then the
# model is wrapped with the adapters, then the trainable-parameter summary is
# printed for the wrapped model.
model.enable_input_require_grads()
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# --- 5. MEMORY SAVER FUNCTIONS (CRITICAL FIX) ---
# Instead of storing huge logits, we convert to simple integers immediately
def preprocess_logits_for_metrics(logits, labels):
    """Reduce logits to predicted token ids before they are accumulated.

    Args:
        logits: Logits tensor, or a tuple whose first element is the logits tensor.
        labels: Unused; present because the Trainer passes it to this hook.

    Returns:
        Tensor of argmax indices taken over the last dimension.
    """
    token_scores = logits[0] if isinstance(logits, tuple) else logits
    return token_scores.argmax(dim=-1)

def compute_metrics(eval_pred):
    """Compute next-token accuracy over the positions that carry a real label.

    Args:
        eval_pred: Pair ``(predictions, labels)``. ``predictions`` holds the
            argmax token ids produced by ``preprocess_logits_for_metrics`` and
            ``labels`` the label ids; both are indexed as (example, position).

    Returns:
        ``{"accuracy": float}``; 0.0 when no position has a label other than -100.
    """
    predictions, labels = eval_pred
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)

    # A causal LM predicts token t+1 at position t, so predictions must be
    # compared with the labels shifted one step to the left. Comparing them
    # position-by-position scores every prediction against the wrong token.
    shifted_predictions = predictions[..., :-1]
    shifted_labels = labels[..., 1:]

    # -100 marks positions that are excluded from the loss; exclude them here too.
    mask = shifted_labels != -100
    total = int(mask.sum())
    if total == 0:
        # Nothing to score: avoid a division by zero.
        return {"accuracy": 0.0}

    correct = int((shifted_predictions[mask] == shifted_labels[mask]).sum())
    return {"accuracy": correct / total}

# 6. TRAINING ARGUMENTS
# Same settings as before, grouped by concern.
training_args = TrainingArguments(
    # --- Output and logging ---
    output_dir=OUTPUT_DIR,
    report_to="none",
    logging_steps=10,

    # --- Optimisation schedule ---
    num_train_epochs=30,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    max_grad_norm=0.3,

    # --- Batch sizes (4 x 4 accumulation steps = 16 examples per optimiser step and device) ---
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=2,  # kept small to prevent OOM during evaluation

    # --- Numerical precision ---
    bf16=True,
    fp16=False,

    # --- Evaluate and checkpoint once per epoch; reload the lowest-eval_loss checkpoint at the end ---
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",

    # --- Data loading ---
    remove_unused_columns=False,
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
)

# 7. TRAINER
def _stack_examples(examples):
    """Collate per-example dicts into one batch dict by stacking each tensor field."""
    return {
        field: torch.stack([example[field] for example in examples])
        for field in ("input_ids", "attention_mask", "pixel_values", "labels")
    }


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=_stack_examples,
    compute_metrics=compute_metrics,
    # Memory saver: logits are reduced to token ids before they are accumulated.
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    callbacks=[LiveLoggingCallback()],
)

# 8. RUN
# get_last_checkpoint lists the folder contents, so it fails on a folder that does
# not exist yet (first run). Only look for a checkpoint when the folder is there.
last_checkpoint = get_last_checkpoint(OUTPUT_DIR) if os.path.isdir(OUTPUT_DIR) else None
if last_checkpoint:
    print(f"üîÑ Resuming from {last_checkpoint}")
    trainer.train(resume_from_checkpoint=last_checkpoint)
else:
    print("üöÄ Starting Exp 5: The Final Validation (Memory Safe Mode)...")
    trainer.train()

# 9. SAVE
# Because load_best_model_at_end=True is set in the training arguments, the model
# held by the trainer at this point is the best checkpoint by eval_loss.
final_path = f"{OUTPUT_DIR}/final_best_adapter"
trainer.save_model(final_path)
print(f"‚úÖ Final Model Saved to: {final_path}")

The Final Showdown Scores


In [None]:
# ==============================================================================
# SECTION 5: FINAL EVALUATION (With BLEU, ROUGE & BERTScore)
# ==============================================================================


# --- 1. RESTORE DATA (Critical Step after Restart) ---
# Rebuilds the test split from the JSON file so that this cell also works after a
# runtime restart, without re-running the training cells.
print("üîÑ Restoring Test Data Split...")
# Read the JSON with an explicit encoding instead of the platform default.
with open(JSON_FILE, 'r', encoding='utf-8') as f:
    full_data = json.load(f)

# Consistency: same test_size and random state (42) as the split in Section 3,
# so the records held out here are the ones held out during training.
_, test_data = train_test_split(full_data, test_size=0.2, random_state=42)
print(f"‚úÖ Data Restored. Test Set Size: {len(test_data)} images.")

# --- 2. SETUP MODEL & METRICS ---
# Adapter location: must match the folder written by the training cell.
EXP_NAME = "Exp5_Final_Validation"
ADAPTER_PATH = f"{PROJECT_PATH}/results/{EXP_NAME}/final_best_adapter"

print(f"‚è≥ Loading Metrics (BLEU, ROUGE, BERTScore)...")
bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")

print(f"‚è≥ Loading Model from: {ADAPTER_PATH}")
# Use the same quantisation settings as the training cell, including double
# quantisation, so that the adapter is evaluated on the same quantised base
# weights it was trained against.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    llm_int8_skip_modules=["multi_modal_projector"]
)

base_model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    quantization_config=bnb_config,
    device_map={"": 0}
)

# Attach the trained adapter on top of the quantised base model. The adapter
# weights are not merged into the base weights: there is no merge call here.
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
model.eval()
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# --- 3. INFERENCE LOOP ---
def _is_answer_correct(answer, ground_truth):
    """Score one prediction against its reference.

    Both strings are expected to be lower-cased and stripped already.

    Returns:
        1 on an exact match. For yes/no references, also 1 when the first word
        of the prediction equals the reference (so "yes, it is ..." counts as
        "yes"). Comparing whole words instead of a raw prefix keeps answers such
        as "normal" or "none" from being scored as "no". 0 otherwise.
    """
    if answer == ground_truth:
        return 1
    if ground_truth in ("yes", "no"):
        # Treat '.' and ',' as separators so "yes." and "yes,it" still match.
        words = answer.replace('.', ' ').replace(',', ' ').split()
        if words and words[0] == ground_truth:
            return 1
    return 0


print("üöÄ Starting Inference on TEST SET...")
results = []
pred_texts = []
ref_texts = []
skipped_items = 0  # test records dropped because their image could not be loaded

for i in tqdm(range(len(test_data))):
    item = test_data[i]

    # A. Load Image
    img_path = os.path.join(IMAGE_DIR, item['image_name'])
    if not os.path.exists(img_path) and not img_path.endswith('.jpg'):
        img_path += '.jpg'

    try:
        # The context manager closes the file; convert() returns an in-memory copy.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
    except OSError as err:
        # Missing or unreadable image: keep the run going (best effort) but make
        # the skip visible, because it shrinks the denominator of every score.
        skipped_items += 1
        print(f"Skipping item {i}: cannot load image '{img_path}' ({err})")
        continue

    # B. Prepare Text
    question = item['question']
    ground_truth = str(item['answer']).strip().lower()

    # C. Generate
    prompt = f"USER: <image>\n{question}\nASSISTANT:"
    inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")

    with torch.inference_mode():
        # Greedy decoding. No temperature is passed: it only applies to sampling
        # and is ignored (with a warning) when do_sample=False.
        generate_ids = model.generate(
            **inputs,
            max_new_tokens=40, # Allow enough space for descriptive answers
            do_sample=False
        )

    # D. Decode: keep only the text after the last "ASSISTANT:" marker.
    generated_text = processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
    answer = generated_text.split("ASSISTANT:")[-1].strip().lower()

    # --- 4. SCORING LOGIC ---
    is_correct = _is_answer_correct(answer, ground_truth)

    # Store Data
    pred_texts.append(answer)
    ref_texts.append(ground_truth)

    # Identify Type based on ANSWER (since JSON tags are unreliable)
    q_category = "CLOSED" if ground_truth in ["yes", "no"] else "OPEN"

    results.append({
        "Question": question,
        "Category": q_category,
        "Ground_Truth": ground_truth,
        "Prediction": answer,
        "Correct": is_correct
    })

if skipped_items:
    print(f"Warning: skipped {skipped_items} of {len(test_data)} test items because their image could not be loaded.")

# --- 5. CALCULATE & PRINT REPORT ---
print("\nüìä Calculating Final Scores...")
# Fail with a clear message instead of a KeyError / metric error further down
# when the inference loop produced nothing (e.g. every image failed to load).
if not results:
    raise RuntimeError("No predictions were produced; check IMAGE_DIR and the inference loop output.")

df = pd.DataFrame(results)

# Accuracy (in percent). A category without any rows yields NaN.
acc_overall = df["Correct"].mean() * 100
acc_closed = df[df["Category"] == "CLOSED"]["Correct"].mean() * 100
acc_open = df[df["Category"] == "OPEN"]["Correct"].mean() * 100

# Text Metrics
bleu_score = bleu.compute(predictions=pred_texts, references=[[r] for r in ref_texts])
rouge_score = rouge.compute(predictions=pred_texts, references=ref_texts)
# BERTScore: no model_type is passed, so the metric falls back to its default
# model for lang="en" (roberta-large according to the bert_score documentation,
# not DistilBERT). Pass model_type explicitly if a smaller model is wanted.
bert_score = bertscore.compute(predictions=pred_texts, references=ref_texts, lang="en", verbose=False)
bert_f1 = np.mean(bert_score['f1'])

print("\n" + "="*50)
print("üèÜ EXPERIMENT 5: FINAL SCIENTIFIC REPORT")
print("="*50)
print(f"1. Clinical Accuracy (Soft-Match):")
print(f"   - Overall Accuracy:      {acc_overall:.2f}%")
print(f"   - Closed (Yes/No) Acc:   {acc_closed:.2f}%  <-- USE FOR TABLE 3.3")
print(f"   - Open (Descriptive) Acc:{acc_open:.2f}%")
print("-" * 30)
print(f"2. Semantic Understanding:")
print(f"   - BERTScore F1: {bert_f1:.4f} (Target > 0.60)")
print("-" * 30)
print(f"3. Text Structure:")
print(f"   - BLEU Score:  {bleu_score['bleu']:.4f}")
print(f"   - ROUGE-L:     {rouge_score['rougeL']:.4f}")
print("="*50)

# Save Final CSV (create the results folder first in case it does not exist yet)
output_csv = f"{PROJECT_PATH}/results/Exp5_Final_Results_Corrected.csv"
os.makedirs(os.path.dirname(output_csv), exist_ok=True)
df.to_csv(output_csv, index=False)
print(f"‚úÖ Full results saved to: {output_csv}")