In [120]:
import json
import os
import time
from copy import deepcopy
from typing import List, Optional, Tuple, Type, TypeVar

from pydantic.dataclasses import dataclass
from tqdm import tqdm

In [121]:
import google.generativeai as genai
# The API key is read from the GOOGLE_API_KEY environment variable so that it
# never has to be pasted into (and accidentally committed with) the notebook.
# Falls back to an empty string when the variable is unset.
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY", ""))

In [122]:
# Generation config: temperature 0.0 with top_p 1.
config = dict(temperature=0.0, top_p=1)

# Safety config: one {"category", "threshold"} entry per harm category below,
# every one of them using the "BLOCK_NONE" threshold.
# NOTE(review): the list carries both a short and a long spelling for some
# categories (e.g. "HARM_CATEGORY_DANGEROUS" and
# "HARM_CATEGORY_DANGEROUS_CONTENT") -- confirm against the SDK which names
# the model actually accepts.
safety = [
    {"category": category, "threshold": "BLOCK_NONE"}
    for category in (
        "HARM_CATEGORY_DANGEROUS",
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
        "HARM_CATEGORY_SEXUAL",
    )
]

In [123]:
# Single source of truth for the model identifier: the same name selects the
# model here and is recorded under "model" in every output row written by
# get_oracle_responses, so the two can no longer drift apart.
model_name = 'gemini-pro'
model = genai.GenerativeModel(model_name, safety_settings=safety, generation_config=config)

In [124]:
# The dataclass code is taken from the author's github repository.
T = TypeVar("T")


@dataclass(frozen=True)
class Document:
    """Immutable record for one retrieved document.

    Only ``title`` and ``text`` are required; every other field is optional
    and defaults to ``None``.
    """

    title: str
    text: str
    id: Optional[str] = None
    score: Optional[float] = None
    hasanswer: Optional[bool] = None
    isgold: Optional[bool] = None
    original_retrieval_index: Optional[int] = None

    @classmethod
    def from_dict(cls: Type[T], data: dict) -> T:
        """Create an instance from a mapping of field names to values.

        The mapping is deep-copied first, so the caller's object is never
        modified. ``id`` and ``score`` default to ``None`` when absent, and a
        present ``score`` is converted with ``float()``.

        Args:
            data: Field values keyed by field name.

        Returns:
            A new instance of ``cls``.

        Raises:
            ValueError: If ``data`` is empty (or otherwise falsy).
        """
        fields = deepcopy(data)
        if not fields:
            raise ValueError("Must provide data for creation of Document from dict.")

        doc_id = fields.pop("id", None)
        raw_score = fields.pop("score", None)
        # A score that was supplied is normalised to float; a missing one stays None.
        fields.update(id=doc_id, score=None if raw_score is None else float(raw_score))
        return cls(**fields)

In [125]:
def get_qa_prompt(question: str, documents: List[Document], query_aware_contextualization: bool) -> str:
    """Build the QA prompt for one question from a template file on disk.

    Args:
        question: Question text, substituted for the template's ``question``
            field.
        documents: Documents to show to the model. Each is rendered on its own
            line as ``Document [i](Title: <title>) <text>`` with ``i`` starting
            at 1, and the joined lines are substituted for the template's
            ``search_results`` field.
        query_aware_contextualization: Selects the template:
            ``./prompting/qa_with_query_aware_contextualization.prompt`` when
            true, ``./prompting/qa.prompt`` otherwise.

    Returns:
        The filled-in prompt string.

    Raises:
        OSError: If the selected template file cannot be opened.
    """
    if query_aware_contextualization:
        prompt_path = "./prompting/qa_with_query_aware_contextualization.prompt"
    else:
        prompt_path = "./prompting/qa.prompt"

    # Decode explicitly as UTF-8 so the prompt text does not depend on the
    # platform's default locale encoding.
    with open(prompt_path, encoding="utf-8") as f:
        # Trailing newlines are stripped from the template before formatting.
        prompt_template = f.read().rstrip("\n")

    # Format the documents into strings (numbering is 1-based).
    formatted_documents = [
        f"Document [{document_index}](Title: {document.title}) {document.text}"
        for document_index, document in enumerate(documents, start=1)
    ]

    return prompt_template.format(question=question, search_results="\n".join(formatted_documents))

In [126]:
def get_oracle_responses(inp : str, out : str, query_aware_contextualization: bool):
    """Query the model once per input example and write the answers as JSON lines.

    Every non-blank line of ``inp`` is parsed as a JSON object from which the
    keys ``"question"``, ``"answers"`` and ``"ctxs"`` are read. Each entry of
    ``"ctxs"`` is turned into a ``Document``, a prompt is built with
    ``get_qa_prompt`` and sent to the module-level ``model``.

    One JSON object per example is written to ``out`` with the keys
    ``"model_prompt"``, ``"model_answer"``, ``"model"`` and
    ``"correct_answer"``. Records are written and flushed as soon as they are
    produced, so an error part-way through the run (network failure, quota,
    malformed line) keeps everything answered so far instead of discarding it.
    Note that ``out`` is therefore truncated at the start of the run.

    Args:
        inp: Path of the input file, one JSON object per line.
        out: Path of the JSON-lines file to write. Its parent directory is
            created if it does not exist.
        query_aware_contextualization: Forwarded to ``get_qa_prompt`` to choose
            the prompt template.

    Raises:
        OSError: If ``inp`` cannot be read or ``out`` cannot be written.
        json.JSONDecodeError: If a non-blank input line is not valid JSON.
        KeyError: If an example lacks one of the expected keys.
        Exception: Whatever ``model.generate_content`` raises is propagated.
    """
    # Make sure the destination exists before spending any API calls on it.
    out_dir = os.path.dirname(out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Explicit UTF-8 keeps reading/writing independent of the platform locale.
    with open(inp, encoding="utf-8") as fin, open(out, "w", encoding="utf-8") as fout:
        for line in tqdm(fin):
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue

            input_example = json.loads(line)

            # Getting question and correct answer
            question = input_example["question"]
            correct_answer = input_example["answers"]

            # Document.from_dict copies its argument, so no extra copy is needed here.
            documents = [Document.from_dict(ctx) for ctx in input_example["ctxs"]]

            qa_prompt = get_qa_prompt(question, documents, query_aware_contextualization)

            response = model.generate_content(qa_prompt)
            try:
                res = response.text
            except Exception as e:
                # Reading .text failed: record a placeholder answer and log the
                # error, the prompt and the prompt feedback for inspection.
                res = "Model refused to answer!"
                print(f'{type(e).__name__}: {e}')
                print(qa_prompt)
                print(response.prompt_feedback)

            output = {
                "model_prompt": qa_prompt,
                "model_answer": res,
                "model": model_name,
                "correct_answer": correct_answer,
            }
            fout.write(json.dumps(output) + "\n")
            # Flush per record so finished work survives a later failure.
            fout.flush()

            time.sleep(1) # Google has set rate limit of 60 requests per minute

In [127]:
# Baseline run (no query-aware contextualization): input examples and the
# JSON-lines file that will receive the model's responses.
input_path, output_path = (
    "./qa_data/nq-open-oracle.json",
    "./responses/gemini_qa/gemini_oracle_responses.jsonl",
)

In [128]:
# Baseline run with the plain QA prompt. One model call (plus a one-second
# pause) is made per input example, so this cell is slow.
get_oracle_responses(inp=input_path, out=output_path, query_aware_contextualization=False)

In [129]:
# Query-aware contextualization (QAC) run: same input examples, separate
# output file. This rebinds input_path/output_path, so the run cell that
# follows must be executed after this one.
input_path, output_path = (
    "./qa_data/nq-open-oracle.json",
    "./responses/gemini_qa/gemini_oracle_QAC_responses.jsonl",
)

In [None]:
# QAC run with the query-aware prompt template. One model call (plus a
# one-second pause) is made per input example, so this cell is slow.
get_oracle_responses(inp=input_path, out=output_path, query_aware_contextualization=True)