In [1]:
import numpy as np
import pandas

In [2]:
# load entity2id
# Entity vocabulary: every line of entity2id.txt is "<entity>\t<integer id>".
with open("data/Freebase/FB15k/entity2id.txt", 'r', encoding='utf-8') as f:
    entity_rows = [row.strip().split('\t') for row in f]
entity2id = {entity_name: int(entity_index) for entity_name, entity_index in entity_rows}

In [3]:
# load relation2id
# Relation vocabulary: every line of relation2id.txt is "<relation>\t<integer id>".
relation2id = {}
with open("data/Freebase/FB15k/relation2id.txt", 'r', encoding='utf-8') as f:
    for row in f:
        relation_name, relation_index = row.strip().split('\t')
        relation2id[relation_name] = int(relation_index)

In [4]:
# load type2id
# Type vocabulary: every line of type2id.txt is "<type>\t<integer id>".
with open("data/Freebase/FB15kET/type2id.txt", 'r', encoding='utf-8') as f:
    type2id = {
        type_name: int(type_index)
        for type_name, type_index in (row.strip().split('\t') for row in f)
    }

In [5]:
# load training set of FB15k
# Each line is "<head>\t<relation>\t<tail>"; stored as (head_id, relation_id, tail_id)
# using the entity2id / relation2id vocabularies.
# UTF-8 is pinned so decoding does not depend on the platform's default locale encoding,
# and the file is streamed line by line instead of being read into a list first.
train_triplet = []
with open('data/Freebase/FB15k/freebase_mtr100_mte100-train.txt', 'r', encoding='utf-8') as f:
    for line in f:
        h, rel, t = line.strip().split("\t")
        train_triplet.append((entity2id[h], relation2id[rel], entity2id[t]))

In [6]:
# load training set of FB15kET
# Each line is "<entity>\t<type>". train_e2t maps entity id -> list of distinct type ids
# in first-seen order; pair_train counts input lines (duplicates included).
# UTF-8 is pinned so decoding does not depend on the platform's default locale encoding.
train_e2t = {}
pair_train = 0
with open("data/Freebase/FB15kET/FB15k_Entity_Type_train.txt", 'r', encoding='utf-8') as f:
    for line in f:
        pair_train += 1
        h, t = line.strip().split("\t")
        entity_types = train_e2t.setdefault(entity2id[h], [])
        if type2id[t] not in entity_types:
            entity_types.append(type2id[t])

In [7]:
# load validation set of FB15kET
# Each line is "<entity>\t<type>". dev_e2t maps entity id -> list of distinct type ids
# in first-seen order; pair_dev counts input lines (duplicates included).
# UTF-8 is pinned so decoding does not depend on the platform's default locale encoding.
dev_e2t = {}
pair_dev = 0
with open("data/Freebase/FB15kET/FB15k_Entity_Type_valid_clean.txt", 'r', encoding='utf-8') as f:
    for line in f:
        pair_dev += 1
        h, t = line.strip().split("\t")
        entity_types = dev_e2t.setdefault(entity2id[h], [])
        if type2id[t] not in entity_types:
            entity_types.append(type2id[t])

In [8]:
# load test set of FB15kET
# Each line is "<entity>\t<type>". test_e2t maps entity id -> list of distinct type ids
# in first-seen order; pair_test counts input lines (duplicates included).
# Types are de-duplicated per entity exactly like the train/dev splits, so a repeated
# (entity, type) line is not scored twice during evaluation.
# UTF-8 is pinned so decoding does not depend on the platform's default locale encoding.
test_e2t = {}
pair_test = 0
with open("data/Freebase/FB15kET/FB15k_Entity_Type_test_clean.txt", 'r', encoding='utf-8') as f:
    for line in f:
        pair_test += 1
        h, t = line.strip().split("\t")
        entity_types = test_e2t.setdefault(entity2id[h], [])
        if type2id[t] not in entity_types:
            entity_types.append(type2id[t])

In [9]:
# Dataset sizes: FB15k training triples and the raw line counts of the three FB15kET splits.
num_train_triples = len(train_triplet)
print("FB15k:", num_train_triples, "triples")
print("FB15kET:", pair_train, "train pairs, ",
      pair_dev, "valid pairs,",
      pair_test, "test pairs")

FB15k: 483142 triples
FB15kET: 136618 train pairs,  15749 valid pairs, 15780 test pairs


In [10]:
# Initialize the zero count matrices behind M_e2r (entity -> relation) and M_r2t (relation -> type):
#   relation_head_type / relation_tail_type have shape (num_relations, num_types)
#   head_link_relation / tail_link_relation have shape (num_entities, num_relations)
num_entities, num_relations, num_types = len(entity2id), len(relation2id), len(type2id)
relation_head_type = np.zeros((num_relations, num_types))
relation_tail_type = np.zeros((num_relations, num_types))
head_link_relation = np.zeros((num_entities, num_relations))
tail_link_relation = np.zeros((num_entities, num_relations))

In [11]:
# Accumulate co-occurrence counts over the training triples:
#   relation_head_type[r, k] / relation_tail_type[r, k]: +1 for each training type k of the
#   head / tail entity of a triple with relation r (skipped if that entity has no training types);
#   head_link_relation[e, r] / tail_link_relation[e, r]: +1 each time entity e is the
#   head / tail of a triple with relation r.
for head, rel, tail in train_triplet:
    head_types = train_e2t.get(head)
    if head_types is not None:
        relation_head_type[rel, head_types] += 1

    tail_types = train_e2t.get(tail)
    if tail_types is not None:
        relation_tail_type[rel, tail_types] += 1

    head_link_relation[head, rel] += 1
    tail_link_relation[tail, rel] += 1

In [12]:
# entity2relation: columns are [head-role counts | tail-role counts] per relation.
# relation2type: rows are [head-role type counts ; tail-role type counts] in the same order,
# so their matrix product pairs each role of a relation with the types seen in that role.
entity2relation = np.hstack([head_link_relation, tail_link_relation])
relation2type = np.vstack([relation_head_type, relation_tail_type])

In [13]:
# calculate M = entity2relation @ relation2type: one score per (entity, type) pair.
entity2type = entity2relation @ relation2type
# Per-entity type ids ordered from highest to lowest score (ties keep argsort's order).
# evaluation() re-sorts each entity's row itself rather than reading this array.
e2t_arg = np.argsort(-entity2type, axis=-1)

In [14]:
# evaluation
def evaluation(data_name='dev'):
    """Filtered MRR and Hits@1/3/10 of the ``entity2type`` scores on one split.

    For every (entity, type) pair of the chosen split, all types are ranked by the
    entity's row of ``entity2type`` (highest score first). The raw rank is then
    filtered: every other type known for that entity in ``train_e2t``, ``dev_e2t``
    or ``test_e2t`` that is ranked above the target is not counted.

    Args:
        data_name: ``'dev'`` to score ``dev_e2t`` or ``'test'`` to score ``test_e2t``.

    Returns:
        ``(MRR, Hits@1, Hits@3, Hits@10)``, each averaged over all pairs of the
        split; all four are 0.0 when the split is empty.

    Raises:
        ValueError: if ``data_name`` is neither ``'dev'`` nor ``'test'``.
    """
    splits = {'dev': dev_e2t, 'test': test_e2t}
    if data_name not in splits:
        # Previously fell through to an UnboundLocalError on ``data_set``.
        raise ValueError("data_name must be 'dev' or 'test', got {!r}".format(data_name))
    data_set = splits[data_name]

    fmrr = fhit10 = fhit3 = fhit1 = 0
    num_of_e2t = 0
    for entity_, target_types in data_set.items():
        # type_arg[k] is the type id at 0-based position k, best score first.
        # NOTE(review): tied scores (e.g. types with zero score) are ordered by argsort,
        # not by score, so raw ranks among ties are arbitrary.
        type_arg = np.argsort(-entity2type[entity_])
        # Inverse permutation: raw_rank[type_id] is that type's 1-based raw rank
        # (replaces an O(num_types) nonzero() scan per label).
        raw_rank = np.empty_like(type_arg)
        raw_rank[type_arg] = np.arange(1, len(type_arg) + 1)

        # All types known for this entity in any split, as a set so a type listed in
        # several splits is filtered once; a plain list subtracted it repeatedly and
        # could report a filtered rank lower than the true one.
        known_types = set()
        for split in (train_e2t, dev_e2t, test_e2t):
            known_types.update(split.get(entity_, ()))
        # int() keeps the metrics plain Python floats rather than NumPy scalars.
        known_ranks = [int(raw_rank[type_id]) for type_id in known_types]

        for type_label in target_types:
            rank = int(raw_rank[type_label])
            # Filtered rank = raw rank minus known types ranked strictly above it.
            # The target is itself in known_ranks but not < rank, so frank >= 1.
            frank = rank - sum(1 for other in known_ranks if other < rank)

            fmrr += 1.0 / frank
            if frank <= 10:
                fhit10 += 1
            if frank <= 3:
                fhit3 += 1
            if frank <= 1:
                fhit1 += 1
        num_of_e2t += len(target_types)

    if num_of_e2t == 0:
        return 0.0, 0.0, 0.0, 0.0
    return fmrr / num_of_e2t, fhit1 / num_of_e2t, fhit3 / num_of_e2t, fhit10 / num_of_e2t

In [15]:
# Score both held-out splits; each result is (MRR, Hits@1, Hits@3, Hits@10).
dev_evaluation, test_evaluation = evaluation('dev'), evaluation('test')

print("results of validation：", dev_evaluation)
print("results of test：", test_evaluation)

results of validation： (0.49361499121152574, 0.3999619023430059, 0.5352085846720427, 0.6762334116451838)
results of test： (0.4957710207982826, 0.4007604562737643, 0.5370088719898606, 0.6818124207858048)
