In [1]:
import json
import logging
import time
from typing import List, Union, Tuple
from indo_eval.llm import AnyOpenAILLM
from indo_eval.retrieval.retrieval import Qdrant, KeywordMatching
from indo_eval.llm import gen_model
logging.basicConfig(filename='log.txt', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')



In [2]:
retrieval_method="kw_matching"
if retrieval_method == "qdrant":
    retrieval = Qdrant()
elif retrieval_method == "kw_matching":
    retrieval = KeywordMatching()
else:
    raise ValueError("Invalid retrieval method")

# examples showcasing that kw matching retrieval is not robust to long queries 
# retrieval.format_sources(retrieval.search("Start date for project 'Blue Lagoon Luxury Resort"))
# retrieval.format_sources(retrieval.search("The query is looking up the starting date for a particular project, and the name of that specific project is given as 'Blue Lagoon Luxury Resort'"))

gen_model.temperature = 0
def retrieval_augmented_response(query: str, answer:  Union[List, Tuple] = None, skip_gen_if_retrieval_failed=False, retrieved_documents=None):
    result = {"query": query, "answer": answer, "retrieved_documents": [], "gpt_response": None, "retrieval_failed": 0}
    # Step 1 - Retrieval
    if not retrieved_documents:
        items = retrieval.search(query)
        # top 3 items
        result['top3'] = [item.dict(include={'id', 'score'}) for item in retrieval.rank(items)[:3]]
        retrieved_documents = """### Sources:\n"""+retrieval.format_sources(items)
    result["retrieved_documents"] = retrieved_documents + "\n\n### Question:\n" + query

    # Step 2: Check if the answer exists in the retrieved documents
    for ans in answer:
        if ans not in result["retrieved_documents"]:
            result['retrieval_failed'] += 1
    if skip_gen_if_retrieval_failed and result['retrieval_failed']:
        return result

    # Step 3: Call OpenAI Chat API to generate a response
    result["gpt_response"] = gen_model(result["retrieved_documents"], temperature=0)

    return result
# result = retrieval_augmented_response("Display location for a project that has the same name as 5 Martin Place", ["Sydney"])

def check_answer(query, true_answer:  Union[List, Tuple], gpt_response):
    prompt = """Evaluate the accuracy of the given response in relation to the true answer for the specified query. After evaluating, provide a judgement as either "Correct" or "Incorrect" based on whether the ##Given Response## accurately matches the ##True Answer##.

    ##Query##: {}
    ##True Answer##: {}
    ##Given Response##: {}
    ##Judgement##:
    """.format(query, json.dumps(true_answer), gpt_response)
    gpt4_llm = AnyOpenAILLM(model_name = "gpt4-short")
    return gpt4_llm(prompt)


In [17]:
# parameters
linguistic_control = "long"
fix_context = False
suffix = ""#"_project_company"
long_query_suffix = '' # "Thank you very much for providing the information.  If you can correctly answer my question, it would be really helpful for my work in Melbourne."
log = False
maximum_test = 1000000
reask = False
reask_file_name = "eval_data/results_shortest_queries_full_doc_db.json"
recover = True
recover_file_name = "eval_data/results_long_queries_all_doc_queries.json"
skip = False
skip_file_name = "eval_data/results_long_queries_all_doc_queries.json"

file_name = f'eval_data/results_{linguistic_control}_queries{suffix}'
if long_query_suffix:
    file_name += "_with_suffix"
if fix_context:
    file_name += "_same_retrieved_document"

# recover
if recover:
    with open(recover_file_name,  'r') as f:
        results_exist = json.load(f)

# reask
if reask:
    with open(reask_file_name, "r") as f:
        full_results = json.load(f)
    reask_query_keywords = ["Champs-Élysées Boutique Hotel",  "Advanced Health Research Facility"]

# skip
if skip:
    with open(skip_file_name, "r") as f:
        full_results = json.load(f)

# load evaluation data
with open(f'eval_data/answers_to_text_queries_all_{linguistic_control}_templates{suffix}.json') as f:
    answers_to_text_queries = json.load(f)

# fix context
if fix_context:
    with open("eval_data/tag_to_retrieved_context_from_correctly_predicted_long_query_dict.json") as f:
        tag_to_retrieved_context_from_long_query_dict = json.load(f)

# evaluate
results = {entity: [] for entity in answers_to_text_queries}
for entity in answers_to_text_queries:
    if log:
        logging.info(f"Entity Type: {entity}")
    print(f"Entity Type: {entity}")
    if entity in ["Company"]:
        continue

    query_tag = -1
    for answer_queries_pair in answers_to_text_queries[entity]:
        num_results = 0
        for answer, query_list in answer_queries_pair:
            query_tag += 1
            answer = eval(answer) 
            assert len(answer)==1 
            answer = answer[0] # no query multiplicity
            if fix_context:
                try:
                    retrieved_documents = tag_to_retrieved_context_from_long_query_dict[entity][str(query_tag)][0][2]
                except: # IndexError: no retrieved doc leading to correct prediction
                    continue
            else: 
                retrieved_documents = None
            for query in query_list:
                if reask and not sum([1 for query_keyword in reask_query_keywords if query_keyword in query]): # keyword not in query:
                    result = full_results[entity][num_results]
                elif skip and sum([1 for query_keyword in reask_query_keywords if query_keyword in query]): # keyword in query:
                    continue  
                elif recover and [result_exist for result_exist in results_exist[entity] if result_exist['query'] == query]:
                    result = [result_exist for result_exist in results_exist[entity] if result_exist['query'] == query][0]
                else:
                    result = retrieval_augmented_response(query+long_query_suffix, answer, retrieved_documents=retrieved_documents)
                    result["judgement"] = check_answer(result["query"], result["answer"], result["gpt_response"])
                    result['query_tag'] = query_tag
                results[entity].append(result)        
                maximum_test -= 1
                num_results += 1
                if maximum_test <= 0:
                    break
                if log:
                    logging.info(f"Query: {query}")
                    logging.info(f"\tAnswer: {answer}") # type(answer) -> 'tuple'
                    logging.info(f"\tGenerated: {result['gpt_response']}")
                    logging.info(f"\tJudgement: {result['judgement']}")
                print(f"Query: {query}")
                print(f"\tAnswer: {answer}")
                print(f"\tGenerated: {result['gpt_response']}")
                print(f"\tJudgement: {result['judgement']}")

            



Entity Type: Project
Query: "The query is looking up the starting date for a particular project, and the name of that specific project is given as 'Blue Lagoon Luxury Resort'. 
	Answer: ('2023-04-01',)
	Generated: The starting date for the Blue Lagoon Luxury Resort project is mentioned in the project overview as April 1, 2023.
	Judgement: Correct
Query: "Let's find out the date when the project named 'Blue Lagoon Luxury Resort' was initiated according to the records held in our project database. 
	Answer: ('2023-04-01',)
	Generated: According to the records in our project database, the Blue Lagoon Luxury Resort project was initiated on the 1st of April, 2023.
	Judgement: Correct
Query: "In reference to our project database, we are interested in locating the details of when a certain project, identified as 'Blue Lagoon Luxury Resort', officially commenced."
	Answer: ('2023-04-01',)
	Generated: According to the project overview provided, the Blue Lagoon Luxury Resort officially commenced

/var/folders/z_/xphnyhxs03sg7p8v5dgkr10w0000gn/T/ipykernel_36468/3048550651.py:20: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.4/migration/
  result['top3'] = [item.dict(include={'id', 'score'}) for item in retrieval.rank(items)[:3]]


AttributeError: usage

In [14]:
# save results
with open(file_name+'.json', 'w') as fp:
    json.dump(results, fp)

In [15]:
len(results['Project'])

120