In [None]:
# PART 1: SETUP + DATA PREP

import hashlib
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from transformers import BertTokenizer, BertModel

# Seed the global torch RNG before anything random happens, so the DataLoader
# shuffle order, dropout masks and the randomly initialised projection/classifier
# weights are reproducible across runs. (torch.manual_seed seeds all devices; it
# does not by itself force deterministic cuDNN kernels on GPU.)
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Load Dataset
dataset = load_dataset("flaviagiammarino/vqa-rad")

print("Train samples:", len(dataset["train"]))
print("Test samples:", len(dataset["test"]))

# Answer Vocabulary (Top-50)
# Read only the "answer" column: iterating whole rows of the split would also
# decode every image just to throw it away.
answers = [ans.lower().strip() for ans in dataset["train"]["answer"]]
answer_freq = Counter(answers)

# The TOP_K most frequent normalised answers become classes 0..TOP_K-1 (equal
# counts keep first-seen order, per Counter.most_common). Every other answer is
# mapped to one extra catch-all class, UNKNOWN_LABEL.
TOP_K = 50
top_answers = [ans for ans, _ in answer_freq.most_common(TOP_K)]

answer_to_label = {ans: idx for idx, ans in enumerate(top_answers)}
UNKNOWN_LABEL = len(answer_to_label)

def encode_answer(ans):
    """Return the class index for a raw answer string.

    The answer is normalised the same way as the vocabulary (lower-cased, then
    stripped). Answers outside the top-K vocabulary yield UNKNOWN_LABEL.
    """
    key = ans.lower().strip()
    if key in answer_to_label:
        return answer_to_label[key]
    return UNKNOWN_LABEL

def add_label(example):
    """Store the integer class of ``example["answer"]`` under ``example["label"]``."""
    label = encode_answer(example["answer"])
    example["label"] = label
    return example


# Applied to every split of the DatasetDict.
dataset = dataset.map(add_label)

# Tokenization
# Questions are padded/truncated to a fixed MAX_LEN so that token tensors can be
# stacked directly in the collate function.
MAX_LEN = 32
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def tokenize_question(example):
    """Add fixed-length BERT ``input_ids`` and ``attention_mask`` for ``example["question"]``."""
    tokens = tokenizer(
        example["question"],
        max_length=MAX_LEN,
        truncation=True,
        padding="max_length",
    )
    for field in ("input_ids", "attention_mask"):
        example[field] = tokens[field]
    return example

dataset = dataset.map(tokenize_question)

# Image Preprocessing
# Frozen ImageNet ResNet-50 used purely as a feature extractor. IMAGENET1K_V1 is
# the checkpoint the deprecated `pretrained=True` flag resolved to
# (resnet50-0676ba61.pth), so the weights are unchanged.
resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# Drop the final fc layer; the network now ends at global average pooling,
# producing (N, 2048, 1, 1) features (2048 matches the projection's input size).
resnet = nn.Sequential(*list(resnet.children())[:-1])
resnet = resnet.to(device)
resnet.eval()

image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # ImageNet per-channel statistics -- requires exactly 3 input channels.
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

def extract_image_feature(image):
    """Return the 2048-d pooled ResNet-50 feature vector for a single PIL image.

    The image is converted to RGB first. ``Normalize`` above uses 3-channel
    statistics, so a grayscale ("L") or RGBA image would otherwise raise a
    broadcasting error. (Radiology images may well be single-channel -- the
    conversion is a no-op copy for images that are already RGB.)

    Returns a 1-D tensor of shape (2048,) on ``device``; no gradients are tracked.
    """
    pixels = image_transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        features = resnet(pixels)
    # (1, 2048, 1, 1) -> (2048,); explicit, unlike squeeze() which depends on sizes.
    return features.flatten()

# Dataset Wrapper
class VQARadDataset(Dataset):
    """Torch ``Dataset`` view over a tokenised, labelled Hugging Face split.

    Token ids, attention mask and label are returned as ``torch.long`` tensors.
    The image is passed through as-is; it is not converted to a tensor here.
    """

    def __init__(self, hf_dataset):
        # Any indexable split whose rows carry "input_ids", "attention_mask",
        # "label" and "image".
        self.dataset = hf_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset[idx]
        sample = {
            key: torch.tensor(row[key], dtype=torch.long)
            for key in ("input_ids", "attention_mask", "label")
        }
        sample["image"] = row["image"]
        return sample

def vqa_collate_fn(batch):
    """Collate a list of VQARadDataset samples into one batch dict.

    Tensor fields are stacked along a new leading batch dimension. Images are
    collected into a plain Python list because they are not tensors yet.
    """
    collated = {
        field: torch.stack([sample[field] for sample in batch])
        for field in ("input_ids", "attention_mask", "label")
    }
    collated["image"] = [sample["image"] for sample in batch]
    return collated

# Wrap each split and batch it. Images stay PIL objects, so vqa_collate_fn stacks
# only the tensor fields; only the training split is shuffled.
BATCH_SIZE = 8

train_dataset = VQARadDataset(dataset["train"])
test_dataset = VQARadDataset(dataset["test"])

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=vqa_collate_fn
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=vqa_collate_fn
)

Using device: cpu


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001-eb8844602202be(…):   0%|          | 0.00/24.2M [00:00<?, ?B/s]

data/test-00000-of-00001-e5bc3d208bb4dee(…):   0%|          | 0.00/10.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1793 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/451 [00:00<?, ? examples/s]

Train samples: 1793
Test samples: 451


Map:   0%|          | 0/1793 [00:00<?, ? examples/s]

Map:   0%|          | 0/451 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Map:   0%|          | 0/1793 [00:00<?, ? examples/s]

Map:   0%|          | 0/451 [00:00<?, ? examples/s]



Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 146MB/s]


In [None]:
# PART 2: MODEL DEFINITION

# Pretrained BERT question encoder, plus a trainable linear layer mapping the
# 2048-d pooled ResNet features into the same 768-d space as BERT's hidden states.
text_encoder = BertModel.from_pretrained("bert-base-uncased")
text_encoder = text_encoder.to(device)

image_projection = nn.Linear(in_features=2048, out_features=768)
image_projection = image_projection.to(device)

class MedicalVilBERT(nn.Module):
    """Late-fusion VQA classifier: [CLS] text embedding ⊕ projected image feature → MLP.

    Despite the name, fusion is plain concatenation of the two embeddings; there
    is no cross-modal attention layer.

    Args:
        text_encoder: module called as ``text_encoder(input_ids=..., attention_mask=...)``
            whose output exposes ``last_hidden_state`` shaped (batch, seq, hidden_dim).
        image_projection: module mapping the image features passed to ``forward``
            to (batch, hidden_dim).
        num_classes: number of answer classes (width of the returned logits).
        hidden_dim: width of both the [CLS] embedding and the projected image
            feature (768 for bert-base).
        fusion_dim: width of the classifier's hidden layer.
        dropout: dropout probability inside the classifier head.
    """

    def __init__(self, text_encoder, image_projection, num_classes,
                 hidden_dim=768, fusion_dim=512, dropout=0.3):
        super().__init__()
        self.text_encoder = text_encoder
        self.image_projection = image_projection

        # Input is the concatenation of the text and image embeddings.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, fusion_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fusion_dim, num_classes)
        )

    def forward(self, input_ids, attention_mask, image_features):
        """Return unnormalised answer logits of shape (batch, num_classes)."""
        text_outputs = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # Position 0 holds BERT's [CLS] token; its final hidden state is used as
        # the question embedding.
        text_cls = text_outputs.last_hidden_state[:, 0, :]
        image_embeds = self.image_projection(image_features)
        fused = torch.cat([text_cls, image_embeds], dim=1)
        return self.classifier(fused)

# One class per top-K answer, plus the catch-all UNKNOWN_LABEL class.
NUM_CLASSES = len(top_answers) + 1

model = MedicalVilBERT(
    text_encoder=text_encoder,
    image_projection=image_projection,
    num_classes=NUM_CLASSES,
)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
# A single parameter group: BERT, the image projection and the new classifier
# head all train with the same learning rate.
optimizer = optim.AdamW(model.parameters(), lr=2e-5)

print("Model initialized successfully.")

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertModel LOAD REPORT from: bert-base-uncased
Key                                        | Status     |  | 
-------------------------------------------+------------+--+-
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED |  | 
cls.predictions.transform.dense.weight     | UNEXPECTED |  | 
cls.predictions.bias                       | UNEXPECTED |  | 
cls.predictions.transform.dense.bias       | UNEXPECTED |  | 
cls.seq_relationship.weight                | UNEXPECTED |  | 
cls.predictions.transform.LayerNorm.weight | UNEXPECTED |  | 
cls.seq_relationship.bias                  | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


Model initialized successfully.


In [None]:
# PART 3: TRAINING + EVAL

EPOCHS = 3
best_test_acc = 0

# ResNet-50 runs in eval mode under no_grad and the image transform is
# deterministic, so an image's feature vector never changes between batches or
# epochs. Many questions share the same image, so features are cached by pixel
# content (mode, size, palette, raw bytes) instead of re-running the CNN for
# every question in every epoch.
_image_feature_cache = {}


def _image_features(images):
    """Return stacked features (batch, feature_dim) for a list of PIL images, via the cache."""
    feats = []
    for img in images:
        key = (
            img.mode,
            img.size,
            tuple(img.getpalette() or ()),
            hashlib.sha1(img.tobytes()).hexdigest(),
        )
        if key not in _image_feature_cache:
            _image_feature_cache[key] = extract_image_feature(img)
        feats.append(_image_feature_cache[key])
    return torch.stack(feats)


def _run_epoch(loader, train):
    """Run one pass over ``loader``, updating weights only when ``train`` is True.

    Returns ``(mean loss per sample, accuracy)``. Accuracy also counts correct
    predictions of the catch-all UNKNOWN class, so it overstates accuracy on
    actual answers.
    """
    model.train(train)
    total_loss, correct, total = 0.0, 0, 0
    batches = tqdm(loader) if train else loader

    with torch.set_grad_enabled(train):
        for batch in batches:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)
            image_features = _image_features(batch["image"])

            logits = model(input_ids, attention_mask, image_features)
            loss = criterion(logits, labels)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # The criterion averages over the batch. Re-weight by batch size so
            # that a short final batch does not count as much as a full one.
            n = labels.size(0)
            total_loss += loss.item() * n
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += n

    return total_loss / total, correct / total


for epoch in range(EPOCHS):
    print(f"\n========== Epoch {epoch+1}/{EPOCHS} ==========")

    # ---------------- TRAINING ----------------
    avg_train_loss, train_acc = _run_epoch(train_loader, train=True)

    print(f"Train Loss: {avg_train_loss:.4f}")
    print(f"Train Accuracy: {train_acc*100:.2f}%")

    # ---------------- TESTING ----------------
    avg_test_loss, test_acc = _run_epoch(test_loader, train=False)

    print(f"Test Loss: {avg_test_loss:.4f}")
    print(f"Test Accuracy: {test_acc*100:.2f}%")

    # NOTE(review): choosing the "best" epoch on the test split leaks test data
    # into model selection; a held-out validation split would be cleaner. Only
    # the score is tracked -- the best epoch's weights are not saved.
    best_test_acc = max(best_test_acc, test_acc)

print("\nBest Test Accuracy Achieved:", best_test_acc * 100)




100%|██████████| 225/225 [18:15<00:00,  4.87s/it]


Train Loss: 1.7319
Train Accuracy: 54.71%
Test Loss: 1.1575
Test Accuracy: 59.87%



100%|██████████| 225/225 [17:44<00:00,  4.73s/it]


Train Loss: 1.2661
Train Accuracy: 64.25%
Test Loss: 1.0272
Test Accuracy: 62.97%



100%|██████████| 225/225 [17:44<00:00,  4.73s/it]


Train Loss: 1.1065
Train Accuracy: 71.28%
Test Loss: 1.0325
Test Accuracy: 67.63%

Best Test Accuracy Achieved: 67.62749445676275
