In [1]:
import  os
import  re
import glob

import  pandas as pd
import  numpy as np

from src.utils import *
from src.models import TransE, rTransE

from src.env import Env
from src.agent import DQN_Network, ExperienceReplay

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# CLG dataset splits to process: one entry per dataset, giving the base
# directory and the train/test file names found inside it.
# NOTE(review): the '-e' suffix on the test file looks like a leftover backup
# extension -- confirm this is the intended file name.
CLG_dbs = [
    dict(
        path='datasets/clg/clg_10e4/',
        train_file='clg_10e4-train.nt',
        test_file='clg_10e4-test.nt-e',
    ),
]

In [3]:
# Load the training split of every configured dataset.
# NOTE(review): `train_df` is rebound on each iteration, so only the last entry
# of CLG_dbs is available to the cells below; `test_file` is printed but never
# loaded in this cell.
for db_ in CLG_dbs:
    path = db_['path']
    train_file = db_['train_file']
    test_file = db_['test_file']

    print('Running...', train_file, test_file)

    # os.path.join keeps the path valid whether or not `path` ends with a separator.
    train_df = load_clg_files(os.path.join(path, train_file))
    
    

Running... clg_10e4-train.nt clg_10e4-test.nt-e


In [4]:
# Preview the first rows of the loaded training frame (columns s, o, p).
train_df.head()

Unnamed: 0,s,o,p
0,<http://caligraph.org/ontology/Mosque_complete...,<http://caligraph.org/ontology/Architectural_s...,<http://www.w3.org/2000/01/rdf-schema#subClassOf>
1,<http://caligraph.org/ontology/Sports_league_d...,<http://caligraph.org/ontology/Organization_di...,<http://www.w3.org/2000/01/rdf-schema#subClassOf>
2,<http://caligraph.org/ontology/Synagogue_compl...,<http://caligraph.org/ontology/Religious_build...,<http://www.w3.org/2000/01/rdf-schema#subClassOf>
3,<http://caligraph.org/ontology/Student_organiz...,<http://caligraph.org/ontology/1890s_establish...,<http://www.w3.org/2000/01/rdf-schema#subClassOf>
4,<http://caligraph.org/ontology/Art_museum_dise...,<http://caligraph.org/ontology/Agent>,<http://www.w3.org/2000/01/rdf-schema#subClassOf>


In [5]:
# Keep only the predicates that occur in more than MIN_TRIPLETS_PER_REL rows.
# A single value_counts() pass replaces one full-frame boolean scan per predicate.
MIN_TRIPLETS_PER_REL = 10

rel_counts = train_df['p'].value_counts()

# Sorted so the relation order (and therefore the relation ids assigned later)
# is the same on every run; iterating a set of strings is not order-stable
# across kernels.
triplet_rels = sorted(rel_counts[rel_counts > MIN_TRIPLETS_PER_REL].index)
num_props = len(triplet_rels)

In [6]:
# Integer-encode the filtered graph: every kept relation and every node that
# appears in one of its rows gets a dense id in order of first appearance.
triplets   = []
node_dict  = {}
rels_dict  = {}

for r_ in triplet_rels:
    rel_rows = train_df[train_df['p'] == r_]
    # Iterate the raw columns instead of iterrows(): same row order, but no
    # per-row Series construction.
    for subj, obj, pred in zip(rel_rows['s'], rel_rows['o'], rel_rows['p']):
        # setdefault returns the existing id, or registers the next free one.
        # Ids are assigned in the order relation -> object -> subject.
        rel_id  = rels_dict.setdefault(pred, len(rels_dict))
        obj_id  = node_dict.setdefault(obj, len(node_dict))
        subj_id = node_dict.setdefault(subj, len(node_dict))

        # NOTE(review): tuples are stored as (o, p, s), i.e. object first --
        # confirm this is the head/tail order that TransE and Env expect.
        triplets.append((obj_id, rel_id, subj_id))

node_count = len(node_dict)
rels_count = len(rels_dict)

# Every selected relation must have contributed at least one row.
assert rels_count == len(triplet_rels)


In [7]:
# Report the number of distinct nodes, distinct relations and encoded triplets.
print(node_count,rels_count,len(triplets))

24526 22 127764


In [8]:
# Show the first ten encoded triplets.
triplets[:10]

[(0, 0, 1),
 (2, 0, 3),
 (4, 0, 5),
 (6, 0, 7),
 (8, 0, 9),
 (10, 0, 11),
 (12, 0, 13),
 (14, 0, 15),
 (16, 0, 17),
 (18, 0, 19)]

In [9]:
# Instantiate the TransE model with the node and relation counts computed above.
model_TransE  = TransE(node_count,rels_count)

In [10]:
# Train on all encoded triplets. The second argument is an empty list
# (presumably the held-out/validation triplets -- TODO confirm against TransE._train).
# The trailing ';' suppresses the cell's return-value display.
model_TransE._train(triplets,[]);

epoch 0,	 train loss 1.14
epoch 50,	 train loss 0.83
epoch 100,	 train loss 0.74


In [11]:
# Build the environment from the encoded triplets and train a DQN agent on it,
# passing the TransE entity and relation embeddings converted to NumPy arrays.
# NOTE(review): [60, 64, 2] is presumably the layer-size list of the Q-network
# -- TODO confirm against DQN_Network.
# NOTE(review): no random seed is set in the visible cells, so the samples
# collected here may differ between runs.
env   = Env(triplets)
agent = DQN_Network([60, 64, 2],lr=1e-3)
agent_samples = agent.train(env,
        model_TransE.entity_embds.detach().numpy(),
        model_TransE.rel_embds.detach().numpy(),
        episodes = 10000,eps_decay_rate=0.9995)

epoch 0	ep_len 6	average loss 0.65	reward -0.60	done False	eps 1.00
epoch 1000	ep_len 6	average loss 0.43	reward -0.40	done False	eps 0.61
epoch 2000	ep_len 1	average loss 0.34	reward 1.00	done True	eps 0.37
epoch 3000	ep_len 1	average loss 0.12	reward 1.00	done True	eps 0.22
epoch 4000	ep_len 6	average loss 0.04	reward -0.60	done False	eps 0.14
epoch 5000	ep_len 6	average loss 0.16	reward -0.60	done False	eps 0.08
epoch 6000	ep_len 6	average loss 0.69	reward -0.40	done False	eps 0.05
epoch 7000	ep_len 1	average loss 0.38	reward 1.00	done True	eps 0.03
epoch 8000	ep_len 1	average loss 0.28	reward 1.00	done True	eps 0.02
epoch 9000	ep_len 1	average loss 0.27	reward 1.00	done True	eps 0.01


In [13]:
# Number of samples returned by agent.train.
len(agent_samples)

678

In [30]:
def process_agent_kpaths(new_samples,K=5,reward_th=.9):
    """Select the agent samples that are full-length paths with a high reward.

    A sample is kept when it has exactly ``K + 1`` elements (``K`` path
    entries followed by the reward) and its last element is strictly greater
    than ``reward_th``. Input order is preserved and no de-duplication is
    performed: identical samples are all returned.

    Args:
        new_samples: Iterable of sequences whose last element is the reward.
        K: Required number of path entries preceding the reward.
        reward_th: Exclusive lower bound on the reward.

    Returns:
        A list containing the samples (the same objects, not copies) that
        pass both checks.
    """
    # The length check runs first, so an empty sample is never indexed with [-1].
    return [
        sample for sample in new_samples
        if len(sample) == K + 1 and sample[-1] > reward_th
    ]

In [31]:
# Filter the agent samples using the default K and reward threshold.
agent_quads = process_agent_kpaths(agent_samples)

In [32]:
# Number of samples that passed the filter.
len(agent_quads)

3

In [33]:
# Display the retained samples.
agent_quads

[[1196, 0, 518, 0, 2932, 1.1],
 [647, 0, 2159, 0, 566, 1.1],
 [14, 0, 3955, 0, 2624, 1.1]]