<a href="https://colab.research.google.com/github/viknes86/Alternative-Assignment-Medical-VQA-Comparison-25056315/blob/main/01_Baseline_CNN_LSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Baseline Model: CNN-LSTM for Medical VQA
## Advanced Machine Learning - Final Project
**Student Name:** J.Vikneswaran A/L Palaniandy
**Student ID:** 25056315

**GitHub:** https://github.com/viknes86/Alternative-Assignment-Medical-VQA-Comparison-25056315

**Google Drive (Data & Weights):** https://drive.google.com/drive/folders/1SPnKmP3lWkdrAqWBtg1vo0aeEugNk2K7?usp=sharing


### Project Objective
To establish a discriminative baseline for the VQA-RAD dataset using a classic **CNN-LSTM** architecture.
* **Visual Encoder:** ResNet50 (Pretrained on ImageNet) to extract visual features.
* **Question Encoder:** LSTM (Long Short-Term Memory) to process text.
* **Fusion:** Element-wise multiplication of visual and textual features.
* **Classifier:** Fully Connected Layer predicting one word from a fixed vocabulary.

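This baseline will be compared against the generative **LLaVA-Med** model to illustrate the "Capacity Wall" of discriminative approaches in medical reasoning: a classifier that picks a single word from a fixed vocabulary cannot produce multi-word answers or free-form explanations.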

Mount Drive & Setup

In [None]:
# ==============================================================================
# SECTION 1: ENVIRONMENT & SETUP
# Purpose: Mount Drive and define project paths.
# ==============================================================================



import os
from google.colab import drive
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models, transforms
from PIL import Image
import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt
import nltk
import numpy as np

# 1. Mount Google Drive so the dataset and saved checkpoints persist across sessions.
drive.mount('/content/drive')

# 2. Define Paths
# UPDATE THIS if your folder name is different
PROJECT_PATH = '/content/drive/MyDrive/AML_FinalProject'
IMAGE_DIR = os.path.join(PROJECT_PATH, 'VQA_RAD Image Folder')
JSON_FILE = os.path.join(PROJECT_PATH, 'VQA_RAD Dataset Public.json')
MODEL_SAVE_PATH = os.path.join(PROJECT_PATH, 'cnn_lstm_vqa_split.pth')

# 3. Device Config: use the GPU when the runtime provides one, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✅ Device: {device}")

Vocabulary & Dataset

In [None]:
# ==============================================================================
# SECTION 2: DATA PIPELINE (UPDATED WITH TRAIN/TEST SPLIT)
# Purpose: Build Vocabulary and Custom Dataset Class.
# ==============================================================================

# Fetch the Punkt tokenizer data used by nltk.tokenize.word_tokenize below.
# Both resource names are requested, presumably so tokenization works across
# NLTK versions that look up either one.
for _nltk_resource in ('punkt', 'punkt_tab'):
    nltk.download(_nltk_resource)

# 1. Vocabulary Class (Unchanged)
class Vocabulary:
    """Bidirectional mapping between words and integer ids.

    Ids 0-3 are reserved for the special tokens <PAD>, <SOS>, <EOS> and
    <UNK>; every word added afterwards receives the next consecutive id.
    Calling the instance looks up a word's id, falling back to the <UNK> id
    for words that were never added.
    """

    def __init__(self):
        specials = ("<PAD>", "<SOS>", "<EOS>", "<UNK>")
        self.word2idx = {token: i for i, token in enumerate(specials)}
        self.idx2word = {i: token for i, token in enumerate(specials)}
        # Next id to hand out.
        self.idx = len(specials)

    def add_word(self, word):
        """Register *word* under the next free id; no-op if it is already known."""
        if word in self.word2idx:
            return
        new_id = self.idx
        self.word2idx[word] = new_id
        self.idx2word[new_id] = word
        self.idx = new_id + 1

    def __call__(self, word):
        """Return the id of *word*, or the <UNK> id when it is not in the vocabulary."""
        unk_id = self.word2idx["<UNK>"]
        return self.word2idx.get(word, unk_id)

    def __len__(self):
        """Number of known tokens, special tokens included."""
        return len(self.word2idx)

def build_vocab(json_path, threshold=1):
    """Build one shared question/answer Vocabulary from the VQA-RAD JSON file.

    Questions are lower-cased and tokenized with NLTK's ``word_tokenize``
    (the same tokenization ``VQARADDataset`` applies); answers are
    lower-cased and split on whitespace. Every token seen at least
    ``threshold`` times is added, in first-seen order.

    Args:
        json_path: Path to the dataset JSON (read with ``pd.read_json``);
            it must provide ``question`` and ``answer`` columns.
        threshold: Minimum occurrence count for a token to be kept.

    Returns:
        A populated ``Vocabulary``.

    NOTE(review): the vocabulary is built from *all* rows, before the
    train/test split, so test-only words also get ids. Restricting it to the
    training rows would map unseen test answers to <UNK>, and
    ``evaluate_model`` would then count an <UNK> prediction as correct --
    change both together if you change this.
    """
    df = pd.read_json(json_path)
    counter = Counter()
    for question in df['question']:
        counter.update(nltk.tokenize.word_tokenize(str(question).lower()))
    for answer in df['answer']:
        counter.update(str(answer).lower().split())

    vocab = Vocabulary()
    for word, count in counter.items():
        if count >= threshold:
            vocab.add_word(word)
    print(f"✅ Vocabulary Built. Total Size: {len(vocab)}")
    return vocab

# 2. Dataset Class (Unchanged)
class VQARADDataset(Dataset):
    """VQA-RAD rows as ``(image, question_ids, answer_label)`` tensors.

    Each JSON row provides ``image_name``, ``question`` and ``answer``. The
    question is tokenized like in ``build_vocab``, wrapped in <SOS>/<EOS> and
    padded/truncated to ``max_len`` ids. The label is the id of the *first*
    whitespace-separated answer word (single-word classification target).

    Args:
        json_file: Path to the VQA-RAD JSON file.
        img_dir: Directory containing the images.
        vocab: Callable mapping a token to an integer id (e.g. ``Vocabulary``).
        transform: Optional transform applied to the RGB PIL image.
        max_len: Fixed question length after padding/truncation
            (default 20, the previously hard-coded value).
    """

    def __init__(self, json_file, img_dir, vocab, transform=None, max_len=20):
        self.data = pd.read_json(json_file)
        self.img_dir = img_dir
        self.vocab = vocab
        self.transform = transform
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def _encode_question(self, question):
        """Tokenize *question* into exactly ``self.max_len`` vocabulary ids."""
        tokens = nltk.tokenize.word_tokenize(str(question).lower())
        ids = [self.vocab("<SOS>")] + [self.vocab(t) for t in tokens] + [self.vocab("<EOS>")]
        # Truncate first, then right-pad; long questions lose <EOS> (as before).
        ids = ids[:self.max_len]
        ids += [self.vocab("<PAD>")] * (self.max_len - len(ids))
        return ids

    def __getitem__(self, idx):
        item = self.data.iloc[idx]
        img_name = item['image_name']
        if not img_name.endswith('.jpg'):
            img_name += '.jpg'
        # Use a context manager so the file handle is released right away
        # instead of whenever the garbage collector gets to it; convert()
        # returns an independent copy, so closing the source is safe.
        with Image.open(os.path.join(self.img_dir, img_name)) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        q_indices = self._encode_question(item['question'])

        # A blank answer has no first word: label it <UNK> instead of
        # raising IndexError.
        ans_tokens = str(item['answer']).lower().split()
        label = self.vocab(ans_tokens[0]) if ans_tokens else self.vocab("<UNK>")
        return image, torch.tensor(q_indices), torch.tensor(label)

# 3. Build the shared vocabulary, the dataset and a reproducible train/test split.
# Resize to ResNet50's input size and normalize with ImageNet statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

vocab = build_vocab(JSON_FILE)
full_dataset = VQARADDataset(JSON_FILE, IMAGE_DIR, vocab, transform)

# Split 80% Train / 20% Test.
# The fixed generator seed (42) makes the test set identical on every run.
# The final evaluation cell re-creates this exact split, so change both together.
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print(f"✅ Data Split: {len(train_dataset)} Training samples | {len(test_dataset)} Test samples")

Model Architecture

In [None]:
# ==============================================================================
# SECTION 3: MODEL ARCHITECTURE (CNN-LSTM)
# Purpose: Define the Visual Encoder (ResNet50), Question Encoder (LSTM) and Classifier.
# ==============================================================================

class CNN_LSTM_VQA(nn.Module):
    """Discriminative VQA baseline: ResNet50 image features times LSTM question features.

    The image goes through an ImageNet-pretrained ResNet50 with its final fc
    layer removed and is projected to ``hidden_size``. The question ids are
    embedded and run through an LSTM, whose top-layer final hidden state is
    used. The two vectors are fused by element-wise multiplication, and a
    linear layer scores every vocabulary entry as the answer.

    Submodule names (resnet, visual_fc, embedding, lstm, classifier) become
    the checkpoint's state_dict keys, so keep them stable.
    """

    def __init__(self, vocab_size, embed_size=128, hidden_size=256, num_layers=1):
        super().__init__()

        # Visual encoder: every ResNet50 child except the final fc layer.
        backbone = models.resnet50(pretrained=True)
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        self.visual_fc = nn.Linear(2048, hidden_size)

        # Question encoder.
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)

        # Answer classifier over the whole vocabulary.
        self.classifier = nn.Linear(hidden_size, vocab_size)

    def forward(self, images, questions):
        """Return answer logits of shape (batch, vocab_size).

        ``questions`` is a (batch, seq_len) tensor of token ids (the LSTM is
        batch_first).
        """
        # Backbone output flattened to (batch, 2048), then projected to hidden_size.
        visual = self.visual_fc(torch.flatten(self.resnet(images), 1))

        # h_n is (num_layers, batch, hidden_size); keep the top layer.
        # NOTE: questions are right-padded, so the LSTM also consumes <PAD> steps.
        _, (h_n, _) = self.lstm(self.embedding(questions))
        textual = h_n[-1]

        return self.classifier(visual * textual)

# Initialize Model: one output class per vocabulary entry, moved to the active device.
model = CNN_LSTM_VQA(len(vocab))
model = model.to(device)
print(model)

Training

In [None]:
# ==============================================================================
# SECTION 4: TRAINING (Run if retraining is needed)
# ==============================================================================

def train_model(model, loader, epochs=20, lr=0.001, vocabulary=None,
                save_path=None, plot=True):
    """Train *model* with cross-entropy loss and Adam, then save a checkpoint.

    Args:
        model: Module mapping ``(images, questions)`` to answer logits.
        loader: Iterable of ``(images, questions, labels)`` batches. Pass the
            training loader only, so the test split stays unseen.
        epochs: Number of passes over ``loader``.
        lr: Adam learning rate (default 0.001, the previously fixed value).
        vocabulary: Object whose ``word2idx`` dict is stored in the
            checkpoint. Defaults to the notebook-level ``vocab`` (Section 2);
            pass it explicitly when training against a different vocabulary.
        save_path: Checkpoint destination; defaults to ``MODEL_SAVE_PATH``.
        plot: Whether to plot the per-epoch loss curve (default True).

    Returns:
        List with the mean batch loss of each epoch.

    Raises:
        ValueError: If ``loader`` yields no batches, because the mean loss
            would otherwise be a division by zero.
    """
    if vocabulary is None:
        vocabulary = vocab
    if save_path is None:
        save_path = MODEL_SAVE_PATH
    # Send batches to wherever the model's parameters live, rather than
    # relying on the global `device` matching.
    model_device = next(model.parameters()).device

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_history = []

    model.train()
    print(f"🚀 Starting Training for {epochs} epochs...")

    for epoch in range(epochs):
        epoch_loss = 0.0
        num_batches = 0
        for imgs, qs, labels in loader:
            imgs = imgs.to(model_device)
            qs = qs.to(model_device)
            labels = labels.to(model_device)

            optimizer.zero_grad()
            loss = criterion(model(imgs, qs), labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_model: loader yielded no batches")
        avg_loss = epoch_loss / num_batches
        loss_history.append(avg_loss)
        print(f"   Epoch [{epoch+1}/{epochs}] Loss: {avg_loss:.4f}")

    # This model shares one vocabulary for questions and answers, so
    # 'answer_to_idx' duplicates 'vocab'; both keys are kept for
    # checkpoint-format compatibility.
    torch.save({
        'state_dict': model.state_dict(),
        'vocab': vocabulary.word2idx,
        'answer_to_idx': vocabulary.word2idx,
    }, save_path)
    print(f"💾 Model Saved to {save_path}")

    if plot:
        plt.plot(loss_history)
        plt.title("Training Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.show()

    return loss_history

# Train on the TRAIN split only; the test split stays held out for Section 5.
# A fresh model (class from Section 3) is built first so training does not
# continue from whatever weights the previous cell left in `model`.
model = CNN_LSTM_VQA(len(vocab)).to(device)
train_loss_history = train_model(model, train_loader, epochs=20)

Evaluation & Visualization

In [None]:
# ==============================================================================
# SECTION 5: EVALUATION & VISUALIZATION
# Purpose: Calculate Accuracy and visualize Predictions.
# ==============================================================================

def evaluate_model(model, loader):
    """Return the top-1 accuracy (in percent) of *model* over every sample in *loader*.

    A prediction counts as correct when the arg-max class id equals the
    label id, which for this single-word classifier means an exact
    first-answer-word match. Batches are moved to the device holding the
    model's parameters.

    Args:
        model: Module mapping ``(images, questions)`` to class logits.
        loader: Iterable of ``(images, questions, labels)`` batches.

    Returns:
        Accuracy in percent as a float. Returns 0.0 (with a warning) when
        the loader yields no samples, instead of dividing by zero.
    """
    model.eval()
    eval_device = next(model.parameters()).device
    correct = 0
    total = 0

    print("🔎 Evaluating on TEST SET...")
    with torch.no_grad():
        for imgs, qs, labels in loader:
            imgs = imgs.to(eval_device)
            qs = qs.to(eval_device)
            labels = labels.to(eval_device)

            predicted = model(imgs, qs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        print("⚠️ No samples to evaluate; reporting 0.00% accuracy.")
        return 0.0

    acc = 100 * correct / total
    print(f"✅ Final Test Accuracy: {acc:.2f}%")
    return acc

# --- 1. Load saved weights if present (handles both checkpoint formats) ---
if os.path.exists(MODEL_SAVE_PATH):
    checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device)

    # Current format (written by train_model): {'state_dict', 'vocab', 'answer_to_idx'}.
    # Legacy format: the bare state_dict.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
        # Predictions below are decoded with the *current* vocab; a different
        # saved mapping would silently mislabel them.
        saved_vocab = checkpoint.get('vocab')
        if saved_vocab is not None and saved_vocab != vocab.word2idx:
            print("⚠️ Checkpoint vocabulary differs from the current one; decoded answers may be wrong.")
    else:
        model.load_state_dict(checkpoint)

    print("📂 Loaded Pre-trained Weights successfully.")
else:
    print("⚠️ No saved model found. Using current model state.")

# --- 2. Evaluate on TEST LOADER (Not full dataloader) ---
evaluate_model(model, test_loader)

# ---------------------------------------------------------
# Visualization: Qualitative Results (Test Set Only)
# ---------------------------------------------------------
def visualize_predictions(model, dataset, num_samples=3, vocabulary=None):
    """Plot random samples with their question, true answer and predicted answer.

    Titles are green and bold when the prediction matches the ground truth,
    red otherwise. Samples are drawn with NumPy's global RNG, so they change
    between calls unless the seed is fixed.

    Args:
        model: Module mapping ``(images, questions)`` to answer logits.
        dataset: Dataset yielding ``(image, question_ids, label)`` tensors,
            e.g. the test subset.
        num_samples: Number of samples to show, capped at ``len(dataset)``.
            Each gets its own column, so values above 3 are supported.
        vocabulary: Object with an ``idx2word`` dict used for decoding;
            defaults to the notebook-level ``vocab`` from Section 2.
    """
    if vocabulary is None:
        vocabulary = vocab
    model.eval()
    model_device = next(model.parameters()).device

    # Handle case where the dataset is smaller than num_samples.
    actual_samples = min(len(dataset), num_samples)
    if actual_samples <= 0:
        print("⚠️ Nothing to visualize: dataset is empty or num_samples is 0.")
        return
    indices = np.random.choice(len(dataset), actual_samples, replace=False)

    # One column per sample; the layout used to be fixed at 3 columns,
    # which crashed for num_samples > 3.
    _, axes = plt.subplots(1, actual_samples, figsize=(4 * actual_samples, 4), squeeze=False)

    # ImageNet statistics used by the normalization transform.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    for ax, idx in zip(axes[0], indices):
        img, q_tensor, label_tensor = dataset[idx]

        with torch.no_grad():
            output = model(img.unsqueeze(0).to(model_device), q_tensor.unsqueeze(0).to(model_device))
        pred_idx = output.argmax(1).item()

        # Decode, skipping <PAD>/<SOS>/<EOS> (ids 0-2).
        q_text = " ".join(vocabulary.idx2word[t.item()] for t in q_tensor if t.item() not in (0, 1, 2))
        truth = vocabulary.idx2word[label_tensor.item()]
        pred = vocabulary.idx2word[pred_idx]

        # Undo the normalization for display.
        img_disp = img.permute(1, 2, 0).numpy() * std + mean
        ax.imshow(np.clip(img_disp, 0, 1))
        ax.axis('off')

        # Set the title once; a second set_title call would reset fontsize.
        is_correct = truth == pred
        ax.set_title(
            f"Q: {q_text}\nTrue: {truth} | Pred: {pred}",
            fontsize=10,
            color='green' if is_correct else 'red',
            fontweight='bold' if is_correct else 'normal',
        )

    plt.tight_layout()
    plt.show()

# --- 3. Visualize TEST DATASET (Not full dataset) ---
visualize_predictions(model, test_dataset)

## Final Evaluation Cell: CNN-LSTM Comparison Metrics
Computes test-set metrics (closed/open accuracy, BLEU-1, ROUGE-L, BERTScore) for comparison with the LLaVA-Med model.

In [None]:
# ==========================================
# FINAL EVALUATION CELL: CNN-LSTM COMPARISON METRICS
# ==========================================

# 1. Install Metrics Libraries (if not already installed)
!pip install rouge-score nltk sacrebleu
!pip install bert_score

import torch
import numpy as np
import pandas as pd
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from bert_score import score  # <--- NEW IMPORT
from tqdm import tqdm
from torchvision import transforms # Ensure transforms is imported

from torch.utils.data import random_split

class SimpleCNN_LSTM(nn.Module):
    """Multiplicative-fusion CNN-LSTM used to reload shared-vocabulary checkpoints.

    Projected ResNet50 features are multiplied element-wise with the
    (single-layer) LSTM's final hidden state and scored over one vocabulary.
    Its attribute names match ``CNN_LSTM_VQA``, so that model's state_dict
    keys line up directly.
    """

    def __init__(self, vocab_size, embed_size=256, hidden_size=512):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        # ResNet50 without its final fc layer.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        self.visual_fc = nn.Linear(2048, hidden_size)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.classifier = nn.Linear(hidden_size, vocab_size)  # Simple Classifier

    def forward(self, images, questions):
        """Return answer logits of shape (batch, vocab_size)."""
        pooled = torch.flatten(self.resnet(images), 1)
        visual = self.visual_fc(pooled)
        _, (h_n, _) = self.lstm(self.embedding(questions))
        return self.classifier(visual * h_n[-1])  # Multiplication Fusion

class ComplexCNN_LSTM(nn.Module):
    """Concat-fusion CNN-LSTM variant with separate input and output vocabularies.

    Projected, batch-normalized ResNet50 features are concatenated with the
    LSTM's final hidden state and fed to a two-layer MLP head. The head is an
    ``nn.Sequential`` whose last Linear sits at index 3, so its weight is
    stored as ``classifier.3.weight``. That is the key the evaluation cell
    uses to detect this architecture.
    """

    def __init__(self, vocab_in, vocab_out, embed_size=256, hidden_size=512):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        # ResNet50 without its final fc layer.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        self.vision_linear = nn.Linear(2048, hidden_size)
        self.bn_vision = nn.BatchNorm1d(hidden_size)
        self.embedding = nn.Embedding(vocab_in, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        head_layers = [
            nn.Linear(hidden_size * 2, hidden_size),  # index 0
            nn.ReLU(),                                # index 1
            nn.Dropout(0.5),                          # index 2
            nn.Linear(hidden_size, vocab_out),        # index 3
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, images, questions):
        """Return answer logits of shape (batch, vocab_out)."""
        pooled = torch.flatten(self.resnet(images), 1)
        visual = self.bn_vision(self.vision_linear(pooled))
        _, (h_n, _) = self.lstm(self.embedding(questions))
        fused = torch.cat([visual, h_n[-1]], dim=1)  # Concat Fusion
        return self.classifier(fused)

def _rebuild_cnn_lstm_from_state_dict(state_dict, device):
    """Instantiate the CNN-LSTM variant whose parameter layout matches *state_dict*.

    Layer sizes are read from the weight shapes. Returns
    ``(model, state_dict_to_load, num_output_classes)``, or ``None`` when
    neither known classifier layout is present.
    """
    if 'classifier.3.weight' in state_dict:
        print("💡 Detected Architecture: COMPLEX (Split Vocab, Concat Fusion)")
        vocab_in = state_dict['embedding.weight'].shape[0]
        vocab_out = state_dict['classifier.3.weight'].shape[0]
        embed_dim = state_dict['embedding.weight'].shape[1]
        # weight_hh_l0 is (4 * hidden, hidden).
        hidden_dim = state_dict['lstm.weight_hh_l0'].shape[1]
        model = ComplexCNN_LSTM(vocab_in, vocab_out, embed_dim, hidden_dim).to(device)
        # Checkpoints of this variant may name the backbone 'vision_encoder'.
        clean_state_dict = {k.replace("vision_encoder", "resnet"): v for k, v in state_dict.items()}
        return model, clean_state_dict, vocab_out

    if 'classifier.weight' in state_dict:
        print("💡 Detected Architecture: SIMPLE (Shared Vocab, Mult Fusion)")
        vocab_size = state_dict['classifier.weight'].shape[0]
        embed_dim = state_dict['embedding.weight'].shape[1]
        hidden_dim = state_dict['lstm.weight_hh_l0'].shape[1]
        model = SimpleCNN_LSTM(vocab_size, embed_dim, hidden_dim).to(device)
        return model, state_dict, vocab_size

    print("❌ Error: Unknown model structure.")
    return None


def evaluate_test_set_complete(model_path, json_path, img_dir):
    """Evaluate a saved CNN-LSTM checkpoint on the held-out test split.

    Steps:
      1. Load the checkpoint and rebuild the matching architecture
         (see ``_rebuild_cnn_lstm_from_state_dict``).
      2. Re-create the 80/20 split with seed 42, as in the data-pipeline cell.
      3. Report closed-question accuracy, and for open questions exact-match
         accuracy, BLEU-1, ROUGE-L F1 and BERTScore F1. Each compares the
         predicted word with the full lower-cased reference answer.

    Results are printed and written to ``cnn_lstm_test_metrics_complete.csv``
    in the current working directory. Samples that raise during evaluation
    (e.g. a missing image) are skipped and counted, not silently dropped.

    Args:
        model_path: Checkpoint containing at least a ``'vocab'`` dict and
            the weights (under ``'state_dict'`` or at the top level).
        json_path: VQA-RAD JSON file; it must be the same file used for
            training for the split to match.
        img_dir: Image directory.

    Returns:
        Dict of metric name to value, or ``None`` if the checkpoint is
        missing or unusable.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"--> Loading model from {model_path}...")

    if not os.path.exists(model_path):
        print("❌ Error: Model file not found.")
        return None

    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

    # 1. Recover vocabularies. 'vocab' encodes the question tokens, and
    # 'answer_to_idx' (written by train_model) decodes the predicted class ids.
    # They are identical for the shared-vocab model. For a split-vocab model,
    # decoding with the question vocab would mislabel predictions.
    if 'vocab' not in checkpoint:
        print("❌ Error: Checkpoint has no 'vocab' entry; cannot encode questions.")
        return None
    raw_vocab = checkpoint['vocab']
    # NOTE(review): assumes 'answer_to_idx', when present, maps answer -> id
    # (the format train_model saves); otherwise fall back to the question vocab.
    answer_vocab = checkpoint.get('answer_to_idx', raw_vocab)
    unk_id = raw_vocab.get('<UNK>', 0)

    def vocab_wrapper(token):
        return raw_vocab.get(token, unk_id)
    vocab_wrapper.word2idx = raw_vocab

    # 2. Rebuild the architecture from the state_dict layout.
    rebuilt = _rebuild_cnn_lstm_from_state_dict(state_dict, device)
    if rebuilt is None:
        return None
    model, clean_state_dict, num_classes = rebuilt

    # strict=False tolerates naming drift, but skipped keys leave layers
    # randomly initialized, so report them instead of claiming success.
    try:
        load_result = model.load_state_dict(clean_state_dict, strict=False)
    except Exception as e:  # e.g. a shape mismatch; evaluation continues best-effort
        print(f"⚠️ Warning during loading: {e}")
    else:
        missing, unexpected = load_result.missing_keys, load_result.unexpected_keys
        if missing or unexpected:
            print(f"⚠️ Partial load: {len(missing)} missing keys {missing[:5]}, "
                  f"{len(unexpected)} unexpected keys {unexpected[:5]}")
        else:
            print("✅ Weights loaded successfully.")

    if len(answer_vocab) != num_classes:
        print(f"⚠️ Answer map has {len(answer_vocab)} entries but the classifier has "
              f"{num_classes} outputs; decoded answers may be wrong.")

    # 3. Setup Dataset & re-create the training split.
    tfm = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    full_dataset = VQARADDataset(json_path, img_dir, transform=tfm, vocab=vocab_wrapper)

    # Must match the data-pipeline cell exactly (80/20, seed 42) to get the same test set.
    train_size = int(0.8 * len(full_dataset))
    test_size = len(full_dataset) - train_size
    _, test_dataset = random_split(full_dataset, [train_size, test_size],
                                   generator=torch.Generator().manual_seed(42))

    print(f"--> Evaluating on TEST SET ({len(test_dataset)} samples)...")

    idx2ans = {v: k for k, v in answer_vocab.items()}

    # 4. Metrics Setup
    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
    smoothie = SmoothingFunction().method4
    # sentence_bleu defaults to cumulative 4-gram weights; unigram-only
    # weights make the value reported as "BLEU-1" actually BLEU-1.
    bleu1_weights = (1, 0, 0, 0)

    closed_correct = closed_total = 0
    open_correct = open_total = 0
    open_bleu_scores, open_rouge_scores = [], []
    open_pred_texts, open_true_texts = [], []
    skipped, first_error = 0, None

    model.eval()

    for i in tqdm(range(len(test_dataset))):
        try:
            row = full_dataset.data.iloc[test_dataset.indices[i]]
            q_type = str(row['answer_type']).strip().upper()
            image, question, _ = test_dataset[i]

            with torch.no_grad():
                output = model(image.unsqueeze(0).to(device), question.unsqueeze(0).to(device))
            pred_idx = output.argmax(1).item()

            pred_text = str(idx2ans.get(pred_idx, "unknown")).lower().strip()
            true_text = str(row['answer']).lower().strip()
            is_match = pred_text == true_text

            # Compute the open-question scores before touching any counter,
            # so a failure cannot leave the tallies half-updated.
            if q_type == 'OPEN':
                bleu = sentence_bleu([true_text.split()], pred_text.split(),
                                     weights=bleu1_weights, smoothing_function=smoothie)
                rouge_l = scorer.score(true_text, pred_text)['rougeL'].fmeasure
        except Exception as e:  # best-effort: one bad sample must not abort the run
            skipped += 1
            if first_error is None:
                first_error = f"test sample {i}: {e!r}"
            continue

        if q_type == 'CLOSED':
            closed_total += 1
            closed_correct += int(is_match)
        elif q_type == 'OPEN':
            open_total += 1
            open_correct += int(is_match)
            open_bleu_scores.append(bleu)
            open_rouge_scores.append(rouge_l)
            open_pred_texts.append(pred_text)
            open_true_texts.append(true_text)

    if skipped:
        print(f"⚠️ Skipped {skipped} test samples because of errors (first: {first_error})")

    # 5. Final Calculation (0 when a category has no samples, instead of nan)
    closed_acc = (closed_correct / closed_total * 100) if closed_total > 0 else 0
    open_acc = (open_correct / open_total * 100) if open_total > 0 else 0
    open_bleu = float(np.mean(open_bleu_scores)) if open_bleu_scores else 0.0
    open_rouge = float(np.mean(open_rouge_scores)) if open_rouge_scores else 0.0

    bert_f1 = 0
    if open_pred_texts:
        _, _, F1 = score(open_pred_texts, open_true_texts, lang="en", verbose=False)
        bert_f1 = F1.mean().item()

    print(f"\n=== TEST SET METRICS ===")
    print(f"Closed Accuracy (Yes/No): {closed_acc:.2f}%")
    print(f"Open Accuracy (Exact):    {open_acc:.2f}%")
    print(f"Open BLEU-1:              {open_bleu:.4f}")
    print(f"Open ROUGE-L:             {open_rouge:.4f}")
    print(f"Open BERTScore F1:        {bert_f1:.4f}")

    # Save
    pd.DataFrame({
        "Metric": ["Closed Accuracy", "Open Accuracy", "Open BLEU-1", "Open ROUGE-L", "Open BERTScore"],
        "Value": [f"{closed_acc:.2f}%", f"{open_acc:.2f}%", open_bleu, open_rouge, bert_f1]
    }).to_csv("cnn_lstm_test_metrics_complete.csv", index=False)
    print("✅ Results saved to 'cnn_lstm_test_metrics_complete.csv'")

    return {
        "closed_accuracy": closed_acc,
        "open_accuracy": open_acc,
        "open_bleu1": open_bleu,
        "open_rougeL": open_rouge,
        "open_bertscore_f1": bert_f1,
        "skipped_samples": skipped,
    }

# RUN
cnn_lstm_test_metrics = evaluate_test_set_complete(
    model_path=MODEL_SAVE_PATH,
    json_path=JSON_FILE,
    img_dir=IMAGE_DIR,
)