In [1]:
import pickle
import random
import sys
import time

import numpy as np
import pandas as pd
import rdflib
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from tqdm import tqdm_notebook

# Make the project-local modules in the parent directory importable.
sys.path.append('..')
from tree_builder import KGPTree, KGPForest
from datastructures import *

In [2]:
# Input graph: the AIFB knowledge graph, parsed by rdflib with format 'n3'.
rdf_file = 'data/AIFB/aifb.n3'
_format = 'n3'

# Labelled entity splits, read as tab-separated files.
# NOTE(review): these two names are crossed relative to the files they point
# at -- `train_file` is AIFB_test.tsv and `test_file` is AIFB_train.tsv. The
# data-loading cell reads them crossed as well, so training still uses
# AIFB_train.tsv. Rename here and in the loading cell(s) together if cleaned up.
test_file = 'data/AIFB/AIFB_train.tsv'
train_file = 'data/AIFB/AIFB_test.tsv'
entity_col, label_col = 'person', 'label_affiliation'

# Predicates passed to KnowledgeGraph.rdflib_to_kg as `label_predicates`.
# Presumably they are excluded so the target label cannot leak into the
# learned paths -- TODO confirm in datastructures.
label_predicates = list(map(rdflib.URIRef, (
    'http://swrc.ontoware.org/ontology#affiliation',
    'http://swrc.ontoware.org/ontology#employs',
    'http://swrc.ontoware.org/ontology#carriedOutBy',
)))

# Output path (the `.p` extension suggests a pickle; unused in the cells shown).
# NOTE(review): the filename says depth10, but the tree and forest cells use
# path_max_depth=8 -- confirm which is intended.
output = 'output/aifb_depth10.p'

# Single KG Path Tree

In [3]:
# Load the RDF graph and the labelled entity splits, then build the
# single-tree classifier.
print(end='Loading data... ', flush=True)
g = rdflib.Graph()
g.parse(rdf_file, format=_format)
print('OK')

# NOTE: the config cell's names are crossed -- `train_file` holds
# 'data/AIFB/AIFB_test.tsv' and `test_file` holds 'data/AIFB/AIFB_train.tsv'.
# Reading them crossed here means `train_data` gets AIFB_train.tsv and
# `test_data` gets AIFB_test.tsv.
test_data = pd.read_csv(train_file, sep='\t')
train_data = pd.read_csv(test_file, sep='\t')

# Fail fast, before the slow fit, on empty splits or missing entities/labels
# (a NaN entity would otherwise silently become the URI 'nan').
for _split_name, _split in (('train', train_data), ('test', test_data)):
    assert len(_split) > 0, f'{_split_name} split is empty'
    assert not _split[[entity_col, label_col]].isna().any().any(), \
        f'{_split_name} split has missing entities or labels'

train_entities = [rdflib.URIRef(x) for x in train_data[entity_col]]
train_labels = train_data[label_col]

test_entities = [rdflib.URIRef(x) for x in test_data[entity_col]]
test_labels = test_data[label_col]

# An entity present in both splits would leak its label into evaluation.
assert not set(train_entities) & set(test_entities), \
    'train and test splits share entities'

kg = KnowledgeGraph.rdflib_to_kg(g, label_predicates=label_predicates)

clf = KGPTree(kg, path_max_depth=8, min_samples_leaf=1, max_tree_depth=None,
              progress=tqdm_notebook)

Loading data... OK


In [4]:
# Fit the single path tree on the training split, reporting wall-clock time so
# readers know what re-running this cell costs.
tree_fit_start = time.perf_counter()
clf.fit(train_entities, train_labels)
print(f'Tree fit took {time.perf_counter() - tree_fit_start:.1f} s')

HBox(children=(IntProgress(value=0, description='Extracting neighborhoods', max=140, style=ProgressStyle(descr…




HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=7763, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=7763, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=7763, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=7763, style=ProgressStyle(description_width…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=7763, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=7763, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=7763, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=7763, style=ProgressStyle(description_width…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=7312, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=7312, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=7312, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=7312, style=ProgressStyle(description_width…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=6232, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=6232, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=6232, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=6232, style=ProgressStyle(description_width…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=5966, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=5966, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=5966, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=5966, style=ProgressStyle(description_width…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=3728, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=3728, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=3728, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=3728, style=ProgressStyle(description_width…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=3354, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=3354, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=3354, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=3354, style=ProgressStyle(description_width…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=3178, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=3178, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=3178, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=3178, style=ProgressStyle(description_width…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=2789, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=2789, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=2789, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=2789, style=ProgressStyle(description_width…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=279, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=279, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=279, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=279, style=ProgressStyle(description_width=…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=248, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=248, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=248, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=248, style=ProgressStyle(description_width=…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=104, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=104, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=104, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=104, style=ProgressStyle(description_width=…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=198, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=198, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=198, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=198, style=ProgressStyle(description_width=…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=169, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=169, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=169, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=169, style=ProgressStyle(description_width=…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=164, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=164, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=164, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=164, style=ProgressStyle(description_width=…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=159, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=159, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=159, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=159, style=ProgressStyle(description_width=…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=81, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='vertex loop', max=81, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='vertex loop', max=81, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='vertex loop', max=81, style=ProgressStyle(description_width='…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=53, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='vertex loop', max=53, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='vertex loop', max=53, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='vertex loop', max=53, style=ProgressStyle(description_width='…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=122, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=122, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=122, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='vertex loop', max=122, style=ProgressStyle(description_width=…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=37, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='vertex loop', max=37, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='vertex loop', max=37, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='vertex loop', max=37, style=ProgressStyle(description_width='…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=34, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='vertex loop', max=34, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='vertex loop', max=34, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='vertex loop', max=34, style=ProgressStyle(description_width='…



HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=29, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='vertex loop', max=29, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='vertex loop', max=29, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='vertex loop', max=29, style=ProgressStyle(description_width='…



In [5]:
# Score the single tree on the test split; the bare last expression lets
# Jupyter render the accuracy as the cell's output instead of print().
preds = clf.predict(test_entities)
accuracy_score(test_labels, preds)

0.9166666666666666


# Forest of KG Path Trees

In [9]:
# Reuse the graph `g` and the train/test splits already loaded for the single
# tree above instead of re-parsing `rdf_file` and re-reading both TSVs.
# A fresh KnowledgeGraph is still built from `g` so the forest does not share a
# `kg` instance with the tree fitted earlier.
# NOTE(review): assumes rdflib_to_kg does not mutate `g` -- confirm.
kg = KnowledgeGraph.rdflib_to_kg(g, label_predicates=label_predicates)

clf = KGPForest(kg, path_max_depth=8,
                min_samples_leaf=1,
                max_tree_depth=None,
                n_estimators=50,
                vertex_sample=0.5,
                progress=tqdm_notebook)

Loading data... OK


In [None]:
# vertex_sample=0.5 in the forest cell suggests random vertex sampling, so seed
# the global RNGs right before fitting to make re-runs reproducible.
# NOTE(review): assumes KGPForest draws from `random` and/or NumPy's global
# RNG -- confirm in tree_builder.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Report wall-clock time so readers know what re-running this cell costs.
forest_fit_start = time.perf_counter()
clf.fit(train_entities, train_labels)
print(f'Forest fit took {time.perf_counter() - forest_fit_start:.1f} s')

HBox(children=(IntProgress(value=0, description='Extracting neighborhoods', max=140, style=ProgressStyle(descr…




HBox(children=(IntProgress(value=0, description='estimator loop', max=50, style=ProgressStyle(description_widt…

HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=7763, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=7763, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=7763, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=7763, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=3803, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=3803, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=3803, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=3803, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=3512, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=3512, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=3512, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=3512, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=2387, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=2387, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=2387, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=2387, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=2032, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=2032, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=2032, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=2032, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=2179, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=2179, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=2179, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=2179, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='depth loop', max=4, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='vertex loop', max=2160, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=2160, style=ProgressStyle(description_width…

HBox(children=(IntProgress(value=0, description='vertex loop', max=2160, style=ProgressStyle(description_width…

In [None]:
# Score the forest on the test split; the bare last expression lets Jupyter
# render the accuracy as the cell's output instead of print().
preds = clf.predict(test_entities)
accuracy_score(test_labels, preds)