In [1]:
%pip install -q transformers datasets torch

Note: you may need to restart the kernel to use updated packages.


In [2]:

from datasets import load_dataset
from transformers import BertTokenizer, ViTFeatureExtractor
from torch.utils.data import DataLoader
from PIL import Image
import torch
import numpy as np
import itertools
import torch
import torch.nn as nn
from transformers import BertModel, ViTModel
import torch.optim as optim
from tqdm import tqdm


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Shared preprocessing objects: the BERT tokenizer encodes the questions and the
# ViT feature extractor (do_resize=True, size=224) prepares the images.
TEXT_CHECKPOINT = "bert-base-uncased"
IMAGE_CHECKPOINT = "google/vit-base-patch16-224"

tokenizer = BertTokenizer.from_pretrained(TEXT_CHECKPOINT)
feature_extractor = ViTFeatureExtractor.from_pretrained(
    IMAGE_CHECKPOINT,
    do_resize=True,
    size=224,
)

def preprocess_data(examples):
    """Convert a batch of flattened VQA rows into model-ready tensors.

    Written for ``Dataset.map(..., batched=True)``: every value in ``examples``
    is a list with one entry per row.

    Args:
        examples: Column-oriented batch providing ``"image_id"`` (paths that are
            opened with PIL), ``"question"`` (strings) and the parallel
            per-row sequences ``"label.ids"`` / ``"label.weights"``.

    Returns:
        The feature extractor's output (which holds ``pixel_values``) extended
        with ``"question"`` (token ids), ``"attention_mask"`` and ``"labels"``
        (one dense score vector of length ``len(id2label)`` per row).

    Note:
        Reads the module-level ``tokenizer``, ``feature_extractor`` and
        ``id2label``; ``id2label`` has to exist by the time this is called.
    """
    # Open each file in a context manager so its handle is released promptly.
    # convert("RGB") loads the pixel data and returns an independent image.
    images = []
    for image_path in examples["image_id"]:
        with Image.open(image_path) as image_file:
            images.append(image_file.convert("RGB"))

    # Every tensor keeps its leading batch dimension. Squeezing dim 0 is a no-op
    # for normal batches but would drop the row axis for a batch of one example,
    # leaving the returned columns with mismatched row counts.
    image_encodings = feature_extractor(images=images, return_tensors="pt")
    text_encodings = tokenizer(
        examples["question"],
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Dense target matrix: row i holds the weight of every answer class for example i.
    labels = torch.zeros(len(examples["label.ids"]), len(id2label))
    for row, (label_ids, weights) in enumerate(zip(examples["label.ids"], examples["label.weights"])):
        for label_id, weight in zip(label_ids, weights):
            labels[row, label_id] = weight

    image_encodings["question"] = text_encodings["input_ids"]
    image_encodings["attention_mask"] = text_encodings["attention_mask"]
    image_encodings["labels"] = labels

    return image_encodings



In [15]:
# Load small slices of the Graphcore/vqa train and validation splits.
# `datasets` warns that trust_remote_code=True will become mandatory for this
# dataset, so opt in explicitly instead of depending on the deprecated default.
train_dataset = load_dataset("Graphcore/vqa", split="train[:1000]", trust_remote_code=True)
val_dataset = load_dataset("Graphcore/vqa", split="validation[:20]", trust_remote_code=True)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [16]:
# Gather every answer id that appears in the loaded train/validation slices and
# assign each one a contiguous class index.
labels_train = [item['ids'] for item in train_dataset['label']]
labels_val = [item['ids'] for item in val_dataset['label']]
labels = labels_train + labels_val
flattened_labels = list(itertools.chain.from_iterable(labels))
# Sort the distinct ids: set iteration order is not guaranteed, and a stable
# order keeps the class index of each answer identical from run to run.
unique_labels = sorted(set(flattened_labels))

label2id = {label: idx for idx, label in enumerate(unique_labels)}
id2label = {idx: label for label, idx in label2id.items()}

def replace_ids(inputs):
    """Rewrite one example's answer ids as contiguous class indices.

    Each entry of ``inputs["label"]["ids"]`` is looked up in the module-level
    ``label2id`` mapping. The example is modified in place and returned.
    """
    label_field = inputs["label"]
    remapped = []
    for original_id in label_field["ids"]:
        remapped.append(label2id[original_id])
    label_field["ids"] = remapped
    return inputs


# NOTE: train_dataset / val_dataset are overwritten with their remapped versions,
# so this cell is not idempotent: running it twice would send already-remapped
# ids through label2id again.
train_dataset = train_dataset.map(replace_ids)
flat_train_dataset = train_dataset.flatten()

val_dataset = val_dataset.map(replace_ids)
flat_val_dataset = val_dataset.flatten()

# Only the final expression of a cell is rendered, so the flattened schema is
# shown once here (the earlier mid-cell expression produced no output).
flat_val_dataset.features


Map: 100%|██████████| 1000/1000 [00:00<00:00, 16398.54 examples/s]
Map: 100%|██████████| 20/20 [00:00<00:00, 4744.96 examples/s]


{'question': Value(dtype='string', id=None),
 'question_type': Value(dtype='string', id=None),
 'question_id': Value(dtype='int32', id=None),
 'image_id': Value(dtype='string', id=None),
 'answer_type': Value(dtype='string', id=None),
 'label.ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'label.weights': Sequence(feature=Value(dtype='float64', id=None), length=-1, id=None)}

In [17]:
# Run preprocess_data over both flattened splits in batched mode and drop the
# columns that are no longer needed once the tensors have been produced.
unused_columns = ["question_type", "question_id", "image_id", "answer_type"]

train_dataset = flat_train_dataset.map(
    preprocess_data,
    batched=True,
    remove_columns=unused_columns,
)
val_dataset = flat_val_dataset.map(
    preprocess_data,
    batched=True,
    remove_columns=unused_columns,
)


Map: 100%|██████████| 1000/1000 [00:10<00:00, 97.53 examples/s]
Map: 100%|██████████| 20/20 [00:00<00:00, 106.31 examples/s]


In [7]:
def collate_fn(batch):
    """Assemble a list of dataset rows into one dictionary of batched tensors.

    For every row the ``pixel_values``, ``question``, ``attention_mask`` and
    ``labels`` entries are converted with ``torch.tensor`` and then stacked
    along a new leading dimension. ``pixel_values`` is returned under the key
    ``"image"``; the other keys keep their names.
    """
    # output key -> key read from each row
    source_keys = {
        "image": "pixel_values",
        "question": "question",
        "attention_mask": "attention_mask",
        "labels": "labels",
    }
    collated = {}
    for out_key, row_key in source_keys.items():
        per_row = [torch.tensor(row[row_key]) for row in batch]
        collated[out_key] = torch.stack(per_row)
    return collated

In [18]:
# Both loaders share the batch size and collate function; only the training
# loader shuffles its rows.
batch_size = 8
loader_options = {"batch_size": batch_size, "collate_fn": collate_fn}

train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_options)


In [20]:
class VQAViTBert(nn.Module):
    """Late-fusion VQA classifier built from a ViT image encoder and a BERT text encoder.

    The pooled outputs of both encoders are concatenated and passed through a
    three-layer MLP that produces one logit per answer class.

    Args:
        num_labels: Number of answer classes (size of the output layer). When
            omitted, ``len(id2label)`` from the notebook namespace is used,
            which is the behaviour this class had before the parameter existed.
    """

    def __init__(self, num_labels=None):
        super().__init__()
        if num_labels is None:
            # Default to the answer vocabulary built earlier in the notebook.
            num_labels = len(id2label)

        # NOTE: loading this checkpoint into ViTModel reports the pooler weights
        # as newly initialised, so pooler_output is only meaningful after training.
        self.vision_model = ViTModel.from_pretrained("google/vit-base-patch16-224")
        self.text_model = BertModel.from_pretrained("bert-base-uncased")

        fused_size = self.vision_model.config.hidden_size + self.text_model.config.hidden_size
        self.fc = nn.Sequential(
            nn.Linear(fused_size, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, num_labels),  # one logit per answer class
        )

    def forward(self, image, question, attention_mask):
        """Return answer logits for a batch.

        Args:
            image: Pixel tensor passed straight to the ViT encoder.
            question: Token ids passed to BERT as ``input_ids``.
            attention_mask: Attention mask passed to BERT alongside ``question``.

        Returns:
            Output of the classification head applied to the concatenated
            pooled image and text features.
        """
        img_feats = self.vision_model(image).pooler_output
        txt_feats = self.text_model(input_ids=question, attention_mask=attention_mask).pooler_output
        combined = torch.cat((img_feats, txt_feats), dim=-1)
        return self.fc(combined)

# Initialize Model
# Instantiating the class loads the pretrained ViT and BERT checkpoints named in
# VQAViTBert.__init__ and sizes the output layer from len(id2label).
model = VQAViTBert()


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
# Select the device and place the model on it; train_model, evaluate_model and
# predict all read this module-level `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Loss and optimizer used by the training and evaluation loops below.
# NOTE(review): the targets built in preprocess_data are per-class score vectors
# that are not normalised to sum to 1 — confirm CrossEntropyLoss is the intended loss.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=2e-4)

def train_model(model, train_dataloader, criterion, optimizer, epochs=10):
    """Train ``model`` and print the mean batch loss after every epoch.

    Args:
        model: Module called as ``model(image, question, attention_mask)``.
        train_dataloader: Iterable of batches with the keys ``"image"``,
            ``"question"``, ``"attention_mask"`` and ``"labels"``.
        criterion: Loss called as ``criterion(output, labels)``.
        optimizer: Optimizer stepped once per batch.
        epochs: Number of passes over ``train_dataloader``.

    Batches are moved to the module-level ``device`` before the forward pass.
    """
    input_keys = ("image", "question", "attention_mask")

    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for batch in tqdm(train_dataloader):
            model_inputs = [batch[key].to(device) for key in input_keys]
            targets = batch["labels"].to(device)

            # Standard step: clear old gradients, forward, backward, update.
            optimizer.zero_grad()
            logits = model(*model_inputs)
            batch_loss = criterion(logits, targets)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_dataloader)}")

# Train for a single epoch; train_model shows a tqdm progress bar over the
# batches and prints the mean loss once the epoch finishes.
train_model(model, train_dataloader, criterion, optimizer, epochs=1)



100%|██████████| 125/125 [21:47<00:00, 10.46s/it]

Epoch 1/1, Loss: 6.45198921585083





In [22]:
def evaluate_model(model, val_dataloader, criterion):
    """Print the mean validation loss and top-1 accuracy of ``model``.

    The reference class of each example is the index of the largest entry in
    its label vector; a prediction counts as correct when the index of the
    largest logit matches it.

    Args:
        model: Module called as ``model(image, question, attention_mask)``.
        val_dataloader: Iterable of batches with the keys ``"image"``,
            ``"question"``, ``"attention_mask"`` and ``"labels"``.
        criterion: Loss called as ``criterion(output, labels)``.

    Batches are moved to the module-level ``device``; gradients are disabled.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    batch_keys = ("image", "question", "attention_mask", "labels")
    with torch.no_grad():
        for batch in val_dataloader:
            image, question, attention_mask, score_vectors = [
                batch[key].to(device) for key in batch_keys
            ]

            # Index of the highest-scored answer serves as the reference class.
            reference = score_vectors.max(dim=1).indices

            logits = model(image, question, attention_mask)
            running_loss += criterion(logits, score_vectors).item()

            predicted = logits.max(dim=1).indices
            correct += (predicted == reference).sum().item()
            total += score_vectors.size(0)

    print(f"Validation Loss: {running_loss/len(val_dataloader)}")
    print(f"Validation Accuracy: {100 * correct / total}%")

# Report mean loss and top-1 accuracy on the validation loader using the same
# criterion that was used for training.
evaluate_model(model, val_dataloader, criterion)


Validation Loss: 5.378966013590495
Validation Accuracy: 15.0%


In [None]:
def predict(image_path, question, model, tokenizer, feature_extractor):
    """Predict the answer class index for a single image/question pair.

    Args:
        image_path: Path of the image file to open with PIL.
        question: Question text.
        model: Module called as ``model(image, input_ids, attention_mask)``.
        tokenizer: Tokenizer applied to ``question``.
        feature_extractor: Image preprocessor applied to the loaded image.

    Returns:
        int: Index of the highest-scoring logit. This is a position in the
        model's output layer, not the dataset's original answer id.

    Tensors are moved to the module-level ``device`` before the forward pass.
    """
    # Match preprocess_data: convert to RGB so grayscale/RGBA/palette files reach
    # the feature extractor with three channels, and close the file afterwards.
    with Image.open(image_path) as image_file:
        rgb_image = image_file.convert("RGB")

    pixel_values = feature_extractor(images=rgb_image, return_tensors="pt")["pixel_values"]
    encoded_question = tokenizer(question, return_tensors="pt", padding="max_length", truncation=True)

    model.eval()
    with torch.no_grad():
        output = model(
            pixel_values.to(device),
            encoded_question["input_ids"].to(device),
            encoded_question["attention_mask"].to(device)
        )
        predicted_label = torch.argmax(output, dim=-1).item()
        return predicted_label

# image_path = "/path/to/image.jpg"
# question = "What is in the image?"
# predicted_answer = predict(image_path, question, model, tokenizer, feature_extractor)
# print(f"Predicted Answer: {predicted_answer}")
