In [1]:
from copy import deepcopy
import json
from typing import List, Optional, Tuple, Type, TypeVar
from tqdm import tqdm
import time
from pydantic.dataclasses import dataclass

In [2]:
# Label written into the "model" field of every output record. Note that the
# model id actually sent to the API is configured separately, inside
# get_multi_doc_qa_responses, and is not this exact string.
model = "llama2-70b"

In [3]:
# Replicate API token: get one at https://replicate.com/account
# SECURITY: a live token used to be hard-coded in this cell (and the value typed
# into getpass() was silently ignored). That token must be revoked; secrets
# belong in the environment, never in the notebook.
from getpass import getpass
import os

# Prompt only when the token is not already provided by the environment, so
# re-running the cell (or Run All) does not ask again.
if not os.environ.get("REPLICATE_API_TOKEN"):
    os.environ["REPLICATE_API_TOKEN"] = getpass("Replicate API token: ")
REPLICATE_API_TOKEN = os.environ["REPLICATE_API_TOKEN"]

In [4]:
import os
from getpass import getpass

from groq import Groq

# Read the Groq key from the environment (prompting once if it is missing)
# instead of committing it. SECURITY: the key previously hard-coded here is
# exposed and must be rotated.
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass("Groq API key: ")
client = Groq(api_key=os.environ["GROQ_API_KEY"])

In [5]:
# The dataclass code is taken from the author's github repository.
T = TypeVar("T")

@dataclass(frozen=True)
class Document:
    """One retrieved passage supplied as context to the QA prompt.

    Instances are immutable (``frozen=True``). Only ``title`` and ``text`` are
    required; every other field defaults to ``None``.
    """

    title: str
    text: str
    id: Optional[str] = None
    score: Optional[float] = None
    hasanswer: Optional[bool] = None
    isgold: Optional[bool] = None
    original_retrieval_index: Optional[int] = None

    @classmethod
    def from_dict(cls: Type[T], data: dict) -> T:
        """Build an instance from a raw dict, e.g. one entry of an example's "ctxs".

        ``data`` is deep-copied first, so the caller's dict is never mutated.
        The ``"id"`` and ``"score"`` keys are optional (missing means ``None``);
        a score that is present and not ``None`` is converted with ``float``.
        All remaining keys are forwarded to the constructor as keyword arguments.

        Raises:
            ValueError: if ``data`` is empty (or otherwise falsy).
        """
        payload = deepcopy(data)
        if not payload:
            raise ValueError("Must provide data for creation of Document from dict.")
        raw_score = payload.pop("score", None)
        # Re-insert id/score so they are always passed, even when absent in the input.
        payload["id"] = payload.pop("id", None)
        payload["score"] = float(raw_score) if raw_score is not None else None
        return cls(**payload)

In [6]:
def get_qa_prompt(question: str, documents: List[Document], query_aware_contextualization: bool):
    """Render the multi-document QA prompt for one question.

    Args:
        question: Text substituted for ``{question}`` in the template.
        documents: Retrieved documents, listed in the given order and numbered
            starting from 1.
        query_aware_contextualization: If truthy, read the template from
            ``./prompting/qa_with_query_aware_contextualization.prompt``;
            otherwise from ``./prompting/qa.prompt`` (both relative to the CWD).

    Returns:
        The template, with trailing newlines stripped, whose ``{question}`` and
        ``{search_results}`` placeholders are filled in. Each search result is
        rendered as ``Document [k](Title: <title>) <text>``, one per line.
    """
    prompt_path = (
        "./prompting/qa_with_query_aware_contextualization.prompt"
        if query_aware_contextualization
        else "./prompting/qa.prompt"
    )
    with open(prompt_path) as template_file:
        template = template_file.read().rstrip("\n")

    search_results = "\n".join(
        f"Document [{rank}](Title: {doc.title}) {doc.text}"
        for rank, doc in enumerate(documents, start=1)
    )
    return template.format(question=question, search_results=search_results)

In [7]:
def get_multi_doc_qa_responses(
    inp: str,
    out: str,
    query_aware_contextualization: bool,
    api_model: str = "llama2-70b-4096",
    sleep_seconds: float = 5.0,
):
    """Query the Groq chat API for every example in a JSONL QA file and append the answers.

    Each non-blank line of ``inp`` is parsed as JSON and must provide the keys
    ``"question"``, ``"answers"`` and ``"ctxs"`` (a list of dicts accepted by
    ``Document.from_dict``). For each example a prompt is built with
    ``get_qa_prompt`` and sent as a single user message through
    ``client.chat.completions.create``. One JSON record with the keys
    ``model_prompt``, ``model_answer``, ``model`` and ``correct_answer`` is then
    appended to ``out``.

    Args:
        inp: Path to the input JSONL file (read as UTF-8).
        out: Path to the output JSONL file. It is opened in append mode, so
            re-running with the same ``out`` adds duplicate records instead of
            overwriting. Its parent directory is created if missing.
        query_aware_contextualization: Forwarded to ``get_qa_prompt`` to select
            the prompt template.
        api_model: Model id sent to the Groq API.
        sleep_seconds: Pause after each request, presumably to stay under the
            API rate limit.
    """
    import os

    out_dir = os.path.dirname(out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Open the output once and flush after every record, so a crash mid-run
    # (e.g. an API error) keeps everything written so far.
    with open(inp, encoding="utf-8") as fin, open(out, "a", encoding="utf-8") as fout:
        for line in tqdm(fin):
            if not line.strip():
                continue  # tolerate blank / trailing lines in the JSONL file
            input_example = json.loads(line)
            question = input_example["question"]
            correct_answer = input_example["answers"]

            # Document.from_dict already deep-copies each context dict.
            documents = [Document.from_dict(ctx) for ctx in input_example["ctxs"]]
            qa_prompt = get_qa_prompt(question, documents, query_aware_contextualization)

            chat_completion = client.chat.completions.create(
                messages=[{"role": "user", "content": qa_prompt}],
                model=api_model,
            )
            response = chat_completion.choices[0].message.content

            # NOTE(review): the "model" label is the notebook-level `model`
            # variable, which is not the same string as `api_model`.
            output = {
                "model_prompt": qa_prompt,
                "model_answer": response,
                "model": model,
                "correct_answer": correct_answer,
            }
            fout.write(json.dumps(output) + "\n")
            fout.flush()

            time.sleep(sleep_seconds)

In [None]:
# 10-document setting, baseline prompt: gold document at positions 0, 4 and 9.
input_paths = [
    f"./qa_data/10_total_documents/nq-open-10_total_documents_gold_at_{pos}.jsonl"
    for pos in (0, 4, 9)
]
output_paths = [
    f"./responses/llama2_qa/llama2_10_doc_at_{pos}_responses.jsonl"
    for pos in (0, 4, 9)
]

In [None]:
# Baseline prompt for each (input, output) pair configured in the previous cell.
# Check alignment up front so a mismatch fails before any paid API calls.
assert len(input_paths) == len(output_paths), "input_paths and output_paths must align"
for inp_path, out_path in zip(input_paths, output_paths):
    get_multi_doc_qa_responses(inp_path, out_path, query_aware_contextualization=False)

In [None]:
# 10-document setting, query-aware contextualization (QAC) prompt:
# gold document at positions 0, 4 and 9.
input_paths = [
    f"./qa_data/10_total_documents/nq-open-10_total_documents_gold_at_{pos}.jsonl"
    for pos in (0, 4, 9)
]
output_paths = [
    f"./responses/llama2_qa/llama2_10_doc_at_{pos}_QAC_responses.jsonl"
    for pos in (0, 4, 9)
]

In [None]:
# Query-aware contextualization prompt for each (input, output) pair configured above.
# Check alignment up front so a mismatch fails before any paid API calls.
assert len(input_paths) == len(output_paths), "input_paths and output_paths must align"
for inp_path, out_path in zip(input_paths, output_paths):
    get_multi_doc_qa_responses(inp_path, out_path, query_aware_contextualization=True)

In [None]:
# 20-document setting, baseline prompt: gold document at positions 0, 4, 9, 14, 19.
# The inputs live under 20_total_documents/. This cell previously pointed at
# 30_total_documents/, while every other cell uses the directory matching the
# nq-open-<N>_ file prefix.
input_paths = [
    f"./qa_data/20_total_documents/nq-open-20_total_documents_gold_at_{pos}.jsonl"
    for pos in (0, 4, 9, 14, 19)
]
output_paths = [
    f"./responses/llama2_qa/llama2_20_doc_at_{pos}_responses.jsonl"
    for pos in (0, 4, 9, 14, 19)
]

In [None]:
# Baseline prompt for each (input, output) pair configured in the previous cell.
# Check alignment up front so a mismatch fails before any paid API calls.
assert len(input_paths) == len(output_paths), "input_paths and output_paths must align"
for inp_path, out_path in zip(input_paths, output_paths):
    get_multi_doc_qa_responses(inp_path, out_path, query_aware_contextualization=False)

In [None]:
# 20-document setting, query-aware contextualization (QAC) prompt:
# gold document at positions 0, 4, 9, 14, 19.
# The inputs live under 20_total_documents/. This cell previously pointed at
# 30_total_documents/, while every other cell uses the directory matching the
# nq-open-<N>_ file prefix.
input_paths = [
    f"./qa_data/20_total_documents/nq-open-20_total_documents_gold_at_{pos}.jsonl"
    for pos in (0, 4, 9, 14, 19)
]
output_paths = [
    f"./responses/llama2_qa/llama2_20_doc_at_{pos}_QAC_responses.jsonl"
    for pos in (0, 4, 9, 14, 19)
]

In [None]:
# Query-aware contextualization prompt for each (input, output) pair configured above.
# Check alignment up front so a mismatch fails before any paid API calls.
assert len(input_paths) == len(output_paths), "input_paths and output_paths must align"
for inp_path, out_path in zip(input_paths, output_paths):
    get_multi_doc_qa_responses(inp_path, out_path, query_aware_contextualization=True)

In [None]:
# 30-document setting, baseline prompt: gold document at positions 0, 4, ..., 29.
input_paths = [
    f"./qa_data/30_total_documents/nq-open-30_total_documents_gold_at_{pos}.jsonl"
    for pos in (0, 4, 9, 14, 19, 24, 29)
]
output_paths = [
    f"./responses/llama2_qa/llama2_30_doc_at_{pos}_responses.jsonl"
    for pos in (0, 4, 9, 14, 19, 24, 29)
]

In [None]:
# Baseline prompt for each (input, output) pair configured in the previous cell.
# Check alignment up front so a mismatch fails before any paid API calls.
assert len(input_paths) == len(output_paths), "input_paths and output_paths must align"
for inp_path, out_path in zip(input_paths, output_paths):
    get_multi_doc_qa_responses(inp_path, out_path, query_aware_contextualization=False)

In [9]:
# 30-document setting, query-aware contextualization (QAC) prompt:
# gold document at positions 0, 4, ..., 29.
input_paths = [
    f"./qa_data/30_total_documents/nq-open-30_total_documents_gold_at_{pos}.jsonl"
    for pos in (0, 4, 9, 14, 19, 24, 29)
]
output_paths = [
    f"./responses/llama2_qa/llama2_30_doc_at_{pos}_QAC_responses.jsonl"
    for pos in (0, 4, 9, 14, 19, 24, 29)
]

In [None]:
# Query-aware contextualization prompt for each (input, output) pair configured above.
# Check alignment up front so a mismatch fails before any paid API calls.
assert len(input_paths) == len(output_paths), "input_paths and output_paths must align"
for inp_path, out_path in zip(input_paths, output_paths):
    get_multi_doc_qa_responses(inp_path, out_path, query_aware_contextualization=True)