In [3]:
import os
import numpy as np
import torch
from collections import defaultdict as ddict
import dgl
from tqdm import tqdm
import pickle as pkl
from collections import Counter

In [34]:
data_path = './FB15k-237/'
# data_path = './wn18rr'
# data_path = './codex-l'
# data_path = './YAGO3-10'

In [35]:
def get_ent_rel_map(data_path):
    with open(os.path.join(data_path, 'entities.dict')) as fin:
        entity2id = dict()
        for line in fin:
            eid, entity = line.strip().split('\t')
            entity2id[entity] = int(eid)

    with open(os.path.join(data_path, 'relations.dict')) as fin:
        relation2id = dict()
        for line in fin:
            rid, relation = line.strip().split('\t')
            relation2id[relation] = int(rid)

    return entity2id, relation2id


def read_triple(data_path, entity2id, relation2id):
    train_triples = []
    with open(os.path.join(data_path, 'train.txt')) as fin:
        for line in fin:
            h, r, t = line.strip().split('\t')
            train_triples.append((entity2id[h], relation2id[r], entity2id[t]))

    valid_triples = []
    with open(os.path.join(data_path, 'valid.txt')) as fin:
        for line in fin:
            h, r, t = line.strip().split('\t')
            valid_triples.append((entity2id[h], relation2id[r], entity2id[t]))

    test_triples = []
    with open(os.path.join(data_path, 'test.txt')) as fin:
        for line in fin:
            h, r, t = line.strip().split('\t')
            test_triples.append((entity2id[h], relation2id[r], entity2id[t]))

    return train_triples, valid_triples, test_triples

def get_train_g(train_triples, num_ent):
    train_triples = torch.LongTensor(train_triples)
    num_tri = train_triples.shape[0]
    g = dgl.graph((train_triples[:, 0].T, train_triples[:, 2].T), num_nodes=num_ent)
    g.edata['rel'] = train_triples[:, 1].T
    
    return g

In [36]:
entity2id, relation2id = get_ent_rel_map(data_path)
num_ent = len(entity2id)
num_rel = len(relation2id)
print('#ent:', num_ent)
print('#rel:', num_rel)

train_triples, valid_triples, test_triples = read_triple(data_path, entity2id, relation2id)
print('#train:', len(train_triples))
print('#valid:', len(valid_triples))
print('#test:', len(test_triples))

#ent: 40559
#rel: 11
#train: 86835
#valid: 2824
#test: 2924


In [37]:
train_g = get_train_g(train_triples, num_ent)

## relational feature for entities

In [38]:
ent_rel_feat = torch.zeros(num_ent, num_rel*2)

for e in tqdm(range(num_ent)):
    out_eid = train_g.out_edges(torch.tensor([e]), form='eid')
    in_eid = train_g.in_edges(torch.tensor([e]), form='eid')
    
    for r in train_g.edata['rel'][out_eid]:
        ent_rel_feat[e, r] += 1
    for r in train_g.edata['rel'][in_eid]:
        ent_rel_feat[e, r + num_rel] += 1

100%|██████████| 40559/40559 [03:06<00:00, 217.21it/s]


In [39]:
pkl.dump(ent_rel_feat, open(os.path.join(data_path, 'ent_rel_feat.pkl'), 'wb'))

## random reserved entities

In [40]:
res_ent_ratio = 0.1
ratio_str = '0p1'

# random select reserved entities
res_ent_map = \
    torch.unique(torch.tensor(np.random.choice(np.arange(num_ent), int(num_ent*res_ent_ratio), replace=False)))

## topk reserved entities for all entities

In [41]:
ent_rel_feat = pkl.load(open(os.path.join(data_path, 'ent_rel_feat.pkl'), 'rb'))

In [42]:
ent_rel_feat_norm = ent_rel_feat / (torch.norm(ent_rel_feat, dim=-1).reshape(-1, 1) + 1e-6)
ent_rel_feat = ent_rel_feat_norm[res_ent_map]
ent_sim = torch.mm(ent_rel_feat_norm, ent_rel_feat.T)

In [43]:
topk_sim, topk_idx = torch.topk(ent_sim, 100, dim=-1)

In [44]:
pkl.dump({'res_ent_map': res_ent_map, 
          'topk_sim': topk_sim,
          'topk_idx': topk_idx
         }, 
         open(os.path.join(data_path, f'res_ent_{ratio_str}.pkl'), 'wb'))