In [2]:
import json
import sys
import time
from tqdm import tqdm
sys.path.append('..')
from grammar.eval.match import SemanticsMatch


def judge_retrieval_result(result):
    """Score one retrieval result by exact set equality of document ids.

    Both id collections are converted to a ``set`` of ``int`` and written back
    into ``result`` in place. ``result['retrieval_judgement']`` is then set to
    1 when the two sets are equal and to 0 otherwise.

    Args:
        result: dict holding 'true_document_ids' and 'retrieved_document_ids',
            each an iterable of values accepted by ``int()``.

    Returns:
        The same dict object that was passed in (mutated).
    """
    result['true_document_ids'] = {int(doc_id) for doc_id in result['true_document_ids']}
    result['retrieved_document_ids'] = {int(doc_id) for doc_id in result['retrieved_document_ids']}
    # Order and duplicates are irrelevant: only the two id sets are compared.
    is_exact_match = result['true_document_ids'] == result['retrieved_document_ids']
    result['retrieval_judgement'] = int(is_exact_match)
    return result

def judge_rag_result(result, semnatics_match, root_dir=None):
    """Attach a 'judgement' to one RAG result, unless it already has one.

    Args:
        result: dict for one example. Reads 'retrieval_judgement', 'query',
            'answer' and 'gpt_response' (first element only); writes 'judgement'.
        semnatics_match: judge object exposing
            ``generate((query, answer, response), verbose=True)`` whose first
            returned element is stored as the judgement.
        root_dir: dataset directory that selects the judging policy. When
            omitted, the notebook-level global ``root_dir`` is used (the
            behaviour before this parameter existed).

    Returns:
        The same dict object that was passed in (mutated).
    """
    if 'judgement' in result:
        return result
    if root_dir is None:
        # Backward compatibility for callers that still rely on the global.
        root_dir = globals().get('root_dir')
    if root_dir == 'spider_closed' and result['retrieval_judgement'] == 0:
        # For 'spider_closed', a failed retrieval is counted as an incorrect
        # answer without consulting the judge.
        result["judgement"] = "Incorrect"
    else:
        result["judgement"] = semnatics_match.generate((result["query"], result["answer"], result["gpt_response"][0]), verbose=True)[0]
    return result

def get_results(linguistic_attr, root_dir):
    """Load the evaluation results for one linguistic attribute and judge them.

    Reads ``{root_dir}/eval_results/results_{linguistic_attr}.json``, prints the
    number of retrieval failures, judges every example that has no 'judgement'
    yet, and prints the number of examples judged "Incorrect".

    Args:
        linguistic_attr: attribute name used in the file name and passed to
            ``SemanticsMatch.from_file`` as ``verbalize_attrs``.
        root_dir: dataset directory; also forwarded to ``judge_rag_result`` so
            the judging policy always matches the data that was loaded.

    Returns:
        The list of result dicts, each carrying a 'judgement' key.
    """
    file_path = f'{root_dir}/eval_results/results_{linguistic_attr}.json'
    with open(file_path, encoding='utf-8') as f:
        results = json.load(f)
    num_retrieval_failure = sum(result['retrieval_judgement'] == 0 for result in results)
    print(f"Retrieval failed in {num_retrieval_failure} out of {len(results)} examples")

    semnatics_match = SemanticsMatch.from_file(root_dir=root_dir, verbalize_attrs=linguistic_attr)
    # An earlier revision paused for 20 seconds after every 9 examples here;
    # that throttle was already disabled and has been removed.
    for result in tqdm(results):
        # Mutates `result` in place, so `results` sees the new judgement.
        judge_rag_result(result, semnatics_match, root_dir=root_dir)
    num_rag_failure = sum(result['judgement'] == "Incorrect" for result in results)
    print(f"RAG failed in {num_rag_failure} out of {len(results)} examples")

    return results


# Dataset directory shared by both runs below.
root_dir = 'spider_closed'

# Judge the 'short' results first, then the 'long' ones.
results_short, results_long = (
    get_results(linguistic_attr, root_dir) for linguistic_attr in ('short', 'long')
)



Retrieval failed in 293 out of 426 examples


100%|██████████| 426/426 [00:00<00:00, 537375.49it/s]


The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to gen

100%|██████████| 430/430 [00:00<00:00, 637747.78it/s]

The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to gen




In [7]:
# from grammar.eval.metric import MetricForHypothesis
import json
from copy import deepcopy
from typing import Union, List, Dict, Tuple
from functools import lru_cache
class Metric:
    """Accuracy and robustness statistics over a list of judged examples.

    Every example is a dict that carries at least 'query_tag' and 'judgement'.
    Vocabulary used by the properties below:

    * competent tag: at least one of its examples has judgement 'Correct'.
    * gap tag: none of its examples has judgement 'Correct'.
    * non-robust tag: a competent tag with at least one example that is not
      'Correct'.

    Attributes:
        results: the list of examples passed to the constructor.
        result: the not-'Correct' examples of non-robust tags, each passed
            through ``_add_retrieval_judgement``.
    """

    def __init__(self, results: List[Dict]):
        self.results = results
        # Evaluate the property once instead of once per example.
        non_robust_tags = self.non_robust_tags
        self.result = [
            self._add_retrieval_judgement(example)
            for example in self.results
            if example['query_tag'] in non_robust_tags and example['judgement'] != 'Correct'
        ]

    @staticmethod
    def _is_correct(judgement) -> bool:
        """Return True for either encoding of a positive judgement.

        'judgement' values are strings ('Correct'), while 'retrieval_judgement'
        may be the string 'Correct' or the integer 1 depending on what
        produced it.
        """
        return judgement == 'Correct' or judgement == 1

    @property
    def competent_tags(self):
        """Set of tags that have at least one 'Correct' example."""
        return {result['query_tag'] for result in self.results if result['judgement'] == 'Correct'}

    @property
    def gap_tags(self):
        """Set of tags that have no 'Correct' example."""
        return self.all_tags - self.competent_tags

    @property
    def tag_to_examples(self):
        """Mapping from tag to the list of examples carrying that tag.

        Built on first access and cached on the instance, so later changes to
        ``self.results`` are not reflected. Caching on the instance (rather
        than through ``lru_cache`` on the method) lets the object be garbage
        collected normally.
        """
        grouped = self.__dict__.get('_tag_to_examples')
        if grouped is None:
            grouped = {}
            for result in self.results:
                grouped.setdefault(result['query_tag'], []).append(result)
            self._tag_to_examples = grouped
        return grouped

    @property
    def non_robust_tags(self):
        """Set of competent tags that also contain at least one example that is not 'Correct'."""
        tag_to_examples = self.tag_to_examples
        return {
            tag
            for tag in self.competent_tags
            if any(example['judgement'] != 'Correct' for example in tag_to_examples[tag])
        }

    def get_incorrect_queries_in_non_robust_group(self):
        """Return {tag: [query, ...]} with the not-'Correct' queries of every non-robust tag."""
        tag_to_examples = self.tag_to_examples
        # non_robust_tags is a set of tags, so the examples are looked up per tag.
        return {
            tag: [example['query'] for example in tag_to_examples[tag] if example['judgement'] != 'Correct']
            for tag in self.non_robust_tags
        }

    def get_correct_queries(self):
        """Return {tag: [query, ...]} with the 'Correct' queries of every tag (possibly empty lists)."""
        return {tag: [example['query'] for example in examples if example['judgement'] == 'Correct'] for tag, examples in self.tag_to_examples.items()}

    def get_correct_incorrect_examples(self):
        """Map each incorrect query of a non-robust tag to the correct queries of the same tag."""
        incorrect_examples = self.get_incorrect_queries_in_non_robust_group()
        correct_examples = self.get_correct_queries()
        return {example: correct_examples[tag] for tag in incorrect_examples for example in incorrect_examples[tag]}

    @property
    def all_tags(self):
        """Set of every tag that occurs in the results."""
        return {result['query_tag'] for result in self.results}

    def get_best_case_accuracy(self):
        """Fraction of tags that are competent; 0.0 when there are no tags."""
        all_tags = self.all_tags
        if not all_tags:
            return 0.0
        return len(self.competent_tags) / len(all_tags)

    def get_accuracy(self, results=None, for_retrieval=False, average_for_each_domain=False, average_all=False):
        """Fraction of examples with a positive judgement.

        Args:
            results: examples to score; defaults to ``self.results``.
            for_retrieval: score 'retrieval_judgement' instead of 'judgement'.
            average_for_each_domain: currently unused.
            average_all: currently unused.

        Returns:
            Value in [0, 1]; 0.0 when ``results`` is empty.
        """
        if results is None:
            results = self.results
        if not results:
            return 0.0

        judgement_key = 'retrieval_judgement' if for_retrieval else 'judgement'
        num_correct = sum(1 for result in results if self._is_correct(result[judgement_key]))
        return num_correct / len(results)

    def get_robustness(self, for_retrieval=False):
        """Accuracy restricted to the examples of competent tags."""
        competent_tags = self.competent_tags
        results_for_competent_tags = [result for result in self.results if result['query_tag'] in competent_tags]
        return self.get_accuracy(results=results_for_competent_tags, for_retrieval=for_retrieval)

    def get_num_total_correct(self) -> Dict[str, Tuple[int, int]]:
        """Count correct and total answers for each query logic group.

        Returns:
            dict: {tag: (num_correct, num_total)}
        """
        counts: Dict[str, Tuple[int, int]] = {}
        for result in self.results:
            num_correct, num_total = counts.get(result['query_tag'], (0, 0))
            counts[result['query_tag']] = (
                num_correct + int(result['judgement'] == 'Correct'),
                num_total + 1,
            )
        return counts

    def _add_retrieval_judgement(self, example, method='context_comparison'):
        """Judge the retrieval of one example via context comparison.

        An example that already has 'retrieval_judgement' is returned
        unchanged. Otherwise a deep copy is returned whose
        'retrieval_judgement' is set to 'Correct' when a 'Correct' example of
        the same tag retrieved the same documents (so retrieval is not at
        fault).

        Raises:
            AssertionError: if the example's tag is not a non-robust tag.
            Exception: for any ``method`` other than 'context_comparison'.
        """
        if 'retrieval_judgement' in example:
            return example
        elif method == 'context_comparison':
            assert example['query_tag'] in self.non_robust_tags, "This context-comparison method is only for an example in non-robust groups"
            example = deepcopy(example)
            same_tag_examples = self.tag_to_examples[example['query_tag']]
            # NOTE(review): this reads 'retrieved_documents_id', whereas
            # judge_retrieval_result uses 'retrieved_document_ids' - confirm
            # which key the un-judged data actually carries.
            for candidate in same_tag_examples:
                if candidate['judgement'] != 'Correct':
                    continue
                if candidate['retrieved_documents_id'] == example['retrieved_documents_id']:  # this is not the retrieval's fault
                    example['retrieval_judgement'] = 'Correct'
                    break
            # NOTE(review): when nothing matches, 'retrieval_judgement' stays
            # unset - confirm whether it should be marked incorrect instead.
            return example

        else:
            raise Exception("So far, we only have context_comparison method")

class MetricForHypothesis():
    """Per-domain ``Metric`` objects plus statistics that compare the domains.

    Attributes:
        metrics: {domain_name: Metric} built from the results of each domain.
    """

    def __init__(self, results: Dict[str, List[Dict]]):
        self.metrics = {domain_name: Metric(results[domain_name]) for domain_name in results}

    def mutual_gap_tags(self):
        """Return the set of tags that are gap tags in every domain.

        Returns an empty set when there are no domains.
        """
        gap_tag_sets = [metric.gap_tags for metric in self.metrics.values()]
        if not gap_tag_sets:
            # set.intersection() needs at least one operand.
            return set()
        # return overlap tags across all the domains
        return set.intersection(*gap_tag_sets)

    def print_data_stat(self):
        """ Print data statistics
        """
        # Identical for every domain, so computed once up front.
        mutual_gap_tags = self.mutual_gap_tags()
        for domain_name, metric in self.metrics.items():
            print("======", domain_name, "======")
            num_toal_correct_for_each_group: Dict[str, Tuple[int, int]] = metric.get_num_total_correct()
            # ensure that gap tags are in the domain-specific cluster
            all_tags_for_domain = list(num_toal_correct_for_each_group.keys())
            gap_tags_for_domain = [tag for tag in all_tags_for_domain if tag in mutual_gap_tags]
            # print data statistics
            print_data_stat_by_groups(num_toal_correct_for_each_group, gap_tags_for_domain)
            print("\n")

    # def get_correct_retrieved_context(self):

    #     self.tag_to_examples

    #     tag_to_retrieved_context_dict = {}
    #     for domain_name in self.results_for_retrieval:
    #         tag_to_retrieved_context_dict[domain_name] = {}
    #         for tag in all_tags[domain_name]:
    #             tag_to_retrieved_context_dict[domain_name][tag] = []
    #             for result in queries_answer_pairs_dict[domain_name]:
    #                 if result['query_tag'] == tag and result['judgement'] == 'Correct':
    #                     retrieved_documents = result['retrieved_documents'].replace("### Question:\n"+result['query'], '')
    #                     query = result['query']
    #                     answer = tuple(result['answer'])
    #                     tag_to_retrieved_context_dict[domain_name][tag].append((query, answer, retrieved_documents))
        
    #     return tag_to_retrieved_context_dict


def print_data_stat_by_groups(group_details: Dict[str, Tuple[int, int]], tags_for_gap_groups: List ):
    """ Print data statistics by groups

    Groups listed in ``tags_for_gap_groups`` are reported as gap groups. Every
    other group with at least one correct example is reported as robust (all
    examples correct) or non-robust (some examples incorrect). Accuracy and
    robustness are printed as 0.0 when their denominator is zero, instead of
    raising ``ZeroDivisionError``.

    Args:
        group_details (Dict[str, Tuple[int, int]] ): Dictionary containing (correct, total) for each group
        tags_for_gap_groups (List): List of indices for gap groups. Gap groups have to be identified by considering both hypothesis and counter-hypothesis clusters.

    Raises:
        AssertionError: if the correct examples of robust and non-robust
            groups do not add up to the overall number of correct examples.
    """

    # gap groups
    num_gap_groups = len(tags_for_gap_groups)
    num_gap_examples = sum(group_details[tag][1] for tag in tags_for_gap_groups)
    print(f"Gap groups: {num_gap_groups} Groups with {num_gap_examples} Examples")

    # non-gap groups (set lookup instead of scanning the list for every tag)
    gap_tag_lookup = set(tags_for_gap_groups)
    details_no_gap_groups = [details for tag, details in group_details.items() if tag not in gap_tag_lookup]

    # robust groups: every example correct; non-robust groups: only some correct
    num_robust_groups = num_robust_examples = num_robust_correct_examples = 0
    num_non_robust_groups = num_non_robust_examples = num_non_robust_correct_examples = 0
    for correct, total in details_no_gap_groups:
        if correct == 0:
            # No correct example: neither robust nor non-robust.
            continue
        if correct == total:
            num_robust_groups += 1
            num_robust_examples += total
            num_robust_correct_examples += correct
        else:
            num_non_robust_groups += 1
            num_non_robust_examples += total
            num_non_robust_correct_examples += correct
    print(f"Roubst groups (# of groups/examples/correct examples): {num_robust_groups} / {num_robust_examples} / {num_robust_correct_examples}")
    print(f"Non-robust groups (# of groups/examples/correct examples): {num_non_robust_groups} / {num_non_robust_examples} / {num_non_robust_correct_examples}")

    # total
    total_examples = sum(total for correct, total in group_details.values())
    print(f"Total number of groups: {len(group_details)} groups with {total_examples} Examples")
    num_correct_examples = sum(correct for correct, total in group_details.values())
    assert num_correct_examples ==  num_non_robust_correct_examples + num_robust_correct_examples

    # accuracy and robustness
    robustness_numerator = num_robust_examples + num_non_robust_correct_examples
    robustness_denominator = num_robust_examples + num_non_robust_examples
    accuracy = num_correct_examples / total_examples if total_examples else 0.0
    robustness = robustness_numerator / robustness_denominator if robustness_denominator else 0.0
    print("\tAccuracy:", round(accuracy, 2), f'({num_correct_examples} / {total_examples})')
    print("\tRobustness:", round(robustness, 2), f'({robustness_numerator} / {robustness_denominator})')


def print_data_stat_by_groups_naive(group_details: List, print_details=False ):
    """ Print data statistics by groups

    A group with no correct example is a gap group, a group whose examples are
    all correct is a robust group, and any other group is non-robust. Accuracy
    and robustness are printed as 0.0 when their denominator is zero, instead
    of raising ``ZeroDivisionError``.

    Args:
        group_details (List): List of tuples containing (correct, total) for each group
        print_details (bool, optional): Print one line per group instead of the
            aggregated summary. Defaults to False.
    """
    if print_details:
        print(f"Index for Query Logic: (Correct, Total) in each group:")
        for i, (correct, total) in enumerate(group_details):
            if correct == 0:
                group_type = " => gap group"
            elif correct == total:
                group_type = ''  # robust groups are printed without a label
            else:
                group_type = "=> Non-robust Group"
            print(f"\tGroup {i}: ({correct}, {total}){group_type}")
        return

    # Single pass that sorts every group into gap / robust / non-robust.
    num_gap_groups = num_gap_examples = 0
    num_robust_groups = num_robust_examples = num_robust_correct_examples = 0
    num_non_robust_groups = num_non_robust_examples = num_non_robust_correct_examples = 0
    for correct, total in group_details:
        if correct == 0:
            num_gap_groups += 1
            num_gap_examples += total
        elif correct == total:
            num_robust_groups += 1
            num_robust_examples += total
            num_robust_correct_examples += correct
        else:
            num_non_robust_groups += 1
            num_non_robust_examples += total
            num_non_robust_correct_examples += correct
    print(f"Gap groups: {num_gap_groups} Groups with {num_gap_examples} Examples")
    print(f"Roubst groups: {num_robust_groups} Groups with {num_robust_examples} Examples")
    print(f"Non-robust groups: {num_non_robust_groups} Groups with {num_non_robust_examples} Examples")

    # total
    total_examples = sum(total for correct, total in group_details)
    print(f"Total number of groups: {len(group_details)} groups with {total_examples} Examples")
    num_correct_examples = sum(correct for correct, total in group_details)
    assert num_correct_examples ==  num_non_robust_correct_examples + num_robust_correct_examples

    # accuracy and robustness
    robustness_numerator = num_robust_examples + num_non_robust_correct_examples
    robustness_denominator = num_robust_examples + num_non_robust_examples
    accuracy = num_correct_examples / total_examples if total_examples else 0.0
    robustness = robustness_numerator / robustness_denominator if robustness_denominator else 0.0
    print("\tAccuracy:", round(accuracy, 2), f'({num_correct_examples} / {total_examples})')
    print("\tRobustness:", round(robustness, 2), f'({robustness_numerator} / {robustness_denominator})')

In [10]:
# long vs short
metric_short, metric_long = Metric(results_short), Metric(results_long)

# Accuracy over all examples of each variant.
print('Accuracy for short: ', metric_short.get_accuracy())
print('Accuracy for long: ', metric_long.get_accuracy())

# Accuracy restricted to the examples of competent tags.
print('Robustness for short: ', metric_short.get_robustness())
print('Robustness for long: ', metric_long.get_robustness())

# Cross-variant comparison object, keyed by variant name.
metric = MetricForHypothesis(dict(long=results_long, short=results_short))

0.25586854460093894
0.2837209302325581
0.5797872340425532
0.4357142857142857
