In [1]:
import json
import string
import regex
from typing import List

In [2]:
def normalize_answer(s: str) -> str:
    """Canonicalize an answer string for lenient comparison (SQuAD-style).

    Steps, in order: lowercase the text, drop every ASCII punctuation
    character, replace the standalone words "a", "an" and "the" with a
    space, then collapse each run of whitespace into a single space
    (leading and trailing whitespace is removed).

    Adapted from the SQuAD evaluation script; see
    https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
    """
    punctuation = set(string.punctuation)

    lowered = s.lower()
    without_punctuation = "".join(char for char in lowered if char not in punctuation)
    without_articles = regex.sub(r"\b(a|an|the)\b", " ", without_punctuation)
    # split() without arguments trims both ends and merges whitespace runs.
    return " ".join(without_articles.split())


def is_correct(prediction: str, ground_truths: List[str]) -> float:
    """Score a prediction by normalized substring match.

    Returns 1.0 when at least one ground truth, after ``normalize_answer``,
    occurs as a substring of the normalized prediction; otherwise 0.0.
    """
    haystack = normalize_answer(prediction).lower()
    # any() stops at the first matching ground truth, like an early return.
    matched = any(
        normalize_answer(candidate).lower() in haystack
        for candidate in ground_truths
    )
    return 1.0 if matched else 0.0

In [3]:
def evaluate_qa_responses(inp : str):
    """Compute the fraction of correctly answered examples in a JSONL file.

    Each non-blank line of ``inp`` is parsed as a JSON object; its
    ``"model_answer"`` value is scored against its ``"correct_answer"``
    value with ``is_correct`` and the scores are averaged.

    Args:
        inp: Path to the JSON Lines responses file.

    Returns:
        Sum of the per-example scores divided by the number of examples.

    Raises:
        ValueError: If the file contains no examples. A malformed line also
            raises ``json.JSONDecodeError``, which is a ``ValueError`` subclass.
        KeyError: If a record lacks ``"correct_answer"`` or ``"model_answer"``.
    """
    correct_count, total_count = 0, 0
    # Read as UTF-8 explicitly rather than relying on the platform default.
    with open(inp, encoding="utf-8") as fin:
        for line in fin:
            # Skip blank lines (e.g. stray trailing newlines) instead of
            # letting json.loads fail on them.
            if not line.strip():
                continue
            input_res = json.loads(line)
            total_count += 1

            correct_answers = input_res["correct_answer"]
            model_answer = input_res["model_answer"]
            correct_count += is_correct(model_answer, correct_answers)

    if total_count == 0:
        # An empty file used to surface as an opaque ZeroDivisionError.
        raise ValueError(f"No QA responses found in {inp!r}")
    return correct_count / total_count

## Gemini Pro

In [4]:
# Gemini Pro baselines: print each header line, then the accuracy computed
# from the matching responses file.
for header, responses_path in (
    ("Accuracies for qa task closedbook =",
     "./responses/gemini_qa/gemini_closedbook_responses.jsonl"),
    ("Accuracies for qa task oracle =",
     "./responses/gemini_qa/gemini_oracle_responses.jsonl"),
    ("Accuracies for qa task (QAC) oracle =",
     "./responses/gemini_qa/gemini_oracle_QAC_responses.jsonl"),
):
    print(header)
    print(evaluate_qa_responses(responses_path))

Accuracies for qa task closedbook =
0.43615819209039547
Accuracies for qa task oracle =
0.7156308851224106
Accuracies for qa task (QAC) oracle =
0.7830508474576271


In [5]:
# Positions of the relevant document, one per response file below.
loc = [0, 4, 9]

# Gemini responses for the 10-document setting, in the same order as `loc`.
input_paths = [
    "./responses/gemini_qa/gemini_10_doc_at_0_responses.jsonl",
    "./responses/gemini_qa/gemini_10_doc_at_4_responses.jsonl",
    "./responses/gemini_qa/gemini_10_doc_at_9_responses.jsonl",
]

In [6]:
# One accuracy line per response file, labelled with the matching `loc` entry.
print("Accuracies for qa task with number of documents = 10 and -")
for idx, responses_path in enumerate(input_paths):
    print("relevant document located at", loc[idx], "is =", evaluate_qa_responses(responses_path))

Accuracies for qa task with number of documents = 10 and -
relevant document located at 0 is = 0.6263653483992467
relevant document located at 4 is = 0.5721280602636535
relevant document located at 9 is = 0.6598870056497175


In [7]:
# Positions of the relevant document, one per response file below.
loc = [0, 4, 9]

# Gemini QAC responses for the 10-document setting, in the same order as `loc`.
input_paths = [
    "./responses/gemini_qa/gemini_10_doc_at_0_QAC_responses.jsonl",
    "./responses/gemini_qa/gemini_10_doc_at_4_QAC_responses.jsonl",
    "./responses/gemini_qa/gemini_10_doc_at_9_QAC_responses.jsonl",
]

In [8]:
# One accuracy line per response file, labelled with the matching `loc` entry.
print("Accuracies for qa task (QAC) with number of documents = 10 and -")
for idx, responses_path in enumerate(input_paths):
    print("relevant document located at", loc[idx], "is =", evaluate_qa_responses(responses_path))

Accuracies for qa task (QAC) with number of documents = 10 and -
relevant document located at 0 is = 0.6738229755178907
relevant document located at 4 is = 0.6497175141242938
relevant document located at 9 is = 0.6922787193973635


In [9]:
# Positions of the relevant document, one per response file below.
loc = [0, 4, 9, 14, 19]

# Gemini responses for the 20-document setting, in the same order as `loc`.
input_paths = [
    "./responses/gemini_qa/gemini_20_doc_at_0_responses.jsonl",
    "./responses/gemini_qa/gemini_20_doc_at_4_responses.jsonl",
    "./responses/gemini_qa/gemini_20_doc_at_9_responses.jsonl",
    "./responses/gemini_qa/gemini_20_doc_at_14_responses.jsonl",
    "./responses/gemini_qa/gemini_20_doc_at_19_responses.jsonl",
]

In [10]:
# One accuracy line per response file, labelled with the matching `loc` entry.
print("Accuracies for qa task with number of documents = 20 and -")
for idx, responses_path in enumerate(input_paths):
    print("relevant document located at", loc[idx], "is =", evaluate_qa_responses(responses_path))

Accuracies for qa task with number of documents = 20 and -
relevant document located at 0 is = 0.5856873822975518
relevant document located at 4 is = 0.5318267419962335
relevant document located at 9 is = 0.5480225988700564
relevant document located at 14 is = 0.5555555555555556
relevant document located at 19 is = 0.6444444444444445


In [11]:
# Positions of the relevant document, one per response file below.
loc = [0, 4, 9, 14, 19]

# Gemini QAC responses for the 20-document setting, in the same order as `loc`.
input_paths = [
    "./responses/gemini_qa/gemini_20_doc_at_0_QAC_responses.jsonl",
    "./responses/gemini_qa/gemini_20_doc_at_4_QAC_responses.jsonl",
    "./responses/gemini_qa/gemini_20_doc_at_9_QAC_responses.jsonl",
    "./responses/gemini_qa/gemini_20_doc_at_14_QAC_responses.jsonl",
    "./responses/gemini_qa/gemini_20_doc_at_19_QAC_responses.jsonl",
]

In [12]:
# One accuracy line per response file, labelled with the matching `loc` entry.
print("Accuracies for qa task (QAC) with number of documents = 20 and -")
for idx, responses_path in enumerate(input_paths):
    print("relevant document located at", loc[idx], "is =", evaluate_qa_responses(responses_path))

Accuracies for qa task (QAC) with number of documents = 20 and -
relevant document located at 0 is = 0.632768361581921
relevant document located at 4 is = 0.599623352165725
relevant document located at 9 is = 0.6203389830508474
relevant document located at 14 is = 0.6222222222222222
relevant document located at 19 is = 0.6734463276836158


In [13]:
# Positions of the relevant document, one per response file below.
loc = [0, 4, 9, 14, 19, 24, 29]

# Gemini responses for the 30-document setting, in the same order as `loc`.
input_paths = [
    "./responses/gemini_qa/gemini_30_doc_at_0_responses.jsonl",
    "./responses/gemini_qa/gemini_30_doc_at_4_responses.jsonl",
    "./responses/gemini_qa/gemini_30_doc_at_9_responses.jsonl",
    "./responses/gemini_qa/gemini_30_doc_at_14_responses.jsonl",
    "./responses/gemini_qa/gemini_30_doc_at_19_responses.jsonl",
    "./responses/gemini_qa/gemini_30_doc_at_24_responses.jsonl",
    "./responses/gemini_qa/gemini_30_doc_at_29_responses.jsonl",
]

In [14]:
# One accuracy line per response file, labelled with the matching `loc` entry.
print("Accuracies for qa task with number of documents = 30 and -")
for idx, responses_path in enumerate(input_paths):
    print("relevant document located at", loc[idx], "is =", evaluate_qa_responses(responses_path))

Accuracies for qa task with number of documents = 30 and -
relevant document located at 0 is = 0.5792843691148776
relevant document located at 4 is = 0.4463276836158192
relevant document located at 9 is = 0.45649717514124294
relevant document located at 14 is = 0.4873822975517891
relevant document located at 19 is = 0.5133709981167608
relevant document located at 24 is = 0.5099811676082863
relevant document located at 29 is = 0.6376647834274953


In [15]:
# Positions of the relevant document, one per response file below.
loc = [0, 4, 9, 14, 19, 24, 29]

# Gemini QAC responses for the 30-document setting, in the same order as `loc`.
input_paths = [
    "./responses/gemini_qa/gemini_30_doc_at_0_QAC_responses.jsonl",
    "./responses/gemini_qa/gemini_30_doc_at_4_QAC_responses.jsonl",
    "./responses/gemini_qa/gemini_30_doc_at_9_QAC_responses.jsonl",
    "./responses/gemini_qa/gemini_30_doc_at_14_QAC_responses.jsonl",
    "./responses/gemini_qa/gemini_30_doc_at_19_QAC_responses.jsonl",
    "./responses/gemini_qa/gemini_30_doc_at_24_QAC_responses.jsonl",
    "./responses/gemini_qa/gemini_30_doc_at_29_QAC_responses.jsonl",
]

In [16]:
# One accuracy line per response file, labelled with the matching `loc` entry.
print("Accuracies for qa task (QAC) with number of documents = 30 and -")
for idx, responses_path in enumerate(input_paths):
    print("relevant document located at", loc[idx], "is =", evaluate_qa_responses(responses_path))

Accuracies for qa task (QAC) with number of documents = 30 and -
relevant document located at 0 is = 0.6369114877589453
relevant document located at 4 is = 0.5408662900188324
relevant document located at 9 is = 0.543879472693032
relevant document located at 14 is = 0.5698681732580038
relevant document located at 19 is = 0.5947269303201507
relevant document located at 24 is = 0.6015065913370998
relevant document located at 29 is = 0.6696798493408663


## RWKV - Accuracy vs Size

In [17]:
# Response files for the accuracy-vs-size comparison, grouped by the model
# name in the path (raven_3b, raven_7b, raven_14b). Within each group the
# `doc_at_N` suffix runs over N = 0, 4, 9, 14, 19.
input_paths = [
    "./responses/rwkv_accuracy_vs_size/raven_3b_20_doc_at_0_responses.jsonl",
    "./responses/rwkv_accuracy_vs_size/raven_3b_20_doc_at_4_responses.jsonl",
    "./responses/rwkv_accuracy_vs_size/raven_3b_20_doc_at_9_responses.jsonl",
    "./responses/rwkv_accuracy_vs_size/raven_3b_20_doc_at_14_responses.jsonl",
    "./responses/rwkv_accuracy_vs_size/raven_3b_20_doc_at_19_responses.jsonl",
    "./responses/rwkv_accuracy_vs_size/raven_7b_20_doc_at_0_responses.jsonl",
    "./responses/rwkv_accuracy_vs_size/raven_7b_20_doc_at_4_responses.jsonl",
    "./responses/rwkv_accuracy_vs_size/raven_7b_20_doc_at_9_responses.jsonl",
    "./responses/rwkv_accuracy_vs_size/raven_7b_20_doc_at_14_responses.jsonl",
    "./responses/rwkv_accuracy_vs_size/raven_7b_20_doc_at_19_responses.jsonl",
    "./responses/rwkv_accuracy_vs_size/raven_14b_20_doc_at_0_responses.jsonl",
    "./responses/rwkv_accuracy_vs_size/raven_14b_20_doc_at_4_responses.jsonl",
    "./responses/rwkv_accuracy_vs_size/raven_14b_20_doc_at_9_responses.jsonl",
    "./responses/rwkv_accuracy_vs_size/raven_14b_20_doc_at_14_responses.jsonl",
    "./responses/rwkv_accuracy_vs_size/raven_14b_20_doc_at_19_responses.jsonl",
]

In [18]:
# Print one accuracy per response file, in the order listed in `input_paths`.
for responses_path in input_paths:
    print("Accuracy is =", evaluate_qa_responses(responses_path))

Accuracy is = 0.19171374764595103
Accuracy is = 0.1856873822975518
Accuracy is = 0.1928436911487759
Accuracy is = 0.22259887005649717
Accuracy is = 0.4177024482109228
Accuracy is = 0.23691148775894538
Accuracy is = 0.24143126177024482
Accuracy is = 0.24444444444444444
Accuracy is = 0.271939736346516
Accuracy is = 0.3951035781544256
Accuracy is = 0.32354048964218457
Accuracy is = 0.3103578154425612
Accuracy is = 0.3009416195856874
Accuracy is = 0.3152542372881356
Accuracy is = 0.4335216572504708


## RWKV - Raven 14b QA Task Accuracy

In [19]:
# File stems (no directory, no ".jsonl" extension) of the Raven 14b response
# files; the evaluation loop in the next cell adds the
# "./responses/rwkv_raven_14b_qa/" prefix and the ".jsonl" suffix.
input_paths = [
    # Closed-book and oracle settings.
    "raven_14b_closedbook_responses",
    "raven_14b_oracle_responses",
    "raven_14b_oracle_QAC_responses",
    # `10_doc` files, plain then QAC.
    "raven_14b_10_doc_at_0_responses",
    "raven_14b_10_doc_at_4_responses",
    "raven_14b_10_doc_at_9_responses",
    "raven_14b_10_doc_at_0_QAC_responses",
    "raven_14b_10_doc_at_4_QAC_responses",
    "raven_14b_10_doc_at_9_QAC_responses",
    # `20_doc` files, plain then QAC.
    "raven_14b_20_doc_at_0_responses",
    "raven_14b_20_doc_at_4_responses",
    "raven_14b_20_doc_at_9_responses",
    "raven_14b_20_doc_at_14_responses",
    "raven_14b_20_doc_at_19_responses",
    "raven_14b_20_doc_at_0_QAC_responses",
    "raven_14b_20_doc_at_4_QAC_responses",
    "raven_14b_20_doc_at_9_QAC_responses",
    "raven_14b_20_doc_at_14_QAC_responses",
    "raven_14b_20_doc_at_19_QAC_responses",
    # `30_doc` files, plain then QAC.
    "raven_14b_30_doc_at_0_responses",
    "raven_14b_30_doc_at_4_responses",
    "raven_14b_30_doc_at_9_responses",
    "raven_14b_30_doc_at_14_responses",
    "raven_14b_30_doc_at_19_responses",
    "raven_14b_30_doc_at_24_responses",
    "raven_14b_30_doc_at_29_responses",
    "raven_14b_30_doc_at_0_QAC_responses",
    "raven_14b_30_doc_at_4_QAC_responses",
    "raven_14b_30_doc_at_9_QAC_responses",
    "raven_14b_30_doc_at_14_QAC_responses",
    "raven_14b_30_doc_at_19_QAC_responses",
    "raven_14b_30_doc_at_24_QAC_responses",
    # NOTE(review): the last 30_doc QAC entry is disabled; the reason is not
    # recorded here. Confirm the responses file exists before re-enabling.
#     "raven_14b_30_doc_at_29_QAC_responses",
]

In [20]:
# `input_paths` holds bare file stems here; build the full path for each
# stem and report its accuracy next to the stem.
for stem in input_paths:
    responses_path = "./responses/rwkv_raven_14b_qa/" + stem + ".jsonl"
    print("Accuracy for", stem, "is =", evaluate_qa_responses(responses_path))

Accuracy for raven_14b_closedbook_responses is = 0.30772128060263654
Accuracy for raven_14b_oracle_responses is = 0.8056497175141243
Accuracy for raven_14b_oracle_QAC_responses is = 0.7619585687382298
Accuracy for raven_14b_10_doc_at_0_responses is = 0.3845574387947269
Accuracy for raven_14b_10_doc_at_4_responses is = 0.3495291902071563
Accuracy for raven_14b_10_doc_at_9_responses is = 0.4384180790960452
Accuracy for raven_14b_10_doc_at_0_QAC_responses is = 0.38342749529190207
Accuracy for raven_14b_10_doc_at_4_QAC_responses is = 0.335969868173258
Accuracy for raven_14b_10_doc_at_9_QAC_responses is = 0.40941619585687383
Accuracy for raven_14b_20_doc_at_0_responses is = 0.32354048964218457
Accuracy for raven_14b_20_doc_at_4_responses is = 0.3103578154425612
Accuracy for raven_14b_20_doc_at_9_responses is = 0.3009416195856874
Accuracy for raven_14b_20_doc_at_14_responses is = 0.3152542372881356
Accuracy for raven_14b_20_doc_at_19_responses is = 0.4335216572504708
Accuracy for raven_14b_2