In [None]:
# install necessary libraries
!pip install openai
!pip install torch
!pip install salesforce-lavis
!pip install transformers

In [None]:
# imports
import torch
from PIL import Image
from lavis.models import load_model_and_preprocess
from transformers import ViltProcessor, ViltForQuestionAnswering
import requests
from transformers import AutoProcessor, AutoModelForCausalLM
from huggingface_hub import hf_hub_download
# Use the GPU when torch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# load vqa variants used in this experiment
# BLIP via LAVIS (VQAv2 checkpoint): model plus image/text preprocessors, placed on `device`.
model_blip_vqa2, vis_processors_blip_vqa2, txt_processors_blip_vqa2 = load_model_and_preprocess(
    name="blip_vqa",
    model_type="vqav2",
    is_eval=True,
    device=device,
)

# Hugging Face checkpoints for ViLT and GIT. These are not moved to `device`.
VILT_CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"
GIT_CHECKPOINT = "microsoft/git-base-vqav2"
processor_vilt = ViltProcessor.from_pretrained(VILT_CHECKPOINT)
model_vilt = ViltForQuestionAnswering.from_pretrained(VILT_CHECKPOINT)
processor_git_vqav2 = AutoProcessor.from_pretrained(GIT_CHECKPOINT)
model_git_vqav2 = AutoModelForCausalLM.from_pretrained(GIT_CHECKPOINT)

In [None]:
# read cub data (source: falcon)
import os
import json

# Few-shot CUB episodes from FALCON. Each record has a one-image 'train_sample' whose
# first sentence names the concept, and a 30-image 'val_sample' with per-image answers.
QUESTIONS_PATH = '/content/drive/MyDrive/FALCON-Release-master/DATASETS/CUB-200-2011/0/cub_fewshot/questions.json'
with open(QUESTIONS_PATH,'r') as f:
  data = json.load(f)
print(len(data))

conceptlist = []        # unique concept names, in first-seen order
_seen_concepts = set()  # O(1) membership test instead of scanning conceptlist

for i in data:
  trtxt = i['train_sample']['text']
  trind = i['train_sample']['image_index']
  vlind = i['val_sample']['image_index']
  assert len(trind)==1, ("expected one training image", trind)
  assert len(vlind)==30, ("expected 30 validation images", len(vlind))
  # presumably sentences of the form "There is a <concept>." -- strip the template
  concept = trtxt[0].lower().replace("there is a ", "").replace(".","")

  if concept not in _seen_concepts:
    _seen_concepts.add(concept)
    conceptlist.append(concept)

print(conceptlist)

In [None]:
# calls gpt with given prompt
import os
import openai

# Take the key from the environment rather than pasting it into the notebook, where it
# would be saved (and shared) with the file. The placeholder keeps the old fallback.
apikey = os.environ.get("OPENAI_API_KEY", "<TODO>") # set OPENAI_API_KEY before running
openai.api_key = apikey
# Authenticated request whose result is unused; it makes a bad key fail here,
# before any descriptions are collected.
openai.Model.list()

# concept -> list of GPT-3 descriptions; filled by the collection cells below
gpt3desc = {}

def callgpt(concept,numconcepts,model="text-davinci-003",max_tokens=32,temperature=0.25):
  """Ask a GPT-3 completion model for short visual descriptions of `concept`.

  Args:
    concept: concept (bird) name inserted into the prompt.
    numconcepts: how many phrases to request, e.g. "three"; non-strings such as 3
      are converted with str().
    model, max_tokens, temperature: forwarded to openai.Completion.create. The
      defaults are the values used in this experiment; note that max_tokens=32 may
      truncate longer answers.

  Returns:
    The completion split on "#", whitespace-normalized and lower-cased, with empty
    pieces (e.g. from "a # # b" or a trailing "#") removed. The number of pieces is
    not enforced here; the evaluation cells assert it.
  """
  inpprompt = "Describe in "+str(numconcepts)+" phrases separated by # -- how the "+concept+" looks like"
  response = openai.Completion.create(model=model,prompt=inpprompt,max_tokens=max_tokens,temperature=temperature)
  # collapse all runs of whitespace (incl. newlines) to single spaces
  text = " ".join(response['choices'][0]['text'].split())
  # strip BEFORE filtering so whitespace-only pieces are dropped rather than kept as ""
  desc = [piece.strip().lower() for piece in text.split("#")]
  desc = [d for d in desc if d]
  print(desc)
  return desc

m=1: collect one GPT-3 description per concept

In [None]:
# collects 1 description per concept
# One GPT-3 query per concept (first 200 concepts); the concept -> [description]
# mapping is saved to conceptdb1.json for the m=1 evaluation cells.
for bird_name in conceptlist[:200]:
  gpt3desc[bird_name] = callgpt(bird_name, "one")

with open("conceptdb1.json", "w") as conceptdb_file:
    json.dump(gpt3desc, conceptdb_file, indent=4)

m=3: collect three GPT-3 descriptions per concept

In [None]:
# collects 3 descriptions per concept
# One GPT-3 query per concept (first 200 concepts); the concept -> [descriptions]
# mapping is saved to conceptdb3.json for the m=3 evaluation cells.
for bird_name in conceptlist[:200]:
  gpt3desc[bird_name] = callgpt(bird_name, "three")

with open("conceptdb3.json", "w") as conceptdb_file:
    json.dump(gpt3desc, conceptdb_file, indent=4)

m=5: collect five GPT-3 descriptions per concept

In [None]:
# collects 5 descriptions per concept
# One GPT-3 query per concept (first 200 concepts); the concept -> [descriptions]
# mapping is saved to conceptdb5.json for the m=5 evaluation cells.
for bird_name in conceptlist[:200]:
  gpt3desc[bird_name] = callgpt(bird_name, "five")

with open("conceptdb5.json", "w") as conceptdb_file:
    json.dump(gpt3desc, conceptdb_file, indent=4)

In [None]:
# only for visualization purposes - prints
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

def mgrid(inlist, ncols=5):
  """Show the images at the given paths in a grid (visualization only).

  Args:
    inlist: iterable of image file paths readable by `plt.imread`.
    ncols: number of grid columns (default 5, as before).
  """
  imgarr = [plt.imread(im) for im in inlist]
  # Ceiling division gives exactly enough rows. int(n/ncols)+1 added an empty row
  # whenever n was a multiple of ncols (30 images -> 7 rows instead of 6).
  # Keep at least one row so ImageGrid is valid for an empty list.
  nrows = max(1, -(-len(imgarr) // ncols))
  fig = plt.figure(figsize=(6., 6.))
  grid = ImageGrid(fig, 111,
                   nrows_ncols=(nrows, ncols),
                   axes_pad=0.1,
                   )

  for ax, im in zip(grid, imgarr):
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.imshow(im)

  plt.show()

In [None]:
# blip vqa
def callblip(pathofimag, question, answer_candidates):
    """Answer `question` about the image file `pathofimag` with the LAVIS BLIP-VQA model.

    Uses rank-based inference over `answer_candidates` and returns the first
    prediction of the (single-image) batch.
    """
    pixels = vis_processors_blip_vqa2["eval"](Image.open(pathofimag).convert("RGB"))
    batch = {
        "image": pixels.unsqueeze(0).to(device),  # add a batch dimension of 1
        "text_input": txt_processors_blip_vqa2["eval"](question),
    }
    predictions = model_blip_vqa2.predict_answers(
        batch, answer_list=answer_candidates, inference_method="rank"
    )
    return predictions[0]

In [None]:
# vilt vqa
def callvilt(pathofimag, question, answer_candidates):
    """Answer `question` about the image file `pathofimag` with ViLT fine-tuned for VQA.

    ViLT scores a fixed label vocabulary (`model_vilt.config.id2label`). If
    `answer_candidates` is non-empty, the prediction is limited to the candidates
    that exist in that vocabulary, matched case-insensitively. The matching
    candidate string is returned, e.g. "Yes"/"No" spelled exactly as callers compare
    it. If no candidate is in the vocabulary, or none are given, the unrestricted
    argmax label is returned.
    """
    raw_image = Image.open(pathofimag).convert("RGB")
    encoding = processor_vilt(raw_image, question, return_tensors="pt")
    # Inference only: skip autograd bookkeeping for the many calls made per run.
    with torch.no_grad():
        logits = model_vilt(**encoding).logits[0]  # scores of the single example
    id2label = model_vilt.config.id2label
    label2id = {str(label).lower(): idx for idx, label in id2label.items()}
    scored = [(logits[label2id[cand.lower()]].item(), cand)
              for cand in (answer_candidates or []) if cand.lower() in label2id]
    if not scored:
        return id2label[logits.argmax(-1).item()]
    # max() keeps the first of equal scores, so candidate order breaks ties
    return max(scored, key=lambda s: s[0])[1]

In [None]:
# gitvqa
def callgit(pathofimag, question, answer_candidates):
    """Answer `question` about the image file `pathofimag` with GIT-base fine-tuned on VQAv2.

    GIT continues the prompt `[CLS] + question` with free text. Only the newly
    generated tokens are decoded. Removing the question from the decoded string with
    `str.lstrip` is wrong because lstrip removes a *set of characters*, not a prefix.
    If the answer equals one of `answer_candidates` (case-insensitive), that candidate
    string is returned (e.g. "yes" -> "Yes"); otherwise the raw answer is returned.
    """
    raw_image = Image.open(pathofimag).convert("RGB")
    pixel_values = processor_git_vqav2(images=raw_image, return_tensors="pt").pixel_values
    input_ids = processor_git_vqav2(text=question, add_special_tokens=False).input_ids
    input_ids = torch.tensor([processor_git_vqav2.tokenizer.cls_token_id] + input_ids).unsqueeze(0)
    with torch.no_grad():  # inference only
        generated_ids = model_git_vqav2.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
    # generate() output starts with the prompt tokens; keep only the continuation
    answer_ids = generated_ids[:, input_ids.shape[1]:]
    answer = processor_git_vqav2.batch_decode(answer_ids, skip_special_tokens=True)[0].strip()
    for candidate in answer_candidates or []:
        if answer.lower() == candidate.lower():
            return candidate
    return answer

m=1, BLIP-VQA: one description (one yes/no question) per concept

In [None]:
import json
imdir = '/content/drive/MyDrive/FALCON-Release-master/DATASETS/CUB-200-2011/0/cub/images.json'
impath = '/content/drive/MyDrive/FALCON-Release-master/DATASETS/CUB-200-2011/images/'
import re

# Evaluate BLIP-VQA with m=1 GPT-3 description per concept. Each description becomes a
# yes/no question, asked of every validation image of every episode in `data`. Each
# image's answers are majority-voted into a boolean prediction and scored against the
# episode's ground truth. All episode records are written once, as one JSON array.
with open(imdir,'r') as fi:
  imfildata = json.load(fi)  # indexed by image_index; values are paths relative to `impath`

with open("conceptdb1.json", "r") as read_file:
  conceptload = json.load(read_file)  # concept -> list of GPT-3 descriptions

allresults = []
for c, i in enumerate(data):
  trtxt = i['train_sample']['text']
  trind = i['train_sample']['image_index']
  print(c, trtxt)

  # Skip malformed episodes instead of reusing the previous episode's concept/images/score.
  if len(trind)!=1:
    print("something not right")
    print("=================")
    continue
  concept = trtxt[0].lower().replace("there is a ", "").replace(".","")

  desc = conceptload[concept]
  assert len(desc)==1, (concept, desc)
  qlist = ["Does this bird have "+d+"?" for d in desc]

  vlind = i['val_sample']['image_index']
  vlans = i['val_sample']['answer']
  if len(vlind)!=30:
    print("something not right -- val mismatch")
    print("=================")
    continue

  pathlist = [impath+imfildata[v] for v in vlind]
  # one row per image, one answer per question
  retans = [[callblip(p,q,["Yes","No"]) for q in qlist] for p in pathlist]
  #mgrid(pathlist)

  predlabel = []
  for r in retans:
    # Deterministic, case-insensitive majority vote: "yes" must strictly outnumber "no".
    # max(set(r), key=r.count) dropped answers not spelled exactly "Yes"/"No" (breaking
    # the length assert) and broke ties by set iteration order.
    votes = [str(a).strip().lower() for a in r]
    predlabel.append(votes.count("yes") > votes.count("no"))

  assert len(vlans)==len(predlabel)
  correct = sum(1 for l1,l2 in zip(vlans,predlabel) if l1==l2)
  print("Correct predictions ->",correct,"/",len(vlans))
  print("=================")
  allresults.append({"concept":concept, "gpt3concepts":desc, "imlist":pathlist, "qlist":qlist, "gtans":vlans, "predans":predlabel, "correct": correct, "total":len(vlans)})

# A single valid JSON document (list of episode records) that json.load can read directly.
with open("results_blipvqa_m1.json", "w") as write_file2:
  json.dump(allresults, write_file2, indent=4)

m=3, BLIP-VQA: three descriptions (three yes/no questions) per concept

In [None]:
import json
imdir = '/content/drive/MyDrive/FALCON-Release-master/DATASETS/CUB-200-2011/0/cub/images.json'
impath = '/content/drive/MyDrive/FALCON-Release-master/DATASETS/CUB-200-2011/images/'
import re

# Evaluate BLIP-VQA with m=3 GPT-3 descriptions per concept. Each description becomes a
# yes/no question, asked of every validation image of every episode in `data`. Each
# image's answers are majority-voted into a boolean prediction and scored against the
# episode's ground truth. All episode records are written once, as one JSON array.
with open(imdir,'r') as fi:
  imfildata = json.load(fi)  # indexed by image_index; values are paths relative to `impath`

with open("conceptdb3.json", "r") as read_file:
  conceptload = json.load(read_file)  # concept -> list of GPT-3 descriptions

allresults = []
for c, i in enumerate(data):
  trtxt = i['train_sample']['text']
  trind = i['train_sample']['image_index']
  print(c, trtxt)

  # Skip malformed episodes instead of reusing the previous episode's concept/images/score.
  if len(trind)!=1:
    print("something not right")
    print("=================")
    continue
  concept = trtxt[0].lower().replace("there is a ", "").replace(".","")

  desc = conceptload[concept]
  assert len(desc)==3, (concept, desc)
  qlist = ["Does this bird have "+d+"?" for d in desc]

  vlind = i['val_sample']['image_index']
  vlans = i['val_sample']['answer']
  if len(vlind)!=30:
    print("something not right -- val mismatch")
    print("=================")
    continue

  pathlist = [impath+imfildata[v] for v in vlind]
  # one row per image, one answer per question
  retans = [[callblip(p,q,["Yes","No"]) for q in qlist] for p in pathlist]
  #mgrid(pathlist)

  predlabel = []
  for r in retans:
    # Deterministic, case-insensitive majority vote: "yes" must strictly outnumber "no".
    # max(set(r), key=r.count) dropped answers not spelled exactly "Yes"/"No" (breaking
    # the length assert) and broke ties by set iteration order.
    votes = [str(a).strip().lower() for a in r]
    predlabel.append(votes.count("yes") > votes.count("no"))

  assert len(vlans)==len(predlabel)
  correct = sum(1 for l1,l2 in zip(vlans,predlabel) if l1==l2)
  print("Correct predictions ->",correct,"/",len(vlans))
  print("=================")
  allresults.append({"concept":concept, "gpt3concepts":desc, "imlist":pathlist, "qlist":qlist, "gtans":vlans, "predans":predlabel, "correct": correct, "total":len(vlans)})

# A single valid JSON document (list of episode records) that json.load can read directly.
with open("results_blipvqa_m3.json", "w") as write_file2:
  json.dump(allresults, write_file2, indent=4)

m=5, BLIP-VQA: five descriptions (five yes/no questions) per concept

In [None]:
import json
imdir = '/content/drive/MyDrive/FALCON-Release-master/DATASETS/CUB-200-2011/0/cub/images.json'
impath = '/content/drive/MyDrive/FALCON-Release-master/DATASETS/CUB-200-2011/images/'
import re

# Evaluate BLIP-VQA with m=5 GPT-3 descriptions per concept. Each description becomes a
# yes/no question, asked of every validation image of every episode in `data`. Each
# image's answers are majority-voted into a boolean prediction and scored against the
# episode's ground truth. All episode records are written once, as one JSON array.
with open(imdir,'r') as fi:
  imfildata = json.load(fi)  # indexed by image_index; values are paths relative to `impath`

with open("conceptdb5.json", "r") as read_file:
  conceptload = json.load(read_file)  # concept -> list of GPT-3 descriptions

allresults = []
for c, i in enumerate(data):
  trtxt = i['train_sample']['text']
  trind = i['train_sample']['image_index']
  print(c, trtxt)

  # Skip malformed episodes instead of reusing the previous episode's concept/images/score.
  if len(trind)!=1:
    print("something not right")
    print("=================")
    continue
  concept = trtxt[0].lower().replace("there is a ", "").replace(".","")

  desc = conceptload[concept]
  assert len(desc)==5, (concept, desc)
  qlist = ["Does this bird have "+d+"?" for d in desc]

  vlind = i['val_sample']['image_index']
  vlans = i['val_sample']['answer']
  if len(vlind)!=30:
    print("something not right -- val mismatch")
    print("=================")
    continue

  pathlist = [impath+imfildata[v] for v in vlind]
  # one row per image, one answer per question
  retans = [[callblip(p,q,["Yes","No"]) for q in qlist] for p in pathlist]
  #mgrid(pathlist)

  predlabel = []
  for r in retans:
    # Deterministic, case-insensitive majority vote: "yes" must strictly outnumber "no".
    # max(set(r), key=r.count) dropped answers not spelled exactly "Yes"/"No" (breaking
    # the length assert) and broke ties by set iteration order.
    votes = [str(a).strip().lower() for a in r]
    predlabel.append(votes.count("yes") > votes.count("no"))

  assert len(vlans)==len(predlabel)
  correct = sum(1 for l1,l2 in zip(vlans,predlabel) if l1==l2)
  print("Correct predictions ->",correct,"/",len(vlans))
  print("=================")
  allresults.append({"concept":concept, "gpt3concepts":desc, "imlist":pathlist, "qlist":qlist, "gtans":vlans, "predans":predlabel, "correct": correct, "total":len(vlans)})

# A single valid JSON document (list of episode records) that json.load can read directly.
with open("results_blipvqa_m5.json", "w") as write_file2:
  json.dump(allresults, write_file2, indent=4)

m=1, ViLT-VQA: one description (one yes/no question) per concept

In [None]:
import json
imdir = '/content/drive/MyDrive/FALCON-Release-master/DATASETS/CUB-200-2011/0/cub/images.json'
impath = '/content/drive/MyDrive/FALCON-Release-master/DATASETS/CUB-200-2011/images/'
import re

# Evaluate ViLT-VQA with m=1 GPT-3 description per concept. Each description becomes a
# yes/no question, asked of every validation image of every episode in `data`. Each
# image's answers are majority-voted into a boolean prediction and scored against the
# episode's ground truth. All episode records are written once, as one JSON array.
with open(imdir,'r') as fi:
  imfildata = json.load(fi)  # indexed by image_index; values are paths relative to `impath`

with open("conceptdb1.json", "r") as read_file:
  conceptload = json.load(read_file)  # concept -> list of GPT-3 descriptions

allresults = []
for c, i in enumerate(data):
  trtxt = i['train_sample']['text']
  trind = i['train_sample']['image_index']
  print(c, trtxt)

  # Skip malformed episodes instead of reusing the previous episode's concept/images/score.
  if len(trind)!=1:
    print("something not right")
    print("=================")
    continue
  concept = trtxt[0].lower().replace("there is a ", "").replace(".","")

  desc = conceptload[concept]
  assert len(desc)==1, (concept, desc)
  qlist = ["Does this bird have "+d+"?" for d in desc]

  vlind = i['val_sample']['image_index']
  vlans = i['val_sample']['answer']
  if len(vlind)!=30:
    print("something not right -- val mismatch")
    print("=================")
    continue

  pathlist = [impath+imfildata[v] for v in vlind]
  # one row per image, one answer per question
  retans = [[callvilt(p,q,["Yes","No"]) for q in qlist] for p in pathlist]
  #mgrid(pathlist)

  predlabel = []
  for r in retans:
    # Deterministic, case-insensitive majority vote: "yes" must strictly outnumber "no".
    # max(set(r), key=r.count) dropped answers not spelled exactly "Yes"/"No" (e.g. a
    # lower-case "yes"), breaking the length assert, and broke ties by set iteration order.
    votes = [str(a).strip().lower() for a in r]
    predlabel.append(votes.count("yes") > votes.count("no"))

  assert len(vlans)==len(predlabel)
  correct = sum(1 for l1,l2 in zip(vlans,predlabel) if l1==l2)
  print("Correct predictions ->",correct,"/",len(vlans))
  print("=================")
  allresults.append({"concept":concept, "gpt3concepts":desc, "imlist":pathlist, "qlist":qlist, "gtans":vlans, "predans":predlabel, "correct": correct, "total":len(vlans)})

# A single valid JSON document (list of episode records) that json.load can read directly.
with open("results_viltvqa_m1.json", "w") as write_file2:
  json.dump(allresults, write_file2, indent=4)

m=3, ViLT-VQA: three descriptions (three yes/no questions) per concept

In [None]:
import json
imdir = '/content/drive/MyDrive/FALCON-Release-master/DATASETS/CUB-200-2011/0/cub/images.json'
impath = '/content/drive/MyDrive/FALCON-Release-master/DATASETS/CUB-200-2011/images/'
import re

# Evaluate ViLT-VQA with m=3 GPT-3 descriptions per concept. Each description becomes a
# yes/no question, asked of every validation image of every episode in `data`. Each
# image's answers are majority-voted into a boolean prediction and scored against the
# episode's ground truth. All episode records are written once, as one JSON array.
with open(imdir,'r') as fi:
  imfildata = json.load(fi)  # indexed by image_index; values are paths relative to `impath`

with open("conceptdb3.json", "r") as read_file:
  conceptload = json.load(read_file)  # concept -> list of GPT-3 descriptions

allresults = []
for c, i in enumerate(data):
  trtxt = i['train_sample']['text']
  trind = i['train_sample']['image_index']
  print(c, trtxt)

  # Skip malformed episodes instead of reusing the previous episode's concept/images/score.
  if len(trind)!=1:
    print("something not right")
    print("=================")
    continue
  concept = trtxt[0].lower().replace("there is a ", "").replace(".","")

  desc = conceptload[concept]
  assert len(desc)==3, (concept, desc)
  qlist = ["Does this bird have "+d+"?" for d in desc]

  vlind = i['val_sample']['image_index']
  vlans = i['val_sample']['answer']
  if len(vlind)!=30:
    print("something not right -- val mismatch")
    print("=================")
    continue

  pathlist = [impath+imfildata[v] for v in vlind]
  # one row per image, one answer per question
  retans = [[callvilt(p,q,["Yes","No"]) for q in qlist] for p in pathlist]
  #mgrid(pathlist)

  predlabel = []
  for r in retans:
    # Deterministic, case-insensitive majority vote: "yes" must strictly outnumber "no".
    # max(set(r), key=r.count) dropped answers not spelled exactly "Yes"/"No" (e.g. a
    # lower-case "yes"), breaking the length assert, and broke ties by set iteration order.
    votes = [str(a).strip().lower() for a in r]
    predlabel.append(votes.count("yes") > votes.count("no"))

  assert len(vlans)==len(predlabel)
  correct = sum(1 for l1,l2 in zip(vlans,predlabel) if l1==l2)
  print("Correct predictions ->",correct,"/",len(vlans))
  print("=================")
  allresults.append({"concept":concept, "gpt3concepts":desc, "imlist":pathlist, "qlist":qlist, "gtans":vlans, "predans":predlabel, "correct": correct, "total":len(vlans)})

# A single valid JSON document (list of episode records) that json.load can read directly.
with open("results_viltvqa_m3.json", "w") as write_file2:
  json.dump(allresults, write_file2, indent=4)

m=5, ViLT-VQA: five descriptions (five yes/no questions) per concept

In [None]:
import json
imdir = '/content/drive/MyDrive/FALCON-Release-master/DATASETS/CUB-200-2011/0/cub/images.json'
impath = '/content/drive/MyDrive/FALCON-Release-master/DATASETS/CUB-200-2011/images/'
import re

# Evaluate ViLT-VQA with m=5 GPT-3 descriptions per concept. Each description becomes a
# yes/no question, asked of every validation image of every episode in `data`. Each
# image's answers are majority-voted into a boolean prediction and scored against the
# episode's ground truth. All episode records are written once, as one JSON array.
with open(imdir,'r') as fi:
  imfildata = json.load(fi)  # indexed by image_index; values are paths relative to `impath`

with open("conceptdb5.json", "r") as read_file:
  conceptload = json.load(read_file)  # concept -> list of GPT-3 descriptions

allresults = []
for c, i in enumerate(data):
  trtxt = i['train_sample']['text']
  trind = i['train_sample']['image_index']
  print(c, trtxt)

  # Skip malformed episodes instead of reusing the previous episode's concept/images/score.
  if len(trind)!=1:
    print("something not right")
    print("=================")
    continue
  concept = trtxt[0].lower().replace("there is a ", "").replace(".","")

  desc = conceptload[concept]
  assert len(desc)==5, (concept, desc)
  qlist = ["Does this bird have "+d+"?" for d in desc]

  vlind = i['val_sample']['image_index']
  vlans = i['val_sample']['answer']
  if len(vlind)!=30:
    print("something not right -- val mismatch")
    print("=================")
    continue

  pathlist = [impath+imfildata[v] for v in vlind]
  # one row per image, one answer per question
  retans = [[callvilt(p,q,["Yes","No"]) for q in qlist] for p in pathlist]
  #mgrid(pathlist)

  predlabel = []
  for r in retans:
    # Deterministic, case-insensitive majority vote: "yes" must strictly outnumber "no".
    # max(set(r), key=r.count) dropped answers not spelled exactly "Yes"/"No" (e.g. a
    # lower-case "yes"), breaking the length assert, and broke ties by set iteration order.
    votes = [str(a).strip().lower() for a in r]
    predlabel.append(votes.count("yes") > votes.count("no"))

  assert len(vlans)==len(predlabel)
  correct = sum(1 for l1,l2 in zip(vlans,predlabel) if l1==l2)
  print("Correct predictions ->",correct,"/",len(vlans))
  print("=================")
  allresults.append({"concept":concept, "gpt3concepts":desc, "imlist":pathlist, "qlist":qlist, "gtans":vlans, "predans":predlabel, "correct": correct, "total":len(vlans)})

# A single valid JSON document (list of episode records) that json.load can read directly.
with open("results_viltvqa_m5.json", "w") as write_file2:
  json.dump(allresults, write_file2, indent=4)

m=1, GIT-VQA: one description (one yes/no question) per concept

In [None]:
import json
imdir = '/content/drive/MyDrive/FALCON-Release-master/DATASETS/CUB-200-2011/0/cub/images.json'
impath = '/content/drive/MyDrive/FALCON-Release-master/DATASETS/CUB-200-2011/images/'
import re

# Evaluate GIT-VQA with m=1 GPT-3 description per concept. Each description becomes a
# yes/no question, asked of every validation image of every episode in `data`. Each
# image's answers are majority-voted into a boolean prediction and scored against the
# episode's ground truth. All episode records are written once, as one JSON array.
with open(imdir,'r') as fi:
  imfildata = json.load(fi)  # indexed by image_index; values are paths relative to `impath`

with open("conceptdb1.json", "r") as read_file:
  conceptload = json.load(read_file)  # concept -> list of GPT-3 descriptions

allresults = []
for c, i in enumerate(data):
  trtxt = i['train_sample']['text']
  trind = i['train_sample']['image_index']
  print(c, trtxt)

  # Skip malformed episodes instead of reusing the previous episode's concept/images/score.
  if len(trind)!=1:
    print("something not right")
    print("=================")
    continue
  concept = trtxt[0].lower().replace("there is a ", "").replace(".","")

  desc = conceptload[concept]
  assert len(desc)==1, (concept, desc)
  qlist = ["Does this bird have "+d+"?" for d in desc]

  vlind = i['val_sample']['image_index']
  vlans = i['val_sample']['answer']
  if len(vlind)!=30:
    print("something not right -- val mismatch")
    print("=================")
    continue

  pathlist = [impath+imfildata[v] for v in vlind]
  # one row per image, one answer per question
  retans = [[callgit(p,q,["Yes","No"]) for q in qlist] for p in pathlist]
  #mgrid(pathlist)

  predlabel = []
  for r in retans:
    # Deterministic, case-insensitive majority vote: "yes" must strictly outnumber "no".
    # max(set(r), key=r.count) dropped answers not spelled exactly "Yes"/"No" (e.g. a
    # lower-case "yes"), breaking the length assert, and broke ties by set iteration order.
    votes = [str(a).strip().lower() for a in r]
    predlabel.append(votes.count("yes") > votes.count("no"))

  assert len(vlans)==len(predlabel)
  correct = sum(1 for l1,l2 in zip(vlans,predlabel) if l1==l2)
  print("Correct predictions ->",correct,"/",len(vlans))
  print("=================")
  allresults.append({"concept":concept, "gpt3concepts":desc, "imlist":pathlist, "qlist":qlist, "gtans":vlans, "predans":predlabel, "correct": correct, "total":len(vlans)})

# A single valid JSON document (list of episode records) that json.load can read directly.
with open("results_gitvqa_m1.json", "w") as write_file2:
  json.dump(allresults, write_file2, indent=4)

m=3, GIT-VQA: three descriptions (three yes/no questions) per concept

In [None]:
import json
imdir = '/content/drive/MyDrive/FALCON-Release-master/DATASETS/CUB-200-2011/0/cub/images.json'
impath = '/content/drive/MyDrive/FALCON-Release-master/DATASETS/CUB-200-2011/images/'
import re

# Evaluate GIT-VQA with m=3 GPT-3 descriptions per concept. Each description becomes a
# yes/no question, asked of every validation image of every episode in `data`. Each
# image's answers are majority-voted into a boolean prediction and scored against the
# episode's ground truth. All episode records are written once, as one JSON array.
with open(imdir,'r') as fi:
  imfildata = json.load(fi)  # indexed by image_index; values are paths relative to `impath`

with open("conceptdb3.json", "r") as read_file:
  conceptload = json.load(read_file)  # concept -> list of GPT-3 descriptions

allresults = []
for c, i in enumerate(data):
  trtxt = i['train_sample']['text']
  trind = i['train_sample']['image_index']
  print(c, trtxt)

  # Skip malformed episodes instead of reusing the previous episode's concept/images/score.
  if len(trind)!=1:
    print("something not right")
    print("=================")
    continue
  concept = trtxt[0].lower().replace("there is a ", "").replace(".","")

  desc = conceptload[concept]
  assert len(desc)==3, (concept, desc)
  qlist = ["Does this bird have "+d+"?" for d in desc]

  vlind = i['val_sample']['image_index']
  vlans = i['val_sample']['answer']
  if len(vlind)!=30:
    print("something not right -- val mismatch")
    print("=================")
    continue

  pathlist = [impath+imfildata[v] for v in vlind]
  # one row per image, one answer per question
  retans = [[callgit(p,q,["Yes","No"]) for q in qlist] for p in pathlist]
  #mgrid(pathlist)

  predlabel = []
  for r in retans:
    # Deterministic, case-insensitive majority vote: "yes" must strictly outnumber "no".
    # max(set(r), key=r.count) dropped answers not spelled exactly "Yes"/"No" (e.g. a
    # lower-case "yes"), breaking the length assert, and broke ties by set iteration order.
    votes = [str(a).strip().lower() for a in r]
    predlabel.append(votes.count("yes") > votes.count("no"))

  assert len(vlans)==len(predlabel)
  correct = sum(1 for l1,l2 in zip(vlans,predlabel) if l1==l2)
  print("Correct predictions ->",correct,"/",len(vlans))
  print("=================")
  allresults.append({"concept":concept, "gpt3concepts":desc, "imlist":pathlist, "qlist":qlist, "gtans":vlans, "predans":predlabel, "correct": correct, "total":len(vlans)})

# A single valid JSON document (list of episode records) that json.load can read directly.
with open("results_gitvqa_m3.json", "w") as write_file2:
  json.dump(allresults, write_file2, indent=4)

m=5, GIT-VQA: five descriptions (five yes/no questions) per concept

In [None]:
import json
imdir = '/content/drive/MyDrive/FALCON-Release-master/DATASETS/CUB-200-2011/0/cub/images.json'
impath = '/content/drive/MyDrive/FALCON-Release-master/DATASETS/CUB-200-2011/images/'
import re

# Evaluate GIT-VQA with m=5 GPT-3 descriptions per concept. Each description becomes a
# yes/no question, asked of every validation image of every episode in `data`. Each
# image's answers are majority-voted into a boolean prediction and scored against the
# episode's ground truth. All episode records are written once, as one JSON array.
with open(imdir,'r') as fi:
  imfildata = json.load(fi)  # indexed by image_index; values are paths relative to `impath`

with open("conceptdb5.json", "r") as read_file:
  conceptload = json.load(read_file)  # concept -> list of GPT-3 descriptions

allresults = []
for c, i in enumerate(data):
  trtxt = i['train_sample']['text']
  trind = i['train_sample']['image_index']
  print(c, trtxt)

  # Skip malformed episodes instead of reusing the previous episode's concept/images/score.
  if len(trind)!=1:
    print("something not right")
    print("=================")
    continue
  concept = trtxt[0].lower().replace("there is a ", "").replace(".","")

  desc = conceptload[concept]
  assert len(desc)==5, (concept, desc)
  qlist = ["Does this bird have "+d+"?" for d in desc]

  vlind = i['val_sample']['image_index']
  vlans = i['val_sample']['answer']
  if len(vlind)!=30:
    print("something not right -- val mismatch")
    print("=================")
    continue

  pathlist = [impath+imfildata[v] for v in vlind]
  # one row per image, one answer per question
  retans = [[callgit(p,q,["Yes","No"]) for q in qlist] for p in pathlist]
  #mgrid(pathlist)

  predlabel = []
  for r in retans:
    # Deterministic, case-insensitive majority vote: "yes" must strictly outnumber "no".
    # max(set(r), key=r.count) dropped answers not spelled exactly "Yes"/"No" (e.g. a
    # lower-case "yes"), breaking the length assert, and broke ties by set iteration order.
    votes = [str(a).strip().lower() for a in r]
    predlabel.append(votes.count("yes") > votes.count("no"))

  assert len(vlans)==len(predlabel)
  correct = sum(1 for l1,l2 in zip(vlans,predlabel) if l1==l2)
  print("Correct predictions ->",correct,"/",len(vlans))
  print("=================")
  allresults.append({"concept":concept, "gpt3concepts":desc, "imlist":pathlist, "qlist":qlist, "gtans":vlans, "predans":predlabel, "correct": correct, "total":len(vlans)})

# A single valid JSON document (list of episode records) that json.load can read directly.
with open("results_gitvqa_m5.json", "w") as write_file2:
  json.dump(allresults, write_file2, indent=4)

In [None]:
# print concept-wise (class-wise) performance (%)
# Expects the results file to hold a JSON array of per-episode records with the keys
# 'concept', 'correct' and 'total'. A file written as back-to-back JSON objects must
# first be wrapped in [ ] with commas added between the objects.
import json

RESULTS_FILE = '/content/<TODO>' # change the name of the file you wish to analyze
# ['results_blipvqa_m1.json','results_blipvqa_m3.json','results_blipvqa_m5.json','results_viltvqa_m1.json','results_viltvqa_m3.json','results_viltvqa_m5.json','results_gitvqa_m1.json','results_gitvqa_m3.json','results_gitvqa_m5.json']
with open(RESULTS_FILE,'r') as fres:
  resultsdata = json.load(fres)
print(len(resultsdata))

conceptresmap = {}     # concept -> summed correct predictions over its episodes
concepttotalmap = {}   # concept -> summed number of predictions over its episodes
accuracy = {}          # concept -> accuracy in %, rounded to 2 decimals

for u in resultsdata:
  conceptresmap[u['concept']] = conceptresmap.get(u['concept'], 0) + u['correct']
  concepttotalmap[u['concept']] = concepttotalmap.get(u['concept'], 0) + u['total']

print(len(conceptresmap))
print(len(concepttotalmap))

print("total accuracy")
total_correct = sum(conceptresmap.values())
total_pred = sum(concepttotalmap.values())
print(total_correct)
print(total_pred)
print(total_correct*100/total_pred if total_pred else float('nan'))
print("class name", "correct pred", "total pred", "class accuracy")
for k in conceptlist:
  # Skip concepts absent from this results file or with zero predictions; any
  # other error is no longer hidden by a bare except.
  if not concepttotalmap.get(k):
    continue
  class_acc = conceptresmap[k]*100/concepttotalmap[k]
  accuracy[k] = round(class_acc,2)
  print(k, conceptresmap[k], concepttotalmap[k], class_acc)

In [None]:
# print concept-wise fn/fp
# fn - actual true, predicted false
# fp - actual false, predicted true
# Reads `resultsdata` (loaded by the accuracy cell) and fills fplist / fnlist for the plots.

print(len(resultsdata))

conceptgt = {}    # concept -> ground-truth answers of all its episodes, concatenated
conceptpred = {}  # concept -> predicted answers of all its episodes, concatenated

for u in resultsdata:
  # Accumulate into fresh lists. Seeding with u['gtans'] itself aliased the record's list,
  # so .extend() grew resultsdata in place and re-running this cell double-counted.
  conceptgt.setdefault(u['concept'], []).extend(u['gtans'])
  conceptpred.setdefault(u['concept'], []).extend(u['predans'])

fplist = {}
fnlist = {}

for k in conceptlist:
  if k not in conceptgt:
    continue  # concept has no episodes in this results file
  gt, pred = conceptgt[k], conceptpred[k]
  assert len(gt)==len(pred), (k, len(gt), len(pred))
  pairs = list(zip(gt, pred))
  fn = sum(1 for pair in pairs if pair == (True, False))
  fp = sum(1 for pair in pairs if pair == (False, True))
  fplist[k] = fp
  fnlist[k] = fn
  print(len(gt), len(pred), k, fn, fp)

In [None]:
# fpplot
import pandas as pd
from matplotlib import pyplot as plt
# Horizontal bar chart of false positives (actual false, predicted true) per concept.
# Height scales with the number of concepts: 0.35 in per bar, e.g. 200 concepts -> 70 in.
fig, ax = plt.subplots(figsize =(5, max(5, 0.35*len(fplist))))
ax.barh(list(fplist.keys()), list(fplist.values()))
for s in ['top', 'bottom', 'left', 'right']:
    ax.spines[s].set_visible(False)
ax.xaxis.set_ticks_position('none')
ax.yaxis.set_ticks_position('none')
ax.xaxis.set_tick_params(pad = 5)
ax.yaxis.set_tick_params(pad = 10)
ax.grid(color ='grey',linestyle ='-.', linewidth = 0.5,alpha = 0.2)
# Label the chart; without a title it is indistinguishable from the fn chart.
ax.set_title("False positives per concept (actual false, predicted true)")
ax.set_xlabel("number of false positives")
ax.set_ylabel("concept")
plt.show()

In [None]:
# fnplot
import pandas as pd
from matplotlib import pyplot as plt
# Horizontal bar chart of false negatives (actual true, predicted false) per concept.
# Height scales with the number of concepts: 0.35 in per bar, e.g. 200 concepts -> 70 in.
fig, ax = plt.subplots(figsize =(5, max(5, 0.35*len(fnlist))))
ax.barh(list(fnlist.keys()), list(fnlist.values()))
for s in ['top', 'bottom', 'left', 'right']:
    ax.spines[s].set_visible(False)
ax.xaxis.set_ticks_position('none')
ax.yaxis.set_ticks_position('none')
ax.xaxis.set_tick_params(pad = 5)
ax.yaxis.set_tick_params(pad = 10)
ax.grid(color ='grey',linestyle ='-.', linewidth = 0.5,alpha = 0.2)
# Label the chart; without a title it is indistinguishable from the fp chart.
ax.set_title("False negatives per concept (actual true, predicted false)")
ax.set_xlabel("number of false negatives")
ax.set_ylabel("concept")
plt.show()