In [None]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import cv2
import numpy
from matplotlib import pyplot as plt
import glob
import json
import os
from mpl_toolkits.axes_grid1 import ImageGrid
import csv
import json
import os
from PIL import Image
from io import BytesIO
import base64
import pandas as pd
import torch
import torch.nn.functional as F
import numpy as np
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from tasks.mm_tasks.refcoco import RefcocoTask
from models.ofa import OFAModel
from PIL import Image, ImageOps

# Load VQA model

In [None]:
# Register RefcocoTask under the name 'refcoco' so that "--task=refcoco" below resolves to it.
tasks.register_task('refcoco', RefcocoTask)

# Runtime switches read by the model loop below and by infer_ofa().
use_cuda = torch.cuda.is_available()
use_fp16 = False

# Build a fairseq generation config from a synthetic command line.
model_path = 'OFA/checkpoint_best.pt'
parser = options.get_generation_parser()
# NOTE(review): the leading "" presumably fills the positional data argument of the
# generation parser -- confirm against fairseq's options.
input_args = ["", "--task=refcoco", "--beam=10", f"--path={model_path}", "--bpe-dir=utils/BPE", "--no-repeat-ngram-size=3", "--patch-image-size=384"]
args = options.parse_args_and_arch(parser, input_args)
cfg = convert_namespace_to_omegaconf(args)

# Load pretrained ckpt & config
task = tasks.setup_task(cfg.task)
# `cfg` is rebound here to the config returned alongside the checkpoint(s).
vqa_models, cfg = checkpoint_utils.load_model_ensemble(
    utils.split_paths(cfg.common_eval.path),
    task=task
)
# Move models to GPU
for model in vqa_models:
    # Inference only: put each model in eval mode.
    model.eval()
    if use_fp16:
        model.half()
    # Only moved to CUDA when available and pipeline model parallelism is off.
    if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
        model.cuda()
    model.prepare_for_inference_(cfg)

# Initialize generator
# `vqa_generator` is the object passed to task.inference_step() in infer_ofa().
vqa_generator = task.build_generator(vqa_models, cfg.generation)

In [None]:
# Image transform
from torchvision import transforms
# Per-channel normalisation constants: Normalize computes (x - 0.5) / 0.5 on each channel.
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

# Pipeline applied to every image before it is handed to the model:
# force RGB -> resize to a square of side task.cfg.patch_image_size (bicubic)
# -> tensor -> normalise. The final flip has p=0, i.e. it is never applied.
patch_resize_transform = transforms.Compose([
    lambda image: image.convert("RGB"),
    transforms.Resize((task.cfg.patch_image_size, task.cfg.patch_image_size), interpolation=Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
    transforms.RandomHorizontalFlip(p=0)
])

# Text preprocess
# One-element tensors holding the source dictionary's BOS / EOS ids; encode_text()
# concatenates them around the encoded instruction.
bos_item = torch.LongTensor([task.src_dict.bos()])
eos_item = torch.LongTensor([task.src_dict.eos()])
# Padding id, used by construct_sample*() to count non-pad tokens.
pad_idx = task.src_dict.pad()


def get_symbols_to_strip_from_output(generator):
    """Return the symbols that should be dropped when decoding generator output.

    Uses ``generator.symbols_to_strip_from_output`` when the generator exposes
    that attribute; otherwise falls back to the set ``{bos, eos}`` taken from
    the generator itself.
    """
    _absent = object()
    symbols = getattr(generator, "symbols_to_strip_from_output", _absent)
    if symbols is not _absent:
        return symbols
    return {generator.bos, generator.eos}


def decode_fn(x, tgt_dict, bpe, generator, tokenizer=None):
    """Decode a generated token tensor into text, bin tokens and code tokens.

    Args:
        x: tensor of generated token ids.
        tgt_dict: dictionary whose ``string`` method renders ids as a
            whitespace-separated token string.
        bpe: optional object with ``decode``; applied to every ordinary token.
        generator: generator used to look up which symbols to strip.
        tokenizer: optional object with ``decode``; applied after ``bpe``.

    Returns:
        A 3-tuple of space-joined strings: ``(text, bins, codes)`` where
        ``bins`` collects tokens starting with ``<bin_`` and ``codes`` those
        starting with ``<code_``.
    """
    rendered = tgt_dict.string(
        x.int().cpu(),
        extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
    )
    words, bins, codes = [], [], []
    for piece in rendered.strip().split():
        if piece.startswith('<bin_'):
            bins.append(piece)
            continue
        if piece.startswith('<code_'):
            codes.append(piece)
            continue

        if bpe is not None:
            piece = bpe.decode('{}'.format(piece))
        if tokenizer is not None:
            piece = tokenizer.decode(piece)

        # A decoded piece without a leading space continues the previous word.
        if words and not piece.startswith(' '):
            words[-1] += piece
        else:
            words.append(piece.strip())

    return ' '.join(words), ' '.join(bins), ' '.join(codes)


def coord2bin(coords, w_resize_ratio, h_resize_ratio):
    """Quantise four whitespace-separated coordinates into ``<bin_N>`` tokens.

    The 1st and 3rd values are scaled by ``w_resize_ratio`` and the 2nd and
    4th by ``h_resize_ratio`` (presumably an ``x0 y0 x1 y1`` box -- confirm
    with callers), divided by ``task.cfg.max_image_size`` and mapped onto
    ``task.cfg.num_bins - 1`` steps.

    Returns:
        The four bin tokens joined by single spaces.
    """
    values = [float(piece) for piece in coords.strip().split()]
    ratios = (w_resize_ratio, h_resize_ratio, w_resize_ratio, h_resize_ratio)
    tokens = []
    for position, ratio in enumerate(ratios):
        scaled = values[position] * ratio / task.cfg.max_image_size * (task.cfg.num_bins - 1)
        tokens.append("<bin_{}>".format(int(round(scaled))))
    return ' '.join(tokens)


def bin2coord(bins, w_resize_ratio, h_resize_ratio):
    """Map four ``<bin_N>`` tokens back to coordinates (inverse of ``coord2bin``).

    Each bin index is divided by ``task.cfg.num_bins - 1``, multiplied by
    ``task.cfg.max_image_size`` and divided by the width ratio (1st and 3rd
    value) or the height ratio (2nd and 4th value).

    Returns:
        A list of four floats.
    """
    # "<bin_123>" -> 123: drop the 5-char prefix and the trailing ">".
    indices = [int(token[5:-1]) for token in bins.strip().split()]
    ratios = (w_resize_ratio, h_resize_ratio, w_resize_ratio, h_resize_ratio)
    return [
        indices[position] / (task.cfg.num_bins - 1) * task.cfg.max_image_size / ratio
        for position, ratio in enumerate(ratios)
    ]


def encode_text(text, length=None, append_bos=False, append_eos=False):
    """Encode ``text`` into a 1-D tensor of token ids from ``task.tgt_dict``.

    Words starting with ``<code_`` or ``<bin_`` are passed through untouched;
    every other word is BPE-encoded with a leading space.

    Args:
        text: raw input string.
        length: if given, the encoded ids are truncated to this many items
            (before BOS/EOS are added).
        append_bos: prepend the module-level ``bos_item``.
        append_eos: append the module-level ``eos_item``.
    """
    pieces = []
    for word in text.strip().split():
        if word.startswith(('<code_', '<bin_')):
            pieces.append(word)
        else:
            pieces.append(task.bpe.encode(' {}'.format(word.strip())))

    token_ids = task.tgt_dict.encode_line(
        line=' '.join(pieces),
        add_if_not_exist=False,
        append_eos=False
    ).long()

    if length is not None:
        token_ids = token_ids[:length]
    if append_bos:
        token_ids = torch.cat([bos_item, token_ids])
    if append_eos:
        token_ids = torch.cat([token_ids, eos_item])
    return token_ids

def construct_sample(image: Image, instruction: str, image2: Image = None):
    """Build a single-item sample dict from an image (or two) and an instruction.

    Both images go through ``patch_resize_transform`` and gain a leading
    dimension of size 1; the instruction is lower-cased, stripped, prefixed
    with a space and encoded with BOS and EOS.

    Returns:
        ``{"id": ..., "net_input": {...}}`` with keys ``src_tokens``,
        ``src_lengths``, ``patch_images``, ``patch_images_2`` (``None`` when
        ``image2`` is not given) and ``patch_masks``.
    """
    patch_image2 = None if image2 is None else patch_resize_transform(image2).unsqueeze(0)
    patch_image = patch_resize_transform(image).unsqueeze(0)

    normalized = ' {}'.format(instruction.lower().strip())
    src_tokens = encode_text(normalized, append_bos=True, append_eos=True).unsqueeze(0)
    # Length = number of non-padding tokens in each row.
    src_lengths = torch.LongTensor([row.ne(pad_idx).long().sum() for row in src_tokens])

    net_input = {
        "src_tokens": src_tokens,
        "src_lengths": src_lengths,
        "patch_images": patch_image,
        "patch_images_2": patch_image2,
        "patch_masks": torch.tensor([True]),
    }
    return {"id": np.array(['42']), "net_input": net_input}

def construct_sample_wo_image(instruction: str):
    """Build a single-item, text-only sample dict from an instruction.

    The instruction is lower-cased, stripped, prefixed with a space and
    encoded with BOS and EOS. The result carries ``src_tokens``,
    ``src_lengths`` and ``patch_masks`` but no image tensors.
    """
    normalized = ' {}'.format(instruction.lower().strip())
    src_tokens = encode_text(normalized, append_bos=True, append_eos=True).unsqueeze(0)
    # Length = number of non-padding tokens in each row.
    src_lengths = torch.LongTensor([row.ne(pad_idx).long().sum() for row in src_tokens])

    net_input = {
        "src_tokens": src_tokens,
        "src_lengths": src_lengths,
        "patch_masks": torch.tensor([True]),
    }
    return {"id": np.array(['42']), "net_input": net_input}

# Function to turn FP32 to FP16
def apply_half(t):
    """Return ``t`` cast to ``torch.half`` if it is float32, else ``t`` unchanged."""
    return t.to(dtype=torch.half) if t.dtype is torch.float32 else t

def infer_ofa(img, question):
    """Run the loaded model on one image/question pair and return the decoded text.

    The sample is moved to CUDA when ``use_cuda`` is set and converted to
    half precision when ``use_fp16`` is set. Only the text part of the first
    hypothesis is returned; bin and code tokens are discarded.
    """
    sample = construct_sample(img, question)
    if use_cuda:
        sample = utils.move_to_cuda(sample)
    if use_fp16:
        sample = utils.apply_to_sample(apply_half, sample)

    with torch.no_grad():
        hypos = task.inference_step(vqa_generator, vqa_models, sample)
        best_tokens = hypos[0][0]["tokens"]
        text, _bins, _codes = decode_fn(best_tokens, task.tgt_dict, task.bpe, vqa_generator)
    return text

## Setting for ChatGPT

In [None]:
# Basic ChatGPT class
import openai

class BasicChatGPT:
    """Minimal stateful chat client on top of ``openai.ChatCompletion``.

    The whole conversation is kept in ``self.messages`` and re-sent on every
    call, so the model sees the full history.

    Args:
        openai_api_key: key assigned to the module-level ``openai.api_key``
            (this is global to the ``openai`` module, not per instance).
        max_tokens: completion length limit forwarded to the API.
    """

    def __init__(self, openai_api_key, max_tokens=100):
        openai.api_key = openai_api_key
        self.max_tokens = max_tokens
        self.messages = []

    def call(self):
        """Send the current history to the API and return the raw response."""
        return openai.ChatCompletion.create(
            model="gpt-4",
            messages=self.messages,
            max_tokens=self.max_tokens,
            seed=123,
        )

    def prompt(self, text):
        """Append ``text`` as a user turn, query the model, return its reply text.

        On success both the user turn and the returned assistant message are
        kept in ``self.messages``. If the call (or reading the reply) raises,
        the user turn is removed again before re-raising, so a retry on the
        same instance does not send the same message twice.
        """
        self.messages.append({
            "role": "user",
            "content": text,
        })

        try:
            response = self.call()
            message = response["choices"][0]["message"]
        except BaseException:
            # Roll back the turn we just added; the exception is re-raised as-is.
            self.messages.pop()
            raise

        self.messages.append(message)
        return message["content"]

## Functions for CoQAH

In [None]:
# Instruction block that the run loop appends to the per-question goal line to
# form the first message of each dialogue. It tells the model to:
#   * wrap every question in [] and the final (one-word) answer in {} -- these
#     are the delimiters handling_response() searches for;
#   * use one of the six question templates listed;
#   * pick <ENTITY> from the listed vocabulary.
# This is runtime text sent to the model: do not reformat or re-indent it.
STARTING_PROMPT = """ 
But you can't access the image and I can access the image.
You can ask me questions with these forms.
The question should be in []
If you can answer the above question, stop asking and give me an answer within 1 word.
The answer should be in {}
[is there a <ENTITY> ?]
[what abnormalities are seen in this image?]
[where is the <ENTITY> ?]
[what level is the <ENTITY> ?]
[what type is the <ENTITY> ?]
[which view is this image taken?]

<ENTITY> :  [pleural effusion or atelectasis or cardiomegaly or enlargement of the cardiac silhouette or edema or hernia or vascular congestion or hilar congestion or pneumothorax or heart failure or lung opacity or pneumonia or tortuosity of the descending aorta or scoliosis or gastric distention or hypoxemia or hypertensive heart disease or hematoma or tortuosity of the thoracic aorta or contusion or emphysema or granuloma or calcification or pleural thickening or thymoma or blunting of the costophrenic angle or consolidation or fracture or pneumomediastinum or air collection]

Ask me the next question after I answer you.
Ask me questions carefully considering existence presupposition.
"""

In [None]:
# Feedback strings sent back to the model by handling_response():
# - BRAKET_ERROR: the reply contained no [...] question.
# - FORMAT_ERROR: the bracketed question did not match a supported template/entity
#   (prefixed with the offending text).
# - EXISTENCE_ERROR: prefix of the "there is no <entity>" reply used when the
#   existence check fails.
BRAKET_ERROR = "The question should be in []"
FORMAT_ERROR = "Unsupported format or unsupported option. choose the most similar option, even if it’s not totally the same."
EXISTENCE_ERROR = "there is no "

# Entity vocabulary accepted by check_format(). Names are written without "the"
# (e.g. "enlargement of cardiac silhouette") because check_format() removes
# " the " from the question before the membership test.
ENTITY = ["pleural effusion", "atelectasis", "cardiomegaly", "enlargement of cardiac silhouette", "edema", "hernia", "vascular congestion", "hilar congestion", "pneumothorax", "heart failure", "lung opacity", "pneumonia", "tortuosity of descending aorta", "scoliosis", "gastric distention", "hypoxemia", "hypertensive heart disease", "hematoma", "tortuosity of thoracic aorta", "contusion", "emphysema", "granuloma", "calcification", "pleural thickening", "thymoma", "blunting of costophrenic angle", "consolidation", "fracture", "pneumomediastinum", "air collection"]

# Maximum number of dialogue turns chatgpt_dialogue() runs before forcing a guess.
question_limit = 5

In [None]:
def ask_why(chatgpt):
    """Ask the model to justify its answer and return its reply text."""
    return chatgpt.prompt("User: " + "Why?")

def guess_answer(chatgpt):
    """Tell the model it is out of questions and return its forced-guess reply."""
    final_request = """You have used all the chances to ask questions. Now, guess a answer anyway.
    Give me an answer within 1 word.
    The answer should be in {}
    """
    return chatgpt.prompt("User: " + final_request)

def post_process_answer(answer):
    """Normalise an answer string: lower-case it and remove every ``.``."""
    lowered = answer.lower()
    return lowered.replace(".", "")

def check_answer_format(answer, chatgpt): 
    """Reject an ``'unknown'`` answer and ask the model for a concrete guess.

    Args:
        answer: the already post-processed answer string.
        chatgpt: chat client exposing ``prompt(text) -> str``.

    Returns:
        ``(True, None)`` when ``answer`` is anything other than ``'unknown'``.
        Otherwise ``(False, re_answer)`` where ``re_answer`` is the text taken
        from the model's follow-up reply: the part between ``{`` and the next
        ``}``; if only one of the two braces is present, the text after ``{``
        or before ``}``; if neither is present, the whole reply.
    """
    if answer == 'unknown':
        prompt = """ You can't say 'unknown'. guess an answer anyway.
        The answer should be in {}
        """
        print(prompt)
        prompt = "User: " + prompt
        response = chatgpt.prompt(prompt)
        abidx = response.find('{')
        # Look for the closing brace only after the opening one (abidx + 1 is 0
        # when there is no opening brace, i.e. search from the start).
        aeidx = response.find('}', abidx + 1)
        if abidx != -1 or aeidx != -1:
            # With a missing closing brace, slice to the end of the reply
            # instead of to index -1, which silently dropped the last character.
            stop = aeidx if aeidx != -1 else len(response)
            response = response[abidx + 1:stop]
        return False, response
    
    return True, None

def check_format(response):
    """Check that a bracketed question uses a supported template and entity.

    The question is lower-cased, stripped of ``?``, of the articles
    `` a ``/`` an ``/`` the `` and of the template prefixes; what remains is
    treated as the entity.

    Returns:
        ``(True, "", None)`` for the two entity-free templates,
        ``(True, entity, None)`` when the remainder is in ``ENTITY``,
        ``(False, remainder, remainder)`` otherwise.
    """
    # Order matters: articles are removed before the template prefixes.
    substitutions = (
        ("?", ""),
        (" a ", " "),
        (" an ", " "),
        (" the ", " "),
        ("is there", ""),
        ("where is", ""),
        ("what level is", ""),
        ("what type is", ""),
    )
    remainder = response.lower()
    for old, new in substitutions:
        remainder = remainder.replace(old, new)
    remainder = remainder.strip()

    entity_free = (
        "what abnormalities are seen in this image",
        "which view is this image taken",
    )
    if any(template in remainder for template in entity_free):
        return True, "", None
    if remainder in ENTITY:
        return True, remainder, None
    return False, remainder, remainder

def check_existence(entity):
    """Ask the VQA model whether ``entity`` is present in the current image.

    Reads the module-level ``img`` (set by the run loop) rather than taking
    the image as a parameter.

    Returns:
        ``True`` if the model's reply contains ``"yes"``, else ``False``.
    """
    reply = infer_ofa(img, f"is there a {entity}?")
    return "yes" in reply
    
def handling_response(response, hist, rep):
    """Interpret one model reply as either a final answer or a question.

    Reads the module-level ``img`` (set by the run loop) when it has to query
    the VQA model.

    Args:
        response: raw reply text from the chat model.
        hist: dict for this turn; ``E<rep>`` (entity) and ``P<rep>`` (question
            actually sent to the VQA model) are added to it in place.
        rep: turn index, used only to build the ``hist`` keys.

    Returns:
        ``(done, answer, feedback, hist)``:
        * reply contains ``{`` and/or ``}`` -> ``(True, text, "DONE!", hist)``
          where ``text`` is the part between ``{`` and the next ``}`` (or after
          ``{`` / before ``}`` when only one brace is present);
        * no ``[...]`` question -> ``(False, None, BRAKET_ERROR, hist)``;
        * unsupported question -> ``(False, None, "<text>: " + FORMAT_ERROR, hist)``;
        * entity judged absent -> ``(False, None, EXISTENCE_ERROR + entity, hist)``;
        * otherwise -> ``(False, None, <VQA model answer>, hist)``.
    """
    abidx = response.find('{')
    # Search for the closing brace only after the opening one (abidx + 1 is 0
    # when there is no opening brace, i.e. search from the start).
    aeidx = response.find('}', abidx + 1)
    if abidx != -1 or aeidx != -1:
        # With a missing closing brace, slice to the end of the reply instead
        # of to index -1, which silently dropped the answer's last character.
        stop = aeidx if aeidx != -1 else len(response)
        a_gpt = response[abidx + 1:stop]
        return True, a_gpt, "DONE!", hist
    
    qbidx = response.find('[')
    # Likewise, the closing bracket must come after the opening one.
    qeidx = response.find(']', qbidx + 1)
    if qbidx == -1 or qeidx == -1:
        return False, None, BRAKET_ERROR, hist
    q_gpt = response[qbidx + 1:qeidx]
    q_gpt = q_gpt.lower()
    correct_format, entity, unsup_e = check_format(q_gpt)
    if not correct_format:
        return False, None, unsup_e + ": " + FORMAT_ERROR, hist
    hist[f"E{str(rep)}"] = entity
    # Entity-free templates yield an empty entity and skip the existence check.
    if len(entity) > 0:
        existence = check_existence(entity)
        if not existence:
            return False, None, EXISTENCE_ERROR + entity, hist

    hist[f"P{str(rep)}"] = q_gpt
    feedback = infer_ofa(img, q_gpt)
    
    return False, None, feedback, hist
    
def chatgpt_dialogue(prompt):
    """Run one question-asking dialogue between the chat model and the VQA model.

    The chat model is prompted up to ``question_limit`` times. Each reply is
    passed to ``handling_response``; its feedback becomes the next prompt.
    If no final answer was produced, the model is forced to guess. An
    ``'unknown'`` answer triggers one re-ask, and the model is always asked
    "Why?" at the end.

    The API key is read from the ``OPENAI_API_KEY`` environment variable,
    falling back to the ``"<YOUR_OPENAI_API_KEY>"`` placeholder when unset.

    Args:
        prompt: the first message to send (goal line + instructions).

    Returns:
        ``(answer, memo)`` where ``answer`` is the post-processed final answer
        and ``memo`` records ``history``, ``raw_answer``, ``answer``,
        ``reason`` and, when applicable, ``guess`` / ``raw_re_answer`` /
        ``re_answer``.
    """
    memo = {"history": []}
    # Never hard-code the key in the notebook; take it from the environment.
    openai_api_key = os.environ.get("OPENAI_API_KEY", "<YOUR_OPENAI_API_KEY>")
    chatgpt = BasicChatGPT(openai_api_key)
    done = False
    answer = None
    hist = {}  # stays defined even if the loop body never runs
    rep = 0
    while not done and rep < question_limit:
        hist = {}
        prompt = "User: " + prompt
        response = chatgpt.prompt(prompt)
        hist[f"Q{str(rep)}"] = response
        # The returned feedback becomes the next prompt.
        done, answer, prompt, hist = handling_response(response, hist, rep)
        hist[f"A{str(rep)}"] = prompt
        memo["history"].append(hist)
        rep += 1
    if not done:
        response = guess_answer(chatgpt)
        done, answer, prompt, hist = handling_response(response, hist, rep)
        if answer is None:
            # The forced guess carried no braces; use the raw reply rather
            # than crashing in post_process_answer(None).
            answer = response
        memo["guess"] = answer
    memo["raw_answer"] = answer
    processed_answer = post_process_answer(answer)
    memo["answer"] = processed_answer
    valid, re_answer = check_answer_format(processed_answer, chatgpt)
    reason = ask_why(chatgpt)
    memo["reason"] = reason
    if valid:
        return processed_answer, memo

    memo["raw_re_answer"] = re_answer
    processed_re_answer = post_process_answer(re_answer)
    memo["re_answer"] = processed_re_answer
    return processed_re_answer, memo

## Load Dataset

In [None]:
# Read the tab-separated test set into parallel lists. Per row: column 1 is the
# image payload (base64-decoded by the run loop), column 2 the ground-truth
# answer, column 3 the question; column 4 is stored in `images2`
# (NOTE(review): its meaning is not visible here -- confirm against the data).
test_set = 'slake_all_closed.tsv'
images = []
gts = []
questions = []
images2 = []
# `with` guarantees the file is closed, including when a malformed row raises.
with open(os.path.join('vqa_data/', test_set), "r") as fp:
    for line in fp:
        column_l = line.rstrip("\n").split("\t")
        # A line without any tab (e.g. a blank line) ends the table.
        if len(column_l) == 1:
            break
        images.append(column_l[1])
        gts.append(column_l[2])
        questions.append(column_l[3])
        images2.append(column_l[4])

## Run CoQAH

In [None]:
import openai
import time
from sklearn.metrics import *
import json

# Evaluate every test question, checkpointing results to `save_file` after each
# item so an interrupted run can resume where it stopped.
save_file = 'coqah_5_slake_all_closed.json'
base_save_file = save_file.replace('coqah','base')

if os.path.isfile(save_file):
    with open(save_file, 'r') as f:
        res_dict = json.load(f)
    # The baseline file is only written by the (currently disabled) baseline
    # code at the bottom of the loop, so it may not exist when resuming.
    if os.path.isfile(base_save_file):
        with open(base_save_file, 'r') as f:
            base_dict = json.load(f)
    else:
        base_dict = {"id": [], "res": [], "gts": []}
    # Resume right after the last stored item (or from the start if none).
    starting_idx = res_dict["id"][-1] + 1 if res_dict["id"] else 0
else:
    res_dict = {"id": [], "res": [], "gts": [], "memo": []}
    base_dict = {"id": [], "res": [], "gts": []}
    starting_idx = 0

for idx in range(starting_idx, len(gts)):
    # `img` is deliberately a module-level name: check_existence() and
    # handling_response() read it as a global.
    img = Image.open(BytesIO(base64.urlsafe_b64decode(images[idx])))
    gt = gts[idx]
    question = questions[idx]

    qprompt = f"Your goal is to answer this question:\n{question}\n"
    sprompt = qprompt + STARTING_PROMPT
    try:
        answer, memo = chatgpt_dialogue(sprompt)
    except Exception as err:
        # Treated as a rate limit: wait, then retry once. A second failure
        # propagates. `Exception` (not a bare except) keeps Ctrl-C working.
        print("wating for time limit")
        print("retrying after error:", repr(err))
        time.sleep(60)
        answer, memo = chatgpt_dialogue(sprompt)

    res_dict["id"].append(idx)
    res_dict["res"].append(answer)
    res_dict["gts"].append(gt)
    res_dict["memo"].append(memo)
    acc = accuracy_score(res_dict["gts"], res_dict["res"])
    with open(save_file, 'w') as f:
        json.dump(res_dict, f)
    print(idx)
    print(question)
    print("gt: ", gt)
    print("pred: ", answer)
    print("AFA ACC: ",acc)
    # baseline_answer = infer_ofa(img, question)
    # base_dict["id"].append(idx)
    # base_dict["res"].append(baseline_answer)
    # base_dict["gts"].append(gt)
    # base_acc = accuracy_score(base_dict["gts"], base_dict["res"])
    # with open(base_save_file, 'w') as f:
    #     json.dump(base_dict, f)
    # print("Baseline: ", baseline_answer)
    # print("Baseline ACC: ", base_acc)
    time.sleep(1)