In [1]:
# %% Imports

import re
import random
from itertools import filterfalse
from collections import Counter
import evaluate
import ollama
import pandas as pd
from tqdm import tqdm

# Seed Python's global `random` generator so the random.shuffle calls made
# later in this notebook are repeatable from a fresh kernel.
random.seed(1718308331)

In [2]:
# %% Metrics

# Load the "exact_match" metric through the `evaluate` library; it is used
# further down to compute the EM score per top_k group.
exact_match = evaluate.load("exact_match")


def calculate_metrics(prediction, ground_truth):
    """Score a predicted answer against the reference at the token level.

    Both strings are lower-cased and split into runs of word characters.
    The overlap is the size of the multiset intersection of the two token
    lists, so a repeated token only counts as often as it occurs on both sides.

    Args:
        prediction: The model's answer as a string.
        ground_truth: The reference answer as a string.

    Returns:
        A ``(precision, recall, f1)`` tuple of floats. Any ratio whose
        denominator would be zero is reported as 0.0.
    """

    def tokenize(text):
        return re.findall(r"\w+", text.lower())

    predicted = tokenize(prediction)
    expected = tokenize(ground_truth)

    # Counter & Counter keeps the minimum count of each shared token.
    overlap = sum((Counter(predicted) & Counter(expected)).values())

    precision = overlap / len(predicted) if len(predicted) != 0 else 0.0
    recall = overlap / len(expected) if len(expected) != 0 else 0.0

    if precision + recall == 0:
        return precision, recall, 0.0

    return precision, recall, 2 * (precision * recall) / (precision + recall)

In [3]:
# %% Prepare llama3 with temperature=0.5

# Modelfile text: start FROM llama3 and set the temperature and seed parameters.
modelfile = """
FROM llama3
PARAMETER temperature 0.5
PARAMETER seed 1718308331
"""

# Create the model named "llama3-temp0.5" from the Modelfile above; the
# experiment loop later generates with this model name.
ollama.create(model="llama3-temp0.5", modelfile=modelfile)

{'status': 'success'}

In [4]:
# %% Load test data

# Read dataset/data/dev.json and keep only its first 1200 rows.
df = pd.read_json("dataset/data/dev.json").head(1200)

In [5]:
# %% Make prompt

def make_prompt(question, context, supporting_facts, top_k=None):
    """Assemble the question-answering prompt sent to the model.

    A context entry is treated as "oracle" when its title appears among the
    titles in ``supporting_facts``. Oracle entries are shuffled and always
    come first. When ``top_k`` is given, the remaining entries are shuffled,
    appended after the oracle ones, and the combined list is cut to ``top_k``.

    Args:
        question: The question text.
        context: Iterable of ``(title, sentences)`` pairs.
        supporting_facts: Iterable of two-item entries whose first item is a title.
        top_k: Maximum number of contexts to keep, or None to use only the
            oracle contexts.

    Returns:
        The prompt string: instruction, numbered contexts, question and the
        answer-length constraint, separated by blank lines.
    """
    gold_titles = set()
    for title, _ in supporting_facts:
        gold_titles.add(title)

    def is_oracle(evidence):
        title, _ = evidence
        return title in gold_titles

    selected = [evidence for evidence in context if is_oracle(evidence)]
    random.shuffle(selected)

    if top_k is not None:
        distractors = [evidence for evidence in context if not is_oracle(evidence)]
        random.shuffle(distractors)
        # Oracle entries stay in front, so truncation drops distractors first.
        selected = (selected + distractors)[:top_k]

    numbered = []
    for i, (title, texts) in enumerate(selected, start=1):
        numbered.append(f"Context {i}: [{title}] {' '.join(texts)}")

    sections = [
        "Please answer the given question based on the given contexts below.",
        *numbered,
        f"Question: {question}",
        "Constraint: Don't give any explanations and use MAX 5 tokens in your response. No yapping.",
    ]
    return "\n\n".join(sections)

In [None]:
# %% Run experiment and save results

# Query the model once per (question, top_k) setting and collect the raw
# responses. top_k=None makes make_prompt use the oracle contexts only; the
# integer values also mix in shuffled non-oracle contexts up to that many.
results = []
# df.itertuples() is a plain iterator with no length, so the total is passed
# explicitly; otherwise tqdm can only show a counter, not a progress bar or ETA.
for row in tqdm(df.itertuples(), total=len(df)):
    for top_k in (None, 1, 3, 5):
        prompt = make_prompt(row.question, row.context, row.supporting_facts, top_k=top_k)
        prediction = ollama.generate(model="llama3-temp0.5", prompt=prompt)
        results.append((top_k, row.question, row.answer, prediction["response"]))

# Persist one row per (top_k, question) so later cells can reload the
# predictions without re-running generation.
pd.DataFrame(results, columns=("top_k", "question", "answer", "prediction")).to_csv("oracle_prediction_results.csv", index=False)

In [6]:
# %% Reload results

# Read the saved predictions back from disk.
results_df = pd.read_csv("oracle_prediction_results.csv")

# Turn top_k into text labels, then swap the "nan" label that astype(str)
# produces for missing values with "n/a".
top_k_labels = results_df["top_k"].astype(str)
results_df["top_k"] = top_k_labels.replace("nan", "n/a")

results_df

Unnamed: 0,top_k,question,answer,prediction
0,,Who is the mother of the director of film Poli...,Małgorzata Braunek,Małgorzata Braunek
1,1.0,Who is the mother of the director of film Poli...,Małgorzata Braunek,Małgorzata Braunek
2,3.0,Who is the mother of the director of film Poli...,Małgorzata Braunek,Małgorzata Braunek
3,5.0,Who is the mother of the director of film Poli...,Małgorzata Braunek,Małgorzata Braunek
4,,"Which film came out first, Blind Shaft or The ...",The Mask Of Fu Manchu,The Mask of Fu Manchu.
...,...,...,...,...
4795,5.0,Do both My Friend From The Park and Punks (Fil...,no,"No, they don't."
4796,,Which country the director of film Romanoff An...,United Kingdom,England
4797,1.0,Which country the director of film Romanoff An...,United Kingdom,England
4798,3.0,Which country the director of film Romanoff An...,United Kingdom,American


In [7]:
# %% Compute metrics

# Per-row token-level scores. Values are coerced with str() because read_csv
# may have parsed some answers/predictions as non-string values (e.g. NaN).
per_row_scores = pd.DataFrame(
    [
        calculate_metrics(str(prediction), str(answer))
        for prediction, answer in zip(results_df["prediction"], results_df["answer"])
    ],
    columns=("Precision", "Recall", "F1"),
    index=results_df.index,  # keep rows aligned with results_df in the concat below
)

# Raw results with the three score columns appended; kept under its own name
# so the aggregated table below does not overwrite it.
scored_df = pd.concat((results_df, per_row_scores), axis=1)


def _exact_match_score(group):
    """Return the exact-match score for one top_k group of predictions/answers."""
    return exact_match.compute(
        # Same str() coercion as the token-level metrics above, for consistency.
        predictions=group["prediction"].astype(str).tolist(),
        references=group["answer"].astype(str).tolist(),
        ignore_case=True,
        ignore_punctuation=True,
    ).get("exact_match")


# One row per top_k: EM from the `evaluate` metric plus the mean of the
# per-row Precision/Recall/F1 columns.
metrics_df = pd.DataFrame({
    # Selecting the needed columns explicitly keeps the grouping column out of
    # apply(), which avoids the pandas deprecation warning about operating on
    # grouping columns.
    "EM": results_df.groupby("top_k")[["prediction", "answer"]].apply(_exact_match_score),
    **scored_df.groupby("top_k").mean(numeric_only=True),
}).reset_index()

metrics_df

  "EM": results_df.groupby("top_k").apply(


Unnamed: 0,top_k,EM,Precision,Recall,F1
0,1.0,0.335833,0.434792,0.440234,0.425343
1,3.0,0.405833,0.563333,0.617513,0.570063
2,5.0,0.375,0.537408,0.611012,0.551896
3,,0.484167,0.63894,0.696369,0.644686


In [12]:
# Print the aggregated metrics as a LaTeX table: floats formatted with "%.4f",
# index column omitted.
print(
    metrics_df.to_latex(
        index=False,
        float_format="%.4f",
    )
)

\begin{tabular}{lrrrr}
\toprule
top_k & EM & Precision & Recall & F1 \\
\midrule
1.0 & 0.3358 & 0.4348 & 0.4402 & 0.4253 \\
3.0 & 0.4058 & 0.5633 & 0.6175 & 0.5701 \\
5.0 & 0.3750 & 0.5374 & 0.6110 & 0.5519 \\
n/a & 0.4842 & 0.6389 & 0.6964 & 0.6447 \\
\bottomrule
\end{tabular}

