In [1]:
from rdflib import Graph
import re
import json
import string
from elasticsearch import Elasticsearch
from elasticsearch.helpers import parallel_bulk
from log_reg_categorizer import LogRegCategorizer
from smart_dataset.evaluation.dbpedia.evaluate import load_type_hierarchy, evaluate, load_ground_truth, load_system_output

In [2]:
from smart_dataset.evaluation.dbpedia.evaluate import load_type_hierarchy
type_hier = load_type_hierarchy('./smart_dataset/evaluation/dbpedia/dbpedia_types.tsv')

Loading type hierarchy from ./smart_dataset/evaluation/dbpedia/dbpedia_types.tsv... 761 types loaded (max depth: 7)


In [3]:
g = Graph()
instance_types = dict()
g.parse('./data/instance_types_en.ttl', format='n3')
for subj, pred, obj in g:
    if str(pred) == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type':
        uri = str(subj)
        type = str(obj)
        if type == 'http://www.w3.org/2002/07/owl#Thing':
            type = 'owl:Thing'
        elif type.startswith('http://dbpedia.org/ontology/'):
            type = type.split(' ')
            type = type[0].split("/")[-1].strip()
            type = 'dbo:' + type
        instance_types[uri] = type

In [4]:
g = Graph()
short_abstracts = dict()
g.parse('./data/short_abstracts_en.ttl', format='n3')
for subj, pred, obj in g:
    if str(pred) == 'http://www.w3.org/2000/01/rdf-schema#comment':
        uri = str(subj)
        if uri in instance_types.keys():
            type = instance_types[uri]
            name = uri.split(' ')
            name = name[0].split("/")[-1].strip()
            name = re.sub(r'__\d+', '', name)
            name = name.replace('_', ' ')
            elem = {'name': name, 'type': type, 'comment': str(obj)}
            short_abstracts[uri] = elem

In [5]:
es = Elasticsearch()
INDEX_NAME = 'ec_index2'

INDEX_SETTINGS = {    
    'settings' : {
        'index' : {
            "number_of_shards" : 1,
            "number_of_replicas" : 1
        },
        'analysis': {
            'analyzer': {
                'my_english_analyzer': {
                    'type': "custom",
                    'tokenizer': "standard",
                    'stopwords': "_english_",
                    'filter': [
                        "lowercase",
                        "english_stop",
                        "filter_english_minimal"
                    ]                
                }
            },
            'filter' : {
                'filter_english_minimal' : {
                    'type': "stemmer",
                    'name': "minimal_english"
                },
                'english_stop': {
                    'type': "stop",
                    'stopwords': "_english_"
                }
            },
        }
    },
    'mappings': {
        'properties': {
            'name': {
                'type': "text",
                'term_vector': "with_positions",
                'analyzer': "my_english_analyzer"
            },
            'types': {
                'type': "text",
                'term_vector': "with_positions",
                'analyzer': "my_english_analyzer"
            },
            'comment': {
                'type': "text",
                'term_vector': "with_positions",
                'analyzer': "my_english_analyzer"
            },

        }
    }
}


if es.indices.exists(index=INDEX_NAME):
    es.indices.delete(index=INDEX_NAME)

es.indices.create(index=INDEX_NAME, body=INDEX_SETTINGS)



{'acknowledged': True, 'shards_acknowledged': True, 'index': 'ec_index2'}

In [6]:
def insert_data(index_name, data):
    for key, val in data.items():
        yield {
                '_index': index_name,
                '_type': '_doc',
                '_id': key,
                '_source': val,
            }

In [7]:
for success, info in parallel_bulk(es, insert_data(INDEX_NAME, short_abstracts), 
                                        chunk_size=1000, thread_count=16, queue_size=16):  
    if not success:
        print('A document failed:', info)



In [9]:
res = es.search(index=INDEX_NAME, q = 'What is the capital of Norway?', size = 5)['hits']['hits']
print(res[0]['_source']['type'].split(' '))

['dbo:AdministrativeRegion']


In [10]:
test_questions = json.load(open('./data/smarttask_dbpedia_test_questions.json'))
lrc = LogRegCategorizer('./data/smarttask_dbpedia_train.json')
baseline_output = list()

for question in test_questions:
    q_id = question['id']
    q_text = question['question']
    q_cat = lrc.predict([q_text])[0]
    if q_cat == 'boolean':
        q_type = ['boolean']
    elif q_cat == 'literal':
        q_type = lrc.predict_literal_type([q_text]).tolist()
    elif q_cat == 'resource':
        res = es.search(index=INDEX_NAME, q=q_text.translate(str.maketrans('', '', string.punctuation)), size=5)['hits']['hits']
        q_type = [hit['_source']['type'] for hit in res]
    else:
        q_type = None
    
    baseline_output.append({
        'id': q_id,
        'question': q_text,
        'category': q_cat,
        'type' : q_type
    })



In [11]:

with open('baseline_entity_cent_results.json', 'w') as outfile:
    json.dump(baseline_output, outfile)

In [13]:
so = load_system_output('./baseline_entity_cent_results.json')
gt = load_ground_truth('./data/smarttask_dbpedia_test.json', type_hier[0].keys())
evaluate(so, gt, type_hier[0], 7)

Loading system predictions from ./baseline_entity_cent_results.json... 
   4369 predictions loaded
Loading ground truth from ./data/smarttask_dbpedia_test.json... 
   4369 questions loaded


Evaluation results:
-------------------
Category prediction (based on 4369 questions)
  Accuracy: 0.939
Type ranking (based on 4369 questions)
  NDCG@5:  0.538
  NDCG@10: 0.517
