In [1]:
# install requirements
import sys

# Ran once for installation; uncomment the lines below to set up a fresh environment.
# !pip3 install transformers==4.15.0 timm==0.4.12 fairscale==0.4.4
# !git clone https://github.com/salesforce/BLIP
# !pip install torchvision

# Switch the working directory to the BLIP checkout; presumably this is what
# lets the later `from models.blip_itm import blip_itm` resolve — TODO confirm.
%cd BLIP

/home/robotics1/Desktop/image-passage-answering-WebQA/BLIP


In [2]:
from PIL import Image
import requests
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load_demo_image(image_size,device,imagefile):
    """Load an image file and convert it into a normalized, batched tensor.

    Args:
        image_size: Target side length in pixels; the image is resized to
            ``(image_size, image_size)`` using bicubic interpolation.
        device: Device the resulting tensor is moved to.
        imagefile: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The transformed image with an extra leading dimension of size 1
        (added by ``unsqueeze(0)``), placed on ``device``.

    Raises:
        Any exception ``PIL.Image.open`` / ``convert`` raises for a missing
        or unreadable file; nothing is caught here.
    """
    # To read from a URL instead of a local file, open the response stream:
    #   Image.open(requests.get(imagefile, stream=True).raw).convert('RGB')
    #
    # convert() returns an independent, fully loaded copy, so the source file
    # can be closed right away instead of lingering until garbage collection
    # (this function is called once per dataset item).
    with Image.open(imagefile) as source_image:
        raw_image = source_image.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        # Per-channel mean / std used for normalization.
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])
    image = transform(raw_image).unsqueeze(0).to(device)
    return image



In [3]:
from models.blip_itm import blip_itm
image_size = 384
#url = "http://images.cocodataset.org/val2017/000000039769.jpg" #get an image file name and its path

imagepath = '/home/robotics1/Desktop/image-only-answering-WebQA/BLIP/'
image1 = imagepath+"cats.png" #specify the image file name

url = image1

# The image tensor is created on `device` (GPU when available).
image = load_demo_image(image_size=image_size,device=device, imagefile=url)

model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'

model = blip_itm(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
# Keep the model on the same device as the image tensor. Pinning it to 'cpu'
# while the image lives on CUDA raises a device-mismatch error on GPU machines.
model = model.to(device)

#specify each question and its answer choices
# NOTE(review): "aminals" looks like a typo for "animals". It is left unchanged
# because the text is passed to the model verbatim and the recorded output was
# produced with this exact string.
question = "What are the aminals in the image?"
answerchoices = ["pigs", "dogs", "cats", "monkeys"]
# question = "How many animals are in the image?"
# answerchoices = ["one", "two", "three", "four"]

#Define a function to merge the question and each answer
def merge_string(question, answer):
    """Return ``question`` and ``answer`` joined by a single space."""
    return " ".join((question, answer))

#Create a list of multiple texts, each of them containing the question and each answer
textlist = [merge_string(question, eachanswer) for eachanswer in answerchoices]

print(textlist)

#Use BLIP to process the given image and multiple texts. The scores will show their relevance between the image and each text.
scorelist=[]
# Inference only: disabling autograd avoids building a graph for every forward
# pass, which saves memory and time without changing the scores.
with torch.no_grad():
    for text in textlist:
        itm_output = model(image,text,match_head='itm')
        # Softmax over dim 1 of the ITM output; column 1 is used as the score.
        scorelist.append(float(torch.nn.functional.softmax(itm_output,dim=1)[:,1]))

print(scorelist)

#Check which answer choice has the highest score.
#The final answer index is the answer given by the model.
#Check if this is the same as the answer in the JSON file.
#If they are the same (correct), store 1 for the problem in the spreadsheet.
#If they are not the same (not correct), store 0 for the problem in the spreadsheet.

scoremax= max(scorelist)
# Look the winning position up once; index() returns the first maximum on ties.
best_index = scorelist.index(scoremax)
print('final answer: ', answerchoices[best_index])
print('final answer index: ', best_index)

load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth
['What are the aminals in the image? pigs', 'What are the aminals in the image? dogs', 'What are the aminals in the image? cats', 'What are the aminals in the image? monkeys']
[0.00709642143920064, 0.010806050151586533, 0.2144784927368164, 0.00195474736392498]
final answer:  cats
final answer index:  2


In [4]:
#WEBQA
import json
import csv

# Load the dataset first so a missing file fails before the (slow) model load.
# The context manager closes the file handle once parsing is done.
with open('randomizedWebQA.json') as f:
    data = json.load(f)
#print(data[0])

model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
model = blip_itm(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
# Keep the model on the same device as the image tensors produced by
# load_demo_image(..., device=device); forcing 'cpu' breaks on CUDA machines.
model = model.to(device)
# Directory prefix that is joined with each item's image file name.
initial_path = '/home/robotics1/Desktop/WEBQA/data_chunks/NewImages'

# Filled by the evaluation loop with the items whose image failed to load.
error_qid_list = []

load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth


In [8]:


# Score every answer choice of each WebQA item against its image with the BLIP
# ITM head and write one CSV row per item:
#   [qid, question, answer from JSON file, model answer index, model answer text]
with open('WebWA_image_passage_QAP_finishing.csv', mode='w', newline='') as csvfile:
    writer = csv.writer(csvfile)

    #for index in range(0,len(data)):
    # Starts at item 8704 (presumably resuming an earlier run — TODO confirm);
    # switch to the commented-out line above for a full pass.
    for index in range(8704,len(data)):
        print(index)

        item = data[index]
        qid = item['qid']
        correct_answer = item["answer"]
        # Drop the first 8 characters of the stored path (the './images'
        # prefix) and join the remainder onto the local image directory.
        final_path = initial_path + item['images'][8:]
        #print(final_path)
        passage = item['passage']

        try:
            image = load_demo_image(image_size=image_size,device=device, imagefile=final_path)
        except Exception as e:
            # Best effort: report the failure, remember the item, and move on.
            print(f"Error occurred: {e}")
            # Despite the list's name, the dataset index (not the qid) is stored.
            error_qid_list.append(index)
            continue

        question = item['question'] + "?"
        answer_choices = item['answer_choices']
        # print(question)
        # print(answer_choices)
        # One candidate text per answer choice: "<question> <answer>. <passage>"
        text_list = [
            merge_string(question, eachanswer) + ". " + passage
            for eachanswer in answer_choices
        ]
        #print(text_list)

        score_list=[]
        # Inference only: no autograd graph is needed, so skip building one.
        with torch.no_grad():
            for text in text_list:
                itm_output = model(image,text,match_head='itm')
                score_list.append(float(torch.nn.functional.softmax(itm_output,dim=1)[:,1]))
        #print(score_list)

        # Position of the highest score (first one on ties), looked up once.
        best_index = score_list.index(max(score_list))
        #print('final answer for index ', index,': ', answer_choices[best_index])
        #print('final answer index: ', best_index)

        #list [qid, question, answer from JSON file, model generated answer index, model-generated answer text]
        writer.writerow([qid, question, correct_answer, best_index, answer_choices[best_index]])

# Persist the items that could not be processed, one per line; the context
# manager guarantees the file is flushed and closed.
with open("error_file_QAP.txt", "w") as error_file:
    for error_qid in error_qid_list:
        error_file.write(str(error_qid) +"\n")

8704
8705
8706
8707
8708
8709
8710
8711
8712
8713
8714
8715
8716
8717
8718
8719
8720
8721
8722
8723
8724
8725
8726
8727
8728
8729
8730
8731
8732
8733
8734
8735
8736
8737
8738
8739
8740
8741
8742
8743
8744
8745
8746
8747
8748
8749
8750
8751
8752
8753
8754
8755
8756
8757
8758
8759
8760
8761
8762
8763
8764
8765
8766
8767
8768
8769
8770
8771
8772
8773
8774
8775
8776
8777
8778
8779
8780
8781
8782
8783
8784
8785
8786
8787
8788
8789
8790
8791
8792
8793
8794
8795
8796
8797
8798
8799
8800
8801
8802
8803
8804
8805
8806
8807
8808
8809
8810
8811
8812
8813
8814
8815
8816
8817
8818
8819
8820
8821
8822
8823
8824
8825
8826
8827
8828
8829
8830
8831
8832
8833
8834
8835
8836
8837
8838
8839
8840
8841
8842
8843
8844
8845
8846
8847
8848
8849
8850
8851
8852
8853
8854
8855
8856
8857
8858
8859
8860
8861
8862
8863
8864
8865
8866
8867
8868
8869
8870
8871
8872
8873
8874
8875
8876
8877
8878
8879
8880
8881
8882
8883
8884
8885
8886
8887
8888
8889
8890
8891
8892
8893
8894
8895
8896
8897
8898
8899
8900
8901
8902
8903


In [7]:
import json
import csv

# Quick sanity check of the dataset: its size and a couple of sample records.
# The context manager closes the file handle as soon as parsing finishes.
with open('randomizedWebQA.json') as f:
    data = json.load(f)
print(len(data))
print(data[1151]['qid'])
# Stored image path with its first 8 characters removed.
print(data[1151]['images'][8:])
print(data[8703])

9438
d5c0b9540dba11ecb1e81171463288e9
/d5c0b9540dba11ecb1e81171463288e9.png
{'qid': 'd5cdd38c0dba11ecb1e81171463288e9', 'images': './images/d5cdd38c0dba11ecb1e81171463288e9.png', 'multiple_images': True, 'passage': 'Left: Tokyo ginza shiseido building 2014 A photo of "Tokyo Ginza Shiseido Building"Ginza,chuo-ku,Tokyo,Japan.\nRight: Nittetsu Kobiki Building, at Ginza, Chuo, Tokyo (2019-01-02) 02 Nittetsu Kobiki Building (), located at 7-16-3 Ginza, Chuo, Tokyo, Japan', 'question': '"Do the Tokyo Ginza Shiseido Building and the Nittetsu Kobiki Building each have fewer than 20 windows?"', 'answer_choices': ['"No, the Tokyo Ginza Shiseido Building and the Nittetsu Kobiki Building do not have fewer than 20 windows."', '"No, the Tokyo Ginza Shiseido Building and the Nittetsu Kobiki Building have more than 20 windows."'], 'answer': 0, 'image_type': 'Natural', 'image subtype': '', 'answer_type': '4way_text', 'multistep_inference': False, 'reasoning_type': ['InfoLookup', 'Abductive', 'Logical']