In [1]:
import os
import numpy as np
from collections import defaultdict

In [2]:
# Dataset folder name; generate_data_paths joins it onto the current working directory.
dataset_name = 'FB15k-237'
# One sub-folder per model whose per-relation metric files are read below.
data_dirs = ['cpg_conve', 'cpg_minerva', 'plain_minerva']
# Metric names; each is the suffix of a '<data_type>_relation_<metric>.txt' file name.
metrics = ['hits_at_1', 'hits_at_3', 'hits_at_5', 'hits_at_10', 'mrr']

In [29]:
# generate the dataset paths
def generate_data_paths(dataset_name, data_dirs, metrics, data_type='test'):
    """Build the metric-file paths for every model folder.

    Each path has the form
    ``<cwd>/<dataset_name>/<data_dir>/<data_type>_relation_<metric>.txt``.

    Args:
        dataset_name: name of the dataset folder under the current working directory.
        data_dirs: iterable of model sub-folder names.
        metrics: iterable of metric names used as file-name suffixes.
        data_type: file-name prefix (defaults to ``'test'``).

    Returns:
        A ``defaultdict`` mapping each model folder to the list of its metric
        file paths, in the order of ``metrics``. No files are touched.
    """
    paths_by_dir = defaultdict(list)
    for model_dir in data_dirs:
        for metric_name in metrics:
            file_name = '{}_relation_{}.txt'.format(data_type, metric_name)
            full_path = os.path.join(os.getcwd(), dataset_name, model_dir, file_name)
            paths_by_dir[model_dir].append(full_path)
    return paths_by_dir


def _write_data_to_file(file_path, data):
    if os.path.exists(file_path):
        append_write = 'a'
    else:
        append_write = 'w+'
    with open(file_path, append_write) as handle:
        handle.write(str(data) + "\n")
        

def correct_data_paths(data_dirs):
    """Rewrite metric files whose relation and value sit on separate lines.

    Every file is read as consecutive (relation, value) line pairs and then
    overwritten in place with one ``relation<TAB>value`` line per pair. The
    same records are also appended to a sibling file whose name replaces the
    last four characters of the path (normally ``.txt``) with ``_new.txt``.

    The conversion is not idempotent: running it on an already converted file
    pairs up adjacent records.

    Args:
        data_dirs: iterable of metric *file* paths (each entry is opened as a file).

    Raises:
        IndexError: if a file has an odd number of lines. This is raised
            before that file is written to, so it is left untouched.
    """
    for file_path in data_dirs:
        with open(file_path, 'r') as handle:
            lines = handle.readlines()

        records = []
        for idx in range(0, len(lines), 2):
            relation = lines[idx].strip()
            value = lines[idx + 1].strip()
            records.append('{}\t{}'.format(relation, value))

        new_path = file_path[:-4] + '_new.txt'
        # Write every record to both files. Previously only the last record
        # reached the '_new.txt' copy, and an empty input raised NameError.
        with open(file_path, 'w') as handle, open(new_path, 'a') as copy_handle:
            for record in records:
                handle.write(record + '\n')
                copy_handle.write(record + '\n')
                

def extract_relation_metrics(paths_dict):
    """Read per-relation metric files into a nested mapping.

    Each file must be named ``<data_type>_relation_<metric>.txt`` and contain
    one ``relation<TAB>value`` record per line; blank lines are skipped. The
    metric name is taken from the file name, so any ``data_type`` prefix works.

    Args:
        paths_dict: mapping of model name to an iterable of metric file paths.

    Returns:
        ``defaultdict`` of model -> ``defaultdict`` of relation -> dict of
        metric name -> accumulated float value. Unseen relations default to
        zeros for the five known metrics.

    Raises:
        ValueError: if a file name lacks the ``_relation_`` marker, or a
            non-blank line is not exactly ``relation<TAB>value``.
        KeyError: if the metric in a file name is not one of the known metrics.
    """
    def _new_relation_entry():
        # Fresh zeroed record per relation so entries never share state.
        return {'hits_at_1': 0,
                'hits_at_3': 0,
                'hits_at_5': 0,
                'hits_at_10': 0,
                'mrr': 0}

    model_metrics = defaultdict(lambda: defaultdict(_new_relation_entry))

    for model_type, paths in paths_dict.items():
        for path in paths:
            # Derive the metric from the file name itself rather than from
            # fixed character offsets, which only fit specific prefixes.
            file_stem = os.path.splitext(os.path.basename(path))[0]
            _, marker, path_type = file_stem.partition('_relation_')
            if not marker:
                raise ValueError(
                    "Cannot derive metric name from file name: {!r}".format(path))
            with open(path, 'r') as handle:
                for line in handle:
                    record = line.strip()
                    if not record:
                        continue
                    relation, metric = record.split('\t')
                    model_metrics[model_type][relation][path_type] += float(metric)
    return model_metrics


def cleanly_print_dict(current_object, padding='', precision=3):
    """Print a (possibly nested) dict as a tab-indented tree.

    Dict values are printed as a key line followed by their own entries, one
    tab deeper. Other values are printed as ``key<TAB>value``. Values that
    convert to ``float`` use ``precision`` decimals; the rest use ``str()``.

    Args:
        current_object: the dict (or dict subclass) to print; may be empty.
        padding: prefix put before every printed line at this level.
        precision: number of decimals for numeric values, at every depth.
    """
    for key, value in current_object.items():
        # Decide per entry, so empty and mixed dicts are handled too.
        if isinstance(value, dict):
            print(padding + str(key))
            # Forward precision; it used to fall back to the default here.
            cleanly_print_dict(value, padding=padding + '\t', precision=precision)
        else:
            try:
                str_value = '{0:.{1}f}'.format(float(value), precision)
            except (TypeError, ValueError, OverflowError):
                # Not representable as a float: print it verbatim.
                str_value = str(value)
            print(padding + str(key) + '\t' + str_value)
        
        
def sample_from_dict(dictionary, k=10):
    """Return a random sub-dict with ``min(k, len(dictionary))`` distinct entries.

    Keys are drawn without replacement using NumPy's global random state, so
    seed it with ``np.random.seed`` for reproducible samples.

    Args:
        dictionary: mapping to sample from.
        k: maximum number of entries to return.

    Returns:
        A plain dict with the sampled keys (original key objects) and values.
        Empty when ``dictionary`` is empty or ``k`` is not positive.
    """
    keys = list(dictionary.keys())
    k = min(k, len(keys))
    if k <= 0:
        return {}
    # Sample positions rather than the keys themselves: this keeps the key
    # types intact and works for keys NumPy cannot hold in a 1-D array.
    # replace=False guarantees k distinct entries.
    chosen_positions = np.random.choice(len(keys), k, replace=False)
    return {keys[pos]: dictionary[keys[pos]] for pos in chosen_positions}


def compute_ratio(source, target):
    """Return the change from ``source`` to ``target`` scaled by ``source``.

    The divisor is ``max(source, 1.)``, so it never drops below 1. For any
    ``source`` of at most 1 the result is simply ``target - source``.
    """
    change = target - source
    divisor = max(source, 1.)
    return change / divisor


def compute_metric_differences(model_metrics, source_model, target_model):
    """Compare two models metric by metric for every relation of the source.

    For each relation and metric found under ``source_model``, the value is
    ``compute_ratio(source_value, target_value)``. Target values are looked up
    by plain indexing, so a relation missing from the target is only handled
    if that mapping supplies defaults (as a ``defaultdict`` does).

    Args:
        model_metrics: mapping of model -> relation -> metric -> value.
        source_model: key of the baseline model.
        target_model: key of the model compared against the baseline.

    Returns:
        Plain dict of relation -> metric -> comparison value.
    """
    baseline = model_metrics[source_model]
    candidate = model_metrics[target_model]
    differences = {}
    for relation, baseline_values in baseline.items():
        candidate_values = candidate[relation]
        differences[relation] = {
            metric_name: compute_ratio(baseline_value, candidate_values[metric_name])
            for metric_name, baseline_value in baseline_values.items()
        }
    return differences
            

def reorder_metric_storage(metric_storage):
    """Swap the two nesting levels of a relation -> metric -> value mapping.

    Args:
        metric_storage: dict of relation -> dict of metric name -> value.

    Returns:
        New dict of metric name -> dict of relation -> value.
    """
    by_metric = {}
    for relation, relation_values in metric_storage.items():
        for metric_name, metric_value in relation_values.items():
            by_metric.setdefault(metric_name, {})[relation] = metric_value
    return by_metric


def sort_dict_by_value(dictionary, decreasing=True):
    """Return the dict's ``(key, value)`` pairs as a list ordered by value.

    Args:
        dictionary: mapping whose values are mutually comparable.
        decreasing: largest value first when true, smallest first otherwise.
    """
    pairs = list(dictionary.items())
    pairs.sort(key=lambda pair: pair[1], reverse=decreasing)
    return pairs


def get_metric_tails(metric2relation2value, tail_len=10):
    """Keep the best- and worst-scoring relations for every metric.

    Args:
        metric2relation2value: dict of metric -> dict of relation -> value.
        tail_len: number of relations to keep at each end. Values of zero or
            less give empty tails.

    Returns:
        Dict of metric -> ``{'top_<tail_len>': {...}, 'bottom_<tail_len>': {...}}``.
        Both inner dicts are ordered from highest to lowest value. They
        overlap when a metric has fewer than ``2 * tail_len`` relations.
    """
    keep = max(tail_len, 0)
    metric_tails_dict = {}
    for metric, relation_metric_values in metric2relation2value.items():
        ranked = sorted(relation_metric_values.items(),
                        key=lambda kv: kv[1],
                        reverse=True)
        # Use an explicit start index: ranked[-0:] would be the whole list.
        bottom_start = max(len(ranked) - keep, 0)
        metric_tails_dict[metric] = {
            'top_{}'.format(tail_len): dict(ranked[:keep]),
            'bottom_{}'.format(tail_len): dict(ranked[bottom_start:]),
        }
    return metric_tails_dict


def read_entities(dataset_name):
    """Load ``<cwd>/<dataset_name>/entities.txt`` as a list of lines.

    Args:
        dataset_name: dataset folder under the current working directory.

    Returns:
        One whitespace-stripped string per line of the file, in file order.
    """
    entities_file = os.path.join(os.getcwd(), dataset_name, 'entities.txt')
    with open(entities_file, 'r') as handle:
        return [raw_line.strip() for raw_line in handle]

In [4]:
# Build the metric-file paths for the 'test' files, keyed by model folder, and display them.
metric_dir_paths = generate_data_paths(dataset_name, data_dirs, metrics, data_type='test')
metric_dir_paths

defaultdict(<function __main__.generate_data_paths.<locals>.<lambda>>,
            {'cpg_conve': ['/Users/georgestoica/Desktop/Research/QA/qa_types/src/qa_cpg/temp/FB15k-237/cpg_conve/test_relation_hits_at_1.txt',
              '/Users/georgestoica/Desktop/Research/QA/qa_types/src/qa_cpg/temp/FB15k-237/cpg_conve/test_relation_hits_at_3.txt',
              '/Users/georgestoica/Desktop/Research/QA/qa_types/src/qa_cpg/temp/FB15k-237/cpg_conve/test_relation_hits_at_5.txt',
              '/Users/georgestoica/Desktop/Research/QA/qa_types/src/qa_cpg/temp/FB15k-237/cpg_conve/test_relation_hits_at_10.txt',
              '/Users/georgestoica/Desktop/Research/QA/qa_types/src/qa_cpg/temp/FB15k-237/cpg_conve/test_relation_mrr.txt'],
             'cpg_minerva': ['/Users/georgestoica/Desktop/Research/QA/qa_types/src/qa_cpg/temp/FB15k-237/cpg_minerva/test_relation_hits_at_1.txt',
              '/Users/georgestoica/Desktop/Research/QA/qa_types/src/qa_cpg/temp/FB15k-237/cpg_minerva/test_relation_hits_at

In [5]:
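# If a model's metric files have each relation and its value on separate lines, run this
# once to rewrite them in place as tab-separated `relation<TAB>value` lines.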
# correct_data_paths(metric_dir_paths['cpg_conve'])

In [6]:
# Parse every metric file into model -> relation -> metric -> value.
model_metrics = extract_relation_metrics(metric_dir_paths)

In [7]:
# Randomly pick up to 10 top-level entries of model_metrics (its keys are the model names).
# NOTE(review): the draw uses NumPy's global RNG and no seed is set in the visible cells,
# so the sample can differ between runs.
sampled_dictionary = sample_from_dict(model_metrics)

In [30]:
# Print the sampled entries as a tab-indented tree: model, relation, then metric values.
cleanly_print_dict(sampled_dictionary)

plain_minerva
	/tv/tv_producer/programs_produced./tv/tv_producer_term/producer_type
		hits_at_5	0.955
		hits_at_3	0.955
		hits_at_10	0.955
		hits_at_1	0.864
		mrr	0.902
	/education/educational_degree/people_with_this_degree./education/education/student
		hits_at_5	0.000
		hits_at_3	0.000
		hits_at_10	0.000
		hits_at_1	0.000
		mrr	0.000
	/base/x2010fifaworldcupsouthafrica/world_cup_squad/current_world_cup_squad./base/x2010fifaworldcupsouthafrica/current_world_cup_squad/current_club
		hits_at_5	0.034
		hits_at_3	0.034
		hits_at_10	0.069
		hits_at_1	0.000
		mrr	0.022
	/film/film/dubbing_performances./film/dubbing_performance/actor
		hits_at_5	0.000
		hits_at_3	0.000
		hits_at_10	0.250
		hits_at_1	0.000
		mrr	0.034
	/award/award_winning_work/awards_won./award/award_honor/award
		hits_at_5	0.127
		hits_at_3	0.110
		hits_at_10	0.203
		hits_at_1	0.034
		mrr	0.094
	/influence/influence_node/influenced_by
		hits_at_5	0.098
		hits_at_3	0.081
		hits_at_10	0.140
		hits_at_1	0.043
		mrr	0.082
	/bus

		hits_at_1	0.182
		mrr	0.237
	/film/film/film_production_design_by
		hits_at_5	0.600
		hits_at_3	0.500
		hits_at_10	0.700
		hits_at_1	0.400
		mrr	0.481
	/common/topic/webpage./common/webpage/category
		hits_at_5	1.000
		hits_at_3	1.000
		hits_at_10	1.000
		hits_at_1	1.000
		mrr	1.000
	/education/university/fraternities_and_sororities
		hits_at_5	1.000
		hits_at_3	1.000
		hits_at_10	1.000
		hits_at_1	0.857
		mrr	0.929
	/award/award_ceremony/awards_presented./award/award_honor/award_winner
		hits_at_5	0.132
		hits_at_3	0.101
		hits_at_10	0.189
		hits_at_1	0.044
		mrr	0.096
	/tv/non_character_role/tv_regular_personal_appearances./tv/tv_regular_personal_appearance/person
		hits_at_5	0.286
		hits_at_3	0.286
		hits_at_10	0.429
		hits_at_1	0.000
		mrr	0.119
	/music/performance_role/guest_performances./music/recording_contribution/performance_role
		hits_at_5	0.000
		hits_at_3	0.000
		hits_at_10	0.000
		hits_at_1	0.000
		mrr	0.059
	/time/event/instance_of_recurring_event
		hits_at_5	0.818
		h

In [31]:
# Pivot the cpg_conve results from relation -> metric -> value to metric -> relation -> value.
reoredered_model_metrics = reorder_metric_storage(model_metrics['cpg_conve'])

In [32]:
# For every metric keep the 10 highest- and 10 lowest-scoring relations.
tail_model_metrics = get_metric_tails(reoredered_model_metrics, tail_len=10)

In [33]:
# Show the top/bottom relations for all metrics.
cleanly_print_dict(tail_model_metrics)

hits_at_1
	top_10
		/location/statistical_region/rent50_2./measurement_unit/dated_money_value/currency	1.000
		/tv/tv_producer/programs_produced./tv/tv_producer_term/producer_type	1.000
		/film/film/film_art_direction_by	1.000
		/education/educational_institution/campuses	1.000
		/education/university/domestic_tuition./measurement_unit/dated_money_value/currency	1.000
		/location/hud_foreclosure_area/estimated_number_of_mortgages./measurement_unit/dated_integer/source	1.000
		/base/petbreeds/city_with_dogs/top_breeds./base/petbreeds/dog_city_relationship/dog_breed	1.000
		/organization/endowed_organization/endowment./measurement_unit/dated_money_value/currency	1.000
		/education/university/local_tuition./measurement_unit/dated_money_value/currency	1.000
		/user/tsegaran/random/taxonomy_subject/entry./user/tsegaran/random/taxonomy_entry/taxonomy	1.000
	bottom_10
		/music/performance_role/guest_performances./music/recording_contribution/performance_role	0.000
		/people/cause_of_death/peo

In [34]:
# Same view restricted to hits_at_1; re-wrapped in a dict so the metric name is printed as a header.
cleanly_print_dict({'hits_at_1': tail_model_metrics['hits_at_1']})

hits_at_1
	top_10
		/location/statistical_region/rent50_2./measurement_unit/dated_money_value/currency	1.000
		/tv/tv_producer/programs_produced./tv/tv_producer_term/producer_type	1.000
		/film/film/film_art_direction_by	1.000
		/education/educational_institution/campuses	1.000
		/education/university/domestic_tuition./measurement_unit/dated_money_value/currency	1.000
		/location/hud_foreclosure_area/estimated_number_of_mortgages./measurement_unit/dated_integer/source	1.000
		/base/petbreeds/city_with_dogs/top_breeds./base/petbreeds/dog_city_relationship/dog_breed	1.000
		/organization/endowed_organization/endowment./measurement_unit/dated_money_value/currency	1.000
		/education/university/local_tuition./measurement_unit/dated_money_value/currency	1.000
		/user/tsegaran/random/taxonomy_subject/entry./user/tsegaran/random/taxonomy_entry/taxonomy	1.000
	bottom_10
		/music/performance_role/guest_performances./music/recording_contribution/performance_role	0.000
		/people/cause_of_death/peo

In [35]:
# List the relations that have an entry under cpg_conve.
list(model_metrics['cpg_conve'].keys())

['/tv/tv_producer/programs_produced./tv/tv_producer_term/producer_type',
 '/education/educational_degree/people_with_this_degree./education/education/student',
 '/base/x2010fifaworldcupsouthafrica/world_cup_squad/current_world_cup_squad./base/x2010fifaworldcupsouthafrica/current_world_cup_squad/current_club',
 '/film/film/dubbing_performances./film/dubbing_performance/actor',
 '/award/award_winning_work/awards_won./award/award_honor/award',
 '/influence/influence_node/influenced_by',
 '/business/job_title/people_with_this_title./business/employment_tenure/company',
 '/base/schemastaging/organization_extra/phone_number./base/schemastaging/phone_sandbox/service_language',
 '/people/person/place_of_birth',
 '/award/award_category/category_of',
 '/base/popstra/celebrity/canoodled./base/popstra/canoodled/participant',
 '/film/film/personal_appearances./film/personal_film_appearance/person',
 '/user/tsegaran/random/taxonomy_subject/entry./user/tsegaran/random/taxonomy_entry/taxonomy',
 '/fil

In [36]:
# Per-relation, per-metric change from plain_minerva (source) to cpg_minerva (target).
# compute_ratio divides by max(source, 1.), so for source values <= 1 this is the plain difference.
minerva_metric_diffs = compute_metric_differences(model_metrics, 'plain_minerva', 'cpg_minerva')

In [37]:
# Pivot the differences to metric -> relation -> value.
minerva_metrics_reordered = reorder_metric_storage(minerva_metric_diffs)

In [38]:
# For every metric keep the 10 largest and 10 smallest differences.
minerva_tail_model_metrics = get_metric_tails(minerva_metrics_reordered, tail_len=10)

In [47]:
# Show the relations with the largest and smallest mrr differences.
cleanly_print_dict(minerva_tail_model_metrics['mrr'])

top_10
	/travel/travel_destination/how_to_get_here./travel/transportation/mode_of_transportation	0.538
	/base/schemastaging/organization_extra/phone_number./base/schemastaging/phone_sandbox/contact_category	0.500
	/sports/sports_team/roster./baseball/baseball_roster_position/position	0.508
	/government/governmental_body/members./government/government_position_held/legislative_sessions	0.619
	/base/eating/practicer_of_diet/diet	0.818
	/film/person_or_entity_appearing_in_film/films./film/personal_film_appearance/type_of_appearance	0.985
	/film/film/distributors./film/film_film_distributor_relationship/region	0.417
	/olympics/olympic_sport/athletes./olympics/olympic_athlete_affiliation/olympics	0.401
	/base/petbreeds/city_with_dogs/top_breeds./base/petbreeds/dog_city_relationship/dog_breed	0.475
	/travel/travel_destination/climate./travel/travel_destination_monthly_climate/month	0.778
bottom_10
	/soccer/football_team/current_roster./soccer/football_roster_position/position	-0.150
	/film/f