In [None]:
# default_exp dataset_builder
# Reload imported modules before each cell executes so edits to library source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
#export 
import random
from kirby.database_proxy import WikiDatabase
import json
import importlib
import spacy
import en_core_web_sm
import json

# Dataset Builder
> Builds several types of datasets augmented with knowledge from the `wikidata` knowledge base

## Dataset Variations

Select a variation by passing its name to `DatasetBuilder.build` through the `dataset_type` keyword argument (the default is `'random'`).

### Description

`DatasetBuilder.build(ds, dataset_type='description')`

Augments each example in the given dataset with the description of every entity (`keyword`) found in its text.

*Example*

`Stephen Curry is my favorite basketball player. {Stephen Curry: {Description: American basketball player}}`

In [None]:
#export
class DatasetBuilder():
    "Build knowledge-augmented datasets from the named entities spaCy finds in each example's `text`."

    # Maps each `dataset_type` accepted by `build` to the name of the per-example method implementing it.
    # `None` marks a type that is listed but has no implementation yet.
    _DATASET_TYPES = {'random': 'random', 'description': 'description', 'relevant': None}

    def __init__(self):
        "Load the spaCy pipeline; the Wikidata proxy is only opened on first access to `db`."
        # The disabled components are not needed: this class only reads `doc.ents` (named entities).
        self.nlp = spacy.load("en_core_web_sm", disable=["tok2vec", "tagger", "parser", "attribute_ruler", "lemmatizer"])
        self._db = None

    @property
    def db(self):
        "The `WikiDatabase` proxy, created on first use so steps that only run spaCy never open it."
        # getattr guards instances created without __init__ (e.g. via __new__).
        if getattr(self, '_db', None) is None:
            self._db = WikiDatabase()
        return self._db

    @db.setter
    def db(self, value):
        "Replace the database proxy (e.g. with a stub)."
        self._db = value

    def build(self, ds, dataset_type='random'):
        """Build an augmented dataset by mapping the handler for `dataset_type` over every example of `ds`.

        Raises:
            ValueError: `dataset_type` is not one of the known types.
            NotImplementedError: the type is known but this class has no handler method for it yet.
        """
        if dataset_type not in self._DATASET_TYPES:
            raise ValueError(f"Unknown dataset_type {dataset_type!r}; expected one of {sorted(self._DATASET_TYPES)}")
        handler_name = self._DATASET_TYPES[dataset_type]
        handler = getattr(self, handler_name, None) if handler_name else None
        if handler is None:
            raise NotImplementedError(f"dataset_type {dataset_type!r} is not implemented yet")
        return ds.map(handler, batched=False)

    def build_knowledge_entities(self, ds, split, num_proc=4):
        """Add an `entities` column to `ds` (see `get_entities`), save it to
        `data/augmented_datasets/entities/<split>/` and return the augmented dataset.

        `num_proc` is forwarded to `ds.map` as the number of worker processes.
        """
        ds = ds.map(self.get_entities, batched=True, num_proc=num_proc)
        ds.save_to_disk('data/augmented_datasets/entities/' + split + '/')
        return ds

    def build_csv(self, ds, split):
        """Add a `knowledge` column to `ds` (see `retrieve_knowledge`), save it to
        `data/augmented_datasets/<split>/` and return the augmented dataset.

        Despite the name, the data is written with `save_to_disk`, not as a CSV file.
        """
        ds = ds.map(self.retrieve_knowledge, batched=False)
        # One directory per split so building 'train' and 'valid' does not overwrite the same output.
        ds.save_to_disk('data/augmented_datasets/' + split + '/')
        return ds

    def get_entities(self, batch):
        """Batched `map` function: add an `entities` column to `batch`.

        For each string in `batch['text']`, the column holds the list of named-entity strings spaCy
        found in it. Spans are converted to plain strings so the column can be stored in a dataset.
        Returns the updated `batch`.
        """
        batch['entities'] = [[span.text for span in doc.ents] for doc in self.nlp.pipe(batch['text'])]
        return batch

    def retrieve_knowledge(self, sequence):
        """Per-example `map` function: set `sequence['knowledge']` to the associations of every
        database entity found in `sequence['text']` (see `add_associations`) and return `sequence`."""
        entities = self.get_entities_in_text(sequence['text'])
        sequence['knowledge'] = self.add_associations(entities)
        return sequence

    def add_associations(self, entities):
        """Return one `{label: associations}` dict per entity record `[id, label, description]`,
        where `associations` comes from `get_entity_associations` (None when the entity has none)."""
        return [{entity[1]: self.get_entity_associations(entity)} for entity in entities]

    def _get_json(self, item):
        """Return a JSON string with the `label` and `description` of an entity record `[id, label, description]`."""
        return json.dumps({"label": item[1], "description": item[2]})

    def get_entities_in_text(self, text):
        """Return the database records of the named entities spaCy finds in `text`.

        Each spaCy entity is looked up by its surface text with `db.get_entity_by_label`. Entities with
        no match are dropped, and an entity mentioned twice yields two records.
        """
        doc = self.nlp(text)
        records = []
        for span in doc.ents:
            record = self.db.get_entity_by_label(span.text)
            if record:
                records.append(record)
        return records

    def get_entity_associations(self, entity):
        """
        Return a dict describing `entity` and its associated properties, or None if it has none.

        `entity` is a record `[entity_id, label, description]` as returned by `get_entities_in_text`.
        A bare `entity_id` string is also accepted; `description` is then None.
        The dict holds `id` and `description` plus one `property name -> related entity label` entry per
        association. When several associations share a property name, the last one wins.
        """
        # Indexing a bare id string would silently take its first character as the id.
        if isinstance(entity, str):
            entity_id, description = entity, None
        else:
            entity_id, description = entity[0], entity[2]
        associations = self.db.get_entity_associations(entity_id)
        if not associations:
            return None
        entity_associations_dict = {'id': entity_id, 'description': description}
        for property_id, related_entity_id in associations:
            property_name, related_entity_label = self.db.get_property_string(property_id, related_entity_id)
            entity_associations_dict[property_name] = related_entity_label
        return entity_associations_dict

# Testing

In [None]:
# Build and save the entity-annotated dataset for every split
from kirby.run_params import RunParams
from kirby.data_manager import DataManager
from datasets import load_dataset

run_params = RunParams(debug=True)
data_manager = DataManager(run_params)
ds_builder = DatasetBuilder()

for split in ('train', 'valid'):
    ds = data_manager.load(split)
    ds_builder.build_knowledge_entities(ds, split)

Using custom data configuration default-d64f335cc8a13d66
Reusing dataset text (/home/rob/.cache/huggingface/datasets/text/default-d64f335cc8a13d66/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5)
Loading cached processed dataset at /home/rob/.cache/huggingface/datasets/text/default-d64f335cc8a13d66/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5/cache-afe4d9c5e6fd685f.arrow


In [None]:
# Inspect the last split loaded above ('valid')
ds

Dataset({
    features: ['text'],
    num_rows: 490
})

In [None]:
# Creation: constructing a builder must yield a DatasetBuilder instance
ds_builder = DatasetBuilder()
assert isinstance(ds_builder, DatasetBuilder)

In [None]:
# Entity extraction: DatasetBuilder has no RAKE extractor (`rake`); it relies on spaCy NER,
# which should pick out the player's name in this sentence.
x = "Stephen Curry is my favorite basketball player."
ner_entities = [span.text for span in ds_builder.nlp(x).ents]
assert 'Stephen Curry' in ner_entities, "spaCy NER failed"

In [None]:
# Database lookup by label returns the [id, label, description] record
ronaldo_record = ds_builder.db.get_entity_by_label('Cristiano Ronaldo')
print(ronaldo_record)
assert ronaldo_record == ['Q11571', 'Cristiano Ronaldo', 'Portuguese association football player'], 'ERROR in `database_proxy`'

In [None]:
# Entities in a sentence: each spaCy entity found in the database comes back as its record
sentence = "Microsoft has bought Bethesda"
entities = ds_builder.get_entities_in_text(sentence)
print(entities)
expected_entities = [
    ['Q2283', 'Microsoft', 'American multinational technology corporation'],
    ['Q224892', 'Bethesda', 'Wikimedia disambiguation page'],
]
assert entities == expected_entities, 'Error in `dataset_builder.get_entities_in_text()`'

In [None]:
# Get associations from an entity. `get_entity_associations` indexes the full record
# [id, label, description], so pass the record itself, not just its id; the result also carries
# the entity's `id` and `description`.
microsoft = entities[0]
associations = ds_builder.get_entity_associations(microsoft)
assert associations == {'id': 'Q2283',
 'description': 'American multinational technology corporation',
 "topic's main Wikimedia portal": 'Portal:Microsoft',
 'founded by': 'Bill Gates',
 'country': 'United States of America',
 'instance of': 'software company',
 'headquarters location': 'Redmond',
 'stock exchange': 'NASDAQ',
 'chief executive officer': 'Steve Ballmer',
 "topic's main category": 'Category:Microsoft',
 'subsidiary': 'Xbox Game Studios',
 'described by source': 'Lentapedia (full versions)',
 'industry': 'technology industry',
 'product or material produced': 'Microsoft Windows',
 'legal form': 'Washington corporation',
 'business division': 'Microsoft Research',
 'history of topic': 'history of Microsoft',
 'member of': 'Alliance for Open Media',
 'permanent duplicated item': None,
 'part of': 'NASDAQ-100',
 'award received': 'Big Brother Awards',
 'owner of': 'Microsoft TechNet',
 'owned by': 'BlackRock',
 'board member': 'Reid Hoffman',
 'chairperson': 'John W. Thompson',
 'location of formation': 'Albuquerque',
 'director / manager': 'Satya Nadella',
 'external auditor': 'Deloitte & Touche LLP',
 'partnership with': 'ID2020'}, 'Error in `dataset_builder.get_entity_associations()`'

In [None]:
# Description: map each database entity found in the text to its description
# (DatasetBuilder has no `description` method, so build the mapping from the entity records).
text = "Darth Vader cut off Luke Skywalker's hand"
{label: {'description': description} for _entity_id, label, description in ds_builder.get_entities_in_text(text)}