In [None]:
import argparse
import json
import os
import random
import sys
from functools import partial
from pathlib import Path
from typing import Dict, List

import numpy as np
import torch
import torch.utils
import torch.nn.functional as F
from torch.utils.data import DataLoader, DistributedSampler

# PACKAGE_PARENT = ".."
# SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
# sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))

import main as detection
import util.dist as dist
import util.misc as utils
from datasets import build_dataset
from datasets.clevr import ALL_ATTRIBUTES
from engine import evaluate
from models import build_model
from util.metrics import MetricLogger

import datasets.transforms as T
from PIL import Image
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

from transformers import RobertaModel, RobertaTokenizerFast

## Settings for MDETR

In [None]:
def local_get_args_parser():
    """Create the argument parser used by this notebook.

    The parser inherits every option from ``detection.get_args_parser()`` and
    defines no options of its own. Help is disabled (``add_help=False``) so the
    parser can itself be used as a parent of another parser.
    """
    return argparse.ArgumentParser(
        "Get predictions for clevr and dump to file",
        parents=[detection.get_args_parser()],
        add_help=False,
    )

In [None]:
# Parse a fixed argument list (rather than sys.argv) so the notebook always
# points at the CLEVR dataset config and the local checkpoint file.
parser = argparse.ArgumentParser(
    "Get predictions for CLEVR and dump to file",
    parents=[local_get_args_parser()],
    add_help=False,
)
args = parser.parse_args(
    ["--dataset_config", "configs/clevr.json", "--resume", "./clevr_checkpoint.pth"]
)

In [None]:
# Image preprocessing: tensor conversion followed by per-channel normalization
# (the mean/std values match the commonly used ImageNet statistics).
normalize = T.Compose(
    [
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

In [None]:
def construct_sample(image, question):
    """Turn one image/question pair into the inputs passed to the model.

    Args:
        image: Image accepted by the module-level ``normalize`` transform.
        question: Mapping that holds the question text under ``"question"``.

    Returns:
        ``(image, captions)`` where ``image`` is the normalized image with a
        leading dimension added by ``unsqueeze(0)`` and ``captions`` is a
        one-element list containing the question text.
    """
    # ``normalize`` is called with an (image, target) pair; there are no
    # annotations at inference time, so an empty placeholder target is passed.
    empty_target = {
        "boxes": torch.zeros(0, 4),
        "labels": torch.zeros(0),
        "iscrowd": torch.zeros(0),
        "positive_map": torch.zeros(0),
    }
    image = normalize(image, empty_target)[0].unsqueeze(0)

    # Only the caption is needed downstream. The previous version also built an
    # unused dict whose "questionId" fell back to a global ``idx``, which raised
    # NameError whenever this function ran before the evaluation loop.
    captions = [question["question"]]
    return image, captions

In [None]:
def infer_mdetr(img, question, model):
    """Answer a single question about an image with the given model.

    Args:
        img: Input image, forwarded to ``construct_sample``.
        question: Question text.
        model: Callable invoked twice: first with ``encode_and_save=True`` to
            obtain a memory cache, then with ``encode_and_save=False`` and that
            cache to obtain the prediction dict.

    Returns:
        The answer for the (single) sample as a string: ``"yes"``/``"no"`` for
        answer type 0, an entry of ``ALL_ATTRIBUTES`` for type 1, or the
        stringified argmax of ``pred_answer_reg`` for type 2.

    Raises:
        ValueError: If the predicted answer type is not 0, 1 or 2.
    """
    samples, captions = construct_sample(img, {"question": question})
    samples = samples.to(device)

    # Pure inference: disabling autograd avoids building a graph on every call.
    with torch.no_grad():
        memory_cache = model(samples, captions, encode_and_save=True)
        outputs = model(samples, captions, encode_and_save=False, memory_cache=memory_cache)

    answer_types = [t.item() for t in outputs["pred_answer_type"].argmax(-1)]
    answers = []
    for i, ans_type in enumerate(answer_types):
        if ans_type == 0:
            answers.append("yes" if outputs["pred_answer_binary"][i].sigmoid() > 0.5 else "no")
        elif ans_type == 1:
            answers.append(ALL_ATTRIBUTES[outputs["pred_answer_attr"][i].argmax(-1).item()])
        elif ans_type == 2:
            answers.append(str(outputs["pred_answer_reg"][i].argmax(-1).item()))
        else:
            # An explicit exception survives ``python -O``, unlike ``assert False``.
            raise ValueError(f"must be one of the answer types, got {ans_type}")
    return answers[0]

In [None]:
# Update dataset specific configs: every key of the JSON config overrides the
# attribute of the same name on ``args``.
if args.dataset_config is not None:
    # https://stackoverflow.com/a/16878364
    d = vars(args)
    with open(args.dataset_config, "r") as f:
        cfg = json.load(f)
    d.update(cfg)

if args.mask_model != "none":
    args.masks = True

# Seed every RNG in use (torch, numpy, random), offset by the process rank.
device = torch.device(args.device)
seed = args.seed + dist.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Load on CPU first; the model is moved to ``device`` once the weights are in.
checkpoint = torch.load(args.resume, map_location="cpu")

# Rebuild the model from the arguments stored in the checkpoint, filling in any
# option that exists on ``args`` but is missing from the stored arguments.
model_args = checkpoint["args"]
model_args.device = args.device

model_args.combine_datasets = ["clevr_question"]
for a in vars(args):
    if a not in vars(model_args):
        vars(model_args)[a] = vars(args)[a]

model, _, _, _, _ = build_model(model_args)
if "ema" in args and args.ema:
    assert "model_ema" in checkpoint
    model.load_state_dict(checkpoint["model_ema"])
else:
    model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
# The notebook only runs inference and never goes through ``evaluate``, so put
# the model in eval mode explicitly; otherwise dropout stays active and the
# predictions are noisy and not reproducible.
model.eval()
# NOTE(review): presumably here so the cell does not echo the module repr.
print("")

## Settings for ChatGPT

In [None]:
# Basic ChatGPT class
import openai

class BasicChatGPT:
    """Minimal stateful wrapper around ``openai.ChatCompletion``.

    The running conversation is kept in ``self.messages`` so that every request
    sends the whole dialogue so far.
    """

    def __init__(self, openai_api_key, max_tokens=100):
        """Set the key on the ``openai`` module and start an empty conversation.

        Args:
            openai_api_key: API key assigned to ``openai.api_key``.
            max_tokens: ``max_tokens`` value sent with every request.
        """
        openai.api_key = openai_api_key
        self.max_tokens = max_tokens
        self.messages = []

    def call(self):
        """Send the accumulated messages and return the raw API response."""
        response = openai.ChatCompletion.create(
          model="gpt-4",
          messages=self.messages,
          max_tokens=self.max_tokens,
          seed = 123,
        )
        return response

    def prompt(self, text):
        """Append ``text`` as a user message, query the model, return its reply.

        On success both the user message and the returned message are kept in
        ``self.messages``. If the request (or reading its response) fails, the
        user message is removed again before the exception propagates, so a
        retry does not send the same prompt twice.

        Args:
            text: Content of the user message.

        Returns:
            The ``"content"`` of the first choice's message.
        """
        # Update the messages so far
        self.messages.append({
            "role": "user",
            "content": text,
        })

        try:
            # Call ChatGPT
            response = self.call()
            message = response["choices"][0]["message"]
        except Exception:
            # Roll back the unanswered user message to keep the history consistent.
            self.messages.pop()
            raise

        # Save the returned message
        self.messages.append(message)

        return message["content"]

## Load Dataset

In [None]:
# Question annotations; the evaluation loop reads ann["questions"].
# A context manager closes the file handle instead of leaking it.
with open('CLEVR-Humans-val.json', 'r') as f:
    ann = json.load(f)
# Directory that each question's "image_filename" is joined onto.
image_path = "CLEVR/CLEVR_v1.0/images/val/"

## Functions for CoQAH

In [None]:
# Instruction block appended after the "Your goal is to answer this question"
# line to form the first message of every dialogue. It tells the chat model to
# put sub-questions in [] and the final answer in {}; handling_response parses
# exactly these delimiters. This text is sent to the model verbatim, so any
# edit (including whitespace) changes the prompt.
STARTING_PROMPT = """ 
But you can't access the image and I can access the image.
You can ask me questions with these forms.
The question should be in []
If you can answer the above question, stop asking and give me an answer within 1 word.
The answer should be in {}
[is there a <SIZE> <COLOR> <MATERIAL> <SHAPE> <RELATION> <SIZE> <COLOR> <MATERIAL> <SHAPE> ?]
[what size is <SIZE> <COLOR> <MATERIAL> <SHAPE> <RELATION> <SIZE> <COLOR> <MATERIAL> <SHAPE> ?]
[what color is <SIZE> <COLOR> <MATERIAL> <SHAPE> <RELATION> <SIZE> <COLOR> <MATERIAL> <SHAPE> ?]
[what material is <SIZE> <COLOR> <MATERIAL> <SHAPE> <RELATION> <SIZE> <COLOR> <MATERIAL> <SHAPE> ?]
[what shape is <SIZE> <COLOR> <MATERIAL> <SHAPE> <RELATION> <SIZE> <COLOR> <MATERIAL> <SHAPE> ?]
[How many <SIZE> <COLOR> <MATERIAL> <SHAPE> <RELATION> <SIZE> <COLOR> <MATERIAL> <SHAPE> are there?]

<SIZE> :  [<EMPTY> or small or large]
<COLOR>:  [<EMPTY> or gray or red or blue or green or brown or purple or cyan or yellow]
<MATERIAL>: [<EMPTY> or rubber or metal]
<SHAPE>: [<EMPTY> or cube or sphere or cylinder or object]
<RELATION>: [<EMPTY> or "left of" or "right of" or "in front of" or "behind"] 
<EMPTY>: Just let it empty.  At least one should be not <EMPTY>

Ask me the next question after I answer you.
Ask me questions carefully  considering existence and uniqueness presupposition.
"""

In [None]:
# Feedback messages returned to the chat model by handling_response when a
# sub-question cannot be used. EXISTENCE_ERROR and UNIQUENESS_ERROR are
# prefixes that handling_response completes with the offending entity.
# NOTE: the name BRAKET_ERROR (sic) is referenced by handling_response.
BRAKET_ERROR = "The question should be in []"
FORMAT_ERROR = "Unsupported format or unsupported option. choose the most similar option, even if it’s not totally the same."
EXISTENCE_ERROR = "there is no "
UNIQUENESS_ERROR = "there are "
EMPTY_ERROR = "At least one should not be <EMPTY>"

# Vocabulary accepted by check_format: a token must appear in one of these
# lists (or in PREPOSITION, which is skipped) for a question to be valid.
# NOTE: the name MATERAIL (sic) is referenced under this spelling by check_format.
SIZE = ["small", "tiny", "large", "big"]
COLOR = ["gray", "red", "blue", "green", "brown", "purple", "cyan", "yellow"]
MATERAIL = ["rubber", "matte", "metal", "metallic", "shiny"]
SHAPE = ["cube", "block", "sphere", "ball", "cylinder", "object", "thing", "cubes", "blocks", "spheres", "balls", "cylinders", "objects", "things"]
RELATION = ["left", "right", "front", "behind"]
PREPOSITION = ["in", "of"]

# Normalization tables: plural -> singular, and synonym -> canonical term.
REMOVE_S = {"cubes": "cube", "blocks": "block", "spheres": "sphere", "balls": "ball", "cylinders": "cylinder", "objects": "object", "things": "thing"}
SYNONYMS = {"tiny": "small", "big": "large", "matte": "rubber", "metallic": "metal", "shiny": "metal", "block": "cube", "ball": "sphere"}

# Final answers accepted by check_answer_format (besides digit strings). The
# same list is spelled out again inside the prompts of guess_answer and
# check_answer_format, so keep the three copies in sync.
ANSWER_OPTIONS = ["yes", "no", "small", "tiny", "large", "big", "gray", "red", "blue", "green", "brown", "purple", "cyan", "yellow", "rubber", "matte", "metal", "metallic", "shiny", "cube", "block", "sphere", "ball", "cylinder", "cubes", "blocks", "spheres", "balls", "cylinders"]

# Maximum number of dialogue turns in chatgpt_dialogue before a guess is forced.
question_limit = 20

In [None]:
def post_process_answer(answer):
    """Normalize a raw answer string to the canonical answer vocabulary.

    The answer is lower-cased, periods are removed and surrounding whitespace
    is stripped; plurals are then mapped through ``REMOVE_S`` and synonyms
    through ``SYNONYMS``.

    Args:
        answer: Raw answer text, or ``None`` when no answer could be extracted.

    Returns:
        The normalized answer, or ``""`` for ``None`` so that callers can treat
        a missing answer as an invalid one instead of crashing.
    """
    if answer is None:
        return ""
    # Strip so that e.g. "{ yes }" still matches the vocabulary tables.
    answer = answer.lower().replace(".", "").strip()
    answer = REMOVE_S.get(answer, answer)
    return SYNONYMS.get(answer, answer)

In [None]:
def ask_why(chatgpt):
    """Ask the chat model to justify its answer and return its reply.

    Args:
        chatgpt: Object exposing ``prompt(text)``.
    """
    return chatgpt.prompt("User: " + "Why?")

In [None]:
def guess_answer(chatgpt):
    """Force a final answer once the question budget is exhausted.

    Sends a fixed instruction listing the allowed answers and returns the chat
    model's raw reply.

    Args:
        chatgpt: Object exposing ``prompt(text)``.
    """
    # The instruction text is sent verbatim, including its indentation.
    instruction = """You have used all the chances to ask questions. Now, guess a answer anyway.
    You can choose the answer in these options ["yes", "no", "small", "tiny", "large", "big", "gray", "red", "blue", "green", "brown", "purple", "cyan", "yellow", "rubber", "matte", "metal", "metallic", "shiny", "cube", "block", "sphere", "ball", "cylinder", "cubes", "blocks", "spheres", "balls", "cylinders"] or a digital number
    Give me an answer within 1 word.
    The answer should be in {}
    """
    return chatgpt.prompt("User: " + instruction)

In [None]:
def check_answer_format(answer, chatgpt):
    """Validate a final answer and, if invalid, ask the model for a new one.

    Args:
        answer: Candidate answer (already post-processed by the caller).
        chatgpt: Object exposing ``prompt(text)``; only used when ``answer``
            is invalid.

    Returns:
        ``(True, None)`` when ``answer`` is in ``ANSWER_OPTIONS`` or is a digit
        string. Otherwise ``(False, reply)`` where ``reply`` is the model's new
        answer: the text inside ``{}`` when braces are present, else the whole
        reply.
    """
    if answer in ANSWER_OPTIONS or answer.isdigit():
        return True, None

    # Sent verbatim; the indentation inside the literal is part of the prompt.
    prompt = """ You have to choose your answer in these options
        ["yes", "no", "small", "tiny", "large", "big", "gray", "red", "blue", "green", "brown", "purple", "cyan", "yellow", "rubber", "matte", "metal", "metallic", "shiny", "cube", "block", "sphere", "ball", "cylinder", "cubes", "blocks", "spheres", "balls", "cylinders"] or a digital number
        The answer should be in {}
        """
    response = chatgpt.prompt("User: " + prompt)

    start = response.find('{')
    # Only accept a closing brace that comes after the opening one.
    end = response.find('}', start + 1)
    if start != -1 or end != -1:
        # With no closing brace, keep everything after '{'. The old slice
        # ``response[start+1:-1]`` silently dropped the last character.
        stop = end if end != -1 else len(response)
        response = response[start + 1:stop]
    return False, response

In [None]:
def check_format(response):
    """Reduce a templated question to its entity phrase and validate its words.

    The question scaffolding ("is there", "what color is", "how many", ...),
    articles, "?" and "<empty>" are removed; every remaining word must belong
    to one of the vocabulary lists (prepositions are skipped).

    Args:
        response: Question text.

    Returns:
        ``(True, entity, None)`` when every word is supported, where ``entity``
        is the space-joined normalized phrase (plurals made singular, relations
        expanded to "left of" / "right of" / "in front of").
        ``(False, partial, word)`` on the first unsupported ``word``, where
        ``partial`` holds the words accepted so far, each followed by a space.
    """
    text = response.lower()
    # Applied strictly in this order, exactly like a chained ``.replace``.
    removals = (
        ("?", ""),
        (" a ", " "),
        (" the ", " "),
        ("is there", ""),
        ("what size is", ""),
        ("what color is", ""),
        ("what material is", ""),
        ("what shape is", ""),
        ("how many", ""),
        ("are there", ""),
        ("<empty>", ""),
    )
    for old, new in removals:
        text = text.replace(old, new)

    vocabulary = (SIZE, COLOR, MATERAIL, SHAPE, RELATION)
    relation_phrase = {"left": "left of", "right": "right of", "front": "in front of"}

    accepted = []
    for word in text.strip().split(" "):
        if not word or word in PREPOSITION:
            continue
        if not any(word in group for group in vocabulary):
            return False, "".join(part + " " for part in accepted), word
        word = REMOVE_S.get(word, word)
        accepted.append(relation_phrase.get(word, word))
    return True, " ".join(accepted), None

In [None]:
def check_existence(entity):
    """Ask the VQA model whether ``entity`` is present in the current image.

    Reads the module-level ``img`` and ``model``.

    Returns:
        ``True`` if the model's answer contains "yes", else ``False``.
    """
    reply = infer_mdetr(img, f"is there a {entity}?", model)
    return "yes" in reply

In [None]:
def check_uniqueness(entity):
    """Ask the VQA model how many instances of ``entity`` are in the image.

    Reads the module-level ``img`` and ``model``.

    Returns:
        The model's answer converted to ``int``, or ``0`` when the answer is
        not a number (for example "yes" or an attribute name).
    """
    uq_q = f"how many {entity} are there?"
    uq_a = infer_mdetr(img, uq_q, model)
    try:
        return int(uq_a)
    except (TypeError, ValueError):
        # Only a non-numeric answer maps to 0. A bare ``except`` would also
        # swallow KeyboardInterrupt and SystemExit.
        return 0

In [None]:
def handling_response(response, hist, rep):
    """Process one chat-model reply: a final answer or a sub-question.

    A reply containing braces is taken as the final answer. Otherwise the text
    inside ``[]`` is treated as a sub-question: it is validated, its entity is
    checked for existence and uniqueness with the VQA model, and only then is
    the question itself answered by the VQA model. Reads the module-level
    ``img`` and ``model``.

    Args:
        response: Raw reply text.
        hist: Dict for this turn; ``E{rep}`` (entity) and ``P{rep}`` (question
            passed to the VQA model) are added to it in place.
        rep: Turn index used in the ``hist`` keys.

    Returns:
        ``(done, answer, feedback, hist)``. ``done`` is ``True`` only for a
        final answer, in which case ``answer`` is the text between the braces
        and ``feedback`` is ``"DONE!"``. Otherwise ``answer`` is ``None`` and
        ``feedback`` is either an error message or the VQA model's answer, to
        be sent back as the next prompt.
    """
    ans_start = response.find('{')
    # Only accept a closing brace that comes after the opening one.
    ans_end = response.find('}', ans_start + 1)
    if ans_start != -1 or ans_end != -1:
        # With no closing brace, keep everything after '{'. The old slice
        # ``response[start+1:-1]`` silently dropped the last character.
        stop = ans_end if ans_end != -1 else len(response)
        return True, response[ans_start + 1:stop], "DONE!", hist

    q_start = response.find('[')
    # Require the closing bracket to follow the opening one.
    q_end = response.find(']', q_start + 1)
    if q_start == -1 or q_end == -1:
        return False, None, BRAKET_ERROR, hist

    q_gpt = response[q_start + 1:q_end].lower().replace("<empty>", "")
    # Plural-free copy, used only for format checking and entity extraction.
    q_gpt_nos = ' '.join(REMOVE_S.get(word, word) for word in q_gpt.split())
    correct_format, entity, unsup_e = check_format(q_gpt_nos)
    if not correct_format:
        return False, None, unsup_e + ": " + FORMAT_ERROR, hist
    if len(entity) == 0:
        return False, None, EMPTY_ERROR, hist

    # Presupposition checks: the entity must exist and be unique before the
    # actual question is forwarded.
    hist[f"E{str(rep)}"] = entity
    if not check_existence(entity):
        return False, None, EXISTENCE_ERROR + entity, hist
    uniqueness = check_uniqueness(entity)
    if uniqueness != 1:
        return False, None, UNIQUENESS_ERROR + str(uniqueness) + " " + entity + "s", hist

    hist[f"P{str(rep)}"] = q_gpt
    feedback = infer_mdetr(img, q_gpt, model)

    return False, None, feedback, hist
    

In [None]:
def chatgpt_dialogue(prompt):
    """Run one full question-asking dialogue and return the final answer.

    The chat model is prompted repeatedly; each reply is processed by
    ``handling_response`` and the resulting feedback becomes the next prompt.
    The loop stops when a final answer is given or after ``question_limit``
    turns, in which case the model is asked to guess. The answer is then
    normalized and validated, and the model is asked to explain it.

    The API key is read from the ``OPENAI_API_KEY`` environment variable; the
    placeholder string is used only when the variable is unset.

    Args:
        prompt: First message to send (without the "User: " prefix).

    Returns:
        ``(answer, memo)``: the normalized answer (or the normalized re-answer
        when the first one is not an allowed option) and a dict recording the
        dialogue ("history", "raw_answer", "answer", "reason", and optionally
        "guess", "raw_re_answer", "re_answer").
    """
    memo = {"history": []}
    # Keep real credentials out of the notebook file.
    openai_api_key = os.environ.get("OPENAI_API_KEY", "<YOUR_OPENAI_API_KEY>")
    chatgpt = BasicChatGPT(openai_api_key)

    done = False
    # Defined up front so the guess path also works when the loop never runs.
    answer = None
    hist = {}
    rep = 0
    while not done and rep < question_limit:
        hist = {}
        response = chatgpt.prompt("User: " + prompt)
        hist[f"Q{rep}"] = response
        done, answer, prompt, hist = handling_response(response, hist, rep)
        hist[f"A{rep}"] = prompt
        memo["history"].append(hist)
        rep += 1

    if not done:
        response = guess_answer(chatgpt)
        done, answer, prompt, hist = handling_response(response, hist, rep)
        memo["guess"] = answer

    memo["raw_answer"] = answer
    # ``answer`` is None when even the guess contained no {...}; normalize an
    # empty string instead so validation below requests a proper answer rather
    # than crashing on ``None.lower()``.
    processed_answer = post_process_answer(answer if answer is not None else "")
    memo["answer"] = processed_answer
    valid, re_answer = check_answer_format(processed_answer, chatgpt)
    memo["reason"] = ask_why(chatgpt)
    if valid:
        return processed_answer, memo

    memo["raw_re_answer"] = re_answer
    processed_re_answer = post_process_answer(re_answer)
    memo["re_answer"] = processed_re_answer
    return processed_re_answer, memo

## Run CoQAH

In [None]:
import openai
import time
from sklearn.metrics import *
import json

save_file = 'coqah_20_clevr_human_val.json'
base_save_file = save_file.replace('coqah','base')

# Resume from earlier results when present; otherwise start from scratch.
if os.path.isfile(save_file):
    with open(save_file, 'r') as f:
        res_dict = json.load(f)
    with open(base_save_file, 'r') as f:
        base_dict = json.load(f)
    # Continue after the last stored question; an existing but empty result
    # file means nothing has been processed yet.
    starting_idx = res_dict["id"][-1] + 1 if res_dict["id"] else 0
else:
    res_dict = {"id": [], "res": [], "gts": [], "memo": []}
    base_dict = {"id": [], "res": [], "gts": []}
    starting_idx = 0

for idx in range(starting_idx, len(ann["questions"])):
    sample = ann["questions"][idx]
    file_name = sample["image_filename"]
    question = sample["question"]
    gt = sample["answer"]
    # NOTE: ``img`` is read as a global by check_existence, check_uniqueness
    # and handling_response, so it must be set before the dialogue starts.
    img = Image.open(os.path.join(image_path, file_name)).convert("RGB")

    qprompt = f"Your goal is to answer this question:\n{question}\n"
    sprompt = qprompt + STARTING_PROMPT
    try:
        answer, memo = chatgpt_dialogue(sprompt)
    except Exception:
        # Retry once after a pause. ``Exception`` (not a bare ``except``) keeps
        # Ctrl-C / kernel interrupt able to stop the run instead of triggering
        # the sleep-and-retry path.
        print("wating for time limit")
        time.sleep(60)
        answer, memo = chatgpt_dialogue(sprompt)

    # Dialogue-based prediction: record it and checkpoint to disk every step.
    res_dict["id"].append(idx)
    res_dict["res"].append(answer)
    res_dict["gts"].append(gt)
    res_dict["memo"].append(memo)
    acc = accuracy_score(res_dict["gts"], res_dict["res"])
    with open(save_file, 'w') as f:
        json.dump(res_dict, f)
    print(idx)
    print("gt: ", gt)
    print("pred: ", answer)
    print("AFA ACC: ",acc)

    # Baseline: the VQA model answers the original question directly.
    baseline_answer = infer_mdetr(img, question, model)
    base_dict["id"].append(idx)
    base_dict["res"].append(baseline_answer)
    base_dict["gts"].append(gt)
    base_acc = accuracy_score(base_dict["gts"], base_dict["res"])
    with open(base_save_file, 'w') as f:
        json.dump(base_dict, f)
    print("Baseline: ", baseline_answer)
    print("Baseline ACC: ", base_acc)
    time.sleep(1)