In [42]:
import json
import os
import sys
import random
from tqdm import tqdm
sys.path.append('..')
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'
import logging
from typing import List, Union, Tuple
from rag_systems.retrieval.retrieval import Qdrant, KeywordMatching
from grammar.eval.match import check_qa_answer
logging.basicConfig(filename='log.txt', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Dataset root; later cells build their input/output paths from it and pass it to
# RetrievalAugGen.from_file.
root_dir = "spider_closed/"
# Read the retrieval corpus from <root_dir>/retrieval_database/company_documents.json.
# `with` closes the handle; JSON is UTF-8, so don't rely on the platform default encoding.
with open(f'{root_dir}/retrieval_database/company_documents.json', encoding='utf-8') as f:
    documents = json.load(f)

In [43]:
# Keyword-matching retriever over `documents`. `retrieve` reads this module-level
# `retrieval` object and calls its `search`, `rank` and `format_sources` methods.
retrieval = KeywordMatching(documents)
def retrieve(query: str, answer: Union[List, Tuple] = None, true_document_ids: List[int] = None,
             retrieved_documents=None, retrieved_document_ids: List[int] = None):
    """Retrieve sources for ``query`` and judge whether retrieval found the gold documents.

    Args:
        query: The question to retrieve for.
        answer: Reference answer; copied into the result unchanged.
        true_document_ids: Gold document IDs (anything ``int()`` accepts).
            ``None`` is treated as "no gold IDs".
        retrieved_documents: Pre-formatted sources. When falsy, the module-level
            ``retrieval`` object is queried and only its top-ranked hit is judged.
        retrieved_document_ids: IDs behind a caller-supplied ``retrieved_documents``.
            Ignored when retrieval is performed here.

    Returns:
        A dict with ``query``, ``answer``, ``true_document_ids`` (set of int),
        ``retrieved_document_ids`` (set of int), ``retrieved_documents``,
        ``gpt_response`` (always ``None`` here) and ``retrieval_judgement``:
        1 if the retrieved ID set equals the gold ID set, 0 if it does not, and
        ``None`` when caller-supplied documents come without IDs (nothing to judge).
    """
    result = {
        "query": query,
        "answer": answer,
        "true_document_ids": {int(i) for i in (true_document_ids or [])},
        # Always present so the judgement and later serialization can rely on the key.
        "retrieved_document_ids": set(),
        "retrieved_documents": [],
        "gpt_response": None,
        "retrieval_judgement": 1,  # correct
    }

    # 1 - Retrieval
    ids_known = True
    if not retrieved_documents:
        items = retrieval.search(query)
        ranked = retrieval.rank(items)
        # Only the top-ranked document counts; an empty ranking retrieves nothing.
        if ranked:
            result["retrieved_document_ids"] = {int(ranked[0].id)}
        retrieved_documents = retrieval.format_sources(items)
    elif retrieved_document_ids is not None:
        result["retrieved_document_ids"] = {int(i) for i in retrieved_document_ids}
    else:
        # Documents were supplied without their IDs: the hit cannot be judged.
        ids_known = False
    result["retrieved_documents"] = retrieved_documents

    # 2 - Judge: correct only when the retrieved IDs match the gold IDs exactly.
    if not ids_known:
        result["retrieval_judgement"] = None
    elif result["retrieved_document_ids"] != result["true_document_ids"]:
        result["retrieval_judgement"] = 0  # incorrect
    return result

from grammar.generator import Generator
from grammar.llm import gen_model
# Force temperature 0 on the shared default model (it is also passed explicitly per call
# in `RetrievalAugGen._generate`).
gen_model.temperature = 0


class RetrievalAugGen(Generator):
    """Retrieval-augmented answer generator built on ``Generator``.

    ``_generate`` formats the retrieved context and the question into a single
    prompt (``### Sources:`` followed by ``### Question:``) and returns one
    completion from the LLM, called with ``temperature=0``.
    """

    # NOTE(review): presumably consumed by ``Generator`` to verbalize
    # ``verbalize_attrs``; both variants are empty here — confirm this is intended.
    verbalizer = {
        "short": "",
        "long": ""
    }

    def __init__(self, llm=None, verbalize_attrs=''):
        """Create the generator.

        Args:
            llm: Callable ``llm(prompt, temperature=...)`` used for generation.
                Defaults to the module-level ``gen_model``.
            verbalize_attrs: Forwarded unchanged to ``Generator.__init__``.
        """
        llm = llm or gen_model
        # Forward the resolved model (the caller's, or the default) to the base class.
        super().__init__(llm=llm, verbalize_attrs=verbalize_attrs)
        # Own handle for `_generate`, so it does not depend on Generator's attribute names.
        self._rag_llm = llm

    def _generate(self, context_query: tuple, num_generations=None, verbose=False):
        """Return a one-element list with the LLM answer for ``(context, query)``.

        ``num_generations`` and ``verbose`` are accepted for interface compatibility
        with ``Generator`` but are not used here.
        """
        context, query = context_query
        prompt = "### Sources:\n" + context + "\n\n### Question:\n" + query
        # Fall back to the default model if this instance was built without __init__
        # (NOTE(review): possibly the case for `from_file` — unverified).
        llm = getattr(self, "_rag_llm", None) or gen_model
        return [llm(prompt, temperature=0)]

In [50]:
# Which linguistic variant of the generated questions to evaluate ("short" or "long").
linguistic_control = "short"
ragen = RetrievalAugGen.from_file(root_dir=root_dir, verbalize_attrs=linguistic_control)

# Evaluation data: a sequence of lists (presumably one per SQL template), each holding
# (answer, query_list) pairs.
with open(f'{root_dir}/QADataGenerator/{linguistic_control}_with_ids.json') as f:
    answers_to_text_queries = json.load(f)

# Collapse the per-template lists into a single answer -> query_list mapping.
# When two templates yield the same answer, the later template's queries win.
semantics_groups = {
    answer: query_list
    for template_pairs in answers_to_text_queries
    for answer, query_list in template_pairs
}
print(f"Number of semantics groups: {len(semantics_groups)}")

Number of semantics groups: 57


In [51]:
# Seed the sampler so re-running this cell evaluates the same subset of queries.
SEED = 42
random.seed(SEED)

# Probability of keeping a query for evaluation, per linguistic control.
LONG_KEEP_PROB = 0.3
SHORT_KEEP_PROB = 0.75
# "long" queries are only subsampled from this group index onward; earlier groups keep every query.
LONG_SUBSAMPLE_FROM_GROUP = 38

results = []

for group_tag, (answer_repr, query_list) in enumerate(semantics_groups.items()):
    # NOTE(review): `eval` runs arbitrary code from the evaluation JSON file. If the stored
    # answers are plain Python literals, `ast.literal_eval` is a safer drop-in — confirm.
    answer = eval(answer_repr)[0]  # only 1 answer; no query multiplicity
    answer_txt = answer[0]
    print(f"{len(query_list)} queries in group {group_tag}")
    # Elements after the answer text are the known gold document IDs. Recompute them for
    # every group: an empty list when none are given, never the previous group's IDs.
    # NOTE(review): groups without IDs are thus judged against an empty gold set (-> 0).
    true_document_ids = [int(i) for i in answer[1:]]
    for query in query_list:
        if linguistic_control == "long":
            # Keep every query in early groups; subsample later groups.
            keep = group_tag < LONG_SUBSAMPLE_FROM_GROUP or random.random() <= LONG_KEEP_PROB
        else:
            keep = random.random() <= SHORT_KEEP_PROB
        if not keep:
            continue
        result = retrieve(query, answer_txt, true_document_ids=true_document_ids)
        result["gpt_response"] = ragen.generate((result["retrieved_documents"], result["query"]), verbose=True)
        result['query_tag'] = group_tag
        results.append(result)
    


10 queries in group 0
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
10 queries in group 1
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` 

In [17]:
# ragen.save(root_dir=root_dir, override=False)

In [52]:
# Sets are not JSON serializable: store the ID sets as sorted lists. Sorting keeps the
# output file deterministic, and re-running this cell is harmless (lists stay lists).
for result in results:
    for key in ('true_document_ids', 'retrieved_document_ids'):
        if key in result:
            result[key] = sorted(result[key])

# Save results, creating the output directory on first use.
results_dir = f'{root_dir}/eval_results'
os.makedirs(results_dir, exist_ok=True)
with open(f'{results_dir}/results_{linguistic_control}.json', 'w') as f:
    json.dump(results, f, indent=4)