In [1]:
import torch
from torch.optim import Adam
import torch.nn as nn
from torchvision import models
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import pandas as pd
from PIL import Image
import json
from collections import Counter
from tqdm import tqdm

In [6]:
train_q = json.load(open('data/questions/train.json'))
train_ann = json.load(open('data/annotations/train.json'))

questions_data = train_q['questions'] if isinstance(train_q, dict) else train_q
annotations_data = train_ann['annotations'] if isinstance(train_ann, dict) else train_ann

print("First 3 question IDs:", [q['question_id'] for q in questions_data[:3]])
print("First 3 annotation question IDs:", [ann['question_id'] for ann in annotations_data[:3]])
print()
print("Are they aligned?", questions_data[0]['question_id'] == annotations_data[0]['question_id'])

First 3 question IDs: [393223003, 393227001, 131074002]
First 3 annotation question IDs: [520737002, 118989008, 564515004]

Are they aligned? False


In [10]:
def load_vqa_data(questions_file, annotations_file):
    questions_json = json.load(open(questions_file))
    annotations_json = json.load(open(annotations_file))
    
    questions_data = questions_json['questions']
    annotations_data = annotations_json['annotations']
    
    ann_dict = {ann['question_id']: ann for ann in annotations_data}
    
    data_list = []
    for q in questions_data:
        q_id = q['question_id']
        if q_id in ann_dict:
            data_list.append({
                'question': q['question'],
                'answer': ann_dict[q_id]['multiple_choice_answer'],
                'image_id': q['image_id']
            })
    
    return pd.DataFrame(data_list)

train_df = load_vqa_data('data/questions/train.json', 'data/annotations/train.json')
test_df = load_vqa_data('data/questions/test.json', 'data/annotations/test.json')
val_df = load_vqa_data('data/questions/val.json', 'data/annotations/val.json')

print(f"Train samples: {len(train_df)}")
print(f"Test samples: {len(test_df)}")
print(f"Val samples: {len(val_df)}")

Train samples: 44375
Test samples: 21435
Val samples: 21435


In [11]:
train = train_df[['question', 'answer', 'image_id']]
test = test_df[['question', 'answer', 'image_id']]
val = val_df[['question', 'answer', 'image_id']]

In [12]:
def top_answers(data, folder='train'):
    ans_cnt = Counter(data['answer'])
    top_ans = [a for a, _ in ans_cnt.most_common(1000)]

    data = data[data['answer'].isin(top_ans)].reset_index(drop=True)

    data['image_path'] = data['image_id'].apply(lambda x: f'data/images/{folder}/COCO_{folder}2014_{x:012d}.jpg')
    
    return data

In [13]:
tokenizer_model = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize_questions(questions, max_len=20):
    encoded = tokenizer_model(
        questions,
        padding='max_length',
        truncation=True,
        max_length=max_len,
        return_tensors='pt'
    )
    return encoded['input_ids'].tolist(), encoded['attention_mask'].tolist()

In [14]:
train_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [15]:
class VQADataset(Dataset):
    def __init__(self, image, question, attention_mask, answer, transform):
        self.image = image
        self.question = question
        self.attention_mask = attention_mask
        self.answer = answer
        self.transform = transform

    def __getitem__(self, idx):
        image = Image.open(self.image[idx]).convert('RGB')
        image = self.transform(image)

        question = torch.tensor(self.question[idx], dtype=torch.long)
        mask = torch.tensor(self.attention_mask[idx], dtype=torch.long)
        answer = torch.tensor(self.answer[idx], dtype=torch.long)
        
        return image, question, mask, answer

    def __len__(self):
        return len(self.question)

In [16]:
from transformers import BertModel

class VQAModel(nn.Module):
    def __init__(self, vocab_size, hidden_dim, num_ans):
        super().__init__()
        
        resnet18 = models.resnet18(pretrained=True)
        self.img_encoder = nn.Sequential(*list(resnet18.children())[:-1])

        for name, param in self.img_encoder.named_parameters():
            if 'layer4' not in name:
                param.requires_grad = False

        self.img_fc = nn.Sequential(
            nn.Linear(512, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        self.bert = BertModel.from_pretrained('bert-base-uncased')

        for param in self.bert.parameters():
            param.requires_grad = False

        self.fusion = nn.Sequential(
            nn.Linear(768 + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim//2, num_ans)
        )

    def forward(self, question, mask, image):
        img_encod = self.img_encoder(image)
        img_encod = img_encod.view(img_encod.size(0), -1)
        img_encod = self.img_fc(img_encod)
    
        bert_output = self.bert(input_ids=question, attention_mask=mask)
        q_features = bert_output.pooler_output 
    
        combined = torch.cat([img_encod, q_features], dim=1)
        return self.classifier(self.fusion(combined))

In [17]:
train = top_answers(train, folder='train')
test = top_answers(test, folder='test')
val = top_answers(val, folder='val')

In [18]:
import os

def completeData(dataset, image_dir, folder):
    def image_exists(img_id):
        filename = f"COCO_{folder}2014_{int(img_id):012d}.jpg"
        path = os.path.join(image_dir, filename)
        return os.path.exists(path)

    return dataset[dataset['image_id'].apply(image_exists)]


In [19]:
train = completeData(train, 'data/images/train/', 'train')
val = completeData(val, 'data/images/val/', 'val')

train.reset_index(drop=True, inplace=True)
val.reset_index(drop=True, inplace=True)

print(train.shape)
print(val.shape)

(20399, 4)
(19107, 4)


In [40]:
train_ques = train['question'].tolist()
train_ans = train['answer'].tolist()
train_imgs = train['image_path'].tolist()

test_ques = test['question'].tolist()
test_ans = test['answer'].tolist()
test_imgs = test['image_path'].tolist()

val_ques = val['question'].tolist()
val_ans = val['answer'].tolist()
val_imgs = val['image_path'].tolist()

In [41]:
train_ques_set = sorted(set(train_ques))
max_len = 0
idx = 0
for cnt, i in enumerate(train_ques_set):
    q = i.split()
    if max_len < len(q):
        max_len = len(q)
        idx = cnt
print(max_len, " ", idx)
train_ques_set[9221]

19   2032


'Is this picture tilted?'

In [42]:
unique_ans = sorted(set(train_ans))

ans_to_idx = {answer: idx for idx, answer in enumerate(unique_ans)}
num_ans = len(ans_to_idx)

train_tokenized_ans = [ans_to_idx[a] for a in train_ans]

In [43]:
val_tokenized_ans = []
val_filtered_ques = []

cnt = 847
for idx, ans in enumerate(val_ans):
    if ans in ans_to_idx:
        val_tokenized_ans.append(ans_to_idx[ans])
        val_filtered_ques.append(val_ques[idx])

val_ques = val_filtered_ques

In [44]:
train_tokenized_ques, train_attention_masks = tokenize_questions(train_ques, max_len=20)
val_tokenized_ques, val_attention_masks = tokenize_questions(val_ques, max_len=20)

In [45]:
train_dataset = VQADataset(
    train_imgs, 
    train_tokenized_ques, 
    train_attention_masks,
    train_tokenized_ans, 
    train_transforms
)

train_data = DataLoader(
    train_dataset, 
    shuffle=True, 
    batch_size=32,
    pin_memory=True
)

val_dataset = VQADataset(
    val_imgs, 
    val_tokenized_ques, 
    val_attention_masks,
    val_tokenized_ans, 
    test_transforms
)

val_data = DataLoader(
    val_dataset, 
    batch_size=32, 
    pin_memory=True
)

In [46]:
# Model
vocab_size = tokenizer_model.vocab_size

model = VQAModel(vocab_size, 512, num_ans)



In [47]:
epochs = 50
learning_rate = 0.0001

In [48]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using: {device}")

criterion_loss = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

Using: cuda




In [49]:
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm  # Recommended for progress bars

model.to(device)
criterion_loss = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
scaler = GradScaler()

best_val_loss = float('inf')
patience = 7
patience_counter = 0

print(f"Starting training on: {device}")

for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    correct = 0
    total = 0
    
    # tqdm adds a nice progress bar for each epoch
    train_loader = tqdm(train_data, desc=f"Epoch {epoch+1}/{epochs} [Train]")
    
    for batch_idx, (images, questions, masks, answers) in enumerate(train_loader):
        images = images.to(device)
        questions = questions.to(device)
        masks = masks.to(device)
        answers = answers.to(device)
        
        optimizer.zero_grad()
        
        with autocast():
            # Model now takes (question, mask, image)
            outputs = model(questions, masks, images) 
            loss = criterion_loss(outputs, answers)
        
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        
        # Gradient clipping to prevent BERT's gradients from exploding
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        scaler.step(optimizer)
        scaler.update()
        
        # Metrics
        epoch_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += answers.size(0)
        correct += (predicted == answers).sum().item()
        
        # Update progress bar
        train_loader.set_postfix(loss=loss.item(), acc=100.*correct/total)
    
    avg_train_loss = epoch_loss / len(train_data)
    train_acc = 100 * correct / total
    
    # --- VALIDATION PHASE ---
    model.eval()
    val_loss = 0
    val_correct = 0
    val_total = 0
    
    with torch.no_grad():
        val_loader = tqdm(val_data, desc=f"Epoch {epoch+1}/{epochs} [Val]")
        for images, questions, masks, answers in val_loader:
            images = images.to(device)
            questions = questions.to(device)
            masks = masks.to(device)
            answers = answers.to(device)
            
            with autocast():
                outputs = model(questions, masks, images)
                loss = criterion_loss(outputs, answers)
            
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            val_total += answers.size(0)
            val_correct += (predicted == answers).sum().item()
            
            val_loader.set_postfix(loss=loss.item(), acc=100.*val_correct/val_total)
    
    avg_val_loss = val_loss / len(val_data)
    val_acc = 100 * val_correct / val_total
    
    # Update scheduler based on validation loss
    scheduler.step(avg_val_loss)
    
    # --- LOGGING & CHECKPOINTING ---
    print(f'\nSummary Epoch {epoch+1}:')
    print(f'Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.2f}%')
    print(f'Val   Loss: {avg_val_loss:.4f} | Val   Acc: {val_acc:.2f}%')
    
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), 'vqa_bert_resnet_best.pth')
        print(' Best model saved!')
    else:
        patience_counter += 1
        print(f'No improvement. Patience: {patience_counter}/{patience}')
        if patience_counter >= patience:
            print('Early stopping triggered!')
            break
    print('-' * 30)

print('\nTraining Complete!')

  scaler = GradScaler()


Starting training on: cuda


  with autocast():
Epoch 1/50 [Train]: 100%|████████████████████████████████████████| 638/638 [06:25<00:00,  1.66it/s, acc=20.9, loss=4.7]
  with autocast():
Epoch 1/50 [Val]: 100%|███████████████████████████████████████████| 570/570 [05:05<00:00,  1.86it/s, acc=22, loss=4.09]



Summary Epoch 1:
Train Loss: 5.0002 | Train Acc: 20.95%
Val   Loss: 4.4155 | Val   Acc: 21.96%
 Best model saved!
------------------------------


Epoch 2/50 [Train]: 100%|█████████████████████████████████████████| 638/638 [04:29<00:00,  2.37it/s, acc=22, loss=4.46]
Epoch 2/50 [Val]: 100%|███████████████████████████████████████████| 570/570 [03:16<00:00,  2.90it/s, acc=22, loss=4.03]



Summary Epoch 2:
Train Loss: 4.6813 | Train Acc: 21.98%
Val   Loss: 4.3794 | Val   Acc: 21.96%
 Best model saved!
------------------------------


Epoch 3/50 [Train]: 100%|█████████████████████████████████████████| 638/638 [04:23<00:00,  2.42it/s, acc=22, loss=5.23]
Epoch 3/50 [Val]: 100%|██████████████████████████████████████████████| 570/570 [03:21<00:00,  2.83it/s, acc=22, loss=4]



Summary Epoch 3:
Train Loss: 4.6449 | Train Acc: 21.98%
Val   Loss: 4.3692 | Val   Acc: 21.96%
 Best model saved!
------------------------------


Epoch 4/50 [Train]: 100%|███████████████████████████████████████| 638/638 [05:09<00:00,  2.06it/s, acc=21.5, loss=4.81]
Epoch 4/50 [Val]: 100%|███████████████████████████████████████████| 570/570 [03:30<00:00,  2.71it/s, acc=22, loss=4.08]



Summary Epoch 4:
Train Loss: 4.6313 | Train Acc: 21.45%
Val   Loss: 4.4125 | Val   Acc: 21.96%
No improvement. Patience: 1/7
------------------------------


Epoch 5/50 [Train]: 100%|████████████████████████████████████████| 638/638 [04:55<00:00,  2.16it/s, acc=21.5, loss=5.4]
Epoch 5/50 [Val]: 100%|███████████████████████████████████████████| 570/570 [03:29<00:00,  2.72it/s, acc=22, loss=4.05]



Summary Epoch 5:
Train Loss: 4.6151 | Train Acc: 21.52%
Val   Loss: 4.3719 | Val   Acc: 21.96%
No improvement. Patience: 2/7
------------------------------


Epoch 6/50 [Train]:  47%|██████████████████▎                    | 299/638 [02:15<02:33,  2.21it/s, acc=21.1, loss=4.76]


KeyboardInterrupt: 