<a href="https://colab.research.google.com/github/viknes86/Alternative-Assignment-Medical-VQA-Comparison-25056315/blob/main/CNN_LSTM_Baseline_FinalModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Baseline Model: CNN-LSTM for Medical VQA
## Advanced Machine Learning - Final Project
**Student Name:** J.Vikneswaran A/L Palaniandy
**Student ID:** 25056315

### Project Objective
To establish a discriminative baseline for the VQA-RAD dataset using a classic **CNN-LSTM** architecture.
* **Visual Encoder:** ResNet50 (Pretrained on ImageNet) to extract visual features.
* **Question Encoder:** LSTM (Long Short-Term Memory) to process text.
* **Fusion:** Element-wise multiplication of visual and textual features.
* **Classifier:** Fully Connected Layer predicting one word from a fixed vocabulary.

This baseline will be compared against the generative **LLaVA-Med** model to demonstrate the "Capacity Wall" of discriminative approaches in medical reasoning.

Mount Drive & Setup

In [None]:
# ==============================================================================
# SECTION 1: ENVIRONMENT & SETUP
# Purpose: Mount Drive and define project paths.
# ==============================================================================

import os
from google.colab import drive
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt
import nltk
import numpy as np

# 1. Attach Google Drive so the dataset and checkpoint paths below resolve
drive.mount('/content/drive')

# 2. Project locations on Drive
# UPDATE THIS if your folder name is different
PROJECT_PATH = '/content/drive/MyDrive/AML_FinalProject'
IMAGE_DIR, JSON_FILE, MODEL_SAVE_PATH = (
    os.path.join(PROJECT_PATH, leaf)
    for leaf in (
        'VQA_RAD Image Folder',         # directory holding the images
        'VQA_RAD Dataset Public.json',  # question/answer annotations
        'cnn_lstm_vqa.pth',             # checkpoint written by train_model
    )
)

# 3. Use the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"âœ… Device: {device}")

Vocabulary & Dataset

In [None]:
# ==============================================================================
# SECTION 2: DATA PIPELINE
# Purpose: Build Vocabulary and Custom Dataset Class.
# ==============================================================================

class Vocabulary:
    """Two-way mapping between word strings and integer ids.

    Ids 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``; each newly added word takes the next free id.
    """

    def __init__(self):
        specials = ("<PAD>", "<SOS>", "<EOS>", "<UNK>")
        self.word2idx = {token: i for i, token in enumerate(specials)}
        self.idx2word = dict(enumerate(specials))
        # Next id to hand out to a new word.
        self.idx = len(specials)

    def add_word(self, word):
        """Register ``word`` under the next free id; known words are left untouched."""
        if word in self.word2idx:
            return
        new_id = self.idx
        self.word2idx[word] = new_id
        self.idx2word[new_id] = word
        self.idx = new_id + 1

    def __call__(self, word):
        """Return the id of ``word``, or the ``<UNK>`` id when it is not known."""
        try:
            return self.word2idx[word]
        except KeyError:
            return self.word2idx["<UNK>"]

    def __len__(self):
        """Return the number of entries, special tokens included."""
        return len(self.word2idx)

def build_vocab(json_path, threshold=1):
    """Build a Vocabulary from the questions and answers in a VQA-RAD JSON file.

    Questions are lower-cased and tokenized with NLTK's ``word_tokenize``;
    answers are lower-cased and split on whitespace. A word is kept when its
    combined count across both columns is at least ``threshold``.

    Args:
        json_path: Path to a JSON file readable by ``pandas.read_json`` that
            has ``question`` and ``answer`` columns.
        threshold: Minimum count a word needs to enter the vocabulary.

    Returns:
        The populated ``Vocabulary``.
    """
    frame = pd.read_json(json_path)
    freq = Counter()

    # Question words: NLTK tokenization
    for raw_question in frame['question']:
        freq.update(nltk.tokenize.word_tokenize(str(raw_question).lower()))

    # Answer words: plain whitespace split. They must be in the vocabulary
    # because the classifier's label is a vocabulary id.
    for raw_answer in frame['answer']:
        freq.update(str(raw_answer).lower().split())

    # Keep every word that is frequent enough, in first-seen order
    vocab = Vocabulary()
    for kept in (word for word, seen in freq.items() if seen >= threshold):
        vocab.add_word(kept)

    print(f"âœ… Vocabulary Built. Total Size: {len(vocab)}")
    return vocab

# Custom Dataset
class VQARADDataset(Dataset):
    """VQA-RAD samples as ``(image, question ids, answer id)`` triples.

    Each row of the JSON file supplies an ``image_name``, a ``question`` and
    an ``answer``. The question becomes a fixed-length list of vocabulary
    ids; the answer is reduced to the vocabulary id of its first word, so the
    task is single-label classification over the vocabulary.

    Args:
        json_file: Path to the annotation file, read with ``pandas.read_json``.
        img_dir: Directory containing the images.
        vocab: Callable mapping a word to its integer id (unknown words are
            expected to map to the ``<UNK>`` id).
        transform: Optional callable applied to the RGB PIL image.
        max_len: Length every question sequence is padded or truncated to.
    """

    def __init__(self, json_file, img_dir, vocab, transform=None, max_len=20):
        self.data = pd.read_json(json_file)
        self.img_dir = img_dir
        self.vocab = vocab
        self.transform = transform
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def _load_image(self, img_name):
        """Open ``img_name`` from ``img_dir`` as RGB and apply the transform."""
        if not img_name.endswith('.jpg'):
            img_name += '.jpg'
        # convert() returns a new, fully loaded image, so the file handle can
        # be released as soon as the with-block ends.
        with Image.open(os.path.join(self.img_dir, img_name)) as handle:
            image = handle.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image

    def _fit_length(self, indices):
        """Return ``indices`` cut or right-padded with ``<PAD>`` to ``max_len``."""
        # Truncation keeps the leading tokens, so an over-long question loses
        # its trailing <EOS>.
        clipped = indices[:self.max_len]
        padding = [self.vocab("<PAD>")] * (self.max_len - len(clipped))
        return clipped + padding

    def _encode_question(self, question):
        """Tokenize a question and wrap it in ``<SOS>`` ... ``<EOS>`` ids."""
        tokens = nltk.tokenize.word_tokenize(str(question).lower())
        indices = [self.vocab("<SOS>")]
        indices.extend(self.vocab(token) for token in tokens)
        indices.append(self.vocab("<EOS>"))
        return self._fit_length(indices)

    def _answer_label(self, answer):
        """Return the vocabulary id of the first word of ``answer``.

        An empty or whitespace-only answer has no first word; it is mapped
        to ``<UNK>`` instead of raising IndexError mid-epoch.
        """
        words = str(answer).lower().split()
        if not words:
            return self.vocab("<UNK>")
        return self.vocab(words[0])

    def __getitem__(self, idx):
        """Return ``(image, question_ids, label)`` for row ``idx``.

        ``question_ids`` and ``label`` are tensors built from Python ints;
        ``image`` is whatever the transform returns (the PIL image itself
        when no transform was given).
        """
        item = self.data.iloc[idx]
        image = self._load_image(item['image_name'])
        q_indices = self._encode_question(item['question'])
        label = self._answer_label(item['answer'])
        return image, torch.tensor(q_indices), torch.tensor(label)

# Setup Transforms (Standard ResNet Norms)
# Resize to 224x224, convert to a tensor, then normalise with the ImageNet
# channel means / standard deviations.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# word_tokenize needs the Punkt sentence-tokenizer data. Recent NLTK releases
# look it up under 'punkt_tab' while older ones use 'punkt', so fetch both;
# nltk.download reports (rather than raises) when a package is unavailable.
nltk.download('punkt')
nltk.download('punkt_tab')

vocab = build_vocab(JSON_FILE)
dataset = VQARADDataset(JSON_FILE, IMAGE_DIR, vocab, transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
print("âœ… Dataset Ready.")

Model Architecture

In [None]:
# ==============================================================================
# SECTION 3: MODEL ARCHITECTURE (CNN-LSTM)
# Purpose: Define the visual encoder (ResNet50), question encoder (LSTM), fusion and classifier.
# ==============================================================================

class CNN_LSTM_VQA(nn.Module):
    """CNN-LSTM baseline that classifies an (image, question) pair over the vocabulary.

    An ImageNet-pretrained ResNet50 (final FC layer removed) encodes the
    image, an LSTM encodes the question, the two ``hidden_size`` vectors are
    fused by element-wise multiplication, and a linear layer produces one
    logit per vocabulary entry.

    Args:
        vocab_size: Number of vocabulary entries; sets both the embedding
            table size and the number of output classes.
        embed_size: Dimension of the word embeddings.
        hidden_size: LSTM hidden size and width of the fused representation.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, embed_size=128, hidden_size=256, num_layers=1):
        super().__init__()

        # 1. Visual Encoder (ResNet50)
        # `pretrained=True` is deprecated in torchvision in favour of the
        # `weights=` enum; IMAGENET1K_V1 is the checkpoint `pretrained=True`
        # used to select. Fall back for torchvision builds without the enum.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            resnet = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        else:
            resnet = models.resnet50(pretrained=True)
        feature_dim = resnet.fc.in_features  # 2048 for ResNet50
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])  # Remove FC layer
        self.visual_fc = nn.Linear(feature_dim, hidden_size)  # Project to LSTM dimension

        # 2. Question Encoder (LSTM)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)

        # 3. Classifier
        self.classifier = nn.Linear(hidden_size, vocab_size)

    def forward(self, images, questions):
        """Return logits of shape (batch, vocab_size).

        Args:
            images: Image batch fed to the ResNet trunk.
            questions: Integer token ids of shape (batch, seq_len), as the
                LSTM is built with ``batch_first=True``.
        """
        # A. Image features: flatten the pooled trunk output to (batch, C),
        #    then project to (batch, hidden_size).
        features = self.resnet(images)
        features = features.view(features.size(0), -1)
        img_embedding = self.visual_fc(features)

        # B. Text features: last layer's final hidden state, (batch, hidden_size).
        #    Padding positions are fed through the LSTM like any other token.
        embeds = self.embedding(questions)
        _, (hidden, _) = self.lstm(embeds)
        txt_embedding = hidden[-1]

        # C. Fusion (element-wise multiplication)
        fused = img_embedding * txt_embedding

        # D. Prediction
        return self.classifier(fused)

# Build the network with one output class per vocabulary entry, then move it
# to the selected device and show its layer summary.
model = CNN_LSTM_VQA(vocab_size=len(vocab))
model = model.to(device)
print(model)

Training

In [None]:
# ==============================================================================
# SECTION 4: TRAINING (Run if retraining is needed)
# ==============================================================================

def train_model(model, loader, epochs=20, lr=0.001, save_path=None):
    """Train ``model`` with Adam and cross-entropy, then save its state dict.

    Args:
        model: Module called as ``model(images, questions)`` that returns
            class logits.
        loader: Sized iterable yielding ``(images, questions, labels)``
            tensor batches.
        epochs: Number of passes over ``loader``.
        lr: Adam learning rate.
        save_path: Where to write the state dict. ``None`` uses the
            module-level ``MODEL_SAVE_PATH``.

    Returns:
        List with the mean batch loss of each epoch.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    num_batches = len(loader)
    if num_batches == 0:
        # Without this guard the per-epoch average divides by zero.
        raise ValueError("train_model received an empty loader; nothing to train on.")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Send batches to wherever the model already lives instead of relying on
    # a global device setting.
    run_device = next(model.parameters()).device

    loss_history = []

    model.train()
    for epoch in range(epochs):
        epoch_loss = 0.0
        for imgs, qs, labels in loader:
            imgs = imgs.to(run_device)
            qs = qs.to(run_device)
            labels = labels.to(run_device)

            optimizer.zero_grad()
            outputs = model(imgs, qs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / num_batches
        loss_history.append(avg_loss)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

    # Save Model
    if save_path is None:
        save_path = MODEL_SAVE_PATH
    torch.save(model.state_dict(), save_path)
    print(f"âœ… Model Saved to {save_path}")
    return loss_history

# Uncomment to train
# history = train_model(model, dataloader)

Evaluation & Visualization

In [None]:
# ==============================================================================
# SECTION 5: EVALUATION & VISUALIZATION
# Purpose: Calculate Accuracy and visualize Predictions.
# ==============================================================================

def evaluate_model(model, loader):
    """Compute top-1 accuracy of ``model`` over ``loader``.

    The model is switched to eval mode (and left in it) and run without
    gradient tracking.

    Args:
        model: Module called as ``model(images, questions)`` that returns
            class logits.
        loader: Iterable yielding ``(images, questions, labels)`` tensor
            batches.

    Returns:
        Accuracy as a percentage in the range 0-100.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    # Send batches to wherever the model already lives instead of relying on
    # a global device setting.
    run_device = next(model.parameters()).device
    correct = 0
    total = 0

    print("ðŸ”Ž Evaluating...")
    with torch.no_grad():
        for imgs, qs, labels in loader:
            imgs = imgs.to(run_device)
            qs = qs.to(run_device)
            labels = labels.to(run_device)

            outputs = model(imgs, qs)
            _, predicted = torch.max(outputs, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        # Without this guard the accuracy computation divides by zero.
        raise ValueError("evaluate_model received no samples to evaluate.")

    acc = 100 * correct / total
    print(f"âœ… Final Accuracy: {acc:.2f}%")
    return acc

# Load Saved Weights (if available)
if os.path.exists(MODEL_SAVE_PATH):
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint saved on a GPU runtime still loads on a CPU-only one.
    model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))
    print("ðŸ“‚ Loaded Pre-trained Weights.")
else:
    print(f"No checkpoint found at {MODEL_SAVE_PATH}; evaluating the model as it currently is.")

# NOTE: `dataloader` wraps the full dataset, the same loader the training
# cell passes to train_model, so this figure is not a held-out test score.
evaluate_model(model, dataloader)

# ---------------------------------------------------------
# Visualization: Qualitative Results
# ---------------------------------------------------------
def visualize_predictions(model, dataset, num_samples=3):
    """Plot random samples with their question, true answer and predicted answer.

    Titles are green and bold for a correct prediction, red otherwise.

    Args:
        model: Module called as ``model(images, questions)`` that returns
            class logits.
        dataset: Indexable dataset returning ``(image, question_ids, label)``
            tensors and exposing a ``vocab`` with ``idx2word``.
        num_samples: How many samples to draw; capped at ``len(dataset)``.
    """
    model.eval()

    # Sampling without replacement cannot draw more items than exist.
    count = min(num_samples, len(dataset))
    if count <= 0:
        return
    indices = np.random.choice(len(dataset), count, replace=False)

    run_device = next(model.parameters()).device
    vocab = dataset.vocab
    # Padding and sequence markers are left out of the displayed question.
    hidden_ids = {vocab(token) for token in ("<PAD>", "<SOS>", "<EOS>")}
    # Constants of the Normalize transform, needed to undo it for display.
    norm_mean = np.array([0.485, 0.456, 0.406])
    norm_std = np.array([0.229, 0.224, 0.225])

    # One column per sample (the grid used to be fixed at three columns).
    plt.figure(figsize=(4 * count, 4))

    for slot, idx in enumerate(indices, start=1):
        img, q_tensor, label_tensor = dataset[idx]

        # Predict
        with torch.no_grad():
            output = model(img.unsqueeze(0).to(run_device), q_tensor.unsqueeze(0).to(run_device))
            pred_idx = output.argmax(1).item()

        # Decode
        q_text = " ".join(vocab.idx2word[t] for t in q_tensor.tolist() if t not in hidden_ids)
        truth = vocab.idx2word[label_tensor.item()]
        pred = vocab.idx2word[pred_idx]

        # Plot: channels-last layout for imshow, then undo the normalisation
        ax = plt.subplot(1, count, slot)
        img_disp = img.permute(1, 2, 0).numpy() * norm_std + norm_mean
        ax.imshow(np.clip(img_disp, 0, 1))
        ax.axis('off')

        # Set text, size and colour in a single call: a later set_title()
        # re-applies the default title font size and would drop fontsize=10.
        title_style = {'color': 'green', 'fontweight': 'bold'} if truth == pred else {'color': 'red'}
        ax.set_title(f"Q: {q_text}\nTrue: {truth} | Pred: {pred}", fontsize=10, **title_style)

    plt.tight_layout()
    plt.show()

# Show a few random qualitative examples from the dataset.
visualize_predictions(model=model, dataset=dataset, num_samples=3)