In [75]:
import json
import os
import time
from copy import deepcopy
from typing import List, Optional, Tuple, Type, TypeVar

from tqdm import tqdm

In [76]:
import google.generativeai as genai

# Take the API key from the GOOGLE_API_KEY environment variable so that no
# credential is ever stored in (or committed with) the notebook. When the
# variable is unset this falls back to an empty string, which is what the
# notebook passed before.
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY", ""))

In [77]:
# Generation config: handed to the model as `generation_config`.
config = dict(temperature=0.0, top_p=1)

# Safety config: one {"category", "threshold"} entry per harm category, every
# one of them set to "BLOCK_NONE". Handed to the model as `safety_settings`.
# NOTE(review): "HARM_CATEGORY_DANGEROUS" / "HARM_CATEGORY_SEXUAL" look like
# they may overlap with "..._DANGEROUS_CONTENT" / "..._SEXUALLY_EXPLICIT" --
# confirm against the category names the SDK accepts.
safety = [
    {"category": harm_category, "threshold": "BLOCK_NONE"}
    for harm_category in (
        "HARM_CATEGORY_DANGEROUS",
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
        "HARM_CATEGORY_SEXUAL",
    )
]

In [78]:
# Single source of truth for the model identifier: the same name is used to
# build the client below and is written into the "model" field of every output
# record, so the two can never drift apart.
model_name = 'gemini-pro'
model = genai.GenerativeModel(model_name, safety_settings = safety, generation_config = config)
# At the time of writing this code, gemini-pro is the same as gemini-1.0-pro.

In [79]:
def get_closedbook_qa_prompt(question: str, prompt_path: str = "./prompting/closedbook_qa.prompt") -> str:
    """Build the closed-book QA prompt for a single question.

    Args:
        question: Text substituted for the ``{question}`` placeholder.
        prompt_path: Path of the prompt template file. Defaults to the
            template used by this notebook.

    Returns:
        The template with trailing newlines stripped and ``question`` filled
        in via ``str.format``.

    Raises:
        OSError: If the template file cannot be opened.
    """
    # Explicit UTF-8 so the template reads the same on every platform instead
    # of depending on the locale's default encoding.
    with open(prompt_path, encoding="utf-8") as f:
        prompt_template = f.read().rstrip("\n")

    return prompt_template.format(question=question)

In [80]:
def get_closedbook_responses(inp: str, out: str, sleep_seconds: float = 0.5) -> None:
    """Ask the model every question in ``inp`` closed-book and save the answers to ``out``.

    ``inp`` is read line by line; each non-blank line must be a JSON object
    with a ``"question"`` and an ``"answers"`` key. For every such line one
    JSON object is appended to ``out`` with the keys ``"model_prompt"``,
    ``"model_answer"``, ``"model"`` and ``"correct_answer"``.

    Records are written and flushed as soon as they are produced, so an
    interruption part-way through keeps everything answered so far, and an
    unwritable ``out`` is detected before any request is sent.

    Args:
        inp: Path of the JSON-lines input file.
        out: Path of the JSON-lines output file. It is opened for writing
            (truncated) at the start, so it must not be the same file as ``inp``.
        sleep_seconds: Pause after each request. The pause on its own caps the
            rate at ``60 / sleep_seconds`` requests per minute (120 for the
            default); staying under a lower quota relies on request latency.

    Raises:
        OSError: If ``inp`` cannot be read or ``out`` cannot be written.
        json.JSONDecodeError: If a non-blank input line is not valid JSON.
        KeyError: If an input record lacks ``"question"`` or ``"answers"``.
        Exception: Whatever ``model.generate_content`` raises is propagated.
    """
    # `inp` is opened first, so a missing input never truncates `out`.
    with open(inp, encoding="utf-8") as fin, open(out, "w", encoding="utf-8") as fout:
        for line in tqdm(fin):
            # Tolerate blank lines (e.g. a trailing empty line) instead of
            # crashing inside json.loads.
            if not line.strip():
                continue

            input_example = json.loads(line)

            # Getting question and correct answer
            question = input_example["question"]
            correct_answer = input_example["answers"]

            qa_closedbook_prompt = get_closedbook_qa_prompt(question)

            response = model.generate_content(qa_closedbook_prompt)
            try:
                res = response.text
            except Exception as e:
                # Accessing `.text` failed: record a placeholder answer and
                # log the details so the run can continue.
                res = "Model refused to answer!"
                print(f'{type(e).__name__}: {e}')
                print(qa_closedbook_prompt)
                print(response.prompt_feedback)

            output = {
                "model_prompt": qa_closedbook_prompt,
                "model_answer": res,
                "model": model_name,
                "correct_answer": correct_answer,
            }
            fout.write(json.dumps(output) + "\n")
            fout.flush()  # persist each record immediately

            time.sleep(sleep_seconds) # Google has set rate limit of 60 requests per minute

In [81]:
# Input questions file. It is consumed one line at a time, each line parsed as
# a JSON object with "question" and "answers" keys (i.e. JSON-lines content,
# despite the .json extension). Relative to the working directory.
input_path = "./qa_data/nq-open-oracle.json"
# Destination for the model's closed-book answers, one JSON object per line.
# Relative to the working directory.
output_path = "./responses/gemini_qa/gemini_closedbook_responses.jsonl"

In [None]:
# Run the full closed-book evaluation: one model request per input line, with
# the answers written to `output_path`. This issues real API calls and pauses
# between requests, so it can take a while.
get_closedbook_responses(input_path, output_path)