In [None]:
from transformers import AutoTokenizer, AutoModel, AutoModelForTokenClassification,pipeline
import pandas as pd
import pickle
import os
import xmltodict
import requests
import xml.etree.ElementTree as ET
import json
import random

# CellText Mining


### Functions contained in this notebook are for extracting raw text from Cell journal articles and then running NER on the output.
    The actual text JSON itself is acquired from Elsevier's TDM API. You will need to request access and use your own API key (when in development) to run the following code.
    Light cleaning is applied to the output to sanitize it and remove artifacts from stemming and the tokenization process; however, the validity of the output is not guaranteed and requires further validation.
    The models used for NER come from Hugging Face; all of them are based on BioBERT or PubMedBERT and are specialized for NER.


In [None]:
## CSV of articles; it must have a "DOI" column (read by the NER loop) and an
## "Open Access" column (used by the filter below)
input_df = pd.read_csv('/NLP/CellArticles.csv')
## If necessary, keep only the rows whose "Open Access" cell is empty
## (don't have to; open-access articles can be used to validate open-access TDM methods)
input_df = input_df.loc[input_df["Open Access"].isna()]


## Import NER models from HuggingFace. Feel free to import different NER models if you wish
def _build_ner(checkpoint, **tokenizer_kwargs):
    """Load one checkpoint and return its (tokenizer, model, "ner" pipeline) triple."""
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, **tokenizer_kwargs)
    model = AutoModelForTokenClassification.from_pretrained(checkpoint)
    return tokenizer, model, pipeline("ner", model=model, tokenizer=tokenizer)


disease_tokenizer, disease_model, disease_nlp = _build_ner("alvaroalon2/biobert_diseases_ner")
genetic_tokenizer, genetic_model, genetic_nlp = _build_ner("alvaroalon2/biobert_genetic_ner")
pubmedbert_gene, pubmedbert_gene_model, pubmedbert_gene_nlp = _build_ner(
    "pruas/BENT-PubMedBERT-NER-Gene", model_max_length=512
)
pubmedbert_disease, pubmedbert_disease_model, pubmedbert_disease_nlp = _build_ner(
    "pruas/BENT-PubMedBERT-NER-Disease", model_max_length=512
)

## Whatever models you choose, separate them into gene and disease pipelines and
## collect each family in its own list
disease = [disease_nlp, pubmedbert_disease_nlp]
genetic = [genetic_nlp, pubmedbert_gene_nlp]

## Model family name -> list of pipelines run for that family
nlps = {'disease': disease, 'genetic': genetic}

In [None]:
## Use Cell API and token to query article
def getCellArticles(cell_doi):
    """Fetch one article from the Elsevier article API and return its '#text' fragments.

    Credentials are read from the environment instead of being embedded in the
    notebook:

    * ``ELSEVIER_API_KEY``   - required API key
    * ``ELSEVIER_INSTTOKEN`` - optional institution token

    NOTE(review): a key and token were previously committed in this function;
    treat them as leaked and rotate them.

    Parameters
    ----------
    cell_doi : str
        DOI of the article; it is inserted into the request path.

    Returns
    -------
    list
        Flat list of the '#text' values found under
        full-text-retrieval-response / originalText / xocs:doc / xocs:serial-item.

    Raises
    ------
    RuntimeError
        If ``ELSEVIER_API_KEY`` is not set.
    requests.HTTPError
        If the API answers with an error status.
    KeyError
        If the response does not have the expected structure.
    """
    cell_api_key = os.environ.get('ELSEVIER_API_KEY')
    cell_token = os.environ.get('ELSEVIER_INSTTOKEN')
    if not cell_api_key:
        raise RuntimeError(
            'Set the ELSEVIER_API_KEY environment variable '
            '(and ELSEVIER_INSTTOKEN if you have an institution token).'
        )

    query_params = {'APIKey': cell_api_key, 'view': 'FULL'}
    if cell_token:
        query_params['insttoken'] = cell_token

    cell_query = f'https://api.elsevier.com/content/article/doi/{cell_doi}'
    # A timeout keeps one stalled request from hanging the whole batch run.
    response = requests.get(cell_query, params=query_params, timeout=60)
    # Fail on HTTP errors here rather than with an opaque parse/KeyError further down.
    response.raise_for_status()

    decoded_response = response.content.decode("utf-8")
    # The JSON round trip turns whatever mapping type xmltodict returns into plain dict/list.
    response_json = json.loads(json.dumps(xmltodict.parse(decoded_response)))

    # Collect every '#text' value under the serial item and flatten the nested lists.
    serial_item = response_json['full-text-retrieval-response']['originalText']['xocs:doc']['xocs:serial-item']
    cell_article = list(flatten(getKey(serial_item, '#text')))
    return cell_article

## Some basic data processing
def df_cleaning(df):
    """Merge sub-word pieces and adjacent tokens of an NER result into whole terms.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per token with 'word', 'start', 'end' and 'entity' columns.

    Returns
    -------
    pandas.DataFrame
        Columns 'entity', 'word', 'start', 'end', one row per merged term.
        A column-less empty DataFrame is returned when there is nothing to merge.
    """
    # Pass 1: glue '##' continuation pieces onto the token that precedes them.
    words, starts, ends, entities = [], [], [], []
    for _, row in df.iterrows():
        piece = row['word']
        if piece.startswith('##'):
            if not words:
                # A continuation piece with no preceding token cannot be attached; skip it.
                continue
            words[-1] = words[-1].strip() + piece[2:].strip()
            ends[-1] = row['end']
        else:
            words.append(piece.strip())
            starts.append(row['start'])
            ends.append(row['end'])
            entities.append(row['entity'])

    if not words:
        return pd.DataFrame()

    # Pass 2: join tokens that touch, or are one character apart, into a single
    # space-separated term. The merged term carries the entity label of its last token.
    merged = []
    current_ent, current_text = entities[0], words[0]
    current_start, current_end = starts[0], ends[0]
    for ent, text, start, end in zip(entities[1:], words[1:], starts[1:], ends[1:]):
        if start in (current_end, current_end + 1):
            current_text = current_text + " " + text
            current_end = end
            current_ent = ent
        else:
            merged.append((current_ent, current_text, current_start, current_end))
            current_ent, current_text = ent, text
            current_start, current_end = start, end
    merged.append((current_ent, current_text, current_start, current_end))

    return pd.DataFrame(merged, columns=['entity', 'word', 'start', 'end'])

## Recursively extracts info from a dict. In this case we are extracting all '#text' fields
def getKey(d, key):
    """Recursively collect the values stored under ``key`` in a dict/list tree.

    The result mirrors the nesting of the input (lists inside lists); pass it
    through ``flatten`` to get a flat sequence. Falsy hits (empty strings,
    None, empty lists) are dropped at every level.

    Parameters
    ----------
    d : dict | list | any
        Node to search. Anything that is neither a dict nor a list yields [].
    key : hashable
        Key to look for (the notebook uses '#text').

    Returns
    -------
    list
        Nested list of the values found.

    Notes
    -----
    When a dict contains ``key`` directly, only that value is returned and the
    dict's other entries are not searched. Inside a list only dict items are
    searched; nested lists and scalars are ignored.
    """
    # Explicit type dispatch replaces the nested bare ``except:`` clauses, which
    # used exceptions for control flow and could hide unrelated errors.
    if isinstance(d, dict):
        if key in d:
            found = [d[key]]
        else:
            found = [getKey(value, key) for value in d.values()]
    elif isinstance(d, list):
        found = [getKey(item, key) for item in d if isinstance(item, dict)]
    else:
        found = []
    # Drop empty results so branches without any hit leave no trace.
    return [item for item in found if item]

## Unwind json
def flatten(container):
    """Lazily yield the leaves of arbitrarily nested lists/tuples, depth first.

    Only ``list`` and ``tuple`` items are descended into; every other item,
    strings included, is yielded unchanged.
    """
    for element in container:
        if isinstance(element, (list, tuple)):
            yield from flatten(element)
        else:
            yield element

## Run NLP on string input and returns results as DF
def NER_results(nlps,string_result):
    """Run every NER pipeline in ``nlps`` on a text and return the combined entities.

    Parameters
    ----------
    nlps : iterable of callables
        Each is called with ``string_result``; its return value is passed to
        ``pd.DataFrame`` and then to ``df_cleaning``.
    string_result : str
        Text to annotate.

    Returns
    -------
    pandas.DataFrame
        Cleaned entities from all pipelines, ordered by 'start'. Rows whose
        'entity' equals '0' are removed. An empty DataFrame is returned when
        no pipeline produced any entity.
    """
    frames = []
    for ner_pipeline in nlps:
        entities = df_cleaning(pd.DataFrame(ner_pipeline(string_result)))
        # df_cleaning returns a column-less frame when nothing was found, so
        # test for the column rather than wrapping the filter in try/except.
        if "entity" in entities.columns:
            entities = entities[entities["entity"] != '0']
        frames.append(entities)

    non_empty = [frame for frame in frames if not frame.empty]
    if not non_empty:
        # Nothing recognised: return an empty frame, keeping the column labels
        # when at least one pipeline supplied them.
        return next((frame for frame in frames if len(frame.columns)), pd.DataFrame())

    # sort_values returns a new frame; the original discarded that result, so
    # the output was never sorted. The stable sort keeps pipeline order on ties.
    return pd.concat(non_empty).sort_values(by=['start'], kind='stable')

## Cleanup any collisons/overlaps
def collision_cleanup(model_type,full_output):
    """Collapse duplicate and overlapping entity spans, keeping the widest span per group.

    Parameters
    ----------
    model_type : value written into the first column of ``full_output``
        (the caller in this notebook passes the key of ``nlps``).
    full_output : pandas.DataFrame
        Needs 'word', 'start' and 'end' columns. NOTE: it is modified in
        place - its first column is overwritten with ``model_type``.

    Returns
    -------
    pandas.DataFrame
        String cells lower-cased, an extra 'interval' column (end - start),
        a fresh 0..n-1 index and one row per span group.

    The grouping pass is repeated until a pass leaves the frame unchanged.
    """
    output= pd.DataFrame()
    # Overwrite the first column with the model family label.
    # presumably that column is 'entity' (df_cleaning puts it first) - verify against caller
    full_output[full_output.columns[0]] = model_type
    # Lower-case every string cell so duplicates differing only in case compare equal.
    # NOTE(review): DataFrame.applymap is deprecated in recent pandas releases - confirm the pinned version.
    process_df = full_output.applymap(lambda s: s.lower() if type(s) == str else s)
    df = process_df.drop_duplicates() ##Handle simple true duplicates
    # Iterate to a fixed point: stop once a pass returns a frame equal to its input.
    while output.equals(df) == False:
        output = df
        df = df.sort_values(['start','end'])
        # c1: this row's word is identical to the previous row's word.
        # NOTE(review): c1 makes a repeated word start a NEW group - confirm this is intended.
        c1 = df['word'].shift() == df['word']
        # c2: the previous row ends at or before this row's start (no overlap).
        c2 = df['end'].shift() - df['start'] <= 0
        #c3 = df['end'].shift() - df['end'] < 0
        df['interval'] = df['end'] - df['start']
        # A new group id begins at every row where c1 or c2 holds; rows that overlap
        # the previous row with a different word stay in the previous row's group.
        df['overlap'] = (c1 | c2).cumsum()
        # With the widest spans sorted first, first() keeps the widest entry of each
        # group (GroupBy.first takes the first non-null value per column).
        df = df.sort_values(['interval'], ascending=False).groupby('overlap').first()
        df = df.reset_index(drop=True)
    return(df)

In [None]:
success = {}        # DOI -> [full article text, DataFrame of entities]
fail = []           # [DOI, model family that was running (None if the fetch itself failed)]
fail_reasons = {}   # DOI -> repr of the exception, so failures can be diagnosed afterwards
## Run NER over all articles. Failed DOIs are stored in `fail`; successful results
## are stored in `success` with the DOI as the key.
for doi in input_df['DOI']:
    print(f'\r{doi}', end='',)
    # Reset per article. The original recorded the loop variable `i`, which was
    # undefined on a fresh kernel and stale whenever the fetch itself failed.
    model_type = None
    try:
        string_result = ''.join(getCellArticles(doi))
        per_family = []
        for model_type in nlps:
            per_family.append(collision_cleanup(model_type, NER_results(nlps[model_type], string_result)))
        # One concat per article instead of growing a frame inside the loop.
        result = pd.concat(per_family).sort_values(by=['start']).reset_index(drop=True)
        success[doi] = [string_result, result]
    except Exception as exc:
        # Best effort: one bad article must not stop the batch, but keep the reason.
        # Catching Exception (not a bare except) lets Ctrl-C interrupt a long run.
        fail.append([doi, model_type])
        fail_reasons[doi] = repr(exc)

In [None]:
## Simple clean-up of spaced hyphens: token merging joins pieces with spaces,
## so "a - b" is collapsed back to "a-b" in every article's 'word' column.
for doi in success:
    entities = success[doi][1]
    entities['word'] = [term.replace(" - ", "-") for term in entities['word']]

In [None]:
#Generate classifications of histone modifiers. This part is optional and can be swapped out at will for any other classification of a gene product
# Load the three tables and tag each row with the table it came from.
reader = pd.read_csv('/NLP/readertbl.csv').assign(classification='reader')
writer = pd.read_csv('/NLP/writertbl.csv').assign(classification='writer')
eraser = pd.read_csv('/NLP/erasertbl.csv').assign(classification='eraser')

# Stack the tables into one lookup frame with a fresh 0..n-1 index.
classification_df = pd.concat([reader, writer, eraser], ignore_index=True)
# Lower-case the gene names before they are compared with the entity words.
classification_df['gene'] = classification_df['gene'].str.lower()

##Scan through DF output and match with a knowledge base
known_genes = set(classification_df['gene'])  # built once; a set makes the lookup fast
for doi in success:
    entities = success[doi][1]
    entities['classification'] = 'NULL'
    # Only words that occur in the knowledge base need any further work.
    for query in set(entities['word']) & known_genes:
        gene_rows = classification_df['gene'] == query
        classification = list(classification_df.loc[gene_rows, 'classification'])
        # Store the whole list of classifications in each matching cell; wrapping
        # it in a one-element Series keeps pandas from spreading the list over rows.
        for j in list(entities[entities['word'] == query].index.values):
            entities.loc[[j], 'classification'] = pd.Series([classification], index=entities.index[[j]])

In [None]:
# with open('savedCellArticles.pkl', 'wb') as f:
#     pickle.dump(success, f)

In [5]:
# Reload cached results. This REPLACES any `success` dict built earlier in the session.
# NOTE: pickle.load can execute arbitrary code - only open a file you created yourself.
with open('savedCellArticles.pkl', 'rb') as cache_file:
    success = pickle.load(cache_file)

In [None]:
# Bounded preview (DOI -> shape of its entity table) instead of dumping every
# article's full text and DataFrame into the notebook output.
{doi: success[doi][1].shape for doi in list(success)[:5]}

In [None]:
##response_json['full-text-retrieval-response']['originalText']['xocs:doc']['xocs:serial-item']['article']['head'] 
##Bolded "abstract/intro" text and data availability statements

In [None]:
##response_json['full-text-retrieval-response']['originalText']['xocs:doc']['xocs:serial-item']['article']['tail'] 
##Citations


In [None]:
##response_json['full-text-retrieval-response']['originalText']['xocs:doc']['xocs:serial-item']['article']['body']['ce:appendices'] 
##No supplemental text itself, may not be necessary

In [None]:
#response_json['full-text-retrieval-response']['originalText']['xocs:doc']['xocs:serial-item']['article']['body']['ce:sections']['ce:section'][1]['ce:section'][0]
#Not all sections structured the same



In [6]:
# One entry per entity whose classification is a list, i.e. matched the knowledge base
# (unmatched entities keep the 'NULL' string placeholder).
article_list = []
for doi in success:
    for label in success[doi][1]['classification']:
        if isinstance(label, list):
            article_list.append(doi)

# Unique DOIs in sorted order: iteration order of a set of strings differs between
# interpreter runs, so sorting is needed for the seeded sample to be reproducible.
result = sorted(set(article_list))

SAMPLE_SIZE = 10
SAMPLE_SEED = 0  # fixed so Restart & Run All draws the same sample
# A private Random instance leaves the global random state untouched. Capping the
# size avoids the ValueError random.sample raises when fewer candidates exist.
sample = random.Random(SAMPLE_SEED).sample(result, min(SAMPLE_SIZE, len(result)))
            

In [7]:
sample

['10.1016/j.bbagrm.2018.10.019',
 '10.1016/j.beha.2004.08.011',
 '10.1016/j.jbior.2012.04.003',
 '10.1016/j.dnarep.2009.04.003',
 '10.1016/j.tig.2006.09.007',
 '10.1016/j.mce.2017.03.016',
 '10.1016/j.dnarep.2011.01.012',
 '10.1016/j.jmb.2008.09.011',
 '10.1016/j.currproblcancer.2018.03.001',
 '10.1016/j.ejmg.2019.103739']