In [1]:
#### Install necessary packages

!pip install crosslingual-coreference==0.2.3 spacy-transformers==1.1.5 wikipedia neo4j
!pip install --upgrade google-cloud-storage
!pip install transformers==4.18.0
!python -m spacy download en_core_web_sm
!pip install py2neo

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting crosslingual-coreference==0.2.3
  Downloading crosslingual_coreference-0.2.3-py3-none-any.whl (11 kB)
Collecting spacy-transformers==1.1.5
  Downloading spacy_transformers-1.1.5-py2.py3-none-any.whl (51 kB)
[K     |████████████████████████████████| 51 kB 153 kB/s 
[?25hCollecting wikipedia
  Downloading wikipedia-1.4.0.tar.gz (27 kB)
Collecting neo4j
  Downloading neo4j-5.0.1.tar.gz (172 kB)
[K     |████████████████████████████████| 172 kB 59.1 MB/s 
[?25hCollecting torch<1.11.0,>=1.10.0
  Downloading torch-1.10.2-cp37-cp37m-manylinux1_x86_64.whl (881.9 MB)
[K     |██████████████████████████████▎ | 834.1 MB 1.2 MB/s eta 0:00:41tcmalloc: large alloc 1147494400 bytes == 0x39c74000 @  0x7fdb31c8e615 0x58e046 0x4f2e5e 0x4d19df 0x51b31c 0x5b41c5 0x58f49e 0x51b221 0x5b41c5 0x58f49e 0x51837f 0x4cfabb 0x517aa0 0x4cfabb 0x517aa0 0x4cfabb 0x517aa0 0x4ba70a 0x538136 0x590055 0x51b180

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers==4.18.0
  Downloading transformers-4.18.0-py3-none-any.whl (4.0 MB)
[K     |████████████████████████████████| 4.0 MB 35.1 MB/s 
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 12.6 MB/s 
Installing collected packages: tokenizers, transformers
  Attempting uninstall: tokenizers
    Found existing installation: tokenizers 0.13.0
    Uninstalling tokenizers-0.13.0:
      Successfully uninstalled tokenizers-0.13.0
  Attempting uninstall: transformers
    Found existing installation: transformers 4.17.0
    Uninstalling transformers-4.17.0:
      Successfully uninstalled transformers-4.17.0
[31mERROR: pip's dependency resolver 

In [2]:
# Add rebel component https://github.com/Babelscape/rebel/blob/main/spacy_component.py
import csv
import hashlib
import io
import json
import os
import re
from getpass import getpass
from typing import List

import crosslingual_coreference
import pandas as pd
import requests
import spacy
import torch
import wikipedia
from neo4j import GraphDatabase, basic_auth
from py2neo import Graph
from spacy import Language
from spacy.tokens import Doc, Span
from transformers import pipeline

[nltk_data] Downloading package omw-1.4 to /root/nltk_data...


In [3]:
def call_wiki_api(item):
    """Look up `item` on Wikidata and return the id of the first search hit.

    Args:
        item: entity text to search for (English).

    Returns:
        The Wikidata id of the first result (e.g. 'Q206901'), or 'id-less'
        when the request fails, the response is not JSON, or there is no hit.
    """
    try:
        # Pass the search term via `params` so requests URL-encodes it; entity
        # text containing "&", "#", "+" etc. would otherwise corrupt the query.
        response = requests.get(
            "https://www.wikidata.org/w/api.php",
            params={
                "action": "wbsearchentities",
                "search": item,
                "language": "en",
                "format": "json",
            },
            timeout=10,
        )
        data = response.json()
        # Return the first id (Could upgrade this in the future)
        return data['search'][0]['id']
    except (requests.RequestException, ValueError, KeyError, IndexError, TypeError):
        # Network error, non-JSON body, error payload, or no search hits.
        return 'id-less'

def extract_triplets(text):
    """Parse REBEL's linearized generation into a list of relation triplets.

    The generated text is read as
    ``<triplet> HEAD <subj> TAIL <obj> RELATION`` where one HEAD may be
    followed by several ``<subj> TAIL <obj> RELATION`` groups. The special
    tokens ``<s>``, ``<pad>`` and ``</s>`` are removed before parsing.

    Returns:
        A list of dicts with keys 'head', 'type' (the relation) and 'tail'.
    """
    def as_triplet(head_text, rel_text, tail_text):
        return {'head': head_text.strip(), 'type': rel_text.strip(), 'tail': tail_text.strip()}

    cleaned = text.strip().replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    found = []
    head = tail = rel = ''
    # 't' = reading head tokens, 's' = reading tail tokens, 'o' = reading relation tokens
    mode = 'x'

    for tok in cleaned.split():
        if tok == "<triplet>":
            # A new head begins; flush the previous triplet if a relation was read.
            if rel != '':
                found.append(as_triplet(head, rel, tail))
                rel = ''
            head = ''
            mode = 't'
        elif tok == "<subj>":
            # Another tail for the same head; flush the previous pair if a relation was read.
            if rel != '':
                found.append(as_triplet(head, rel, tail))
            tail = ''
            mode = 's'
        elif tok == "<obj>":
            rel = ''
            mode = 'o'
        elif mode == 't':
            head += ' ' + tok
        elif mode == 's':
            tail += ' ' + tok
        elif mode == 'o':
            rel += ' ' + tok

    # Flush the trailing triplet if every part of it was seen.
    if head != '' and rel != '' and tail != '':
        found.append(as_triplet(head, rel, tail))
    return found


@Language.factory(
    "rebel",
    requires=["doc.sents"],
    assigns=["doc._.rel"],
    default_config={
        "model_name": "Babelscape/rebel-large",
        "device": 0,
    },
)
class RebelComponent:
    """spaCy pipeline component that extracts relation triplets with REBEL.

    Each sentence of the Doc is run through a Hugging Face
    ``text2text-generation`` pipeline. The generated text is parsed with
    ``extract_triplets``, and accepted triplets are stored in ``doc._.rel`` as
    ``{sha1(head + tail + relation): {"relation": ..., "head_span": {...},
    "tail_span": {...}}}``. Each span dict holds the entity ``text`` and its
    Wikidata ``id`` (``'id-less'`` when no id was found).
    """

    def __init__(
        self,
        nlp,
        name,
        model_name: str,
        device: int,
    ):
        assert model_name is not None, ""
        self.triplet_extractor = pipeline("text2text-generation", model=model_name, tokenizer=model_name, device=device)
        # Cache: entity text -> Wikidata id, so each entity is looked up only once.
        self.entity_mapping = {}
        # Register custom extension on the Doc
        if not Doc.has_extension("rel"):
            Doc.set_extension("rel", default={})

    def get_wiki_id(self, item: str):
        """Return the Wikidata id for `item`, calling the API only on a cache miss.

        Failed lookups ('id-less') are cached too, so they are not retried.
        """
        mapping = self.entity_mapping.get(item)
        if mapping:
            return mapping
        res = call_wiki_api(item)
        self.entity_mapping[item] = res
        return res

    def _generate_triplets(self, sent: Span) -> List[dict]:
        """Run REBEL on one sentence and parse its output into triplet dicts."""
        # NOTE(review): this nesting of the pipeline output appears specific to
        # the pinned transformers version (4.18.0); re-check if upgrading.
        output_ids = self.triplet_extractor(sent.text, return_tensors=True, return_text=False)[0]["generated_token_ids"]["output_ids"]
        # Decode without skipping special tokens: extract_triplets needs the
        # <triplet>/<subj>/<obj> markers.
        extracted_text = self.triplet_extractor.tokenizer.batch_decode(output_ids[0])
        return extract_triplets(extracted_text[0])

    def set_annotations(self, doc: Doc, triplets: List[dict]):
        """Add each triplet whose head and tail both occur in ``doc.text`` to ``doc._.rel``."""
        for triplet in triplets:
            head, tail, relation = triplet['head'], triplet['tail'], triplet['type']

            # Remove self-loops (relationships that start and end at the entity)
            if head == tail:
                continue

            # Skip relations whose entities are not present in the text: the
            # REBEL model sometimes hallucinates entities. This is a literal
            # substring test, because entity text often contains regex
            # metacharacters such as "(", ")", "[" or "+", which would break
            # or distort a regex search.
            if head not in doc.text or tail not in doc.text:
                continue

            index = hashlib.sha1("".join([head, tail, relation]).encode('utf-8')).hexdigest()
            if index not in doc._.rel:
                # Get wiki ids and store results
                doc._.rel[index] = {
                    "relation": relation,
                    "head_span": {'text': head, 'id': self.get_wiki_id(head)},
                    "tail_span": {'text': tail, 'id': self.get_wiki_id(tail)},
                }

    def __call__(self, doc: Doc) -> Doc:
        """Annotate `doc` sentence by sentence and return it."""
        for sent in doc.sents:
            sentence_triplets = self._generate_triplets(sent)
            self.set_annotations(doc, sentence_triplets)
        return doc

In [4]:

import torch
if torch.cuda.is_available():
    !nvidia-smi
else:
    print("☹️")

Thu Oct  6 03:36:11 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   46C    P8     9W /  70W |      3MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [5]:
# Use the first GPU when torch can see one, otherwise the CPU (-1).
# Override manually if needed: a GPU index, or -1 to force CPU.
DEVICE = 0 if torch.cuda.is_available() else -1

# Add coreference resolution model
coref = spacy.load('en_core_web_sm', disable=['ner', 'tagger', 'parser', 'attribute_ruler', 'lemmatizer'])
coref.add_pipe(
    "xx_coref", config={"chunk_size": 2500, "chunk_overlap": 2, "device": DEVICE})


# Define rel extraction model. The parser is kept enabled because the "rebel"
# component requires doc.sents. Pipe names must be spelled exactly
# ("attribute_ruler", not "attribute_rules") for disable to take effect.
rel_ext = spacy.load('en_core_web_sm', disable=['ner', 'lemmatizer', 'attribute_ruler', 'tagger'])
rel_ext.add_pipe("rebel", config={
    'device': DEVICE,
    'model_name': 'Babelscape/rebel-large'}  # same as the factory's default model
    )

models/crosslingual-coreference/minilm/model.tar.gz: 358490KB [00:19, 18426.69KB/s]                            
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...


Downloading:   0%|          | 0.00/357 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/489 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/4.83M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/8.66M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/225M [00:00<?, ?B/s]

Some weights of the model checkpoint at nreimers/mMiniLMv2-L12-H384-distilled-from-XLMR-Large were not used when initializing XLMRobertaModel: ['lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.bias']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLMRobertaModel were not initialized from the model checkpoint at nreimers/mMiniLMv2-L12-H384-distilled-from-XLMR-Large and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-st

Downloading:   0%|          | 0.00/1.39k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.51G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.20k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/780k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/123 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/344 [00:00<?, ?B/s]

<__main__.RebelComponent at 0x7f0b0cbe0890>

In [6]:
#### Test Input & Output
# Smoke test on a short abstract: resolve coreferences first, then extract
# relations and print each one keyed by its hash id.
input_text = "Amyotrophic lateral sclerosis (ALS) is a progressive neurodegenerative disease that affects motor neurons. Mutations in the SPTLC1 subunit of serine palmitoyltransferase (SPT), which catalyzes the first step in the de novo synthesis of sphingolipids (SLs), cause childhood-onset ALS. SPTLC1-ALS variants map to a transmembrane domain that interacts with ORMDL proteins, negative regulators of SPT activity."

coref_text = coref(input_text)._.resolved_text
doc = rel_ext(coref_text)

for rel_id, rel_dict in doc._.rel.items():
    print(f"{rel_id}: {rel_dict}")

56a55aa697d77a4b1ce80ef589a69cd065fda679: {'relation': 'subclass of', 'head_span': {'text': 'Amyotrophic lateral sclerosis', 'id': 'Q206901'}, 'tail_span': {'text': 'neurodegenerative disease', 'id': 'Q1755122'}}
2d840fbfc8840ce10d3b9ee4fd41d7ca3c58e74f: {'relation': 'subclass of', 'head_span': {'text': 'SPTLC1', 'id': 'Q18035483'}, 'tail_span': {'text': 'serine palmitoyltransferase', 'id': 'Q64428867'}}
d3d8724c4b6d2569dd14a9b2c49aed4e0d1712ad: {'relation': 'has part', 'head_span': {'text': 'ORMDL', 'id': 'Q29731643'}, 'tail_span': {'text': 'transmembrane domain', 'id': 'Q7834587'}}


In [7]:
# PubTator export settings:
#   Inputfile  - text file with one PMID per line
#   Format     - export format, appended to the PubTator export URL
#   Bioconcept - comma-separated concept filter; "" sends no filter
Inputfile, Format, Bioconcept = "./pmid", "pubtator", ""

In [10]:
#### Extract & Store the triplets

def SubmitPMIDList(Pmid,Format,Bioconcept):
    """Export the given PMIDs from PubTator and return their title + abstract text.

    Args:
        Pmid: a single PMID string, or an iterable of PMIDs, to export.
        Format: export format, appended to the PubTator export URL. The
            response is parsed as the line-based "pubtator" format.
        Bioconcept: optional comma-separated concept filter; "" (or None)
            sends no filter.

    Returns:
        The title and abstract lines of the exported articles joined by single
        spaces (the text is also printed). Returns None when the HTTP status is
        not 200 or the response has no title/abstract lines.
    """
    if isinstance(Pmid, str):
        Pmid = [Pmid]
    # Export exactly the requested PMIDs, ignoring blank entries.
    stripped = (str(p).strip() for p in Pmid)
    payload = {"pmids": [p for p in stripped if p]}
    if Bioconcept:
        payload["concepts"] = Bioconcept.split(",")

    r = requests.post(
        "https://www.ncbi.nlm.nih.gov/research/pubtator-api/publications/export/" + Format,
        json=payload,
        timeout=60,
    )
    # Check the status before parsing: an error body has no title/abstract lines.
    if r.status_code != 200:
        print("[Error]: HTTP code " + str(r.status_code))
        return None

    # PubTator text lines look like "<PMID>|t|<title>" and "<PMID>|a|<abstract>".
    # Split on "|" instead of slicing a fixed offset, because PMIDs vary in length.
    texts = []
    for line in r.text.splitlines():
        parts = line.split("|", 2)
        if len(parts) == 3 and parts[0].strip().isdigit() and parts[1] in ("t", "a"):
            text = parts[2].strip()
            if text:
                texts.append(text)

    if not texts:
        print("[Error]: no title/abstract in PubTator response")
        return None

    # Join with a space so the title's last sentence stays separate from the
    # abstract's first sentence for the sentence splitter.
    result = " ".join(texts)
    print(result)
    return result

def store_pubtator_summary(i):
    """Fetch PMID `i` from PubTator, extract its relations and import them into Neo4j.

    Per-article failures are printed and skipped so one bad article does not
    stop the loop. A NameError (e.g. `driver` or `import_query` not defined
    yet because the Neo4j cell has not been run) is re-raised, because it is a
    setup problem rather than a problem with this article.
    """
    try:
        input_text = SubmitPMIDList([i], Format, Bioconcept)
        if input_text is None:
            print(f"Couldn't fetch text for {i}; skipping")
            return
        coref_text = coref(input_text)._.resolved_text
        doc = rel_ext(coref_text)
        params = list(doc._.rel.values())
        if params:
            run_query(import_query, {'data': params})
    except NameError:
        raise
    except Exception as e:
        print(f"Couldn't parse text for {i} due to {e}")

def run_query(query, params=None):
    """Run a Cypher `query` on the global Neo4j `driver` and return the rows as a DataFrame.

    Args:
        query: Cypher query string.
        params: optional query parameters; None (the default) means no
            parameters and avoids a shared mutable default dict.

    Returns:
        A pandas DataFrame with one row per result record and the result keys
        as columns.
    """
    with driver.session() as session:
        result = session.run(query, params or {})
        # Build the frame inside the `with`: the result is read from the
        # session, which is closed on exit.
        return pd.DataFrame([r.values() for r in result], columns=result.keys())

# Read one PMID per line from `Inputfile`, skipping blank lines.
with io.open(Inputfile, 'r', encoding="utf-8") as file_input:
    pmidlist = {"pmids": [line.strip() for line in file_input if line.strip()]}

# NOTE: store_pubtator_summary needs `driver` and `import_query` from the
# "Send to Neo4j" cell; run that cell before this loop.
for i in pmidlist['pmids']:
    store_pubtator_summary(i)

SPTLC1 variants associated with ALS produce distinct sphingolipid signatures through impaired interaction with ORMDL proteins.Amyotrophic lateral sclerosis (ALS) is a progressive neurodegenerative disease affecting motor neurons. Mutations in the SPTLC1 subunit of serine-palmitoyltransferase (SPT), which catalyzes the first step in the de novo synthesis of sphingolipids cause childhood-onset ALS. SPTLC1-ALS variants map to a transmembrane domain that interacts with ORMDL proteins, negative regulators of SPT activity. We show that ORMDL binding to the holoenzyme complex is impaired in cells expressing pathogenic SPTLC1-ALS alleles, resulting in increased sphingolipid synthesis and a distinct lipid signature. C-terminal SPTLC1 variants cause the peripheral sensory neuropathy HSAN1 due to the synthesis of 1-deoxysphingolipids (1-deoxySLs) that form when SPT metabolizes L-alanine instead of L-serine. Limiting L-serine availability in SPTLC1-ALS expressing cells increased 1-deoxySL and shif

In [11]:
#### Send to Neo4j

# Connection settings come from the environment so credentials are never
# committed with the notebook. NEO4J_URI changes whenever a new server is up.
host = os.environ.get("NEO4J_URI", 'bolt://3.87.24.194:7687')
user = os.environ.get("NEO4J_USER", 'neo4j')
password = os.environ.get("NEO4J_PASSWORD") or getpass("Neo4j password: ")
driver = GraphDatabase.driver(host, auth=(user, password))

# Imports one row per relation dict (as stored in doc._.rel).
# - Entities are merged on their Wikidata id, or on their text when the id is 'id-less'.
# - The relationship type is the relation text upper-cased, with spaces replaced by "_".
# - apoc.merge.relationship requires the APOC plugin on the Neo4j server.
import_query = """
UNWIND $data AS row
MERGE (h:Entity {id: CASE WHEN NOT row.head_span.id = 'id-less' THEN row.head_span.id ELSE row.head_span.text END})
ON CREATE SET h.text = row.head_span.text
MERGE (t:Entity {id: CASE WHEN NOT row.tail_span.id = 'id-less' THEN row.tail_span.id ELSE row.tail_span.text END})
ON CREATE SET t.text = row.tail_span.text
WITH row, h, t
CALL apoc.merge.relationship(h, toUpper(replace(row.relation,' ', '_')),
  {},
  {},
  t,
  {}
)
YIELD rel
RETURN distinct 'done' AS result;
"""