In [1]:
import json
import os
import time
from copy import deepcopy
from typing import List, Optional, Tuple, Type, TypeVar

from tqdm import tqdm

In [2]:
import google.generativeai as genai
# Never hardcode the key in the notebook: read it from the GOOGLE_API_KEY
# environment variable. Falls back to "" (the original placeholder) when unset.
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY", ""))

In [3]:
# Single source of truth for the model id. It is passed to the API here and is also
# recorded as "model" in every output record, so the two must not drift apart.
# At the time of writing, 'gemini-pro' is the same model as 'gemini-1.0-pro'.
model_name = 'gemini-pro'
model = genai.GenerativeModel(model_name)

In [4]:
def get_kv_retrieval_prompt(data: List[Tuple[str, str]], key: str, query_aware_contextualization: bool) -> str:
    """Build the key-value retrieval prompt for a single example.

    Args:
        data: Ordered key-value records, in the order they should appear in the
            prompt. Each record is indexed as ``record[0]`` (key) and ``record[1]``
            (value), so ``[key, value]`` lists loaded from JSON work as well.
        key: The key whose value the model is asked to retrieve.
        query_aware_contextualization: If True, use
            ``./prompting/kv_retrieval_with_query_aware_contextualization.prompt``;
            otherwise use ``./prompting/kv_retrieval.prompt``.

    Returns:
        The template (trailing newlines stripped) with ``{formatted_kv_records}``
        replaced by a JSON object of the records and ``{key}`` by ``key``.

    Raises:
        OSError: If the template file cannot be read. Paths are relative to the
            current working directory.
    """
    template_name = (
        "kv_retrieval_with_query_aware_contextualization.prompt"
        if query_aware_contextualization
        else "kv_retrieval.prompt"
    )
    with open(f"./prompting/{template_name}", encoding="utf-8") as f:
        prompt_template = f.read().rstrip("\n")

    # One JSON object with one record per line, e.g. '{"k1": "v1",\n "k2": "v2"}'.
    # json.dumps escapes quotes, backslashes and control characters, so the block
    # stays valid JSON. ensure_ascii=False keeps non-ASCII text verbatim, so ordinary
    # strings render exactly as a plain '"..."' f-string would.
    def _quote(item) -> str:
        return json.dumps(str(item), ensure_ascii=False)

    formatted_kv_records = "{" + ",\n ".join(
        f"{_quote(record[0])}: {_quote(record[1])}" for record in data
    ) + "}"

    return prompt_template.format(formatted_kv_records=formatted_kv_records, key=key)

In [5]:
def _place_gold_record(records: List[List[str]], key: str, value: str, correct_index: int) -> List[List[str]]:
    """Return a copy of ``records`` with the ``[key, value]`` pair moved to ``correct_index``.

    Raises:
        ValueError: If ``[key, value]`` is not in ``records`` or ``correct_index``
            is not a valid position (``0 <= correct_index < len(records)``).
    """
    if not 0 <= correct_index < len(records):
        raise ValueError(f"correct_index {correct_index} is out of range for {len(records)} records")
    reordered = list(records)  # shallow copy is enough: only the outer list is rearranged
    gold_record = reordered.pop(reordered.index([key, value]))
    reordered.insert(correct_index, gold_record)
    return reordered


def get_responses(inp: str, out: str, correct_index: int, query_aware_contextualization: bool,
                  request_delay: float = 0.0) -> None:
    """Query the model once per example in ``inp`` and write one JSON line per answer to ``out``.

    For every example, the gold ``[key, value]`` pair is moved to position
    ``correct_index`` within ``ordered_kv_records``. The prompt is then built with
    ``get_kv_retrieval_prompt`` and sent to the module-level ``model``.

    Args:
        inp: Input JSONL path. Each non-blank line must contain
            ``ordered_kv_records``, ``key`` and ``value``.
        out: Output JSONL path (overwritten).
        correct_index: Position at which the gold pair is placed.
        query_aware_contextualization: Forwarded to ``get_kv_retrieval_prompt``.
        request_delay: Seconds to sleep after each API call (default 0: no delay).
            A note from the original author cites a 60 requests/minute limit,
            which ``1.0`` stays under.

    Each output record has ``model_prompt``, ``model_answer``, ``model`` and
    ``correct_answer``. Records are written and flushed as they are produced, so an
    error part-way through keeps every answer obtained so far.
    """
    with open(inp, encoding="utf-8") as fin, open(out, "w", encoding="utf-8") as fout:
        for line in tqdm(fin):
            if not line.strip():
                continue  # tolerate blank lines (e.g. an extra newline at end of file)
            input_example = json.loads(line)
            key = input_example["key"]
            value = input_example["value"]

            ordered_kv_records = _place_gold_record(
                input_example["ordered_kv_records"], key, value, correct_index
            )
            kv_prompt = get_kv_retrieval_prompt(
                data=ordered_kv_records, key=key, query_aware_contextualization=query_aware_contextualization
            )

            response = model.generate_content(kv_prompt)
            try:
                model_answer = response.text
            except ValueError:
                # The SDK's `.text` accessor raises ValueError when the response has
                # no text part (e.g. it was blocked). Record an empty answer so the
                # run continues and the example is scored as wrong, and report it.
                tqdm.write(f"No text in response for key {key!r}; recording empty answer")
                model_answer = ""

            output = {
                "model_prompt": kv_prompt,
                "model_answer": model_answer,
                "model": model_name,
                "correct_answer": value,
            }
            fout.write(json.dumps(output) + "\n")
            fout.flush()

            if request_delay > 0:
                time.sleep(request_delay)

In [6]:
# Gold-pair positions tested for each input file: first, evenly spaced middles, last.
correct_indices = [[0, 24, 49, 74], [0, 34, 69, 104, 139], [0, 49, 99, 149, 199, 249, 299]]

# One input file per context size (number of key-value pairs).
input_paths = [
    f"./kv_retrieval_data/kv-retrieval-{num_keys}_keys.jsonl" for num_keys in (75, 140, 300)
]

# output_paths[i][j]: baseline responses for input_paths[i] with the gold pair
# placed at correct_indices[i][j].
output_paths = [
    [f"./responses/gemini_kv/gemini_kv_{num_keys}_key_at_{position}_responses.jsonl" for position in positions]
    for num_keys, positions in zip((75, 140, 300), correct_indices)
]

In [None]:
# Baseline run (no query-aware contextualization): one output file per gold position.
# Validate the config up front. A misaligned list would otherwise silently skip runs
# or fail only after hours of API calls.
assert len(input_paths) == len(output_paths) == len(correct_indices), "config lists are misaligned"
assert all(len(outs) == len(positions) for outs, positions in zip(output_paths, correct_indices)), \
    "each output_paths group must match its correct_indices group"

for in_path, out_paths, gold_positions in zip(input_paths, output_paths, correct_indices):
    for out_path, gold_position in zip(out_paths, gold_positions):
        get_responses(in_path, out_path, gold_position, query_aware_contextualization=False)

In [None]:
# Gold-pair positions tested for each input file: first, evenly spaced middles, last.
correct_indices = [[0, 24, 49, 74], [0, 34, 69, 104, 139], [0, 49, 99, 149, 199, 249, 299]]

# Same inputs as the baseline run: one file per context size (number of key-value pairs).
input_paths = [
    f"./kv_retrieval_data/kv-retrieval-{num_keys}_keys.jsonl" for num_keys in (75, 140, 300)
]

# output_paths[i][j]: query-aware-contextualization (QAC) responses for
# input_paths[i] with the gold pair placed at correct_indices[i][j].
output_paths = [
    [f"./responses/gemini_kv/gemini_kv_QAC_{num_keys}_key_at_{position}_responses.jsonl" for position in positions]
    for num_keys, positions in zip((75, 140, 300), correct_indices)
]

In [None]:
# Query-aware contextualization run: one output file per gold position.
# Validate the config up front. A misaligned list would otherwise silently skip runs
# or fail only after hours of API calls.
assert len(input_paths) == len(output_paths) == len(correct_indices), "config lists are misaligned"
assert all(len(outs) == len(positions) for outs, positions in zip(output_paths, correct_indices)), \
    "each output_paths group must match its correct_indices group"

for in_path, out_paths, gold_positions in zip(input_paths, output_paths, correct_indices):
    for out_path, gold_position in zip(out_paths, gold_positions):
        get_responses(in_path, out_path, gold_position, query_aware_contextualization=True)