<a href="https://colab.research.google.com/github/tomasonjo/blogs/blob/master/ie_pipeline/SpaCy_informationextraction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the libraries used by this notebook (shell installs, versions unpinned).
!pip install crosslingual-coreference spacy-transformers wikipedia neo4j
!pip install --upgrade google-cloud-storage
# NOTE(review): in the saved output of this cell, this upgrade replaces
# transformers 4.17.0 with 4.18.0 and pip reports a conflict with
# spacy-transformers 1.1.5 (requires transformers<4.18.0) — confirm the
# upgrade is intentional before pinning versions.
!pip install --upgrade transformers
# Download the small English spaCy model that later cells load by name.
!python -m spacy download en_core_web_sm


Collecting transformers<4.19,>=4.1
  Using cached transformers-4.17.0-py3-none-any.whl (3.8 MB)
Installing collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.18.0
    Uninstalling transformers-4.18.0:
      Successfully uninstalled transformers-4.18.0
Successfully installed transformers-4.17.0
Collecting transformers
  Using cached transformers-4.18.0-py3-none-any.whl (4.0 MB)
Installing collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.17.0
    Uninstalling transformers-4.17.0:
      Successfully uninstalled transformers-4.17.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
spacy-transformers 1.1.5 requires transformers<4.18.0,>=3.4.0, but you have transformers 4.18.0 which is incompatible.[0m
Successfully installed tran

# Restart the runtime so the packages installed above are loaded before running the cells below

In [None]:
import spacy
import crosslingual_coreference

[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [None]:
# Add rebel component https://github.com/Babelscape/rebel/blob/main/spacy_component.py
import requests
import hashlib
from spacy import Language
from typing import List

from spacy.tokens import Doc, Span

from transformers import pipeline

def call_wiki_api(item):
  """Look up the Wikidata id of the first search hit for ``item``.

  Queries the Wikidata ``wbsearchentities`` endpoint (English, JSON) and
  returns the id of the first result.

  Args:
    item: Entity text to search for.

  Returns:
    The id string of the first search hit, or ``'id-less'`` when ``item`` is
    empty, the request fails, or the response contains no usable hit.
  """
  if not item:
    # Nothing to search for; skip the network round trip.
    return 'id-less'
  try:
    # Hand the search term over via ``params`` so it is URL-encoded; pasting
    # it into the URL breaks on characters such as '&' or '#'.
    response = requests.get(
        "https://www.wikidata.org/w/api.php",
        params={
            "action": "wbsearchentities",
            "search": item,
            "language": "en",
            "format": "json",
        },
        timeout=10,  # without a timeout a stalled connection blocks forever
    )
    response.raise_for_status()
    data = response.json()
    # Return the first id (Could upgrade this in the future)
    return data['search'][0]['id']
  except (requests.RequestException, ValueError, LookupError, TypeError):
    # Best-effort lookup: network/HTTP errors, undecodable JSON, and missing
    # or empty 'search' results all map to the sentinel. Unlike a bare
    # ``except``, KeyboardInterrupt and SystemExit still propagate.
    return 'id-less'

def extract_triplets(text):
    """
    Parse the generated text and return the triplets it encodes.

    The text is split on whitespace after removing the "<s>", "<pad>" and
    "</s>" markers. Tokens following "<triplet>" build the head, tokens
    following "<subj>" build the tail, and tokens following "<obj>" build the
    relation type. Each triplet is a dict with the keys 'head', 'type' and
    'tail'. Tokens seen before the first marker are ignored.
    """
    def _pack(head, rel, tail):
        # Field values are accumulated with a leading space; trim on output.
        return {'head': head.strip(), 'type': rel.strip(),'tail': tail.strip()}

    cleaned = text.strip()
    for marker in ("<s>", "<pad>", "</s>"):
        cleaned = cleaned.replace(marker, "")

    found = []
    head = tail = rel = ''
    slot = None  # field that plain tokens are currently appended to
    for piece in cleaned.split():
        if piece == "<triplet>":
            # A new head starts: emit the pending triplet, if any.
            if rel:
                found.append(_pack(head, rel, tail))
                rel = ''
            head = ''
            slot = 'head'
        elif piece == "<subj>":
            # Another tail for the same head: emit the pending triplet first.
            if rel:
                found.append(_pack(head, rel, tail))
            tail = ''
            slot = 'tail'
        elif piece == "<obj>":
            rel = ''
            slot = 'rel'
        elif slot == 'head':
            head += ' ' + piece
        elif slot == 'tail':
            tail += ' ' + piece
        elif slot == 'rel':
            rel += ' ' + piece

    # The trailing triplet is only kept when all three parts were filled.
    if head and rel and tail:
        found.append(_pack(head, rel, tail))
    return found


@Language.factory(
    "rebel",
    requires=["doc.sents"],
    assigns=["doc._.rel"],
    default_config={
        "model_name": "Babelscape/rebel-large",
        "device": 0,
    },
)
class RebelComponent:
    """spaCy pipeline component that fills ``doc._.rel`` with relations.

    Each sentence of the document is passed through a text2text-generation
    pipeline, the generated text is parsed with ``extract_triplets`` and the
    resulting triplets are stored on the document together with the ids
    returned by ``call_wiki_api`` for their head and tail.
    """

    def __init__(self, nlp, name, model_name: str, device: int):
        """Create the generation pipeline and register the ``rel`` extension.

        Args:
            nlp: Pipeline object supplied by the spaCy factory (not used here).
            name: Component name supplied by the spaCy factory (not used here).
            model_name: Model used for both the model and the tokenizer.
            device: Device index forwarded to the transformers pipeline.
        """
        assert model_name is not None, ""
        self.triplet_extractor = pipeline(
            "text2text-generation",
            model=model_name,
            tokenizer=model_name,
            device=device,
        )
        # Cache of entity text -> looked-up id, so each text is queried once.
        self.entity_mapping = {}
        # The extension is registered on the Doc class, so only do it once.
        if not Doc.has_extension("rel"):
            Doc.set_extension("rel", default={})

    def get_wiki_id(self, item: str):
        """Return the id for ``item``, consulting the cache before the API."""
        cached = self.entity_mapping.get(item)
        if not cached:
            cached = call_wiki_api(item)
            self.entity_mapping[item] = cached
        return cached

    def _generate_triplets(self, sent: Span) -> List[dict]:
        """Generate and parse the triplets for a single sentence."""
        generation = self.triplet_extractor(
            sent.text, return_tensors=True, return_text=False
        )
        token_ids = generation[0]["generated_token_ids"]["output_ids"]
        decoded = self.triplet_extractor.tokenizer.batch_decode(token_ids[0])
        # Only the first decoded sequence is parsed.
        return extract_triplets(decoded[0])

    def set_annotations(self, doc: Doc, triplets: List[dict]):
        """Store ``triplets`` in ``doc._.rel``, keyed by a SHA-1 of their text.

        Triplets whose head equals their tail are skipped, and a triplet whose
        key is already present is not stored again.
        """
        for triplet in triplets:
            head_text = triplet['head']
            tail_text = triplet['tail']

            # Remove self-loops (relationships that start and end at the entity)
            if head_text == tail_text:
                continue

            key_source = head_text + tail_text + triplet['type']
            index = hashlib.sha1(key_source.encode('utf-8')).hexdigest()
            if index in doc._.rel:
                continue

            # Get wiki ids and store results
            head_span = {'text': head_text, 'id': self.get_wiki_id(head_text)}
            tail_span = {'text': tail_text, 'id': self.get_wiki_id(tail_text)}
            doc._.rel[index] = {
                "relation": triplet["type"],
                "head_span": head_span,
                "tail_span": tail_span,
            }

    def __call__(self, doc: Doc) -> Doc:
        """Annotate ``doc`` sentence by sentence and return it."""
        for sent in doc.sents:
            self.set_annotations(doc, self._generate_triplets(sent))
        return doc

In [None]:
DEVICE = 0 # Number of the GPU, -1 if want to use CPU

# Add coreference resolution model
# Only the "xx_coref" component added below is needed from this pipeline, so
# the listed en_core_web_sm components are disabled.
coref = spacy.load('en_core_web_sm', disable=['ner', 'tagger', 'parser', 'attribute_ruler', 'lemmatizer'])
coref.add_pipe(
    "xx_coref", config={"chunk_size": 2500, "chunk_overlap": 2, "device": DEVICE})

# Define rel extraction model
# The parser stays enabled because the "rebel" component requires doc.sents.
# The component is named 'attribute_ruler'; the previous 'attribute_rules'
# matched no component, so it was never actually disabled.
rel_ext = spacy.load('en_core_web_sm', disable=['ner', 'lemmatizer', 'attribute_ruler', 'tagger'])
rel_ext.add_pipe("rebel", config={
    'device':DEVICE, # Number of the GPU, -1 if want to use CPU
    'model_name':'Babelscape/rebel-large'} # Model used, will default to 'Babelscape/rebel-large' if not given
    )



Some weights of the model checkpoint at microsoft/infoxlm-base were not used when initializing XLMRobertaModel: ['lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<__main__.RebelComponent at 0x7f12349c2810>

In [None]:
input_text = "Christian Drosten works in Germany. He likes to work for Google."

# Run coreference resolution first and extract relations from the resolved text.
coref_text = coref(input_text)._.resolved_text

doc = rel_ext(coref_text)

# `ner` is disabled in `rel_ext`, so this loop normally prints nothing. Only
# built-in Span attributes are printed: the `description` and `score`
# extensions are not registered anywhere in this notebook and reading them
# would raise AttributeError as soon as the document had any entity.
for span in doc.ents:
    print((span.text, span.kb_id_, span.label_))

# One line per stored relation: hash key followed by the relation record.
for value, rel_dict in doc._.rel.items():
    print(f"{value}: {rel_dict}")

385d9ac30c35635c7e36a1636ded4a091f830cf3: {'relation': 'country of citizenship', 'head_span': {'text': 'Christian Drosten', 'id': 'Q1079331'}, 'tail_span': {'text': 'Germany', 'id': 'Q183'}}
471a35571b66cc8e1e3415e27f7086505310efc6: {'relation': 'employer', 'head_span': {'text': 'Christian Drosten', 'id': 'Q1079331'}, 'tail_span': {'text': 'Google', 'id': 'Q95'}}


In [1]:
import os
from getpass import getpass

import pandas as pd
import wikipedia
from neo4j import GraphDatabase

# Define Neo4j connection
# Settings are read from the environment so that no password is stored in the
# notebook. NEO4J_URI and NEO4J_USER fall back to the previous values; the
# password is prompted for when NEO4J_PASSWORD is not set.
host = os.environ.get('NEO4J_URI', 'bolt://35.172.134.253:7687')
user = os.environ.get('NEO4J_USER', 'neo4j')
password = os.environ.get('NEO4J_PASSWORD') or getpass('Neo4j password: ')
driver = GraphDatabase.driver(host,auth=(user, password))

# Cypher statement that stores one head-[relation]->tail pattern per entry of
# the $data parameter. Entity nodes are merged on `id`: the looked-up id, or
# the raw text when that id is 'id-less'. The relationship type is the
# relation name upper-cased with spaces replaced by underscores.
import_query = """
UNWIND $data AS row
MERGE (h:Entity {id: CASE WHEN NOT row.head_span.id = 'id-less' THEN row.head_span.id ELSE row.head_span.text END})
ON CREATE SET h.text = row.head_span.text
MERGE (t:Entity {id: CASE WHEN NOT row.tail_span.id = 'id-less' THEN row.tail_span.id ELSE row.tail_span.text END})
ON CREATE SET t.text = row.tail_span.text
WITH row, h, t
CALL apoc.merge.relationship(h, toUpper(replace(row.relation,' ', '_')),
  {},
  {},
  t,
  {}
)
YIELD rel
RETURN distinct 'done' AS result;
"""


def run_query(query, params=None):
    """Run a Cypher query on the module-level ``driver`` and tabulate the result.

    Args:
        query: Cypher statement to execute.
        params: Optional dict of query parameters; an empty dict is used when
            omitted.

    Returns:
        A pandas DataFrame with one row per returned record and the result
        keys as column names.
    """
    # Build a fresh dict per call instead of sharing one mutable default
    # object between all calls.
    if params is None:
        params = {}
    with driver.session() as session:
        result = session.run(query, params)
        # Records are consumed while the session is still open.
        return pd.DataFrame([r.values() for r in result], columns=result.keys())

def store_wikipedia_summary(page):
    """Extract relations from a Wikipedia page summary and import them.

    The summary of ``page`` is run through ``coref`` and ``rel_ext``, and the
    extracted relation records are passed to ``import_query`` as its ``data``
    parameter. Any exception is reported with a printed message instead of
    being raised.
    """
    try:
        summary = wikipedia.page(page).summary
        resolved = coref(summary)._.resolved_text
        relations = rel_ext(resolved)._.rel
        # Only the relation records are imported; their hash keys are dropped.
        run_query(import_query, {'data': list(relations.values())})
    except Exception as e:
        print(f"Couldn't parse text for {page} due to {e}")


ModuleNotFoundError: ignored

In [None]:
# Wikipedia page titles whose summaries are processed and imported.
ladies = ["Jennifer Doudna", "Rachel Carson", "Sara Seager OC", "Rosalind Franklin", "Gertrude Elion", "Rita Levi-Montalcini"]

# store_wikipedia_summary catches and prints its own errors, so a failing page
# does not stop the loop.
for l in ladies:
  print(f"Parsing {l}")
  store_wikipedia_summary(l)

Parsing Jennifer Doudna


  num_effective_segments = (seq_lengths + self._max_length - 1) // self._max_length


Parsing Rachel Carson
Parsing Sara Seager OC
Parsing Rosalind Franklin
Parsing Gertrude Elion
Parsing Rita Levi-Montalcini


In [9]:
# Enrich every Entity node whose id starts with 'Q' using the Wikidata SPARQL
# endpoint: `wikipedia_name` is set from the English rdfs:label, and the
# capitalised English label of each wdt:P31 value is added as a node label.
# The work runs through apoc.periodic.iterate with batchSize 1 and retry 1.
# NOTE(review): the saved output of this cell reports 5 failed batches; the
# cause is not visible here — possibly entities without a P31 label, which
# would make `label` null — confirm.
run_query("""
CALL apoc.periodic.iterate("
  MATCH (e:Entity)
  WHERE e.id STARTS WITH 'Q'
  RETURN e
","
  // Prepare a SparQL query
  WITH 'SELECT * WHERE{ ?item rdfs:label ?name . filter (?item = wd:' + e.id + ') filter (lang(?name) = \\\"en\\\") ' +
     'OPTIONAL {?item wdt:P31 [rdfs:label ?label] .filter(lang(?label)=\\\"en\\\")}}' AS sparql, e
  // make a request to Wikidata
  CALL apoc.load.jsonParams(
    'https://query.wikidata.org/sparql?query=' + 
      sparql,
      { Accept: 'application/sparql-results+json'}, null)
  YIELD value
  UNWIND value['results']['bindings'] as row
  SET e.wikipedia_name = row.name.value
  WITH e, row.label.value AS label
  CALL apoc.create.addLabels( e, [ apoc.text.capitalize(label) ] )
  YIELD node
  RETURN distinct 'done'", {batchSize:1, retry:1})
""")

Unnamed: 0,batches,total,timeTaken,committedOperations,failedOperations,failedBatches,retries,errorMessages,batch,operations,wasTerminated,failedParams,updateStatistics
0,72,72,6,67,5,5,0,{},"{'total': 72, 'committed': 67, 'failed': 5, 'e...","{'total': 72, 'committed': 67, 'failed': 5, 'e...",False,{},"{'nodesDeleted': 0, 'labelsAdded': 0, 'relatio..."
