In [1]:
import rdflib
import pandas as pd

from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import GridSearchCV, StratifiedKFold

import sys
sys.path.append('..')
from tree_builder import KGPTree
from datastructures import *
import time

import pickle

In [3]:
rdf_file = '../data/AIFB/aifb.n3'
_format = 'n3'
train_file = '../data/AIFB/AIFB_test.tsv'
test_file = '../data/AIFB/AIFB_train.tsv'
entity_col = 'person'
label_col = 'label_affiliation'
label_predicates = [
    rdflib.URIRef('http://swrc.ontoware.org/ontology#affiliation'),
    rdflib.URIRef('http://swrc.ontoware.org/ontology#employs'),
    rdflib.URIRef('http://swrc.ontoware.org/ontology#carriedOutBy')
]
output = 'output/aifb_depth10.p'

In [4]:
print(end='Loading data... ', flush=True)
g = rdflib.Graph()
g.parse(rdf_file, format=_format)
print('OK')

test_data = pd.read_csv(train_file, sep='\t')
train_data = pd.read_csv(test_file, sep='\t')

train_entities = [rdflib.URIRef(x) for x in train_data[entity_col]]
train_labels = train_data[label_col]

test_entities = [rdflib.URIRef(x) for x in test_data[entity_col]]
test_labels = test_data[label_col]

kg = KnowledgeGraph.rdflib_to_kg(g, label_predicates=label_predicates)

clf = KGPTree(kg, path_max_depth=6, neighborhood_depth=8, min_samples_leaf=1, max_tree_depth=5)

Loading data... OK


In [6]:
# Fit the KGPTree on the training entities and their labels.
# In the recorded run below this starts with an "Extracting neighborhoods..."
# pass over 140 items (~32 s), followed by a series of further progress bars.
clf.fit(train_entities, train_labels)

  0%|          | 0/140 [00:00<?, ?it/s]

Extracting neighborhoods...


100%|██████████| 140/140 [00:32<00:00,  4.37it/s]
100%|██████████| 7763/7763 [00:05<00:00, 1421.83it/s]
100%|██████████| 7763/7763 [00:08<00:00, 904.22it/s] 
100%|██████████| 7763/7763 [00:09<00:00, 823.40it/s]
100%|██████████| 7312/7312 [00:01<00:00, 4004.08it/s]
100%|██████████| 7312/7312 [00:03<00:00, 1992.59it/s]
100%|██████████| 7312/7312 [00:03<00:00, 2367.51it/s]
100%|██████████| 6232/6232 [00:00<00:00, 19243.78it/s]
100%|██████████| 6232/6232 [00:00<00:00, 11858.66it/s]
100%|██████████| 6232/6232 [00:00<00:00, 8295.63it/s]
100%|██████████| 5966/5966 [00:00<00:00, 30988.27it/s]
100%|██████████| 5966/5966 [00:00<00:00, 19042.62it/s]
100%|██████████| 5966/5966 [00:01<00:00, 3939.70it/s]
100%|██████████| 3728/3728 [00:01<00:00, 2129.68it/s]
100%|██████████| 3728/3728 [00:02<00:00, 1669.59it/s]
100%|██████████| 3728/3728 [00:02<00:00, 1388.83it/s]
100%|██████████| 3354/3354 [00:00<00:00, 4889.65it/s]
100%|██████████| 3354/3354 [00:00<00:00, 3810.65it/s]
100%|██████████| 3354/3354 [0

In [7]:
# Predict a label for every held-out entity, then report the fraction of
# predictions that match the reference labels.
preds = clf.predict(test_entities)
print(accuracy_score(y_true=test_labels, y_pred=preds))

0.8611111111111112
