# In [None]:
import json
import numpy as np
import random

# Seed the RNG so the random fallback in parse_multi_choice_response
# (used when no answer can be parsed) is reproducible across runs.
random.seed(1234)


# ----------- Process Multi-choice -------------
def parse_multi_choice_response(response, all_choices, index2ans):
    """
    Parse the predicted choice index (e.g. A, B, C, D) from a generated response.

    Candidates are collected by the first of these stages that finds any:
      1. bracketed indices such as "(A)";
      2. stand-alone indices surrounded by spaces such as " A ";
      3. only for responses longer than 5 whitespace-separated tokens:
         the answer text from ``index2ans`` (case-insensitive substring).

    Args:
        response: Raw text generated by the model.
        all_choices: Choice indices to look for, e.g. ["A", "B", "C", "D"].
        index2ans: Mapping from choice index to its answer text.

    Returns:
        The single candidate if exactly one was found; the candidate whose
        last occurrence is furthest right if several were found; otherwise a
        random element of ``all_choices`` (drawn from the global ``random``
        state).
    """
    # Strip every edge punctuation/whitespace character in ONE pass. Stripping
    # one character at a time is order-dependent: "A,." or "B. " would keep
    # part of its tail, defeat the stand-alone " A " match below, and silently
    # fall through to a random guess.
    response = response.strip(",.!?;:' \t\r\n")
    response = " " + response + " "  # add space to avoid partial match

    index_ans = True  # candidates are choice indices (False: answer text)
    ans_with_brack = False

    # Stage 1: bracketed indices, e.g. (A) (B) (C) (D)
    candidates = [choice for choice in all_choices if f"({choice})" in response]
    if candidates:
        ans_with_brack = True
    else:
        # Stage 2: stand-alone indices, e.g. A B C D
        candidates = [choice for choice in all_choices if f" {choice} " in response]

    # Stage 3: for longer responses, look for the answer content itself.
    if not candidates and len(response.split()) > 5:
        lowered = response.lower()
        for index, ans in index2ans.items():
            if ans.lower() in lowered:
                candidates.append(index)
                index_ans = False  # it's content ans.

    if not candidates:  # still no answer: randomly choose one.
        return random.choice(all_choices)
    if len(candidates) == 1:
        return candidates[0]

    # Several candidates: keep the one mentioned last in the response.
    if index_ans:
        if ans_with_brack:
            start_indexes = [response.rfind(f"({can})") for can in candidates]
        else:
            start_indexes = [response.rfind(f" {can} ") for can in candidates]
    else:
        lowered = response.lower()
        start_indexes = [lowered.rfind(index2ans[can].lower()) for can in candidates]
    return candidates[np.argmax(start_indexes)]


def eval(filename):
    """
    Score multiple-choice predictions stored in JSONL files.

    Each non-blank line of a file is parsed as a JSON object from which the
    keys "question_id", "text" and "prompt" are read. The prediction parsed
    from "text" is compared with the second "_"-separated field of
    "question_id", which is used as the ground-truth label.

    NOTE: the name shadows the built-in ``eval``; it is kept unchanged so
    existing callers continue to work.

    Args:
        filename: An iterable of JSONL file paths, or a single path string.

    Returns:
        A list with one entry per file; each entry is a list of booleans,
        one per item, telling whether the prediction matched the label.
        A summary line (path, item count, random_cnt, mean accuracy) is
        printed per file.
    """
    # Use the argument actually passed in (rather than a module-level global),
    # and accept a lone path string as well as a collection of paths.
    paths = [filename] if isinstance(filename, str) else filename

    all_choices = ["A", "B", "C", "D"]
    all_accs = []
    for path in paths:
        # Context manager closes the handle; blank lines would crash json.loads.
        with open(path, encoding="utf-8") as f:
            data = [json.loads(line) for line in f if line.strip()]

        accs = []
        # Never incremented: the parser does not report when it falls back to
        # a random guess. Kept so the printed summary keeps its columns.
        random_cnt = 0
        for item in data:
            label = item["question_id"].split("_")[1]
            gen = item["text"]

            # Prompt lines 1..4 hold the options; the first three characters
            # of each are dropped (presumably an "A. "-style prefix -- confirm
            # against the prompt format).
            option_lines = item["prompt"].split("\n")[1:5]
            index2ans = {
                letter: line[3:] for letter, line in zip(all_choices, option_lines)
            }

            pred = parse_multi_choice_response(gen, all_choices, index2ans)
            accs.append(pred == label)

        all_accs.append(accs)

        print(path, len(accs), random_cnt, np.mean(accs))
    return all_accs

# In [None]:
# JSONL prediction file(s) to score.
filenames = ["imagewikiqa-predictions_llava-7b_imagenet-and-llava-trained.jsonl"]
# eval prints one summary line per file and returns the per-item correctness lists.
all_accs = eval(filenames)