In [None]:
# default_exp dataset_builder
%load_ext autoreload
%autoreload 2

In [None]:
#export 
import random
from rake_nltk import Rake
from kirby.database_proxy import WikiDatabase
import json
import importlib
import spacy
import en_core_web_sm
import tensorflow_hub as hub
from sklearn.neighbors import NearestNeighbors

# Dataset Builder
> Builds an optimal dataset with knowledge-base relations from a vanilla dataset.



In [None]:
#export
class DatasetBuilder():
    """Build a dataset augmented with knowledge-base entities and relations.

    The builder bundles a RAKE keyword extractor (`self.rake`), the project's
    `WikiDatabase` proxy (`self.db`), a spaCy pipeline (`self.nlp`) and a
    TF-Hub sentence encoder (`self.encoder`). The encoder is loaded here but
    not used by any method of this class.
    """

    # TF-Hub module loaded when no explicit `module_url` is given. The larger
    # "https://tfhub.dev/google/universal-sentence-encoder-large/5" module was
    # listed as the alternative choice for this setting.
    DEFAULT_ENCODER_URL = "https://tfhub.dev/google/universal-sentence-encoder/4"

    def __init__(self, module_url=None):
        """Create the helper objects used by the builder.

        Args:
            module_url: TF-Hub URL of the sentence encoder to load. Defaults to
                `DEFAULT_ENCODER_URL` when omitted or None.
        """
        self.rake = Rake()
        self.db = WikiDatabase()
        self.nlp = en_core_web_sm.load()
        self.encoder = hub.load(module_url or self.DEFAULT_ENCODER_URL)

    def build(self, ds, dataset_type='random'):
        """Build a dataset based on the given dataset `ds`.

        Args:
            ds: dataset object exposing `map(function, batched=...)`.
            dataset_type: one of 'random', 'description' or 'relevant'.

        Returns:
            For 'random', whatever `ds.map` returns. For 'description' and
            'relevant', which are not implemented yet, `ds` itself unchanged.

        Raises:
            ValueError: if `dataset_type` is not one of the known names.
        """
        if dataset_type == 'random':
            # NOTE(review): `self.random` is not defined anywhere in this class;
            # it must be supplied before this branch can run -- confirm where
            # it is meant to come from.
            # The mapped result is returned instead of being discarded;
            # presumably `ds.map` returns a new dataset rather than mutating
            # `ds` in place -- TODO confirm against the dataset type in use.
            return ds.map(self.random, batched=False)
        if dataset_type in ('description', 'relevant'):
            # Placeholder modes: nothing is applied yet.
            return ds
        raise ValueError(
            f"Unknown dataset_type {dataset_type!r}; "
            "expected 'random', 'description' or 'relevant'"
        )

    def keyword(self, x):
        """Return the top-ranked phrase of the text `x`, or None if there is none."""
        ranked_phrases = self.get_ranked_phrases(x)
        # An empty ranking would otherwise raise IndexError on `[0]`.
        return ranked_phrases[0] if ranked_phrases else None

    def get_ranked_phrases(self, x):
        """Run RAKE on the text `x` and return its ranked phrases."""
        self.rake.extract_keywords_from_text(x)
        return self.rake.get_ranked_phrases()

    # Does not read instance state, but stays a regular method so existing
    # `builder.add_to_accepted(...)` calls keep their signature.
    def add_to_accepted(self, a_sentences, sentence):
        """Append `sentence` to the list `a_sentences`, mutating it in place.

        When the list already holds more than two items, the oldest one
        (index 0) is dropped first.
        """
        if len(a_sentences) > 2:
            a_sentences.pop(0)
        a_sentences.append(sentence)

    def get_entities_in_text(self, text):
        """Return the knowledge-base entries for the entities spaCy finds in `text`.

        Each recognised entity's surface text is looked up with
        `self.db.get_entity_by_label`; the lookup results are returned in the
        order the entities appear. Results are appended as-is, so a failed
        lookup contributes whatever the database proxy returns for a miss.
        """
        # Parse the argument itself; the previous code read an unrelated name
        # `x`, which only resolved when a global of that name happened to exist.
        doc = self.nlp(text)
        return [self.db.get_entity_by_label(ent.text) for ent in doc.ents]

    def entity(self, ranked_phrases):
        """Return the first knowledge-base entity matching one of `ranked_phrases`.

        Phrases are tried in order; returns None when none of them resolves or
        when `ranked_phrases` is empty.
        """
        for phrase in ranked_phrases:
            # Uses the `self.db` proxy created in `__init__`; the attribute
            # previously referenced here (`self.kba`) is never assigned.
            # NOTE(review): assumes a label lookup is the intended query -- confirm.
            entity = self.db.get_entity_by_label(phrase)
            if entity is not None:
                return entity
        return None

    def get_entity_associations(self, entity_id):
        """
        Given an `entity_id` return a dictionary containing all the associated properties.

        Keys are property names and values are the related entities' labels,
        both obtained from `self.db.get_property_string`. Values are stored
        as returned (including None), and a property name that occurs more
        than once keeps only its last value.
        """
        entity_associations_dict = {}
        associations = self.db.get_entity_associations(entity_id)
        for property_id, related_entity_id in associations:
            property_name, related_entity_label = self.db.get_property_string(property_id, related_entity_id)
            entity_associations_dict[property_name] = related_entity_label
        return entity_associations_dict

# Testing

In [None]:
# Creation: instantiate the builder once; the test cells below reuse `ds_builder`.
# The constructor creates the Rake extractor and the WikiDatabase proxy and loads
# the spaCy model and the TF-Hub encoder.
ds_builder = DatasetBuilder()
# Sanity check that construction completed and produced the expected type.
assert isinstance(ds_builder, DatasetBuilder)

<sqlite3.Connection object at 0x7f6bce64fc60>


In [None]:
# Test ranked phrases through the builder's own methods, so that
# `get_ranked_phrases` and `keyword` are exercised rather than the
# underlying RAKE object alone.
x = "Stephen Curry is my favorite basketball player."
ranked_phrases = ds_builder.get_ranked_phrases(x)
assert ranked_phrases == ['favorite basketball player', 'stephen curry'], "RAKE failed"
# `keyword` returns the top-ranked phrase of the same text.
assert ds_builder.keyword(x) == ranked_phrases[0], "RAKE failed"

In [None]:
# Query the database once so the printed value and the asserted value are
# guaranteed to be the same result (the lookup used to run twice).
ronaldo_entity = ds_builder.db.get_entity_by_label('Cristiano Ronaldo')
print(ronaldo_entity)
assert ronaldo_entity == ['Q11571', 'Cristiano Ronaldo', 'Portuguese association football player'], 'ERROR in `database_proxy`'

['Q11571', 'Cristiano Ronaldo', 'Portuguese association football player']


In [None]:
# Get Entities from the sentence
x = "Microsoft has bought Bethesda"
entities = ds_builder.get_entities_in_text(x)
print(entities)
assert entities == [['Q2283', 'Microsoft', 'American multinational technology corporation'],\
                    ['Q224892', 'Bethesda', 'Wikimedia disambiguation page']],\
                    'Error in `dataset_builder.get_entities_in_text()`'

[['Q2283', 'Microsoft', 'American multinational technology corporation'], ['Q224892', 'Bethesda', 'Wikimedia disambiguation page']]


In [None]:
# Get associations from an entity
associations = ds_builder.get_entity_associations(entities[0][0])
assert associations == {"topic's main Wikimedia portal": 'Portal:Microsoft',
 'founded by': 'Bill Gates',
 'country': 'United States of America',
 'instance of': 'software company',
 'headquarters location': 'Redmond',
 'stock exchange': 'NASDAQ',
 'chief executive officer': 'Steve Ballmer',
 "topic's main category": 'Category:Microsoft',
 'subsidiary': 'Xbox Game Studios',
 'described by source': 'Lentapedia (full versions)',
 'industry': 'technology industry',
 'product or material produced': 'Microsoft Windows',
 'legal form': 'Washington corporation',
 'business division': 'Microsoft Research',
 'history of topic': 'history of Microsoft',
 'member of': 'Alliance for Open Media',
 'permanent duplicated item': None,
 'part of': 'NASDAQ-100',
 'award received': 'Big Brother Awards',
 'owner of': 'Microsoft TechNet',
 'owned by': 'BlackRock',
 'board member': 'Reid Hoffman',
 'chairperson': 'John W. Thompson',
 'location of formation': 'Albuquerque',
 'director / manager': 'Satya Nadella',
 'external auditor': 'Deloitte & Touche LLP',
 'partnership with': 'ID2020'}, 'Error in `dataset_builder.get_entity_associations()`'