In [6]:
import json
import sys
sys.path.append('..')
from grammar.eval.result import RAGResult
from grammar.eval.tag_group import TaggedGroup 
from grammar.eval.match import SemanticsMatch
from tqdm import tqdm

In [7]:
def init_eval_results(file_path, closed_domain=None):
    """Load saved RAG evaluation records from a JSON file as ``RAGResult`` objects.

    Args:
        file_path: Path of a JSON file containing a sequence of records. Each
            record is read through the keys ``query``, ``answer``,
            ``gpt_response``, ``true_document_ids``, ``query_tag``,
            ``retrieved_document_ids`` and ``retrieval_judgement``.
        closed_domain: Value forwarded as ``RAGResult(closed_domain=...)``.
            When left as ``None`` the module-level ``closed_domain`` variable
            is used, which is what this function did before the parameter
            existed.

    Returns:
        A list with one ``RAGResult`` per record, in file order.

    Raises:
        NameError: If ``closed_domain`` is not passed, the file is non-empty
            and no module-level ``closed_domain`` is defined.
        AssertionError: If a record's ``gpt_response`` is neither ``None`` nor
            a string once a list value has been unwrapped.
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(file_path, encoding='utf-8') as f:
        records = json.load(f)

    # Resolve the legacy global only when there is something to build, so an
    # empty file still yields [] without needing any notebook state.
    if closed_domain is None and records:
        if 'closed_domain' not in globals():
            raise NameError("name 'closed_domain' is not defined")
        closed_domain = globals()['closed_domain']

    # convert the JSON records into dataclass instances in a new list
    eval_results = []
    for record in records:
        gpt_response = record['gpt_response']
        if isinstance(gpt_response, list):
            # Keep only the first generation. An empty list carries no
            # response at all, so treat it as None instead of raising
            # IndexError on [0].
            gpt_response = gpt_response[0] if gpt_response else None
        assert gpt_response is None or isinstance(gpt_response, str), f"The gpt_response should be a string, not {type(gpt_response)}"

        eval_results.append(
            RAGResult(
                query=record['query'],
                answer=record['answer'],
                gpt_response=gpt_response,
                true_document_ids=record['true_document_ids'],
                query_tag=record['query_tag'],
                retrieved_document_ids=record['retrieved_document_ids'],
                retrieval_judgement=record['retrieval_judgement'],
                closed_domain=closed_domain,
            )
        )
    return eval_results

def get_eval_results(eval_results, linguistic_attr, root_dir, file_path, test_mode=False):
    """Judge every result, print failure counts and persist the outcome.

    Args:
        eval_results: Results to judge; each one is judged in place through
            ``judge_retrieval_response`` and ``judge_rag_response``.
        linguistic_attr: Passed to ``SemanticsMatch.from_file`` as
            ``verbalize_attrs``.
        root_dir: Passed to ``SemanticsMatch.from_file`` and, in test mode,
            to ``SemanticsMatch.save``.
        file_path: Destination of the JSON dump written when ``test_mode`` is
            falsy.
        test_mode: When truthy the matcher is switched to test mode and saved
            with ``override=True``, and no results file is written. Otherwise
            the judged results are dumped to ``file_path``.

    Returns:
        The tuple ``(eval_results, tagged_group)`` where ``tagged_group`` is
        the ``TaggedGroup`` built from ``eval_results``.
    """
    tagged_group = TaggedGroup(eval_results)
    matcher = SemanticsMatch.from_file(root_dir=root_dir, verbalize_attrs=linguistic_attr)
    if test_mode:
        matcher.set_test_mode(True)

    # Retrieval is judged with the 'use_exist' method before the RAG answer.
    for item in tqdm(eval_results):
        item.judge_retrieval_response(tagged_group=tagged_group, method='use_exist')
        item.judge_rag_response(matcher)

    retrieval_failures = sum(item.retrieval_judgement == 0 for item in eval_results)
    print(f"Retrieval failed in {retrieval_failures} out of {len(eval_results)} examples")
    rag_failures = sum(item.judgement == "Incorrect" for item in eval_results)
    print(f"RAG failed in {rag_failures} out of {len(eval_results)} examples")

    if test_mode:
        # Test runs only write the matcher back; the results file is untouched.
        matcher.save(root_dir=f'{root_dir}', override=True)
    else:
        payload = []
        for item in eval_results:
            record = item.asdict()
            # The id collections are turned into lists so json can serialize them.
            for key in ('true_document_ids', 'retrieved_document_ids'):
                record[key] = list(record[key])
            payload.append(record)
        with open(file_path, 'w') as out_file:
            json.dump(payload, out_file, indent=4)

    return eval_results, tagged_group

# Run configuration shared by every evaluation below.
root_dir = 'spider'
closed_domain = True
test_mode = True

# Each results file is both the input of init_eval_results and the file_path
# handed to get_eval_results, so build every path once.
eval_dir = f'{root_dir}/eval_results'
path_long_balanced = f'{eval_dir}/results_long_balanced.json'
path_short_balanced = f'{eval_dir}/results_short_balanced.json'
path_long = f'{eval_dir}/results_long.json'
path_short = f'{eval_dir}/results_short.json'

print('============ Balanced ============')
results_long_balanced = init_eval_results(path_long_balanced)
results_short_balanced = init_eval_results(path_short_balanced)
# The metric_* names receive the TaggedGroup returned by get_eval_results.
results_long_balanced, metric_long_balanced = get_eval_results(
    results_long_balanced, 'long', root_dir,
    file_path=path_long_balanced, test_mode=test_mode,
)
results_short_balanced, metric_short_balanced = get_eval_results(
    results_short_balanced, 'short', root_dir,
    file_path=path_short_balanced, test_mode=test_mode,
)

print('============ Imbalanced ============')
results_long = init_eval_results(path_long)
results_short = init_eval_results(path_short)
results_long, metric_long = get_eval_results(
    results_long, 'long', root_dir,
    file_path=path_long, test_mode=test_mode,
)
results_short, metric_short = get_eval_results(
    results_short, 'short', root_dir,
    file_path=path_short, test_mode=test_mode,
)



100%|██████████| 570/570 [00:00<00:00, 593680.97it/s]


The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to gen

100%|██████████| 570/570 [00:00<00:00, 805238.56it/s]


The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to gen

100%|██████████| 436/436 [00:00<00:00, 797173.73it/s]


The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to gen

100%|██████████| 397/397 [00:00<00:00, 769472.59it/s]

The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to gen




In [8]:
# long vs short on the imbalanced results: one row per metric, short printed first.
imbalanced_groups = (('short', metric_short), ('long', metric_long))
report_rows = (
    ('Baseline Accuracy', 'get_accuracy', {}),
    ('Accuracy (Remove LLM Errors)', 'get_accuracy', {'for_retrieval': True}),
    ('Robustness (Removing Gap Examples)', 'get_robustness', {}),
    ('Robustness (Removing LLM Errors & Gap Examples)', 'get_robustness', {'for_retrieval': True}),
)
for title, method_name, options in report_rows:
    for variant, group in imbalanced_groups:
        # print() inserts the second space that follows the colon in the output.
        print(f'{title} for {variant}: ', getattr(group, method_name)(**options))



Baseline Accuracy for short:  0.34760705289672544
Baseline Accuracy for long:  0.3463302752293578
Accuracy (Remove LLM Errors) for short:  0.3677581863979849
Accuracy (Remove LLM Errors) for long:  0.3532110091743119
Robustness (Removing Gap Examples) for short:  0.6699029126213593
Robustness (Removing Gap Examples) for long:  0.5189003436426117
Robustness (Removing LLM Errors & Gap Examples) for short:  0.6747572815533981
Robustness (Removing LLM Errors & Gap Examples) for long:  0.5223367697594502


In [13]:
# long vs short on the balanced results: one row per metric, short printed first.
balanced_groups = (('short', metric_short_balanced), ('long', metric_long_balanced))
report_rows = (
    ('Baseline Accuracy', 'get_accuracy', {}),
    ('Accuracy (Remove LLM Errors)', 'get_accuracy', {'for_retrieval': True}),
    ('Robustness (Removing Gap Examples)', 'get_robustness', {}),
    ('Robustness (Removing LLM Errors & Gap Examples)', 'get_robustness', {'for_retrieval': True}),
)
for title, method_name, options in report_rows:
    for variant, group in balanced_groups:
        # print() inserts the second space that follows the colon in the output.
        print(f'{title} for {variant}: ', getattr(group, method_name)(**options))



Baseline Accuracy for short:  0.27388535031847133
Baseline Accuracy for long:  0.2760084925690021
Accuracy (Remove LLM Errors) for short:  0.28662420382165604
Accuracy (Remove LLM Errors) for long:  0.2802547770700637
Robustness (Removing Gap Examples) for short:  0.9555555555555556
Robustness (Removing Gap Examples) for long:  0.9629629629629629
Robustness (Removing LLM Errors & Gap Examples) for short:  1.0
Robustness (Removing LLM Errors & Gap Examples) for long:  0.9777777777777777


In [10]:
# Knowledge-based accuracy of the balanced results, short printed before long.
for variant, group in (('short', metric_short_balanced), ('long', metric_long_balanced)):
    print(f'Knowledge-based Accuracy for {variant}: ', group.get_knowledge_accuracy())

Knowledge-based Accuracy for short:  0.5263157894736842
Knowledge-based Accuracy for long:  0.543859649122807
