In [1]:
import json
import torch
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
from num2words import num2words
from ofa.ofa_infer import OFAInference
from evaluate_metrics import compute_f1
from lavis.models import load_model_and_preprocess
from transformers import ViltProcessor, ViltForQuestionAnswering

## OFA 

In [2]:
# Alternative kept for reference: build OFAInference with no arguments
# (what that call loads is not visible in this notebook -- TODO confirm).
# ofa = OFAInference()

# Module-level OFA wrapper used by infer_ofa(); it is constructed with the
# local path 'models/ofa_huge.pt' passed as `pretrained_path`.
ofa = OFAInference(pretrained_path='models/ofa_huge.pt')

def infer_ofa(image_path, question):
    """Answer ``question`` about the image at ``image_path`` with the OFA wrapper.

    The answer returned by ``ofa.ofa_inference`` is split on whitespace and
    each token is passed through ``num2words``; tokens that ``num2words``
    rejects are kept unchanged.

    Args:
        image_path: Image location, forwarded as-is to ``ofa.ofa_inference``.
        question: Question text, forwarded as-is to ``ofa.ofa_inference``.

    Returns:
        str: The processed tokens joined with single spaces.
    """
    answer = ofa.ofa_inference(image_path, question)

    normalised = []
    for token in answer.split():
        try:
            normalised.append(num2words(token))
        except Exception:
            # num2words could not convert this token: keep it verbatim.
            # `Exception` rather than a bare `except` so that a kernel
            # interrupt (KeyboardInterrupt) or SystemExit is not swallowed.
            normalised.append(token)
    return ' '.join(normalised)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  "Argument interpolation should be of type InterpolationMode instead of int. "


## LAVIS

In [11]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# load_model_and_preprocess returns three objects; all of them are used by
# infer_lavis(): the model plus the "eval" image and text processors.
(model, vis_processors, txt_processors) = load_model_and_preprocess(
    name="blip_vqa",
    model_type="aokvqa",
    is_eval=True,
    device=device,
)

def infer_lavis(image_path, question):
    """Answer ``question`` about the image at ``image_path`` with the LAVIS model.

    The image is opened, converted to RGB, run through ``vis_processors["eval"]``,
    given a leading dimension with ``unsqueeze(0)`` and moved to ``device``.
    The question is run through ``txt_processors["eval"]``. The first answer
    returned by ``model.predict_answers`` is split on whitespace and each token
    is passed through ``num2words``; tokens it rejects are kept unchanged.

    Args:
        image_path: Path of the image file to open.
        question: Question text.

    Returns:
        str: The processed tokens joined with single spaces.
    """
    # Context manager so the underlying file is closed deterministically;
    # convert() hands back a separate image that stays usable afterwards.
    with Image.open(image_path) as img:
        raw_image = img.convert("RGB")
    image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
    question = txt_processors["eval"](question)
    answer = model.predict_answers(
        samples={"image": image, "text_input": question}, inference_method="generate")

    normalised = []
    for token in answer[0].split():
        try:
            normalised.append(num2words(token))
        except Exception:
            # num2words could not convert this token: keep it verbatim.
            # `Exception` rather than a bare `except` so that a kernel
            # interrupt (KeyboardInterrupt) or SystemExit is not swallowed.
            normalised.append(token)
    return ' '.join(normalised)

## Public-test

In [15]:
# Load the public-test annotations and report how many there are.
with open('data/test/evjvqa_public_test-lang-qtype-answer.json', 'r', encoding='utf-8') as f:
    test_data = json.load(f)

annotations = test_data['annotations']

print(len(annotations))

# Only English questions are sent to the model. Selecting them once, up front,
# gives the progress bar the real amount of work and lets each answer be paired
# with its annotation directly instead of through a manually advanced counter.
# (To narrow the run further, extend the condition, e.g. with
#  `anno['question_type'] in ['WHAT_COLOR']`.)
en_annotations = [anno for anno in annotations if anno['language'] == 'en']

gold_answers = []
ofa_answers = []    # stays empty: OFA inference is not run in this cell
vilt_answers = []   # stays empty: ViLT inference is not run in this cell
lavis_answers = []

for anno in tqdm(en_annotations):
    lavis_answers.append(infer_lavis(anno['img_path'], anno['question']))
    gold_answers.append(anno['answer'])

# Map annotation id -> answer. The OFA / ViLT dicts are created but left empty
# because their answer lists above are empty.
gold_dict = {}
ofa_dict = {}
vilt_dict = {}
lavis_dict = {}

# tqdm wraps the list itself (not an enumerate object) so the bar has a total.
for anno, gold, pred in zip(tqdm(en_annotations), gold_answers, lavis_answers):
    idx = anno['id']
    gold_dict[idx] = gold
    lavis_dict[idx] = pred

# Number of English annotations processed; shown as the cell's output.
i = len(en_annotations)

i

5015

In [18]:
# Write the id -> LAVIS answer mapping as indented JSON, keeping non-ASCII
# characters unescaped.
results_path = './outputs/results-lavis-aokvqa.json'
with open(results_path, mode='w', encoding='utf-8') as f:
    json.dump(lavis_dict, f, ensure_ascii=False, indent=4)

## Private-test

In [12]:
# Load the private-test annotations and report how many there are.
with open('data/private-test/evjvqa_private_test-desc-lang-qtype.json', 'r', encoding='utf-8') as f:
    ptest_data = json.load(f)

annotations = ptest_data['annotations']

print(len(annotations))

# One list per model plus the reference answers; only the LAVIS and gold
# lists are filled by the loop below.
gold_answers, ofa_answers, vilt_answers, lavis_answers = [], [], [], []

for anno in tqdm(annotations):
    # Skip every annotation that is not an English question.
    # (Optional narrower filter: anno['question_type'] in ['WHAT_COLOR'].)
    if anno['language'] != 'en':
        continue
    # Disabled alternatives:
    #     vilt_answers.append(infer_vilt(anno['img_path'], anno['question']))
    #     ofa_answers.append(infer_ofa(anno['img_path'], anno['question']))
    lavis_answers.append(infer_lavis(anno['img_path'], anno['question']))
    gold_answers.append(anno['answer'])

10000


  0%|          | 0/10000 [00:00<?, ?it/s]

In [13]:
# Re-select the English annotations in file order: the same condition the
# inference loop used when it filled gold_answers / lavis_answers.
en_annotations = [anno for anno in annotations if anno['language'] == 'en']

# `annotations` and the answer lists live in separate cells, so stale kernel
# state (e.g. answers from another split) would otherwise be paired with the
# wrong ids without any error. Fail loudly instead.
assert len(en_annotations) == len(gold_answers) == len(lavis_answers), (
    "annotations and answer lists are out of sync; re-run the inference cell"
)

# Map annotation id -> answer. The OFA / ViLT dicts are created but left empty.
gold_dict = {}
ofa_dict = {}
vilt_dict = {}
lavis_dict = {}

# tqdm wraps the list itself (not an enumerate object) so the bar has a total.
for anno, gold, pred in zip(tqdm(en_annotations), gold_answers, lavis_answers):
    idx = anno['id']
    gold_dict[idx] = gold
    lavis_dict[idx] = pred

# Number of English annotations processed; shown as the cell's output.
i = len(en_annotations)

i

0it [00:00, ?it/s]

3343

In [14]:
# Write the id -> LAVIS answer mapping as indented JSON, keeping non-ASCII
# characters unescaped.
results_path = './outputs/private-test/results-lavis-aokvqa.json'
with open(results_path, mode='w', encoding='utf-8') as f:
    json.dump(lavis_dict, f, ensure_ascii=False, indent=4)