In [1]:
import json
import os

# Dataset split to process; it selects both the input file below and the
# output folder used by the subgraph-extraction cell.
split = 'dev_set'

# Path of the conversations file for the selected split. The historical name
# `train_conversations_path` is kept as an alias, but the path follows `split`
# and is not necessarily the training set.
conversations_path = './data/%s/%s_ALL.json' % (split, split)
train_conversations_path = conversations_path
PREFIX_E = 'http://www.wikidata.org/entity/'

# Load the conversations of the selected split. Decode as UTF-8 explicitly so
# parsing does not depend on the platform's default locale encoding.
with open(conversations_path, "r", encoding="utf-8") as conversations_file:
    conversations = json.load(conversations_file)
print("%d conversations loaded" % len(conversations))

2240 conversations loaded


In [2]:
# load KG
import re

from hdt import HDTDocument, TripleComponentRole
from settings import *

# Open the Wikidata HDT dump. `hdt_path` comes from the `settings` star import.
# os.path.join works whether or not `hdt_path` ends with a separator; plain
# string concatenation silently produces a wrong path when it does not.
hdt_file = 'wikidata20200309.hdt'
kg = HDTDocument(os.path.join(hdt_path, hdt_file))

# Hop configuration for later `kg.compute_hops` calls. An empty predicate-id
# list is passed here.
# NOTE(review): the meaning of the positional arguments (1, namespace, True,
# False) is defined by the hdt bindings; presumably 1 is the hop count. Confirm.
namespace = 'predef-wikidata2020-03-all'
predicates_ids = []
kg.configure_hops(1, predicates_ids, namespace, True, False)

In [None]:
# for every conversation, retrieve the KG subgraph around its seed and answer entities and dump it as JSON
import numpy as np
import scipy.sparse as sp
from collections import Counter

# Per-conversation subgraph sizes (number of entity ids / predicate ids),
# appended to by the extraction loop and summarised with `stats` at the end.
subgraph_entities, subgraph_relations = [], []

def retrieve_subgraph(kg, matched_entity_ids, max_triples=500000000, offset=0):
    """Page through ``kg.compute_hops`` and concatenate every page it returns.

    Starting at ``offset``, ``kg.compute_hops(matched_entity_ids, max_triples, offset)``
    is called repeatedly, advancing the offset by ``max_triples`` after each call,
    until a call returns an empty (falsy) entity list.

    Args:
        kg: object exposing ``compute_hops`` (e.g. the HDTDocument ``kg``).
        matched_entity_ids: ids the hop expansion starts from.
        max_triples: page size passed to ``compute_hops``; also the offset step.
        offset: offset of the first page requested.

    Returns:
        Tuple ``(entity_ids, predicate_ids, adjacencies)`` holding the results of
        all non-empty pages, concatenated in call order.

    NOTE(review): pages are concatenated as-is. If a page's adjacencies index into
    that page's own entity list, indices from later pages will not line up with
    the merged ``entity_ids``. Confirm against the hdt bindings.
    """
    all_entities, all_predicates, all_adjacencies = [], [], []
    page_offset = offset
    entities, predicates, adjacency = kg.compute_hops(matched_entity_ids, max_triples, page_offset)
    while entities:
        all_entities += entities
        all_predicates += predicates
        all_adjacencies += adjacency
        page_offset += max_triples
        entities, predicates, adjacency = kg.compute_hops(matched_entity_ids, max_triples, page_offset)
    return all_entities, all_predicates, all_adjacencies


# Per-split output directory for the subgraph dumps. It is created up front so a
# fresh run does not fail with FileNotFoundError on the first write.
subgraphs_dir = './data/subgraphs/%s' % split
os.makedirs(subgraphs_dir, exist_ok=True)

for j, conversation in enumerate(conversations):
    print(j)  # progress: index of the conversation being processed
    seed_entity = PREFIX_E + conversation['seed_entity'].split('/')[-1]
    seed_entity_text = conversation['seed_entity_text']
    # NOTE(review): answer ids below are filtered for falsy values before being
    # used as hop seeds, but the seed id is not. Confirm a missing seed cannot occur.
    seed_entity_id = kg.string_to_global_id(seed_entity, TripleComponentRole.OBJECT)

    questions = []
    answer_entities = []  # per question: entity URIs of its Wikidata answers
    answer_texts = []
    answer_ids = []       # per question: KG ids, aligned with answer_entities
    for turn in conversation['questions']:
        questions.append(turn['question'])
        answers = turn['answer'].split(';')
        answer_texts.append(turn['answer_text'])

        # Consider only answers that are Wikidata entities. Strip whitespace around
        # each ';'-separated item so the extracted entity id is clean.
        _answer_entities = []
        _answer_ids = []
        for answer in answers:
            if 'www.wikidata.org' in answer:
                entity = PREFIX_E + answer.strip().split('/')[-1]
                _answer_entities.append(entity)
                _answer_ids.append(kg.string_to_global_id(entity, TripleComponentRole.OBJECT))

        answer_entities.append(_answer_entities)
        answer_ids.append(_answer_ids)

    # Hop seeds are the seed entity plus every truthy answer id. Falsy ids are
    # skipped, presumably because they mark entities missing from the KG.
    matched_entity_ids = [seed_entity_id]
    matched_entity_ids.extend(a for _as in answer_ids for a in _as if a)

    # retrieve relevant subgraph
    entity_ids, predicate_ids, adjacencies = retrieve_subgraph(kg, matched_entity_ids)

    subgraph_entities.append(len(entity_ids))
    subgraph_relations.append(len(predicate_ids))

    # Dump the sample with its subgraph as JSON. Serialise before opening the file
    # so a serialisation error cannot leave a truncated dump behind.
    sample = {'seed_entity': seed_entity, 'seed_entity_text': seed_entity_text, 'seed_entity_id': seed_entity_id,
              'questions': questions,
              'answer_entities': answer_entities, 'answer_texts': answer_texts, 'answer_ids': answer_ids,
              'entities': entity_ids, 'predicates': predicate_ids, 'adjacencies': adjacencies}
    json_object = json.dumps(sample)
    with open(os.path.join(subgraphs_dir, '%d.json' % j), "w", encoding="utf-8") as outfile:
        outfile.write(json_object)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14


In [None]:
def stats(_list):
    """Print summary statistics for a list of numbers.

    Prints a ``min mean max`` line, then the ``Counter`` histogram of the values,
    then a blank separator (``print('\\n')``).

    An empty list prints ``no values`` instead of the two summary lines. Without
    this guard, ``min``/``max`` would raise ``ValueError``, which happens when the
    extraction loop was interrupted before any conversation was processed.
    """
    if len(_list) == 0:
        print('no values')
    else:
        print(min(_list), np.mean(_list), max(_list))
        print(Counter(_list))
    print('\n')

# Summarise subgraph sizes: entity-id counts first, then predicate-id counts.
for sizes in (subgraph_entities, subgraph_relations):
    stats(sizes)