In [1]:
import json
import os
import pandas as pd
import numpy as np
import jsonlines

In [6]:
def get_results(factfile):
    """Load predictions, ground truths and raw records from a JSON-lines file.

    Every line must provide a ``prediction`` and a ``correct`` field. Each
    field is either a string or a list of strings. Values are lower-cased and
    stripped (element-wise for lists), so later comparisons ignore case and
    surrounding whitespace.

    Args:
        factfile: path of the JSON-lines file to read.

    Returns:
        ``(preds, gts, records)`` — three parallel lists holding the normalized
        predictions, the normalized ground truths, and the parsed records.
    """
    def normalize(value):
        # A bare string is normalized directly; anything else is treated as
        # an iterable of strings and normalized element-wise.
        if isinstance(value, str):
            return value.lower().strip()
        return [item.lower().strip() for item in value]

    preds, gts, records = [], [], []
    with jsonlines.open(factfile) as reader:
        for record in reader:
            preds.append(normalize(record['prediction']))
            gts.append(normalize(record['correct']))
            records.append(record)
    return preds, gts, records
    
def accuracy(preds, gts):
    """Return the percentage (0-100) of predictions that match their ground truth.

    Matching rules:
      * a string prediction matches an answer only if it is equal to it;
      * a list prediction (e.g. top-k candidates) matches an answer if the
        answer is one of its elements;
      * a list/tuple ground truth is a set of acceptable answers — the
        prediction is correct if it matches any of them.

    Args:
        preds: predictions (strings or lists of strings), e.g. from ``get_results``.
        gts: ground truths aligned one-to-one with ``preds``.

    Returns:
        Accuracy as a percentage; 0.0 when ``preds`` is empty.

    Raises:
        ValueError: if ``preds`` and ``gts`` differ in length. Without this
            check, ``zip`` would silently drop items while the result would
            still be divided by ``len(preds)``.
    """
    if len(preds) != len(gts):
        raise ValueError(
            f"preds and gts must have the same length ({len(preds)} != {len(gts)})"
        )
    if not preds:
        # Avoid ZeroDivisionError on an empty result file.
        return 0.0

    ncorrect = 0
    for p, g in zip(preds, gts):
        # Treat a single ground truth as a one-element list of accepted answers.
        answers = g if isinstance(g, (list, tuple)) else [g]
        if isinstance(p, str):
            hit = p in answers  # exact equality against any accepted answer
        else:
            hit = any(answer in p for answer in answers)
        if hit:
            ncorrect += 1
    return ncorrect / len(preds) * 100

In [7]:
def get_accuracy_from_results(factfile):
    """Load *factfile* via ``get_results`` and return its accuracy in percent."""
    results = get_results(factfile)
    preds, gts = results[0], results[1]
    return accuracy(preds, gts)

def get_incorrect_preds(factfile):
    """Return ``(record_index, record)`` pairs for every mispredicted record.

    A record counts as correct when:
      * its prediction is a string equal to the ground truth, or
      * its prediction is a list (e.g. top-k candidates) containing the
        ground truth.
    A list ground truth is treated as a set of acceptable answers, any of
    which may match.

    The previous ``pred != gt`` check was always True for list predictions,
    so every top-k record was reported as incorrect.

    Args:
        factfile: JSON-lines results file readable by ``get_results``.

    Returns:
        List of ``(index, record)`` tuples, where ``index`` is the record's
        position in the file.
    """
    preds, gts, records = get_results(factfile)

    def _is_correct(pred, gt):
        answers = gt if isinstance(gt, (list, tuple)) else [gt]
        if isinstance(pred, str):
            return pred in answers
        return any(answer in pred for answer in answers)

    return [
        (record_id, record)
        for record_id, (pred, gt, record) in enumerate(zip(preds, gts, records))
        if not _is_correct(pred, gt)
    ]

def get_all_records(factfile):
    """Return the parsed records of *factfile* (predictions and ground truths are discarded)."""
    return get_results(factfile)[2]

In [5]:
# Model names; each one selects a per-model predictions file below.
models = [
    "mistral",
    "gemma",
    "phi",
    "falcon",
    "llama",
]

In [6]:
def analyze_accuracy_in_topk(modelname, k=5):
    """Return the top-``k`` accuracy (percent) of *modelname*.

    Reads ``predictions_<modelname>_top<k>.jsonl`` relative to the working directory.
    """
    topk_file = f"predictions_{modelname}_top{k}.jsonl"
    return get_accuracy_from_results(topk_file)

def analyze_accuracy(filename):
    """Return the accuracy (percent) computed from the JSON-lines results file *filename*."""
    return get_accuracy_from_results(filename)

In [11]:
# Llama accuracy on the "instructed" augmented results file.
analyze_accuracy("results_instructed_augmented_llama.jsonl")

91.52542372881356

In [12]:
# Llama accuracy on the plain (non-"instructed") augmented results file.
analyze_accuracy("results_augmented_llama.jsonl")

79.66101694915254

In [None]:
# NOTE(review): repeats the call from the In [11] cell above and was never
# executed (In [None]); consider deleting this cell.
analyze_accuracy("results_instructed_augmented_llama.jsonl")

In [8]:
# Mistral accuracy on the plain (non-"instructed") augmented results file.
analyze_accuracy('results_augmented_mistral.jsonl')

93.02325581395348

In [9]:
# Mistral accuracy on the "instructed" augmented results file.
analyze_accuracy('results_instructed_augmented_mistral.jsonl')

97.67441860465115

In [54]:
# Print each model's top-5 accuracy, one "<model>  <accuracy>" line per model.
for model_name in models:
    topk_acc = analyze_accuracy_in_topk(model_name)
    print(f"{model_name}  {topk_acc}")

mistral  95.33622559652929
gemma  93.87417218543047
phi  88.17204301075269
falcon  93.255620316403
llama  93.45898004434589


In [2]:
def get_augmented_records(filename, dataset=None):
    """Load records from *filename* and prime each prompt with its answer.

    For each record, the ``attribute`` of the matching dataset entry (looked
    up by the record's ``known_id``) is used as the answer:

      * ``prompt`` becomes ``"<prompt> <attribute>. <prompt>"``;
      * ``subject`` is overwritten with the attribute as a string.

    Args:
        filename: JSON-lines file whose records have ``known_id`` and
            ``prompt`` fields.
        dataset: container indexed by ``known_id`` whose entries have an
            ``attribute`` key. Presumably this is the list loaded from
            known_1000.json, with position == ``known_id`` (confirm). Defaults
            to the notebook-level ``original_dataset``. That global is looked
            up at call time, so it may be defined in a later cell.

    Returns:
        List of the modified records, in file order.
    """
    if dataset is None:
        dataset = original_dataset

    augmented_records = []
    with jsonlines.open(filename) as reader:
        for record in reader:
            attribute = dataset[record['known_id']]['attribute']
            prompt = record['prompt']
            record['subject'] = f"{attribute}"
            record['prompt'] = f"{prompt} {attribute}. {prompt}"
            augmented_records.append(record)
    return augmented_records


def dump_to_file(records, modelname, filename):
    """Write *records* as JSON lines to ``<filename>_<modelname>.jsonl``.

    ``filename`` acts as a prefix. The file is opened in ``'w'`` mode, so any
    existing file with that name is replaced.
    """
    out_path = f"{filename}_{modelname}.jsonl"
    with jsonlines.open(out_path, 'w') as writer:
        writer.write_all(records)

In [4]:
# Load the reference dataset. The context manager closes the file handle that
# the previous json.load(open(...)) left open.
with open("../notebooks/data/known_1000.json") as dataset_file:
    original_dataset = json.load(dataset_file)


def get_incorrectly_done_records(modelname, k=5):
    """Return the records whose ground truth *modelname* missed in its top-``k`` predictions.

    A record is a hit when:
      * its prediction is a list (top-k candidates) containing the answer, or
      * its prediction is a string equal to the answer.
    A list ground truth counts as a hit if any of its answers matches.

    Each missed record is annotated in place:
      * ``clean_subject``: the subject from ``original_dataset[known_id]``;
      * ``subject``: the same subject with ALL whitespace removed.

    Args:
        modelname: model name used in ``predictions_<modelname>_top<k>.jsonl``.
        k: top-k size used in the predictions file name (default 5, as before).

    Returns:
        List of the annotated missed records, in file order.
    """
    preds, gts, records = get_results(f"predictions_{modelname}_top{k}.jsonl")

    def _is_hit(pred, gt):
        answers = gt if isinstance(gt, (list, tuple)) else [gt]
        if isinstance(pred, str):
            return pred in answers
        return any(answer in pred for answer in answers)

    incorrect_records = []
    for pred, gt, record in zip(preds, gts, records):
        if _is_hit(pred, gt):
            continue
        clean_subject = original_dataset[record['known_id']]['subject']
        # NOTE(review): this removes internal spaces too ("Barack Obama" ->
        # "BarackObama"); confirm it is intended.
        record['subject'] = ''.join(clean_subject.split())
        record['clean_subject'] = clean_subject
        incorrect_records.append(record)
    return incorrect_records

In [8]:
# Records whose ground truth is missing from Mistral's top-5 predictions.
mistral_incorrect = get_incorrectly_done_records('mistral')

In [11]:
# Records whose ground truth is missing from Llama's top-5 predictions.
llama_incorrect = get_incorrectly_done_records('llama')

In [10]:
# Reload the reference dataset, closing the file handle afterwards.
# NOTE(review): this duplicates the load in the In [4] cell above.
with open("../notebooks/data/known_1000.json") as dataset_file:
    original_dataset = json.load(dataset_file)

In [19]:
# Size of the loaded dataset. The recorded output is 1209, not the 1000 the
# file name suggests.
len(original_dataset)

1209

In [9]:
# Write Mistral's missed records to incorrect_mistral.jsonl.
dump_to_file(mistral_incorrect, 'mistral', 'incorrect')

In [12]:
# Write Llama's missed records to incorrect_llama.jsonl.
dump_to_file(llama_incorrect, 'llama', 'incorrect')

In [13]:
# For every record in incorrect_llama.jsonl, prepend the prompt with its own
# answer ("<prompt> <correct>. <prompt>") and make the answer the subject.
with jsonlines.open("augmented_llama.jsonl", "w") as sink, \
        jsonlines.open("incorrect_llama.jsonl") as source:
    for record in source:
        question, answer = record['prompt'], record['correct']
        record.update(prompt=f"{question} {answer}. {question}", subject=answer)
        sink.write(record)


In [10]:
# For every record in incorrect_mistral.jsonl, prepend the prompt with its own
# answer ("<prompt> <correct>. <prompt>") and make the answer the subject.
with jsonlines.open("augmented_mistral.jsonl", "w") as sink, \
        jsonlines.open("incorrect_mistral.jsonl") as source:
    for record in source:
        question, answer = record['prompt'], record['correct']
        record.update(prompt=f"{question} {answer}. {question}", subject=answer)
        sink.write(record)


In [7]:
import random

MULTIDOC_SEED = 42  # fixed seed so the sampled distractor facts are reproducible


def build_multidoc_records(records, n_distractors=3, rng=None):
    """Rewrite each record's prompt as a question over a multi-fact context.

    Each record's own fact ``"<prompt> <correct>"`` is mixed with up to
    ``n_distractors`` facts drawn from the *other* records, and the facts are
    shuffled. The record's ``prompt`` then becomes
    ``"Answer based on the context. Context: <facts, one per line>"``
    followed by a newline and ``"Question: <original prompt>"``.
    The record's ``subject`` is set to its ``correct`` value.

    Args:
        records: list of dicts with ``prompt`` and ``correct`` keys; modified in place.
        n_distractors: number of facts from other records to include, capped
            at the number of other records available.
        rng: object with ``sample``/``shuffle`` (e.g. ``random.Random``);
            defaults to the global ``random`` module.

    Returns:
        The same ``records`` list, after modification.
    """
    if rng is None:
        rng = random
    # Build every fact BEFORE any prompt is rewritten. The previous version
    # sampled from records it had already rewritten, which nested whole
    # contexts into later prompts.
    facts = [f"{r['prompt']} {r['correct']}" for r in records]
    for i, record in enumerate(records):
        others = facts[:i] + facts[i + 1:]
        context = [facts[i]] + rng.sample(others, min(n_distractors, len(others)))
        rng.shuffle(context)
        context_text = "\n".join(context)
        record['prompt'] = (f"Answer based on the context. Context: {context_text}"
                            "\n"
                            f"Question: {record['prompt']}")
        record['subject'] = record['correct']
    return records


# Read all input first so a failed read cannot leave a truncated output file.
with jsonlines.open("incorrect_mistral.jsonl") as reader:
    incorrect_mistral_rows = list(reader)

with jsonlines.open("incorrect_multidoc_mistral.jsonl", "w") as writer:
    for row in build_multidoc_records(incorrect_mistral_rows,
                                      rng=random.Random(MULTIDOC_SEED)):
        writer.write(row)

In [None]:
import random

# NOTE(review): this cell repeats the In [7] cell above, was never executed
# (In [None]), and writes the same output file; consider deleting it.
with jsonlines.open("incorrect_mistral.jsonl") as reader:
    all_prompts = list(reader)

# Build every "<prompt> <correct>" fact before any prompt is rewritten.
# Sampling from already-rewritten records would nest whole contexts into
# later prompts.
all_facts = [f"{r['prompt']} {r['correct']}" for r in all_prompts]
rng = random.Random(42)  # fixed seed for reproducible distractor sampling

with jsonlines.open("incorrect_multidoc_mistral.jsonl", "w") as writer:
    for i, line in enumerate(all_prompts):
        other_facts = all_facts[:i] + all_facts[i + 1:]
        # Own fact plus up to 3 distractor facts (fewer if not enough records), shuffled.
        full_context = [all_facts[i]] + rng.sample(other_facts, min(3, len(other_facts)))
        rng.shuffle(full_context)
        full_context = "\n".join(full_context)
        line['prompt'] = f"Answer based on the context. Context: {full_context}" + \
                            "\n" + \
                            f"Question: {line['prompt']}"
        line['subject'] = line['correct']
        writer.write(line)