In [7]:
import json
import random

from numpy import mean
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RepeatedStratifiedKFold

from util.parse import generate_dictionaries, generate_id_dict, encode_triples

In [19]:
from config import FB15K

dataset = "FB15K"

with open(FB15K().PARSED_MODELS_PATH, 'r') as file:
    fb15k_models = json.load(file)

# models = armstrader_models if dataset.identifier == "armstrader" else vcslam_models
models = fb15k_models
template = fb15k_models
triples = []
for model in models[:]:
    for triple in model:
        triples.append(tuple(triple))

classes, predicates = generate_dictionaries(triples)
classes_mapping = generate_id_dict(classes)
predicates_mapping = generate_id_dict(predicates)


In [9]:
len(triples)

272115

In [10]:


# triples = modrel.reduce_relations(triples, targets)

#test_triples = random.sample(triples, floor(len(triples) / 10))
# test_triples = [triple for triple in test_triples if triple[0] == 'http://schema.org/Offer' and triple[2] == str(XSD.string)]



In [11]:
limit = 20000
test = 2000

Statistics Recommender Baseline

In [12]:
from modelextension.statistics_recommender import StatisticsRecommender as SR
from util.metrics import calc_hits_mrr
from util.utilities import prepare_data

X, y = prepare_data(template, c_map=classes_mapping,
                    p_map=predicates_mapping, shuffle=False, multiply=1)

encoded_triples = encode_triples(triples, classes_mapping, predicates_mapping)
sr = SR(triples=encoded_triples[:limit-test])

In [13]:

pred = sr.predict_links(X[test:limit])
hits_mrr = calc_hits_mrr( pred, y[test:limit])
print(f'SR {dataset} MRR {hits_mrr["mrr"]}, Hits@1 {hits_mrr["hits@1"]}, Hits@3 {hits_mrr["hits@3"]}')

SR FB15K MRR 0.9505958333333334, Hits@1 0.9048, Hits@3 0.99845


Train RFC

In [14]:
%%capture
from util.metrics import calc_hits_mrr
from util.utilities import prepare_data

from sklearn.model_selection import cross_validate

cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=1, random_state=1)

X, y = prepare_data(models = template, multiply=1, c_map=classes_mapping,
                    p_map=predicates_mapping)
rfc = RandomForestClassifier(n_estimators=10, max_depth=20)
rfc_results = cross_validate(rfc, X[:], y[:].reshape(-1), scoring='accuracy', cv=cv, n_jobs=-1,
                             error_score='raise', return_estimator=True)

In [15]:
mrrs = []
hits_1 = []
hits_3 = []
for est in rfc_results['estimator']:
    pred = est.predict_proba(X[:limit])
    hits_mrr = calc_hits_mrr(pred, y)
    hits_1.append(hits_mrr['hits@1'])
    hits_3.append(hits_mrr['hits@3'])
    mrrs.append(hits_mrr['mrr'])

print(f'RFC {dataset} 10 20 | {mean(mrrs):.3f} {mean(hits_1):.3f} {mean(hits_3):.3f}')


RFC FB15K 10 20 | 0.590 0.483 0.643


Train RGCN

In [18]:
import torch
from linkprediction import utils
from linkprediction.rgcn import LinkPredict, node_norm_to_edge_norm

train_triples = triples[:limit]
test_triples = random.sample(triples[:limit], int(len(triples[:limit])/10))
print("training triples size", len(train_triples))

train_data = encode_triples(train_triples, classes_mapping, predicates_mapping)
valid_data = encode_triples(test_triples, classes_mapping, predicates_mapping)

# load graph data
num_nodes = len(classes)
num_rels = len(predicates)

# create model
rgcn_model = LinkPredict(in_dim=num_nodes,
                         h_dim=100,
                         num_rels=num_rels,
                         num_bases=10,
                         num_hidden_layers=2,
                         dropout=0.1,
                         use_cuda=False,
                         reg_param=0.01)

# validation and testing triplets
valid_data = torch.LongTensor(valid_data)
test_data = torch.LongTensor(valid_data)

# build test graph
test_graph, test_rel, test_norm = utils.build_test_graph(
    num_nodes, num_rels, train_data)
test_deg = test_graph.in_degrees(
    range(test_graph.number_of_nodes())).float().view(-1, 1)
test_node_id = torch.arange(0, num_nodes, dtype=torch.long).view(-1, 1)
test_rel = torch.from_numpy(test_rel)
test_norm = node_norm_to_edge_norm(test_graph, torch.from_numpy(test_norm).view(-1, 1))

# build adj list and calculate degrees for sampling
adj_list, degrees = utils.get_adj_and_degrees(num_nodes, train_data)

# optimizer
optimizer = torch.optim.Adam(rgcn_model.parameters(), lr=0.001)

forward_time = []
backward_time = []

# training loop
# print("start training...")

epoch = 0
best_mrr = 0
best_hits3 = 0
checkpoint = None
while True:
    rgcn_model.train()
    epoch += 1

    # perform edge neighborhood sampling to generate training graph and data
    g, node_id, edge_type, node_norm, data, labels = \
        utils.generate_sampled_graph_and_labels(
            train_data, 20, 0.3,
            num_rels, adj_list, degrees, 5,
            "neighbor")
    # print("Done edge sampling")

    # set node/edge feature
    node_id = torch.from_numpy(node_id).view(-1, 1).long()
    edge_type = torch.from_numpy(edge_type)
    edge_norm = node_norm_to_edge_norm(g, torch.from_numpy(node_norm).view(-1, 1))
    data, labels = torch.from_numpy(data), torch.from_numpy(labels)
    deg = g.in_degrees(range(g.number_of_nodes())).float().view(-1, 1)

    embed = rgcn_model(g, node_id, edge_type, edge_norm)
    loss = rgcn_model.get_loss(g, embed, data, labels)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(rgcn_model.parameters(), 1.0)  # clip gradients
    optimizer.step()
    optimizer.zero_grad()

    # validation
    if epoch % 100 == 0:
        rgcn_model.eval()
        # print("start eval")
        embed = rgcn_model(test_graph, test_node_id, test_rel, test_norm)
        results = utils.calc_mrr(embed, rgcn_model.w_relation, torch.LongTensor(train_data),
                                 valid_data, test_data, hits=[1, 3], eval_bz=100,
                                 eval_p="filtered")
        mrr = results['mrr']
        hits3 = results['hits@3']
        hits1 = results['hits@1']
        print(f"epoch {epoch} MRR {mrr:.2f} hits@1 {hits1:.2f} hits@3 {hits3:.2f}", end="")
        if best_hits3 < hits3:
            best_hits3 = hits3
        if best_mrr <= mrr:
            best_mrr = mrr
            checkpoint = {'state_dict': rgcn_model.state_dict(), 'epoch': epoch}
            print(f"*** | ", end="")

        else:
            print(f" | ", end="")
        # if hits3 == 1:
        #    break

        if epoch >= 5000:
            break


training triples size 20000


  norm = 1.0 / in_deg


epoch 100 MRR 0.00 hits@1 0.00 hits@3 0.00*** | epoch 200 MRR 0.00 hits@1 0.00 hits@3 0.00*** | epoch 300 MRR 0.00 hits@1 0.00 hits@3 0.00*** | epoch 400 MRR 0.00 hits@1 0.00 hits@3 0.00*** | epoch 500 MRR 0.00 hits@1 0.00 hits@3 0.00*** | epoch 600 MRR 0.01 hits@1 0.00 hits@3 0.01*** | epoch 700 MRR 0.01 hits@1 0.01 hits@3 0.02*** | epoch 800 MRR 0.01 hits@1 0.01 hits@3 0.01*** | epoch 900 MRR 0.02 hits@1 0.01 hits@3 0.02*** | epoch 1000 MRR 0.03 hits@1 0.02 hits@3 0.03*** | epoch 1100 MRR 0.03 hits@1 0.02 hits@3 0.03*** | epoch 1200 MRR 0.03 hits@1 0.02 hits@3 0.03*** | epoch 1300 MRR 0.04 hits@1 0.03 hits@3 0.04*** | epoch 1400 MRR 0.04 hits@1 0.03 hits@3 0.04*** | epoch 1500 MRR 0.04 hits@1 0.03 hits@3 0.04*** | epoch 1600 MRR 0.04 hits@1 0.03 hits@3 0.05*** | epoch 1700 MRR 0.05 hits@1 0.04 hits@3 0.06*** | epoch 1800 MRR 0.05 hits@1 0.04 hits@3 0.05*** | epoch 1900 MRR 0.05 hits@1 0.04 hits@3 0.06*** | epoch 2000 MRR 0.05 hits@1 0.04 hits@3 0.05 | epoch 2100 MRR 0.05 hits@1 0.04 

In [20]:

# use best model checkpoint
print("Using best epoch: {}".format(checkpoint['epoch']))
rgcn_model.eval()
rgcn_model.load_state_dict(checkpoint['state_dict'])
rgcn_embed = rgcn_model(test_graph, test_node_id, test_rel, test_norm)
rgcn_results = utils.calc_mrr(rgcn_embed, rgcn_model.w_relation, torch.LongTensor(train_data), valid_data[:limit],
                         test_data[:limit], hits=[1, 3], eval_bz=100, eval_p="filtered")
mrr = rgcn_results['mrr']
hits3 = rgcn_results['hits@3']
hits1 = rgcn_results['hits@1']
print(f"RGCN {dataset} MRR {mrr} hits@1 {hits1} hits@3 {hits3}")

Using best epoch: 5000
RGCN FB15K MRR 0.08570608496665955 hits@1 0.06575000286102295 hits@3 0.09000000357627869
