## 1. Run inference on a few samples and inspect the ranking results

In [26]:
import json
import os

import numpy as np
import torch

from torch.utils.data import DataLoader

from model import KGEModel

from dataloader import TrainDataset
from dataloader import BidirectionalOneShotIterator

CKPT_PATH = "../models/RotatE_Wiki15k_0"
# The training-time hyper-parameters are stored as JSON next to the checkpoint.
with open(os.path.join(CKPT_PATH, "config.json"), mode="r") as config_file:
    configs = json.load(config_file)
class Args:
    """Attribute-style wrapper around a dict of hyper-parameters.

    Every keyword argument becomes an instance attribute, e.g.
    ``Args(gamma=24.0).gamma == 24.0``.
    """

    def __init__(self, **entries):
        vars(self).update(entries)
# Attribute-style view of the checkpoint's training configuration.
args = Args(**configs)

# Runtime settings for this notebook.
GPU_DEVICE = 0  # CUDA device index that load_model uses when use_cuda is set
DATA_PATH = "../data/Wiki15k/"

In [27]:
def load_data(data_path):
    """Load the entity/relation id maps and the id-encoded triples of a dataset.

    ``data_path`` must contain:
      * ``entities.dict`` / ``relations.dict``: tab-separated ``<id> <name>`` lines;
      * ``train.txt``, ``valid.txt``, ``test.txt``, ``infer.txt``: tab-separated
        ``<head> <relation> <tail>`` name lines.
    Blank lines are ignored; files are read as UTF-8.

    Returns:
        (infer_triples, all_true_triples, entity2id, relation2id), where every
        triple is a ``(head_id, relation_id, tail_id)`` tuple and
        ``all_true_triples`` is train + valid + test concatenated.

    Raises:
        KeyError: if a triple names an entity/relation missing from the dicts.
        ValueError: if a non-blank line has the wrong number of fields.
    """

    def read_fields(file_path):
        """Yield the tab-split fields of every non-blank line of ``file_path``."""
        with open(file_path, encoding='utf-8') as fin:
            for line in fin:
                line = line.strip()
                if line:  # tolerate blank / trailing empty lines
                    yield line.split('\t')

    def read_dict(file_path):
        """Read an ``<id>\\t<name>`` file into a ``{name: id}`` dict."""
        return {name: int(idx) for idx, name in read_fields(file_path)}

    def read_triple(file_path, entity2id, relation2id):
        """Read ``<h>\\t<r>\\t<t>`` name triples and map them to id tuples."""
        return [(entity2id[h], relation2id[r], entity2id[t])
                for h, r, t in read_fields(file_path)]

    entity2id = read_dict(os.path.join(data_path, 'entities.dict'))
    relation2id = read_dict(os.path.join(data_path, 'relations.dict'))

    train_triples = read_triple(os.path.join(data_path, 'train.txt'), entity2id, relation2id)
    print('#train: %d' % len(train_triples))
    valid_triples = read_triple(os.path.join(data_path, 'valid.txt'), entity2id, relation2id)
    print('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(data_path, 'test.txt'), entity2id, relation2id)
    print('#test: %d' % len(test_triples))
    # Every known-true triple; used for filtered ranking at evaluation time.
    all_true_triples = train_triples + valid_triples + test_triples

    infer_triples = read_triple(os.path.join(data_path, 'infer.txt'), entity2id, relation2id)
    print('#infer: %d' % len(infer_triples))

    return infer_triples, all_true_triples, entity2id, relation2id

# _ = load_data("../data/Wiki15k/")

In [28]:
def load_model(ckpt_path, model_name, nent, nrel, hdim, gamma, de=True, dr=False, use_cuda=True, gpu_device=None):
    """Build a KGEModel and restore its weights from ``<ckpt_path>/checkpoint``.

    Args:
        ckpt_path: directory holding the ``checkpoint`` file; the file is read
            with ``torch.load`` and must contain a ``'model_state_dict'`` entry.
        model_name: model name forwarded to KGEModel (e.g. "RotatE").
        nent: number of entities.
        nrel: number of relations.
        hdim: hidden (embedding) dimension forwarded to KGEModel.
        gamma: ``gamma`` hyper-parameter forwarded to KGEModel.
        de: forwarded as ``double_entity_embedding``.
        dr: forwarded as ``double_relation_embedding``.
        use_cuda: place the model on ``cuda:<gpu_device>`` if True, else on CPU.
        gpu_device: CUDA device index; ``None`` means the notebook-level
            ``GPU_DEVICE``.

    Returns:
        The KGEModel with the checkpoint weights loaded, on the chosen device.
    """
    kge_model = KGEModel(
        model_name=model_name,
        nentity=nent,
        nrelation=nrel,
        hidden_dim=hdim,
        gamma=gamma,
        double_entity_embedding=de,
        double_relation_embedding=dr
    )

    print('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        print('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad)))

    if use_cuda:
        device = torch.device("cuda", GPU_DEVICE if gpu_device is None else gpu_device)
    else:
        device = torch.device("cpu")
    kge_model.to(device)

    # Restore model from checkpoint directory. map_location remaps tensors that
    # were saved on a GPU onto the target device, so a GPU-trained checkpoint can
    # be restored on CPU or on a different GPU.
    print('Loading checkpoint %s...' % ckpt_path)
    checkpoint = torch.load(os.path.join(ckpt_path, 'checkpoint'), map_location=device)
    kge_model.load_state_dict(checkpoint['model_state_dict'])

    return kge_model

In [29]:
def inference(infer_triples, all_true_triples, model, id2ent, id2rel, args, save_path="results/"):
    # save ranking results
    save_rt = []
    save_f = os.path.join(save_path, "ranking_results.json")
    
    metrics, ranking_rt = model.test_step(model, infer_triples, all_true_triples, args)
    for metric in metrics:
        print('%s %s %f' % ("test", metric, metrics[metric]))
    # get ranking results
    for item in ranking_rt:
        p_h, p_r, p_t = id2ent[item["positive_sample"][0]], id2rel[item["positive_sample"][1]], id2ent[item["positive_sample"][2]]
        topk_rank_ents = [id2ent[eid] for eid in item["topk_rank_idxs"]]
        save_rt.append(
            {"mode": item["mode"],
            "positive_triple": [p_h, p_r, p_t],
            "topk_rank_ents": topk_rank_ents,
            "topk_rank_scores": item["topk_rank_scores"]
            }
        )
    with open(save_f, "w") as f:
        json.dump(save_rt, f, indent=4, ensure_ascii=False)

In [30]:
# run
# Load the dataset and build reverse id -> name maps for readable output.
infer_triples, all_true_triples, entity2id, relation2id = load_data(DATA_PATH)
id2ent = {v: k for k, v in entity2id.items()}
id2rel = {v: k for k, v in relation2id.items()}
# Load the model. Use CUDA only when a device is actually available, instead of
# hard-coding use_cuda=True, which fails on CPU-only hosts.
# NOTE(review): test_step may consult args independently to decide on device
# placement; confirm it agrees with this choice on CPU-only machines.
kge_model = load_model(
    CKPT_PATH, args.model, len(entity2id), len(relation2id), args.hidden_dim, args.gamma,
    args.double_entity_embedding, args.double_relation_embedding,
    use_cuda=torch.cuda.is_available(),
)
# Evaluate on the inference triples and write the ranking results to disk.
inference(infer_triples, all_true_triples, kge_model, id2ent, id2rel, args, save_path="../results/")

#train: 159036
#valid: 8727
#test: 8761
#infer: 1000
Model Parameter Configuration:
Parameter gamma: torch.Size([1]), require_grad = False
Parameter embedding_range: torch.Size([1]), require_grad = False
Parameter entity_embedding: torch.Size([15817, 2000]), require_grad = True
Parameter relation_embedding: torch.Size([182, 1000]), require_grad = True
Loading checkpoint ../models/RotatE_Wiki15k_0...
test MRR 0.408852
test MR 364.835000
test HITS@1 0.328500
test HITS@3 0.453500
test HITS@10 0.555000


## 2. Test trained embeddings with cosine similarity

In [31]:
import numpy as np
import copy

# Embeddings exported next to the checkpoint. Paths are derived from the config
# constants so they stay in sync with section 1.
ent_embed_file = os.path.join(CKPT_PATH, "entity_embedding.npy")
rel_embed_file = os.path.join(CKPT_PATH, "relation_embedding.npy")

ent_embeddings = np.load(ent_embed_file)
rel_embeddings = np.load(rel_embed_file)
print(ent_embeddings.shape)
print(rel_embeddings.shape)

# Rebuild the entity name -> id map (tab-separated "<id> <name>" lines) so this
# section does not depend on section 1's load_data call having been run.
data_path = DATA_PATH
entity2id = dict()
with open(os.path.join(data_path, 'entities.dict'), encoding='utf-8') as fin:
    for line in fin:
        line = line.strip()
        if not line:  # tolerate blank / trailing empty lines
            continue
        eid, entity = line.split('\t')
        entity2id[entity] = int(eid)

(15817, 2000)
(182, 1000)


In [32]:
def find_topk_sim_ents(cur_ent, ent2id, ent_embeddings, topk=10, exclude_self=True):
    """Return the ``topk`` entities whose embeddings are most cosine-similar to ``cur_ent``.

    Args:
        cur_ent: entity name; must be a key of ``ent2id``.
        ent2id: ``{entity name: row index}`` into ``ent_embeddings``.
        ent_embeddings: 2-D array-like with one embedding per row.
        topk: maximum number of neighbours to return.
        exclude_self: if True, ``cur_ent`` itself is left out of the result.

    Returns:
        (topk_idxs, topk_ents): row indices and the matching entity names,
        most similar first. Rows with a zero-norm embedding rank last.

    Raises:
        KeyError: if ``cur_ent``, or a returned row index, has no entry in ``ent2id``.
    """
    ent_embeddings = np.asarray(ent_embeddings)
    cur_ent_id = ent2id[cur_ent]
    query = ent_embeddings[cur_ent_id]

    # Cosine similarity of every row against the query, by broadcasting
    # (no need to materialise one copy of the query per row).
    dots = ent_embeddings @ query
    norms = np.linalg.norm(ent_embeddings, axis=1) * np.linalg.norm(query)
    # Division by a zero norm would give NaN. argsort places NaN last, so the
    # reversal below would rank it first; score such rows -inf instead.
    coss = np.full(dots.shape, -np.inf)
    np.divide(dots, norms, out=coss, where=norms > 0)

    # Most similar first.
    argsorts = np.argsort(coss)[::-1]
    if exclude_self:
        # Drop only the query entity's own row.
        argsorts = argsorts[argsorts != cur_ent_id]
    topk_idxs = list(argsorts[:max(topk, 0)])

    # Invert the mapping once instead of scanning ent2id per hit. setdefault
    # keeps the first name seen for an id, matching a front-to-back scan.
    id2ent = {}
    for name, idx in ent2id.items():
        id2ent.setdefault(idx, name)
    topk_ents = [id2ent[idx] for idx in topk_idxs]
    return topk_idxs, topk_ents

In [33]:
# Spot-check a few entities: print each one's 11 nearest neighbours by cosine
# similarity. exclude_self=False keeps the query entity itself in the list.
test_ents = ["Q107008", "Q1701293", "Q2702789"]
for cur_ent in test_ents:
    topk_idxs, topk_ents = find_topk_sim_ents(
        cur_ent, entity2id, ent_embeddings, topk=11, exclude_self=False
    )
    print(topk_ents)

['Q107008', 'Q49575', 'Q1768', 'Q281034', 'Q4030', 'Q229430', 'Q107432', 'Q354508', 'Q82222', 'Q109053', 'Q104358']
['Q1701293', 'Q1494959', 'Q1705192', 'Q539757', 'Q7344802', 'Q3021483', 'Q7816353', 'Q7535711', 'Q1767860', 'Q5300549', 'Q2130828']
['Q2702789', 'Q2060840', 'Q1584317', 'Q511731', 'Q1321379', 'Q186941', 'Q1404450', 'Q2450848', 'Q168383', 'Q512858', 'Q192557']
