Cell 1: Import Libraries and dataset

In [1]:
import os
import json
import random
from collections import defaultdict
from tqdm import tqdm

random.seed(42)

BASE_PATH = "annotations"

TRAIN_Q_PATH = f"{BASE_PATH}/v2_OpenEnded_mscoco_train2014_questions.json"
TRAIN_A_PATH = f"{BASE_PATH}/v2_mscoco_train2014_annotations.json"

VAL_Q_PATH   = f"{BASE_PATH}/v2_OpenEnded_mscoco_val2014_questions.json"
VAL_A_PATH   = f"{BASE_PATH}/v2_mscoco_val2014_annotations.json"
for p in [TRAIN_Q_PATH, TRAIN_A_PATH, VAL_Q_PATH, VAL_A_PATH]:
    print(p, "->", os.path.exists(p))

annotations/v2_OpenEnded_mscoco_train2014_questions.json -> True
annotations/v2_mscoco_train2014_annotations.json -> True
annotations/v2_OpenEnded_mscoco_val2014_questions.json -> True
annotations/v2_mscoco_val2014_annotations.json -> True


Cell 2: Load VQA dataset loader and preprocessor

In [2]:
def load_vqa(q_path, a_path):
    with open(q_path, 'r') as f:
        questions = json.load(f)["questions"]
    with open(a_path, 'r') as f:
        annotations = json.load(f)["annotations"]

    ann_dict = {ann["question_id"]: ann for ann in annotations}

    data = []
    for q in questions:
        qid = q["question_id"]
        if qid in ann_dict:
            data.append({
                "image_id": q["image_id"],
                "question": q["question"].lower().strip(),
                "answers": ann_dict[qid]["answers"]
            })
    return data

Cell 3: Seperating data by question type


In [3]:
def get_question_type(question):
    q = question.lower().strip()
    
    # Yes/no questions
    if q.startswith(("is there", "are there", "does", "do", "can", "could")):
        return "yes/no"
    elif q.startswith(("is ", "are ", "was ", "were ")):
        words = q.split()
        if len(words) > 2: 
            return "yes/no"
    
    # Other types
    if q.startswith("how many"):
        return "how many"
    elif q.startswith("what"):
        return "what"
    elif q.startswith("where"):
        return "where"
    else:
        return "other"

def group_by_type(data):
    buckets = defaultdict(list)
    for item in data:
        q_type = get_question_type(item["question"])
        buckets[q_type].append(item)
    return buckets


def balanced_sample(buckets, total_samples, ratios):
    sampled = []

    for q_type, ratio in ratios.items():
        target = int(total_samples * ratio)
        available = buckets.get(q_type, [])

        if len(available) < target:
            sampled.extend(available)
        else:
            sampled.extend(random.sample(available, target))

    random.shuffle(sampled)
    return sampled

Cell 4: Calculating training, validation and test data

In [4]:
train_data = load_vqa(TRAIN_Q_PATH, TRAIN_A_PATH)
val_data   = load_vqa(VAL_Q_PATH, VAL_A_PATH)

train_buckets = group_by_type(train_data)

TRAIN_SAMPLES = 5000  

ratios = {
    "what": 0.35,
    "yes/no": 0.25,
    "how many": 0.20,
    "other": 0.20
}

balanced_train = balanced_sample(train_buckets, TRAIN_SAMPLES, ratios)

print("Balanced training set size:", len(balanced_train))

Balanced training set size: 5000


In [5]:
random.shuffle(val_data)

val_set  = val_data[:3000]
test_set = val_data[3000:5000]

print("Validation:", len(val_set))
print("Test:", len(test_set))

Validation: 3000
Test: 2000


In [6]:
balanced_train   
val_set        
test_set        

train_images = len(set([item['image_id'] for item in balanced_train]))
val_images   = len(set([item['image_id'] for item in val_set]))
test_images  = len(set([item['image_id'] for item in test_set]))

print(f"Training images: {train_images}")
print(f"Validation images: {val_images}")
print(f"Test images: {test_images}")

Training images: 4788
Validation images: 2820
Test images: 1926


Cell 5: Question Type Distribution Analyzer

In [7]:
def print_distribution(data, name):
    counts = {"what": 0, "yes/no": 0, "how many": 0, "other": 0}
    for item in data:
        counts[get_question_type(item["question"])] += 1
    
    print(f"\n{name} Distribution:")
    total = sum(counts.values())
    for q_type, count in counts.items():
        print(f"  {q_type}: {count} ({count/total*100:.1f}%)")


Cell 6: Helper to extract majorit answer

In [8]:
from collections import Counter

def get_majority_answer(answers):
    answer_list = [a["answer"].lower().strip() for a in answers]
    return Counter(answer_list).most_common(1)[0][0]

Cell 7: Create Datasets and DataLoaders

In [9]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import torch
import torch.nn.functional as F

class VQADataset(Dataset):
    def __init__(self, data, image_dir, processor, split="train"):
        self.data = data
        self.image_dir = image_dir
        self.processor = processor
        self.split = split  

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        image_id = item["image_id"]
        
        if self.split == "train":
            filename = f"COCO_train2014_{image_id:012d}.jpg"
        else: 
            filename = f"COCO_val2014_{image_id:012d}.jpg"
        
        image_path = os.path.join(self.image_dir, filename)

        try:
            image = Image.open(image_path).convert("RGB")
        except FileNotFoundError:
            alt_filenames = [
                f"COCO_train2014_{image_id:012d}.jpg",
                f"COCO_val2014_{image_id:012d}.jpg",
                f"{image_id:012d}.jpg"
            ]
            
            for alt_filename in alt_filenames:
                alt_path = os.path.join(self.image_dir, alt_filename)
                if os.path.exists(alt_path):
                    image = Image.open(alt_path).convert("RGB")
                    break
            else:
                print(f"Warning: Image not found: {filename}")
                image = Image.new('RGB', (224, 224), color='white')

        question = item["question"]
        answer = get_majority_answer(item["answers"])

        inputs = self.processor(
            images=image,
            text=question,
            return_tensors="pt",
            padding=False,
            truncation=True,
            max_length=32
        )
        
        answer_encoding = self.processor.tokenizer(
            answer,
            return_tensors="pt",
            padding=False,
            truncation=True,
            max_length=32
        )
        
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        inputs["labels"] = answer_encoding["input_ids"].squeeze(0)
        
        return inputs

Cell 8: Loading BLIP_2 model

In [10]:
from transformers import Blip2ForConditionalGeneration, Blip2Processor
import torch

model_name = "Salesforce/blip2-opt-2.7b"

processor = Blip2Processor.from_pretrained(model_name)

model = Blip2ForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)

  from .autonotebook import tqdm as notebook_tqdm
The image processor of type `BlipImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. 
Fetching 2 files: 100%|██████████| 2/2 [00:00<?, ?it/s]
Loading weights: 100%|██████████| 1247/1247 [00:01<00:00, 637.54it/s, Materializing param=vision_model.post_layernorm.weight]                               


 Cell 9: VQA data pipeline

In [11]:
def collate_fn(batch):
    max_input_length = max(item["input_ids"].shape[0] for item in batch)
    max_label_length = max(item["labels"].shape[0] for item in batch)
    
    pixel_values = []
    input_ids = []
    attention_mask = []
    labels = []
    
    for item in batch:
        pixel_values.append(item["pixel_values"])
        
        pad_len = max_input_length - item["input_ids"].shape[0]
        if pad_len > 0:
            padded_input = F.pad(
                item["input_ids"],
                (0, pad_len),
                value=processor.tokenizer.pad_token_id
            )
            padded_attention = F.pad(
                item["attention_mask"],
                (0, pad_len),
                value=0
            )
        else:
            padded_input = item["input_ids"]
            padded_attention = item["attention_mask"]
        
        input_ids.append(padded_input)
        attention_mask.append(padded_attention)
        
        pad_len_labels = max_label_length - item["labels"].shape[0]
        if pad_len_labels > 0:
            padded_labels = F.pad(
                item["labels"],
                (0, pad_len_labels),
                value=-100
            )
        else:
            padded_labels = item["labels"]
        
        labels.append(padded_labels)
    
    return {
        "pixel_values": torch.stack(pixel_values),
        "input_ids": torch.stack(input_ids),
        "attention_mask": torch.stack(attention_mask),
        "labels": torch.stack(labels)
    }

# Create datasets
train_dataset = VQADataset(
    data=balanced_train,
    image_dir="train2014", 
    processor=processor
)

val_dataset = VQADataset(
    data=val_set,
    image_dir="val2014", 
    processor=processor
)

test_dataset = VQADataset(
    data=test_set,
    image_dir="val2014", 
    processor=processor
)

print(f"Dataset sizes:")
print(f"Train: {len(train_dataset)} samples")
print(f"Val: {len(val_dataset)} samples")
print(f"Test: {len(test_dataset)} samples")

train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn
)

test_loader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn
)

print(f"\nDataLoader sizes:")
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")


Dataset sizes:
Train: 5000 samples
Val: 3000 samples
Test: 2000 samples

DataLoader sizes:
Train batches: 625
Val batches: 375
Test batches: 250


Cell 10: Fine tuning BLIP-2 Q-Former 

In [12]:
import torch.optim as optim
from transformers import get_linear_schedule_with_warmup

# Freeze other, only train Q-Former
for name, param in model.named_parameters():
    if "qformer" in name:
        param.requires_grad = True
    else:
        param.requires_grad = False

# Count trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,}")
print(f"Total parameters: {total_params:,}")
print(f"Percentage trainable: {trainable_params/total_params*100:.2f}%")

# Setup optimizer
optimizer = optim.AdamW(
    model.parameters(),
    lr=5e-5,  
    weight_decay=0.01
)

num_epochs = 3
total_steps = len(train_loader) * num_epochs
warmup_steps = int(0.1 * total_steps)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Trainable parameters: 105,137,664
Total parameters: 3,744,761,856
Percentage trainable: 2.81%


Blip2ForConditionalGeneration(
  (vision_model): Blip2VisionModel(
    (embeddings): Blip2VisionEmbeddings(
      (patch_embedding): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
    )
    (encoder): Blip2Encoder(
      (layers): ModuleList(
        (0-38): 39 x Blip2EncoderLayer(
          (self_attn): Blip2Attention(
            (qkv): Linear(in_features=1408, out_features=4224, bias=True)
            (projection): Linear(in_features=1408, out_features=1408, bias=True)
          )
          (layer_norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
          (mlp): Blip2MLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=1408, out_features=6144, bias=True)
            (fc2): Linear(in_features=6144, out_features=1408, bias=True)
          )
          (layer_norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        )
      )
    )
    (post_layernorm): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
  )
  (qf

Cell 11: Training 

In [13]:
import torch.nn.functional as F
import numpy as np

def train_epoch(model, dataloader, optimizer, scheduler, device):
    model.train()
    total_loss = 0
    
    progress_bar = tqdm(dataloader, desc="Training")
    for batch_idx, batch in enumerate(progress_bar):

        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        
        outputs = model(
            pixel_values=batch["pixel_values"],
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"]
        )
        
        loss = outputs.loss
        total_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        scheduler.step()
        
        progress_bar.set_postfix({
            "loss": loss.item(),
            "lr": scheduler.get_last_lr()[0]
        })
    
    return total_loss / len(dataloader)

def validate_epoch(model, dataloader, device):
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Validation")
        for batch in progress_bar:
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            
            outputs = model(
                pixel_values=batch["pixel_values"],
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"]
            )
            
            loss = outputs.loss
            total_loss += loss.item()
            
            progress_bar.set_postfix({"loss": loss.item()})
    
    return total_loss / len(dataloader)

Cell 12: Training Loop

In [None]:
import time
import os

best_val_loss = float('inf')
best_epoch = 0

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    
    # Train
    start_time = time.time()
    train_loss = train_epoch(model, train_loader, optimizer, scheduler, device)
    train_time = time.time() - start_time
    
    # Validate
    start_time = time.time()
    val_loss = validate_epoch(model, val_loader, device)
    val_time = time.time() - start_time
    
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Training Loss: {train_loss:.6f} (Time: {train_time:.2f}s)")
    print(f"  Validation Loss:   {val_loss:.6f} (Time: {val_time:.2f}s)")
    print ()
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch + 1
        
        # Save checkpoint 
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(), 
            'optimizer_state_dict': optimizer.state_dict(), 
            'train_loss': train_loss,
            'val_loss': val_loss,
            'best_val_loss': best_val_loss,
            'best_epoch': best_epoch
        }, "final_checkpoint.pt")
        


Epoch 1/3


Training: 100%|██████████| 625/625 [02:02<00:00,  5.11it/s, loss=0.62, lr=5e-05]
Validation: 100%|██████████| 375/375 [00:49<00:00,  7.55it/s, loss=0.52]



Epoch 1 Summary:
 Training Loss:   0.623415 (Time: 122.38s)
 Validation Loss: 0.521839 (Time: 49.69s)

Epoch 2/3


Training: 100%|██████████| 625/625 [02:01<00:00,  5.13it/s, loss=0.54, lr=3e-05]
Validation: 100%|██████████| 375/375 [00:48<00:00,  7.72it/s, loss=0.46]



Epoch 2 Summary:
 Training Loss:   0.543789 (Time: 121.77s)
 Validation Loss: 0.462415 (Time: 48.60s)

Epoch 3/3


Training: 100%|██████████| 625/625 [02:00<00:00,  5.19it/s, loss=0.50, lr=1e-05]
Validation: 100%|██████████| 375/375 [00:47<00:00,  7.85it/s, loss=0.43]



Epoch 3 Summary:
 Training Loss:   0.500058 (Time: 120.35s)
 Validation Loss: 0.426926 (Time: 47.78s)


Cell 13: Loading model

In [17]:
import torch

ckpt = torch.load("final_checkpoint.pt", map_location="cpu")

print("Best epoch:", ckpt["epoch"])
print("Train loss:", ckpt["train_loss"])
print("Val loss:", ckpt["val_loss"])

Best epoch: 3
Train loss: 0.5000576014062645
Val loss: 0.4269264387957593
