In [1]:
import os

import requests
from transformers import Blip2Processor, BlipForQuestionAnswering, Blip2Model
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoProcessor
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
import pickle


# Load the BLIP VQA processor and the pretrained base checkpoint.
processor = AutoProcessor.from_pretrained('Salesforce/blip-vqa-base')
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
# To start from previously fine-tuned weights instead, uncomment:
# model.load_state_dict(torch.load("blip/blip_weights.pth"))

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device)

# Clear the CUDA allocator cache and seed torch's RNG. The seed is set after
# the model has been loaded, so model construction does not consume from it.
torch.cuda.empty_cache()
torch.manual_seed(42)

class VQADataset(torch.utils.data.Dataset):
    """VQA (v2) dataset."""

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # get image + text
        question = self.dataset[idx]['question']
        answer = self.dataset[idx]['answer']
        image_id = self.dataset[idx]['image_id']
        image_path = #image path
        image = Image.open(image_path).convert("RGB")
        text = question
        # print(image_id)
        encoding = self.processor(image, text, padding="max_length", truncation=True, return_tensors="pt")
        labels = self.processor.tokenizer.encode(
            answer, max_length= 8, pad_to_max_length=True, return_tensors='pt'
        )
        # print(labels)
        encoding["labels"] = labels
        # remove batch dimension
        for k,v in encoding.items():  encoding[k] = v.squeeze()
        return encoding


In [None]:
# Split the training annotations into two disjoint parts: the first 70% for
# training and the remaining 30% for validation. The validation slice used to
# be "train[30%:]", which overlapped the training slice on the 30%-70% range
# and leaked training examples into the validation loss.
training_dataset = load_dataset("json", data_files="path/to/train/annotations", split="train[:70%]")
valid_dataset = load_dataset("json", data_files="path/to/train/annotations", split="train[70%:]")

In [None]:
# Print the training split object returned by load_dataset for a quick sanity check.
print(training_dataset)

In [None]:
# Report how many examples each split holds.
split_sizes = (len(training_dataset), len(valid_dataset))
print("Training sets: {} - Validating set: {}".format(*split_sizes))

In [None]:
# Wrap the raw splits so each item is a processor encoding with labels.
# The wrapped validation set gets its own name: rebinding `valid_dataset` to a
# wrapper of itself made this cell fail when it was executed a second time.
train_dataset = VQADataset(dataset=training_dataset,
                           processor=processor)
valid_vqa_dataset = VQADataset(dataset=valid_dataset,
                               processor=processor)
batch_size = 4
# Pinned host memory only helps host-to-GPU copies; requesting it without CUDA
# just produces a warning.
pin_memory = torch.cuda.is_available()
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory)
# Validation only averages the loss over all batches, so order is irrelevant
# and shuffling is unnecessary.
valid_dataloader = DataLoader(valid_vqa_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory)


optimizer = torch.optim.AdamW(model.parameters(), lr=4e-5)
# Multiply the learning rate by 0.9 on every scheduler.step(). `last_epoch=-1`
# is the default, and `verbose` is deprecated for LR schedulers, so both
# arguments are omitted.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

In [None]:
num_epochs = 100
# Mixed precision (float16 autocast + loss scaling) is only enabled when the
# model is on a CUDA device; on CPU both the autocast context and the scaler
# become pass-throughs instead of emitting warnings.
use_amp = device.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# NOTE(review): this loop never writes the fine-tuned weights to disk, while
# the evaluation cell later loads "blip/blip_weights.pth" -- confirm where that
# file is meant to come from.
for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    epoch_loss = 0
    for batch in tqdm(train_dataloader, desc='Training'):
        input_ids = batch.pop('input_ids').to(device)
        pixel_values = batch.pop('pixel_values').to(device)
        attention_masked = batch.pop('attention_mask').to(device)
        labels = batch.pop('labels').to(device)

        optimizer.zero_grad()
        with torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
            outputs = model(input_ids=input_ids,
                            pixel_values=pixel_values,
                            attention_mask=attention_masked,
                            labels=labels)
        loss = outputs.loss
        epoch_loss += loss.item()

        # Scale the loss before backward, then let the scaler drive the
        # optimizer step and refresh its scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    # ---- validation pass ----
    model.eval()
    eval_loss = 0
    # No gradients are needed here; disabling them avoids building autograd
    # graphs for every validation batch.
    with torch.no_grad():
        for batch in tqdm(valid_dataloader, desc='Validation'):
            with torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
                outputs = model(input_ids=batch.pop('input_ids').to(device),
                                pixel_values=batch.pop('pixel_values').to(device),
                                attention_mask=batch.pop('attention_mask').to(device),
                                labels=batch.pop('labels').to(device))
            eval_loss += outputs.loss.item()

    # Mean per-batch losses and the learning rate that was used for this epoch.
    print("Epoch: {} - Training loss: {} - Eval Loss: {} - LR: {}".format(epoch+1, epoch_loss/len(train_dataloader), eval_loss/len(valid_dataloader), optimizer.param_groups[0]["lr"]))
    scheduler.step()
    

In [3]:
import json
question_file =#test questions file

with open(question_file, 'r') as input_file:
    questions = json.load(input_file)

temp = questions['questions']
with open("file.json", "w") as file:
    # Dump the data into the file as JSON
    json.dump(temp, file, indent=4)

In [4]:
# Collect the question text and the image id of every test question, keeping
# the order in which the entries appear in the file.
question_entries = questions["questions"]
qs = [entry['question'] for entry in question_entries]
imgs = [entry["image_id"] for entry in question_entries]

In [None]:
# Both lists are built from the same entries, so the two counts should match.
for collected in (qs, imgs):
    print(len(collected))

In [6]:
import json
annotation_file = #test annotations file

with open(annotation_file, 'r') as input_file:
    annotations = json.load(input_file)

temp = annotations['annotations']
with open("file.json", "w") as file:
    # Dump the data into the file as JSON
    json.dump(temp, file, indent=4)

In [7]:
# Take the 'multiple_choice_answer' of every annotation, in file order.
answers = [entry['multiple_choice_answer'] for entry in annotations['annotations']]

In [None]:
# Number of ground-truth answers collected from the annotations.
print(len(answers))

In [None]:
from PIL import Image
import requests
from transformers import AutoProcessor, BlipForQuestionAnswering
import torch

model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
model.load_state_dict(torch.load("blip/blip_weights.pth"))
processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
res = []
correct = 0
for idx in range(len(qs)):
    image_path = #path to image
    image = Image.open(image_path).convert("RGB")
    text = qs[idx]
    inputs = processor(images=image, text=text, return_tensors="pt")
    outputs = model.generate(**inputs)
    print(idx)
    out = (processor.decode(outputs[0], labels = answers[idx], skip_special_tokens=True))
    res.append(out)
    print(res[idx], answers[idx])
    if(res[idx]==answers[idx]):
        correct=correct+1

In [None]:
# Summarise the evaluation: exact matches, predictions made, and their ratio.
total_predictions = len(res)
print(correct)
print(total_predictions)
print(correct/total_predictions)