# The knowledge network of COVID-19 through text mining of COVID-19 literature

### <font color='red'>A full version of the interactive network can be found on our website:</font> 

### [https://covid19.insilicom.com/](https://covid19.insilicom.com/)




## Table of Contents

1. [Introduction](#Introduction)

2. [The knowledge network](#The-knowledge-network)

3. [Import packages](#Import-packages)

4. [Configurations](#Configurations)

5. [Querying CORD19 dataset by keywords](#Querying-CORD19-dataset-by-keywords)

6. [Acquiring PubMed data](#Acquiring-some-PubMed-abstracts-for-generating-training-data-for-document-retrievel-and-background-edges-for-information-extraction)

7. [Acquiring PubMed papers](#Acquiring-PubMed-papers)

8. [Functions for text processing](#Functions-for-text-processing)

9. [Document retrieval](#Document-Retrieval)

10. [The BERT model](#The-BERT-model)

11. [Conclusion of document retrieval](#Conclusion-of-document-retrieval)

12. [Information extraction](#Information-extraction)

# Introduction

In the Kaggle Challenge, we aimed at building an automatic pipeline to extract knowledge from COVID-19 literature related to a topic of interest. The knowledge we will extract is represented as relationships between terms related to the topic of interest. The extracted knowledge can be visualized as an interactive network so that researchers can quickly explore and grasp existing knowledge by navigating in the network. Advanced knowledge discovery tools can be built on top of this framework, which will be explored in the next phase of the project.

The input to the pipeline is a set of keywords, called keywords of interest (KOI); the output is a set of relationships among terms related to the KOI, directly or indirectly, where each relationship is associated with the sentence from which it was extracted. The interactive network allows users to quickly see the related terms, their relationships, the sentences from which the relationships were extracted, and the articles containing those sentences. Instead of reading a large number of articles, which can be very time-consuming, a user can quickly grasp the key knowledge related to the KOI using this tool.

Our pipeline consists of two modules: document retrieval and information extraction. In the first module, we first use KOI to retrieve a set of relevant papers containing the KOI. In addition to this standard approach, we have developed a deep learning based method to retrieve additional papers that do not contain KOI, but are still relevant to KOI.

To build a model that can retrieve articles that are relevant to the KOI but do not contain it, we need to generate quality training data. The positive cases were obtained by querying PubMed using the KOI and then removing the KOI from the returned articles. The rationale is that even after the KOI are removed from an article, a knowledgeable researcher reading the remaining words can still judge that the article is relevant to the KOI. We manually tested this assumption and found it quite reasonable. The negative cases were obtained by sampling PubMed articles randomly and making sure that the returned articles do not contain the KOI. We trained the deep learning model using standard word embeddings and the BERT model. The model trained on these data was designed to learn the patterns in articles that are relevant to the KOI but do not contain it. When applied to articles that do not contain the KOI, it ranks them according to how closely they are related to the KOI. The model achieved satisfactory accuracy on the training dataset. We manually read some of the top articles retrieved from CORD-19 and found that many of them are indeed relevant to the KOI.

To extract relationships relevant to the KOI, we first extracted all the terms in the relevant articles and considered two terms related if they co-occur in the same sentence. More sophisticated approaches can be used (e.g., those from some of our previous relationship extraction studies). These relationships include some trivial ones, which we would like to remove. To that end, we used the PubMed API to retrieve a large number of PubMed articles that match the same KOI but are not related to COVID-19. We extracted relationships from them using the same co-occurrence approach and subtracted the frequently occurring background relationships (likely the trivial ones we want to remove) from those obtained from the COVID-19 articles retrieved in the first module. We further extracted terms that co-occur with the KOI in the same sentences and allow users to select these terms in the network. These terms and the relationships involving them are likely more relevant to the KOI.
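
The co-occurrence and background-subtraction idea can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the pipeline's actual code: the function names, the toy sentences, and the `max_background_count` threshold are all hypothetical.

```python
from collections import Counter
from itertools import combinations

def cooccurrence_edges(sentences):
    """Count unordered term pairs that co-occur in the same sentence."""
    edges = Counter()
    for terms in sentences:
        # sorted(set(...)) gives each pair one canonical form, counted once per sentence
        for pair in combinations(sorted(set(terms)), 2):
            edges[pair] += 1
    return edges

def subtract_background(edges, background, max_background_count=1):
    """Keep edges whose background count does not exceed a (hypothetical) threshold."""
    return {pair: n for pair, n in edges.items()
            if background.get(pair, 0) <= max_background_count}

# Toy inputs: each sentence is already reduced to its list of terms.
covid_sentences = [["fever", "cough", "screening"], ["screening", "pcr"], ["fever", "cough"]]
background_sentences = [["fever", "cough"], ["fever", "cough"], ["cough", "fever", "patient"]]

covid_edges = cooccurrence_edges(covid_sentences)
background_edges = cooccurrence_edges(background_sentences)
kept = subtract_background(covid_edges, background_edges)
```

In this toy example, the pair (cough, fever) is frequent in the background and is dropped, while the pairs involving "screening" survive.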

Below, we first show images of the interactive network, followed by our pipeline and some intermediate results.


## How to cite the work?
```
@inproceedings{tian2020knowledge,
    author = {Shubo Tian and Jian Wang and Yuhang Liu and Wanjing Wang and Chun-Chao Lo and Xiaodong Pang and Yuchuan Tao and Jinfeng Zhang},
    title = {The knowledge network of COVID-19 through text mining of COVID-19 literature},
    address = {Florida State University, Insilicom LLC, Tallahassee, FL, USA},
    year = {2020},
    url = {\url{https://covid19.insilicom.com/}},
}
```

## List of keywords for each task

In [None]:
# list of keywords for each task
task_keywords = {
    "task_1": "transmission, incubation, environmental stability, natural history, diagnostics, infection prevention, control, incubation period, asymptomatic shedding, asymptomatic transmission, Seasonality, charge distribution, adhesion to surfaces, environmental survival, Persistence, stability, Disease models, Immune response, immunity, personal protective equipment, PPE",
    "task_2": "risk factor, Smoking, pre-existing pulmonary disease, pulmonary disease, Co-infection, co-existing infection, co-morbidities, pregnancy, Socio-economic factor, behavioral factor, economic impact, Transmission dynamics, reproductive number, incubation period, serial interval, modes of transmission, environmental factor, Severity, high-risk, Public health mitigation measure",
    "task_3": "Genetics, evolution, genomic, genome, strain, field surveillance, genetic sequencing, host, animal host",
    "task_4": "Vaccine, therapeutic, drug, treatment, inhibitor, naproxen, clarithromycin, minocycline, viral replication, Antibody-Dependent Enhancement, ADE, animal model, antiviral agent, universal coronavirus vaccine, prophylaxis clinical",
    "task_5": "medical care, surge capacity, nursing home, resource allocation, personal protective equipment, PPE, process of care, clinical characterization, management, long term care, skilled nursing, surge medical staff, Age-adjusted mortality, Acute Respiratory Distress Syndrome, ARDS, organ failure, Extracorporeal membrane oxygenation, ECMO, mechanical ventilation, extrapulmonary manifestation, cardiomyopathy, cardiac arrest, regulatory standard, elastomeric respirator, telemedicine, hospital flow, workforce protection, workforce allocation, community-based support, supply chain management, clinical care, public health intervention, infection prevention, supportive interventions, adjunctive interventions",
    "task_6": "non-pharmaceutical intervention, non-pharmaceutical, NPI, school closure, travel ban, social distancing, compliance, economic impact",
    "task_7": "geographic variation, geographic, geographic spread, geographic mortality",
    "task_8": "Diagnostics, surveillance, mitigation measure, early detection, point-of-care test, point-of-care, rapid bed-side, rapid bed-side test, screening, evolutionary hosts, transmission host",
    "task_9": "Ethical, ethics, ethical principle, ethical issue, fear, anxiety, social media",
    "task_10": "information sharing, data sharing, inter-sectoral collaboration, data standards, nomenclature, risk communication, governmental public health, communicating with high-risk populations, community measure, equity consideration, inequity, data-gathering, standardized nomenclature, information-sharing, Risk communication, Misunderstanding, disadvantaged population, marginalized population, underrepresented minorities"
}

Choose a task

In [None]:
task_num = '8'
KOI = task_keywords[f'task_{task_num}'] 
input_path = 'https://covid19.insilicom.com/task' + task_num + '/'

# The knowledge network

In [None]:
from IPython.display import Image
import requests

In [None]:
url=input_path + 'figure/network1.png'
Image(requests.get(url).content, width=800, height=800)

### Hovering over a node pops up its information and highlights its connecting edges

In [None]:
url=input_path + 'figure/network2.png'
Image(requests.get(url).content, width=800, height=800)

### Clicking an edge pops up all the sentences relevant to the edge

In [None]:
url=input_path + 'figure/network3.png'
Image(requests.get(url).content, width=800, height=800)

# Import packages

In [None]:
import os
from Bio import Entrez
from Bio import Medline
import random
import csv
import re
import json
import pandas as pd
import nltk
from nltk import sent_tokenize
from nltk import word_tokenize
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
from itertools import combinations
import string
import fasttext
from sklearn.model_selection import train_test_split

# Configurations

Note that the switch for running the BERT model is set to False because BERT cannot be trained on the Kaggle server. We have tested it on our local machine. If you run the code on a local machine with a GPU, change it to True.

In [None]:
# directory of covid-19 dataset
datadir = '/kaggle/input/CORD-19-research-challenge'
metadata = 'metadata.csv'
subsetdata = ['biorxiv_medrxiv', 'comm_use_subset', 'noncomm_use_subset', 'custom_license']

# configuration for running all code or not
# Running the whole notebook takes a long time.
# To run the whole notebook, change if_run_all to True
if_run_all = False

if if_run_all:
    covid19_dataset_path = 'covid19_datasets'
    if not os.path.exists(covid19_dataset_path): os.mkdir(covid19_dataset_path)
    pubmed_tiabs_path = 'pubmed_tiabs'
    if not os.path.exists(pubmed_tiabs_path): os.mkdir(pubmed_tiabs_path)
    bert_data_path = 'bert_datasets'
    if not os.path.exists(bert_data_path): os.mkdir(bert_data_path)
else:
    covid19_dataset_path = '/kaggle/input/covid19-processed-data'
    pubmed_tiabs_path = '/kaggle/input/covid19-processed-data'
    bert_data_path = '/kaggle/input/covid19-processed-data'

# configuration for BERT model
output_folder='.'
# Fine-tuning is disabled by default because the BERT model cannot be trained on the Kaggle server.
# If you run the code on your local machine, change it to True
if_finetune_bert_model = False # if finetune BERT model or not

if if_finetune_bert_model==True:
    bert_model_name='bert_model'
else:
    bert_model_name='path_to_model/bert_model' #specify the path to the trained models

# configuration of others
edges_path = 'covid19_edges'
if not os.path.exists(edges_path): os.mkdir(edges_path)
# Entrez.email = "" # use your own email address
# Entrez.api_key = "" # use your own pubmed api key
stop_words = set(stopwords.words("english"))
wnlem = WordNetLemmatizer()
model = fasttext.load_model('/kaggle/input/fasttextmodel/lid.176.ftz')
random.seed(10000)

# Preprocessing COVID-19 Dataset
The COVID-19 open research dataset consists of papers from different sources including CZI, PMC, BioRxiv/MedRxiv. The dataset includes a metadata file and many files in json format. We start with combining the data in different formats into one dataframe. Then different keywords are used to query papers of different topics for downstream analysis.

## Combining the metadata and JSON Data
First, we combine the data in the metadata file and those in json format into one dataframe. We acquire all text including title, abstract, full text, bibliography entries, figures and table annotations, and others when available and organize them into the dataframe. This dataframe provides easier access to the dataset for downstream analysis, such as document retrieval and information extraction.

In [None]:
# function for building the dictionary that maps each paper to the location of its full text json file
def get_fulltext_links(subsetdata):
    """
    Develop a dictionary mapping the filename to the location of full text
    json file if the json file exists.
    Args:
        subsetdata: a list of the directories for json files.
    """
    fulltextlinks = {}
    for subset in subsetdata:
        subsetdir = f'{datadir}/{subset}/{subset}'
        for root, dirs, _ in os.walk(subsetdir, topdown=False):
            for name in dirs:
                subsetfiles = os.listdir(f'{root}/{name}')
                for subsetfile in subsetfiles:
                    filename = subsetfile.split('.')[0]
                    fulltextlinks[filename] = f'{root}/{name}/{subsetfile}'
    return fulltextlinks

# function for combining all text of the papers together
def get_all_covid19_text(df, fulltextlinks):
    """
    Combine all the text of the COVID-19 dataset
    (the metadata file and all json files) into one dataframe.
    Args:
        df: the dataframe to combine all text.
        fulltextlinks: the dictionary mapping papers to their json files. 
    """
    for idx, row in df.iterrows():
        if row['has_pdf_parse'] == True:
            full_text_file = fulltextlinks[row['sha'].split(';')[0]]
        elif row['has_pmc_xml_parse'] == True:
            full_text_file = fulltextlinks[row['pmcid']]
        else: continue
        # use a context manager so the file handle is closed after parsing
        with open(full_text_file, 'rb') as f:
            jsonfile = json.load(f)
        title = jsonfile['metadata']['title']
        if 'abstract' in jsonfile:
            abstract = '\n'.join([text['text'] for text in jsonfile['abstract']])
        else: abstract = ''
        if 'body_text' in jsonfile:
            fulltext = '\n'.join([text['text'] for text in jsonfile['body_text']])
        else: fulltext = ''
        # .get() avoids a KeyError when a section is missing from the json file
        if jsonfile.get('bib_entries') is not None:
            bibs = '\n'.join([bib['title'] for bib in jsonfile['bib_entries'].values()])
        else: bibs = ''
        if jsonfile.get('ref_entries') is not None:
            figures = '\n'.join([ref['text'] for ref in jsonfile['ref_entries'].values() if ref['type'] == 'figure'])
            tables = '\n'.join([ref['text'] for ref in jsonfile['ref_entries'].values() if ref['type'] == 'table'])
        else:
            figures = ''
            tables = ''
        if jsonfile.get('back_matter') is not None:
            othertext = '\n'.join([text['text'] for text in jsonfile['back_matter']])
        else: othertext = ''
        
        if df.loc[idx, 'title'] == '' and title != '': df.loc[idx, 'title'] = title
        if df.loc[idx, 'abstract'] == '' and abstract != '': df.loc[idx, 'abstract'] = abstract
        if fulltext != '': df.loc[idx, 'full_text'] = fulltext
        if bibs != '': df.loc[idx, 'bib_entries'] = bibs
        if figures != '': df.loc[idx, 'figures'] = figures
        if tables != '': df.loc[idx, 'tables'] = tables
        if othertext != '': df.loc[idx, 'other_text'] = othertext
    return df
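
As a small illustration of the JSON handling above, the sketch below pulls the title, abstract, and body text out of a toy record. The `collect_text` helper and the sample record are hypothetical; only the field names (`metadata`, `abstract`, `body_text`, `text`) are taken from the code above, and real CORD-19 files carry many more fields.

```python
import json

def collect_text(paper):
    """Join the text of a CORD-19-style parsed paper, tolerating missing sections."""
    title = paper.get("metadata", {}).get("title", "")
    # `or []` covers both a missing key and an explicit null value
    abstract = "\n".join(p["text"] for p in paper.get("abstract") or [])
    body = "\n".join(p["text"] for p in paper.get("body_text") or [])
    return {"title": title, "abstract": abstract, "full_text": body}

# Toy record with no abstract section.
sample = json.loads("""
{"metadata": {"title": "A toy paper"},
 "body_text": [{"text": "First paragraph."}, {"text": "Second paragraph."}]}
""")
fields = collect_text(sample)
```

Using `.get` with a default means a record without an abstract simply yields an empty string instead of raising an error.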

In [None]:
# Preprocess the covid19 dataset
if if_run_all:
    fulltextlinks = get_fulltext_links(subsetdata)
    df = pd.read_csv(f'{datadir}/{metadata}', na_filter= False)
    df['url'] = df['url'].apply(lambda x: x if x != '' else 'None')
    df = df.assign(full_text = '', bib_entries = '', figures = '', tables = '', other_text = '')
    df = get_all_covid19_text(df, fulltextlinks)
    columns_to_keep = ['cord_uid', 'title', 'abstract', 'authors', 'journal', 'publish_time', 'url']
    df = df[columns_to_keep]
    df.to_csv(f'{covid19_dataset_path}/covid19_dataset_alldata.csv', index = False)
else:
    df = pd.read_csv(f'{covid19_dataset_path}/covid19_dataset_alldata.csv', na_filter= False)

## Querying CORD19 dataset by keywords

We first implement a function to search papers in the COVID-19 dataset relevant to KOI. The function can be used to search papers containing different keywords. Using the function we are able to search papers containing a specific keyword or a list of keywords.

In [None]:
# function for covid19 data subset by key word phrase 
def covid19_data_subset(df, phrases, if_and = False):
    """
    Subset papers containing any keywords in the phrases by
    returning a list of indexes of the subset papers in a dataframe.
    Args:
        df: the dataframe containing all papers.
        phrases: a string of keywords split by ',' to subset the papers in df.
        if_and: if True, subset data containing all keywords in the phrases,
                if False, subset data containing any keyword in the phrases.
    """
    assert isinstance(phrases, str) and phrases != '' 
    phrases = [phrase.split() for phrase in phrases.lower().split(',')]
    and_indexes = set()
    or_indexes = set()
    for idx, row in df.iterrows():
        text = ' '.join([row.title, row.abstract])
        text = text.lower()
        count_phrase = 0
        for phrase in phrases:
            count_word = 0
            for word in phrase:
                if word in text: count_word += 1
            if count_word == len(phrase):
                count_phrase += 1
                or_indexes.add(idx)
        if count_phrase == len(phrases): and_indexes.add(idx)
    if if_and:
        return list(and_indexes)
    return list(or_indexes)
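
The matching rule used by `covid19_data_subset` is worth seeing on a toy input: a phrase matches when every one of its words occurs somewhere in the lowercased text, as a substring and not necessarily adjacent. The sketch below reproduces only that rule; the `phrase_in_text` helper and the sample documents are hypothetical.

```python
def phrase_in_text(phrase, text):
    """True when every word of the phrase occurs somewhere in the text (substring match)."""
    text = text.lower()
    return all(word in text for word in phrase.lower().split())

docs = {
    "a": "Risk assessment identified age as a major factor.",
    "b": "A brisk walk was a satisfactory outcome.",
    "c": "Vaccine trials are ongoing.",
}
hits = [key for key, text in docs.items() if phrase_in_text("risk factor", text)]
```

Document "b" is also returned because "brisk" contains "risk" and "satisfactory" contains "factor". Substring matching favors recall over precision; a word-boundary regular expression would be one way to tighten it if that trade-off is not wanted.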

We used 'risk factor' as a keyword to search the dataset for papers on COVID-19 risk factors. All papers containing both words, risk and factor, in the title or abstract were collected into a subset for further research on the risk factor topic. Similarly, we used 'vaccine' and 'therapeutic' as keywords to search for papers on vaccines and therapeutics, respectively.

In [None]:
# covid19 subset for risk factor
if if_run_all:
    df_positive_index = covid19_data_subset(df, 'risk factor')
    df_positive = df.loc[df_positive_index]
    df_negative = df.drop(df_positive_index)
    df_positive.to_csv(f'{covid19_dataset_path}/covid19_dataset_risk_positive.csv', index = False)
    df_negative.to_csv(f'{covid19_dataset_path}/covid19_dataset_risk_negative.csv', index = False)

Make a covid19 subset for a task

In [None]:
# covid19 subset for a task
if if_run_all:
    df_subset = df.loc[covid19_data_subset(df, KOI)]
    # df_subset.to_csv(f'{covid19_dataset_path}/covid19_subset_task_{task_num}.csv', index = False)

# Acquiring some PubMed abstracts for generating training data for document retrievel and background edges for information extraction
The PubMed database contains more than 30 million publications of biomedical research. We will download a subset of PubMed abstracts relevant to the KOI to generate data for training deep learning models as described in Introduction. The abstracts will also be used for generating background edges for the second module, information extraction.

In addition to abstracts relevant to the KOI, we also randomly sampled some abstracts to serve as negative samples in deep learning model training. We limited the publication time to between Jan. 1, 2009 and Dec. 31, 2018. 


In [None]:
# Function for getting pmids for a query
def get_query_pmids(query, max_ret = 10000,
                    start_date = '2009/01/01', end_date = '2018/12/31'):
    """
    Return a list of pmids for papers published in a period and
    containing keywords by query to the pubmed database through Entrez.
    Args:
        query: keywords split by ',' for searching the pubmed database.
        max_ret: maximum number of pmids returned in one query.
        start_date: earliest publication date (YYYY/MM/DD).
        end_date: latest publication date (YYYY/MM/DD).
    """
    query = ' OR '.join(query.lower().split(','))
    search_results = Entrez.read(Entrez.esearch(
        db = "pubmed", term = query, retmax = max_ret,
        mindate = start_date, maxdate = end_date, datetype = "pdat"
        ))
    counts = int(search_results["Count"])
    idlist = set(search_results["IdList"])
    if counts > max_ret:
        for start in range(max_ret, counts, max_ret):
            search_results = Entrez.read(Entrez.esearch(
                db = "pubmed", term = query, retstart = start, retmax = max_ret,
                mindate = start_date, maxdate = end_date, datetype = "pdat"
                ))
            idlist = idlist | set(search_results["IdList"])
    return list(idlist)

# Function for writing pmids list to text file
def write_pmids_list(pmids_list, file_name):
    assert isinstance(file_name, str)
    with open(f'{pubmed_tiabs_path}/{file_name}.txt', 'w', encoding = 'utf8') as f:
        for line in pmids_list:
            f.write(line+'\n')
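
To make the query construction in `get_query_pmids` concrete, the sketch below turns a comma-separated keyword string into an OR-joined search term. It is a slight variation on the code above, not a copy of it: the hypothetical `build_or_query` helper also strips whitespace and skips empty items, which is an assumption added for the example.

```python
def build_or_query(keywords):
    """Turn a comma-separated keyword string into an OR-joined search term."""
    terms = [term.strip() for term in keywords.lower().split(",")]
    # skip empty items, e.g. the one produced by a trailing comma
    return " OR ".join(term for term in terms if term)

query = build_or_query("Diagnostics, surveillance, point-of-care test,")
```

Skipping empty items keeps a trailing comma from producing a dangling `OR` at the end of the search term.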

In [None]:
# Generate the pmid list of all papers between 2009/01/01-2018/12/31 on pubmed 
if if_run_all:
    query = ""
    pmids_all = get_query_pmids(query)
    # write_pmids_list(pmids_all, 'pmids_all')

In [None]:
# Generate the pmid list for pubmed papers containing risk factor
if if_run_all:
    # parentheses group the singular and plural forms before AND is applied
    query = "risk[TIAB] AND (factor[TIAB] OR factors[TIAB])"
    pmids_positive = get_query_pmids(query)
    #write_pmids_list(pmids_positive, 'pmids_risk')

Query the list of PMIDs for PubMed papers containing the keywords of a task

In [None]:
# Query list of pmids for pubmed papers containing KOI of a task
if if_run_all:
    pmids_list = get_query_pmids(KOI)
    # write_pmids_list(pmids_list, f'pmids_task_{task_num}')

## Acquiring PubMed papers
After a list of PMIDs is returned for a given KOI, we randomly select up to 100,000 PMIDs from the list. We then retrieve the titles and abstracts of the papers with the selected PMIDs from PubMed. The following functions read the list of PMIDs and retrieve paper titles and abstracts from the PubMed database.

In [None]:
# Function for loading list of pmids
def load_pmids_list(pmids_file):
    """
    Return a list of pmids from a text file.
    Args:
        pmids_file: a text file containing pmids with one pmid in each line.
    """
    assert isinstance(pmids_file, str)
    with open(f'{pubmed_tiabs_path}/{pmids_file}.txt', 'r', encoding = 'utf8') as f:
        pmids_list = [line.strip() for line in f]
    return pmids_list

# Function for getting titles and abstracts of a list of less than 10000 pmids
def get_pubmed_tiabs(pmids_list, max_len = 10000):
    """
    Return a dictionary of title and abstract of papers by pmid
    in the pmids_list in one query of the pubmed database.
    Args:
        pmids_list: a list of pmids to query for the title and abstract of papers.
    """
    assert isinstance(pmids_list, list) and len(pmids_list) <= max_len
    pubmed_tiabs = {}
    records =  Medline.parse(Entrez.efetch(
        db="pubmed", id=pmids_list, rettype='medline', retmode='text'
        ))
    for record in records:
        if 'PMID' not in record: continue
        pubmed_tiabs[record['PMID']] = {'title': record.get('TI', ''),
                                        'abstract': record.get('AB', '')}
    return pubmed_tiabs

# Function for getting titles and abstracts of a list of more than 10000 pmids as a dataframe
def get_pubmed_tiabs_all(pmids_list, batch = 10000):
    """
    Acquire title and abstract of papers in a list of pmids by querying the pubmed database
    in batches and return the pmid, title and abstract as a dataframe.
    Args:
        pmids_list: a list of pmids to query for the title and abstract of papers.
        batch: maximum number of pmids sent in one query.
    """
    assert isinstance(pmids_list, list)
    len_list = len(pmids_list)
    df_tiabs = {'pmid':[], 'title':[], 'abstract':[]}
    for i in range(0, len_list, batch):
        pmid_list = pmids_list[i:i+batch]  # slicing past the end of a list is safe
        pubmed_tiabs = get_pubmed_tiabs(pmid_list, max_len = batch)
        for k,v in pubmed_tiabs.items():
            df_tiabs['pmid'].append(k)
            df_tiabs['title'].append(v['title'])
            df_tiabs['abstract'].append(v['abstract'])
    return pd.DataFrame(data=df_tiabs)
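
The batching pattern used in `get_pubmed_tiabs_all` can be shown on its own, without any network access. The `batches` helper and the toy PMID list below are hypothetical and only illustrate the slicing.

```python
def batches(items, size):
    """Yield consecutive slices of at most `size` items."""
    for start in range(0, len(items), size):
        # a slice that runs past the end simply returns the remaining items
        yield items[start:start + size]

pmids = [str(n) for n in range(25)]
chunks = list(batches(pmids, 10))
sizes = [len(chunk) for chunk in chunks]
```

With 25 items and a batch size of 10, the last chunk holds the 5 leftover items, so no special case is needed for the final batch.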

Acquiring pubmed papers for background edges of a task

In [None]:
# acquiring pubmed papers for background edges of a task
if if_run_all:
    # pmids_list = load_pmids_list(f'pmids_task_{task_num}')
    pmids_list = random.sample(pmids_list, 100000) if len(pmids_list) > 100000 else pmids_list
    df_tiabs = get_pubmed_tiabs_all(pmids_list)
    # df_tiabs.to_csv(f'{pubmed_tiabs_path}/covid19_tiabs_task_{task_num}.csv', index = False)

## Building the PMID lists for positive and negative samples

First, we use the load_pmids_list function to get the PMIDs of the positive papers (those that contain the KOI).

Then, we take all PMIDs from '2009/01/01' to '2018/12/31' and remove the PMIDs of the positive papers to build the negative samples, which are papers not relevant to the KOI. A small number of them are actually relevant, but we expect this level of noise to be tolerable.

In [None]:
# acquiring pubmed papers relevant to KOI for BERT training
if if_run_all:
    # pmids_all = load_pmids_list('pmids_all')
    # pmids_positive = load_pmids_list('pmids_risk')
    # delete the intersection of pmids_all and pmids_positive to get the negative pmids
    pmids_negative = list(set(pmids_all) - set(pmids_positive))

    random.seed(10000)
    # Randomly get 10000 samples of positive papers
    pmids_positive = random.sample(pmids_positive, 10000)
    get_pubmed_tiabs_all(pmids_positive).to_csv(f'{pubmed_tiabs_path}/pubmed_tiabs_risk_positive.csv', index = False)
    # Randomly get 30000 samples of negative papers
    pmids_negative = random.sample(pmids_negative, 30000)
    get_pubmed_tiabs_all(pmids_negative).to_csv(f'{pubmed_tiabs_path}/pubmed_tiabs_risk_negative.csv', index = False)
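
Two details of the sampling step are easy to overlook: `random.sample` raises a `ValueError` when asked for more items than the population holds, and the iteration order of a set of strings can differ between Python processes, so seeding alone may not make the sample repeatable. The sketch below addresses both; the `sample_up_to` helper and the toy IDs are hypothetical.

```python
import random

def sample_up_to(population, k, seed=10000):
    """Sample k items without replacement, or return all items if fewer than k exist."""
    rng = random.Random(seed)
    population = sorted(population)  # fixed order so the seed gives a repeatable result
    if len(population) <= k:
        return population
    return rng.sample(population, k)

all_ids = {"1", "2", "3", "4", "5", "6"}
positive_ids = {"2", "4"}
negative_ids = all_ids - positive_ids   # set difference, as in the cell above
negatives = sample_up_to(negative_ids, 3)
```

Sorting before sampling costs little and makes the selected negatives depend only on the seed and the set contents.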

# Functions for text processing
Once all the data have been retrieved from the COVID-19 dataset and PubMed, we can proceed to the second module, information extraction. Working with text data requires extensive preprocessing, so we developed the following functions for denoising, normalization, and tokenization.

In [None]:
# Function for text preprocess
def text_preprocess(text):
    """
    Preprocess a string of text including title and abstract.
    Args:
        text: a string of text.
    """
    assert isinstance(text, str)
    text = re.sub(r'[\r\n]', ' ', text)
    text = re.sub(r'(?<![A-Z])\.(?=[A-Z])', '. ', text)
    text = re.sub(r"\(\w+[^\(\)]+ et al\.?.*?\)",'',text)
    text = re.sub(r'((?<!\d)|(?<=\d{4}))\.[\[\(]?(\d+|\d+-\d+)((,|, )(\d+|\d+-\d+))*[\)\]]?([\sA-Z]|$)', '. ', text)
    text = re.sub(r'(?<=\s)\[\d+\](\s\[\d+\])*\s|\[\d+(,\s\d+)*?\]\s?', '', text)
    text = re.sub(r"\(https?://.*?\)|\(?https?://.*?([\s,]|$)",'',text)
    text = re.sub(r"\(doi:.*?\)|\(?doi:.*?([\s,]|$)",'',text)
    text = re.sub(r'[\t\|/]', ' ', text)
    text = re.sub('–|‐', '-', text)
    text = re.sub('·', '.', text)
    return text

# Function for cleaning sentence
def sentence_preprocess(text):
    """
    Preprocess a string of text in a sentence.
    Args:
        text: a string of text.
    """
    assert isinstance(text, str)
    text = re.sub(r"^\[|\](?=\.$)|^\d+\s+(?=[A-Z])", '', text)
    text = re.sub(r"^\[.*?\]\s+(?=[A-Z])", '', text)
    text = re.sub(r"^[\w\s-]*?:\s*(?=[A-Zo])", '', text)
    text = re.sub(r"^\((.*)\)(.$)", r"\1\2", text)
    text = re.sub(r"^(Publisher Summary|Material and methods)\s*(?=[A-Z])", '', text)
    text = re.sub(r"^(((Abstract|ABSTRACT|Background:?\]?|Methods?\]?|METHODS?|Objectives?|Aims?|Summary|SUMMARY|Conclusions?\]?|Findings?|Results?\]?|Discussion)\s*)+)(?=[A-Z4])", '', text)
    return text

# Function for tokenizing sentence into words
def word_tokens(text):
    """
    Tokenize a string of text into a list of lemmatized words.
    Args:
        text: a string of text.
    """
    assert isinstance(text, str)
    tokens = word_tokenize(text.lower())
    tokens = [w for w in tokens if w not in string.punctuation]
    tokens = [w for w in tokens if not re.match(r'\d*,?\d*[\.-]?\d*$|\d*,?\d*\.?\d*-(\d*,?\d*\.?\d*)?$', w)]
    tokens = [w for w in tokens if re.match(r'^[\w\d]', w)]
    tokens = [w for w in tokens if w not in stop_words]
    tokens = [wnlem.lemmatize(w) for w in tokens]
    tokens = [w for w in tokens if len(w) > 2]
    return tokens

# Function for tokenizing sentence into noun words
def tagged_word_tokens(text):
    """
    Tokenize a string of text, keep only the nouns,
    and return them as a list of lemmatized words.
    Args:
        text: a string of text.
    """
    assert isinstance(text, str)
    tokens = word_tokenize(text)
    tokens = nltk.pos_tag(tokens)
    tokens = [(w.lower(), t) for w, t in tokens]
    tokens = [w for w, t in tokens if t in ['NN', 'NNS', 'NNP', 'NNPS']]
    tokens = [w for w in tokens if w not in string.punctuation]
    tokens = [w for w in tokens if not re.match(r'\d*,?\d*[\.-]?\d*$|\d*,?\d*\.?\d*-(\d*,?\d*\.?\d*)?$', w)]
    tokens = [w for w in tokens if re.match(r'^[\w\d]', w)]
    tokens = [w for w in tokens if w not in stop_words]
    tokens = [wnlem.lemmatize(w) for w in tokens]
    tokens = [w for w in tokens if len(w) > 2]
    return tokens
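
To give a feel for the kind of denoising these functions perform, here is a simplified, standalone sketch that strips bracketed numeric citations and bare URLs. The patterns are deliberately simpler than those in `text_preprocess` and are not equivalent to them; the `strip_citations_and_urls` helper and the sample sentence are hypothetical.

```python
import re

def strip_citations_and_urls(text):
    """Remove bracketed numeric citations and bare URLs, then tidy whitespace."""
    text = re.sub(r"\[\d+(?:\s*[,-]\s*\d+)*\]", "", text)  # [3], [1, 2], [4-6]
    text = re.sub(r"https?://\S+", "", text)               # bare URLs
    text = re.sub(r"\s+([.,;])", r"\1", text)              # no space before punctuation
    return re.sub(r"\s{2,}", " ", text).strip()            # collapse repeated spaces

raw = "Fever is common [1, 2] in adults [4-6]. See https://example.org/data for details."
clean = strip_citations_and_urls(raw)
```

Applying the substitutions in this order leaves the sentence boundaries intact, which matters because the later steps split the text into sentences.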

# Document Retrieval

## Data Preparation

Our goal is to find the papers which are relevant to KOI, but do not contain the KOI. Therefore, we will delete the words in KOI from the papers to form the positive samples of our training data. We will randomly choose papers using PMIDs, which do not appear in the positive paper list, as negative samples of the training data.

In [None]:
# Function for processing data for information retrieval model
def model_data_process(df, phrases = ''):
    """
    Process text of titles and abstracts in a dataframe
    by removing keywords in the phrases and punctuations.
    Args:
        df: a dataframe of titles and abstracts by ids.
        phrases: a string of keywords split by ',' to be removed from the titles and abstracts.
    """
    assert isinstance(phrases, str)
    colnames = list(df.columns)
    df = df.rename(columns = {colnames[0]:'pid'})
    if phrases != '': phrases = [phrase.split() for phrase in phrases.split(',')]
    for idx, row in df.iterrows():
        title = text_preprocess(row.title)
        abstract = text_preprocess(row.abstract)
        tokens_title = word_tokenize(title)
        tokens_abstract = word_tokenize(abstract)
        if phrases != '':
            for phrase in phrases:
                count_in_title = 0
                count_in_abstract = 0
                for word in phrase:
                    if word in title.lower(): count_in_title += 1
                    if word in abstract.lower(): count_in_abstract += 1
                if count_in_title == len(phrase):
                    tokens_title = [w for w in tokens_title if phrase[0] not in wnlem.lemmatize(w.lower()) ]
                if count_in_abstract == len(phrase):
                    tokens_abstract = [w for w in tokens_abstract if phrase[0] not in wnlem.lemmatize(w.lower())]
        title = ' '.join([w for w in tokens_title]) # if w not in string.punctuation
        abstract = ' '.join([w for w in tokens_abstract]) # if w not in string.punctuation
        df.loc[idx, 'title'] = title
        df.loc[idx, 'abstract'] = abstract
    return df
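
The keyword-removal idea behind `model_data_process` can be illustrated with a much smaller sketch. It differs from the function above in two labeled ways: it tokenizes on whitespace instead of using NLTK, and it does not lemmatize. The `remove_keywords` helper and the sample sentence are hypothetical.

```python
def remove_keywords(text, keywords):
    """Drop every whitespace-separated token that contains one of the keywords."""
    keywords = [k.strip().lower() for k in keywords.split(",") if k.strip()]
    kept = [tok for tok in text.split()
            if not any(k in tok.lower() for k in keywords)]
    return " ".join(kept)

masked = remove_keywords("Smoking is a risk factor; other factors include age.", "risk,factor")
```

Because the test is a substring match, inflected forms such as "factors" are removed along with the keyword itself, so the classifier cannot rely on the keyword appearing in the text.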

## Building training dataset
Read positive papers and delete the KOI to build the positive training data
<br>
Read the negative papers to build the negative training data
<br>
Merge the positive and negative cases and build training and validation data with the ratio 4:1
<br>
Save the result

In [None]:
# training datasets
if if_run_all:
    df_ir = pd.read_csv(f'{pubmed_tiabs_path}/pubmed_tiabs_risk_positive.csv', na_filter= False)
    df_positive = model_data_process(df_ir, 'risk,factor')
    df_positive['label'] = [1] * df_positive.shape[0]

    df_ir = pd.read_csv(f'{pubmed_tiabs_path}/pubmed_tiabs_risk_negative.csv', na_filter= False)
    df_negative = model_data_process(df_ir)
    df_negative['label'] = [-1] * df_negative.shape[0]

    df_ir = pd.concat([df_positive, df_negative], ignore_index=True)
    train_sample, test_sample = train_test_split(df_ir, test_size=0.20, random_state=42, shuffle=True)
    train_sample.to_csv(f'{bert_data_path}/train_sample.csv', index = False)
    test_sample.to_csv(f'{bert_data_path}/test_sample.csv', index = False)

## Building prediction dataset
Apply the same processing to the prediction data as above.
<br>
We use the CORD-19 papers that do not contain the KOI as the test set.
<br>
We will use the BERT model to find the papers in the test set that are relevant to the KOI.

In [None]:
# Build the prediction dataset.
# Process the test sample in the same way as above. We use all the papers in the four
# folders of the Kaggle COVID-19 dataset as the test sample.
# For the papers saved as model_negative (those without the KOI), we will use the BERT model,
# which is discussed in the following part, to find the papers that talk about the keyword.
if if_run_all:
    df_ir = pd.read_csv(f'{covid19_dataset_path}/covid19_dataset_risk_positive.csv', na_filter= False)
    df_ir = df_ir[['cord_uid', 'title', 'abstract']]
    df_positive = model_data_process(df_ir, 'risk, factor')
    df_positive['label'] = [1] * df_positive.shape[0]
    df_positive.to_csv(f'{bert_data_path}/model_positive.csv', index = False)

    df_ir = pd.read_csv(f'{covid19_dataset_path}/covid19_dataset_risk_negative.csv', na_filter= False)
    df_ir = df_ir[['cord_uid', 'title', 'abstract']]
    df_negative = model_data_process(df_ir)
    df_negative['label'] = [-1] * df_negative.shape[0]
    df_negative.to_csv(f'{bert_data_path}/model_negative.csv', index = False)

# The BERT model
BERT (Bidirectional Encoder Representations from Transformers, https://github.com/google-research/bert) is a method of pre-training language representations.

The whole process consists of the following steps:

```
1) Data pre-processing
2) BERT model preparation
3) Model building by adding a layer for classification
4) Model training
5) Prediction
```

### Load packages

In [None]:
# load package 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import unicodedata

import six

import sentencepiece as spm
# Some ideas are from https://www.kaggle.com/gunesevitan/nlp-with-disaster-tweets-eda-cleaning-and-bert/notebook
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow.keras.backend as K

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.models import load_model
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_score, recall_score, f1_score
from tensorflow.keras.utils import to_categorical
import os
import re
import string
from nltk.corpus import stopwords

from nltk.tokenize import word_tokenize

## Data Pre-processing

This step includes the following:
```
- Convert input text to the BERT model input format
- Map Greek letters to their English names and expand abbreviations
- Clean up text
```
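
To make the first step concrete, here is a minimal sketch of the input format produced for one text: the tokens are wrapped in `[CLS]` and `[SEP]`, truncated to fit `max_seq_length`, and zero-padded, with a mask marking the real tokens. The function name `to_bert_inputs` and the toy vocabulary are assumptions for illustration only; the notebook's `convert_text_format` builds the same three arrays for a whole DataFrame using the real BERT tokenizer.

```python
def to_bert_inputs(tokens, vocab, max_seq_length):
    """Build (input_ids, input_mask, segment_ids) for a single token list."""
    # Reserve two positions for [CLS] and [SEP], truncating if necessary.
    tokens = ["[CLS]"] + tokens[:max_seq_length - 2] + ["[SEP]"]
    input_ids = [vocab[t] for t in tokens]
    input_mask = [1] * len(input_ids)    # 1 marks a real token
    segment_ids = [0] * len(input_ids)   # single-sentence input: all zeros

    # Zero-pad all three lists up to max_seq_length.
    pad = max_seq_length - len(input_ids)
    input_ids += [0] * pad
    input_mask += [0] * pad              # 0 marks padding
    segment_ids += [0] * pad
    return input_ids, input_mask, segment_ids


# Toy vocabulary, made up for this example.
toy_vocab = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "risk": 3, "of": 4, "infection": 5}
ids, mask, segments = to_bert_inputs(["risk", "of", "infection"], toy_vocab, 8)
print(ids)       # [1, 3, 4, 5, 2, 0, 0, 0]
print(mask)      # [1, 1, 1, 1, 1, 0, 0, 0]
print(segments)  # [0, 0, 0, 0, 0, 0, 0, 0]
```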

In [None]:
# Convert input text to the BERT input format before training the model
def convert_text_format(df,tokenizer,max_seq_length):
    # Convert the texts into the BERT model input format
    
    texts_series=df['title']+' '+df['abstract']
    texts_series=texts_series.apply(clean_texts)
    texts=texts_series.values
    texts_input_ids=[]
    texts_input_masks=[]
    texts_segment_ids=[]
    max_len_temp=0
    for i,text in enumerate(texts):
    
        tokens_text = tokenizer.tokenize(text)
        # Account for [CLS] and [SEP] with "- 2"
        if len(tokens_text) > max_seq_length - 2:
            tokens_text = tokens_text[0:(max_seq_length - 2)]
        tokens = []
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_text:
            tokens.append(token)
            segment_ids.append(0)
        
        tokens.append("[SEP]")
        segment_ids.append(0)
        input_ids = tokenizer.convert_tokens_to_ids(tokens)
    
        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_masks = [1] * len(input_ids)
        max_len_temp=max(max_len_temp,len(input_ids))
        
        
        # Zero-pad up to the sequence length.
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_masks.append(0)
            segment_ids.append(0)
        
        assert len(input_ids) == max_seq_length
        assert len(input_masks) == max_seq_length
        assert len(segment_ids) == max_seq_length
        if i < 5:
            print("*** Example ***")
            
            print("tokens: ",
                         " ".join([x for x in tokens]))
            print("input_ids: ", " ".join([str(x) for x in input_ids]))
            print("input_masks: ", " ".join([str(x) for x in input_masks]))
            print("segment_ids: ", " ".join([str(x) for x in segment_ids]))
        if i%10000==0:
            print(i)
        texts_input_ids.append(input_ids)
        texts_input_masks.append(input_masks)
        texts_segment_ids.append(segment_ids)
    
    print("Longest tokens: %d"%max_len_temp)

    return [np.array(texts_input_ids).astype(np.int32), 
                np.array(texts_input_masks).astype(np.int32), 
                np.array(texts_segment_ids).astype(np.int32)]

# preprocessing
def greek_to_eng(c):
    # Map a Greek letter to its English name; other characters are returned unchanged
    return greek_alphabet.get(c, c)

# Some code is from https://www.kaggle.com/rftexas/text-only-kfold-bert

def remove_emoji(text):
    # remove emoji
    emoji_pattern = re.compile("["
                           u"\U0001F600-\U0001F64F"  # emoticons
                           u"\U0001F300-\U0001F5FF"  # symbols & pictographs
                           u"\U0001F680-\U0001F6FF"  # transport & map symbols
                           u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
                           u"\U00002702-\U000027B0"
                           u"\U000024C2-\U0001F251"
                           "]+", flags=re.UNICODE)
    return emoji_pattern.sub(r'', text)



def clean_texts(text,remove_stopwords=False,remove_punc=True,more_clean=True,lower=True):
    # Clean the texts.

    # Convert words to lower case and split them
    if lower:
        text = text.lower().split()
    else:
        text = text.split()

    # Optionally, remove stop words
    if remove_stopwords:
        stops = set(stopwords.words("english"))
        text = [w for w in text if w not in stops]
    
    text = " ".join(text)
    # Clean the text
    
    text = remove_emoji(text)

    
    
    # Remove url (https://stackoverflow.com/questions/3809401/what-is-a-good-regular-expression-to-match-a-url)
    text = re.sub(r"https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)", "", text)
    

    # Expand contractions, including variants with a mis-encoded apostrophe
    text = re.sub(r"what's", "what is ", text)
    text = re.sub(r"don\x89Ûªt", "do not", text)
    text = re.sub(r"i\x89Ûªm", "i am", text)
    text = re.sub(r"you\x89Ûªve", "you have", text)
    text = re.sub(r"it\x89Ûªs", "it is", text)
    text = re.sub(r"doesn\x89Ûªt", "does not", text)
    text = re.sub(r"i\x89Ûªve", "i have", text)
    text = re.sub(r"can\x89Ûªt", "cannot", text)
    text = re.sub(r"wouldn\x89Ûªt", "would not", text)
    text = re.sub(r"that\x89Ûªs", "that is", text)
    text = re.sub(r"\'ve", " have ", text)
    text = re.sub(r"can't", "cannot ", text)
    text = re.sub(r"can’t", "cannot ", text)
    text = re.sub(r"n't", " not ", text)
    text = re.sub(r"i'm", "i am ", text)
    text = re.sub(r"\'re", " are ", text)
    text = re.sub(r"\'d", " would ", text)
    text = re.sub(r"\'ll", " will ", text)


    text = text.split()
    text = " ".join(text)
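    # Map Greek letters to their English names and rebuild the string first, so that
    # the printable filter below tests single characters and keeps the mapped names
    text = ''.join(greek_to_eng(c) for c in text)
    text = ''.join([x for x in text if x in string.printable])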
    if more_clean:
        text = convert_abbrev_in_text(text) # map abbrev
    if remove_punc:
        punc=string.punctuation
        for p in punc:
            text = text.replace(p, '') # remove punctuation
    text = text.split()
    text = " ".join(text)   
    
    return text

def convert_abbrev(word):
    # Expand a known abbreviation (case-insensitive lookup); return other words unchanged
    return abbreviations.get(word.lower(), word)


def convert_abbrev_in_text(text):
    # Expand abbreviations token by token
    tokens = word_tokenize(text)
    tokens = [convert_abbrev(word) for word in tokens]
    text = ' '.join(tokens)
    return text


# Map Greek letters to their English names
greek_alphabet = {
    u'\u0391': 'Alpha',
    u'\u0392': 'Beta',
    u'\u0393': 'Gamma',
    u'\u0394': 'Delta',
    u'\u0395': 'Epsilon',
    u'\u0396': 'Zeta',
    u'\u0397': 'Eta',
    u'\u0398': 'Theta',
    u'\u0399': 'Iota',
    u'\u039A': 'Kappa',
    u'\u039B': 'Lamda',
    u'\u039C': 'Mu',
    u'\u039D': 'Nu',
    u'\u039E': 'Xi',
    u'\u039F': 'Omicron',
    u'\u03A0': 'Pi',
    u'\u03A1': 'Rho',
    u'\u03A3': 'Sigma',
    u'\u03A4': 'Tau',
    u'\u03A5': 'Upsilon',
    u'\u03A6': 'Phi',
    u'\u03A7': 'Chi',
    u'\u03A8': 'Psi',
    u'\u03A9': 'Omega',
    u'\u03B1': 'alpha',
    u'\u03B2': 'beta',
    u'\u03B3': 'gamma',
    u'\u03B4': 'delta',
    u'\u03B5': 'epsilon',
    u'\u03B6': 'zeta',
    u'\u03B7': 'eta',
    u'\u03B8': 'theta',
    u'\u03B9': 'iota',
    u'\u03BA': 'kappa',
    u'\u03BB': 'lamda',
    u'\u03BC': 'mu',
    u'\u03BD': 'nu',
    u'\u03BE': 'xi',
    u'\u03BF': 'omicron',
    u'\u03C0': 'pi',
    u'\u03C1': 'rho',
    u'\u03C3': 'sigma',
    u'\u03C4': 'tau',
    u'\u03C5': 'upsilon',
    u'\u03C6': 'phi',
    u'\u03C7': 'chi',
    u'\u03C8': 'psi',
    u'\u03C9': 'omega',
}

# map abbreviations (from https://www.kaggle.com/rftexas/text-only-kfold-bert)
abbreviations = {
    "$" : " dollar ",
    "€" : " euro ",
    "4ao" : "for adults only",
    "a.m" : "before midday",
    "a3" : "anytime anywhere anyplace",
    "aamof" : "as a matter of fact",
    "acct" : "account",
    "adih" : "another day in hell",
    "afaic" : "as far as i am concerned",
    "afaict" : "as far as i can tell",
    "afaik" : "as far as i know",
    "afair" : "as far as i remember",
    "afk" : "away from keyboard",
    "app" : "application",
    "approx" : "approximately",
    "apps" : "applications",
    "asap" : "as soon as possible",
    "asl" : "age, sex, location",
    "atk" : "at the keyboard",
    "ave." : "avenue",
    "aymm" : "are you my mother",
    "ayor" : "at your own risk", 
    "b&b" : "bed and breakfast",
    "b+b" : "bed and breakfast",
    "b.c" : "before christ",
    "b2b" : "business to business",
    "b2c" : "business to customer",
    "b4" : "before",
    "b4n" : "bye for now",
    "b@u" : "back at you",
    "bae" : "before anyone else",
    "bak" : "back at keyboard",
    "bbbg" : "bye bye be good",
    "bbc" : "british broadcasting corporation",
    "bbias" : "be back in a second",
    "bbl" : "be back later",
    "bbs" : "be back soon",
    "be4" : "before",
    "bfn" : "bye for now",
    "blvd" : "boulevard",
    "bout" : "about",
    "brb" : "be right back",
    "bros" : "brothers",
    "brt" : "be right there",
    "bsaaw" : "big smile and a wink",
    "btw" : "by the way",
    "bwl" : "bursting with laughter",
    "c/o" : "care of",
    "cet" : "central european time",
    "cf" : "compare",
    "cia" : "central intelligence agency",
    "csl" : "can not stop laughing",
    "cu" : "see you",
    "cul8r" : "see you later",
    "cv" : "curriculum vitae",
    "cwot" : "complete waste of time",
    "cya" : "see you",
    "cyt" : "see you tomorrow",
    "dae" : "does anyone else",
    "dbmib" : "do not bother me i am busy",
    "diy" : "do it yourself",
    "dm" : "direct message",
    "dwh" : "during work hours",
    "e123" : "easy as one two three",
    "eet" : "eastern european time",
    "eg" : "example",
    "embm" : "early morning business meeting",
    "encl" : "enclosed",
    "encl." : "enclosed",
    "etc" : "and so on",
    "faq" : "frequently asked questions",
    "fawc" : "for anyone who cares",
    "fb" : "facebook",
    "fc" : "fingers crossed",
    "fig" : "figure",
    "fimh" : "forever in my heart", 
    "ft." : "feet",
    "ft" : "featuring",
    "ftl" : "for the loss",
    "ftw" : "for the win",
    "fwiw" : "for what it is worth",
    "fyi" : "for your information",
    "g9" : "genius",
    "gahoy" : "get a hold of yourself",
    "gal" : "get a life",
    "gcse" : "general certificate of secondary education",
    "gfn" : "gone for now",
    "gg" : "good game",
    "gl" : "good luck",
    "glhf" : "good luck have fun",
    "gmt" : "greenwich mean time",
    "gmta" : "great minds think alike",
    "gn" : "good night",
    "g.o.a.t" : "greatest of all time",
    "goat" : "greatest of all time",
    "goi" : "get over it",
    "gps" : "global positioning system",
    "gr8" : "great",
    "gratz" : "congratulations",
    "gyal" : "girl",
    "h&c" : "hot and cold",
    "hp" : "horsepower",
    "hr" : "hour",
    "hrh" : "his royal highness",
    "ht" : "height",
    "ibrb" : "i will be right back",
    "ic" : "i see",
    "icq" : "i seek you",
    "icymi" : "in case you missed it",
    "idc" : "i do not care",
    "idgadf" : "i do not give a damn fuck",
    "idgaf" : "i do not give a fuck",
    "idk" : "i do not know",
    "ie" : "that is",
    "i.e" : "that is",
    "ifyp" : "i feel your pain",
    "IG" : "instagram",
    "iirc" : "if i remember correctly",
    "ilu" : "i love you",
    "ily" : "i love you",
    "imho" : "in my humble opinion",
    "imo" : "in my opinion",
    "imu" : "i miss you",
    "iow" : "in other words",
    "irl" : "in real life",
    "j4f" : "just for fun",
    "jic" : "just in case",
    "jk" : "just kidding",
    "jsyk" : "just so you know",
    "l8r" : "later",
    "lb" : "pound",
    "lbs" : "pounds",
    "ldr" : "long distance relationship",
    "lmao" : "laugh my ass off",
    "lmfao" : "laugh my fucking ass off",
    "lol" : "laughing out loud",
    "ltd" : "limited",
    "ltns" : "long time no see",
    "m8" : "mate",
    "mf" : "motherfucker",
    "mfs" : "motherfuckers",
    "mfw" : "my face when",
    "mofo" : "motherfucker",
    "mph" : "miles per hour",
    "mr" : "mister",
    "mrw" : "my reaction when",
    "ms" : "miss",
    "mte" : "my thoughts exactly",
    "nagi" : "not a good idea",
    "nbc" : "national broadcasting company",
    "nbd" : "not big deal",
    "nfs" : "not for sale",
    "ngl" : "not going to lie",
    "nhs" : "national health service",
    "nrn" : "no reply necessary",
    "nsfl" : "not safe for life",
    "nsfw" : "not safe for work",
    "nth" : "nice to have",
    "nvr" : "never",
    "nyc" : "new york city",
    "oc" : "original content",
    "og" : "original",
    "ohp" : "overhead projector",
    "oic" : "oh i see",
    "omdb" : "over my dead body",
    "omg" : "oh my god",
    "omw" : "on my way",
    "p.a" : "per annum",
    "p.m" : "after midday",
    "pm" : "prime minister",
    "poc" : "people of color",
    "pov" : "point of view",
    "pp" : "pages",
    "ppl" : "people",
    "prw" : "parents are watching",
    "ps" : "postscript",
    "pt" : "point",
    "ptb" : "please text back",
    "pto" : "please turn over",
    "qpsa" : "what happens", #"que pasa",
    "ratchet" : "rude",
    "rbtl" : "read between the lines",
    "rlrt" : "real life retweet", 
    "rofl" : "rolling on the floor laughing",
    "roflol" : "rolling on the floor laughing out loud",
    "rotflmao" : "rolling on the floor laughing my ass off",
    "rt" : "retweet",
    "ruok" : "are you ok",
    "sfw" : "safe for work",
    "sk8" : "skate",
    "smh" : "shake my head",
    "sq" : "square",
    "srsly" : "seriously", 
    "ssdd" : "same stuff different day",
    "tbh" : "to be honest",
    "tbs" : "tablespooful",
    "tbsp" : "tablespooful",
    "tfw" : "that feeling when",
    "thks" : "thank you",
    "tho" : "though",
    "thx" : "thank you",
    "tia" : "thanks in advance",
    "til" : "today i learned",
    "tl;dr" : "too long i did not read",
    "tldr" : "too long i did not read",
    "tmb" : "tweet me back",
    "tntl" : "trying not to laugh",
    "ttyl" : "talk to you later",
    "u" : "you",
    "u2" : "you too",
    "u4e" : "yours for ever",
    "utc" : "coordinated universal time",
    "w/" : "with",
    "w/o" : "without",
    "w8" : "wait",
    "wassup" : "what is up",
    "wb" : "welcome back",
    "wtf" : "what the fuck",
    "wtg" : "way to go",
    "wtpa" : "where the party at",
    "wuf" : "where are you from",
    "wuzup" : "what is up",
    "wywh" : "wish you were here",
    "yd" : "yard",
    "ygtr" : "you got that right",
    "ynk" : "you never know",
    "zzz" : "sleeping bored and tired"
}




## BERT model preparation
This step includes the following:
```
- Tokenization classes
- Load BERT model
```

#### Tokenization classes

The tokenization code is taken from the BERT repository.
###### Cite: [Tokenization Code](https://github.com/google-research/bert/blob/master/tokenization.py)
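
The core of the WordPiece step is a greedy longest-match-first search over the vocabulary. The sketch below illustrates that idea for a single word; the function `wordpiece_sketch` and the small vocabulary are made up for this example and are not part of the BERT code.

```python
def wordpiece_sketch(word, vocab, unk_token="[UNK]"):
    """Split one word into word pieces by greedy longest-match-first search."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink from the right.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a ## prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            # No piece matches at this position: the whole word becomes unknown.
            return [unk_token]
        pieces.append(match)
        start = end
    return pieces


# Toy vocabulary, made up for this example.
toy_wordpiece_vocab = {"un", "##aff", "##able", "able"}
print(wordpiece_sketch("unaffable", toy_wordpiece_vocab))  # ['un', '##aff', '##able']
print(wordpiece_sketch("xyz", toy_wordpiece_vocab))        # ['[UNK]']
```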

In [None]:
"""Tokenization classes implementation.

The file is forked from:
https://github.com/google-research/bert/blob/master/tokenization.py.
"""
SPIECE_UNDERLINE = "▁"

def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text.decode("utf-8", "ignore")
        elif isinstance(text, unicode):
            return text
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python2 or Python 3?")


        
def printable_text(text):
    """Returns text encoded in a way suitable for print or `tf.logging`."""

    # These functions want `str` for both Python2 and Python3, but in one case
    # it's a Unicode string and in the other it's a byte string.
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text
        elif isinstance(text, unicode):
            return text.encode("utf-8")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python2 or Python 3?")


def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    index = 0
    with tf.io.gfile.GFile(vocab_file, "r") as reader:
        while True:
            token = convert_to_unicode(reader.readline())
            if not token:
                break
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab


def convert_by_vocab(vocab, items):
    """Converts a sequence of [tokens|ids] using the vocab."""
    output = []
    for item in items:
        output.append(vocab[item])
    return output


def convert_tokens_to_ids(vocab, tokens):
    return convert_by_vocab(vocab, tokens)


def convert_ids_to_tokens(inv_vocab, ids):
    return convert_by_vocab(inv_vocab, ids)


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


class FullTokenizer(object):
    """Runs end-to-end tokenization."""

    def __init__(self, vocab_file, do_lower_case=True, split_on_punc=True):
        self.vocab = load_vocab(vocab_file)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.basic_tokenizer = BasicTokenizer(
                do_lower_case=do_lower_case, split_on_punc=split_on_punc)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

    def tokenize(self, text):
        split_tokens = []
        for token in self.basic_tokenizer.tokenize(text):
            for sub_token in self.wordpiece_tokenizer.tokenize(token):
                split_tokens.append(sub_token)

        return split_tokens

    def convert_tokens_to_ids(self, tokens):
        return convert_by_vocab(self.vocab, tokens)

    def convert_ids_to_tokens(self, ids):
        return convert_by_vocab(self.inv_vocab, ids)


class BasicTokenizer(object):
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

    def __init__(self, do_lower_case=True, split_on_punc=True):
        """Constructs a BasicTokenizer.

        Args:
            do_lower_case: Whether to lower case the input.
            split_on_punc: Whether to apply split on punctuations. By default BERT
                starts a new token for punctuations. This makes detokenization difficult
                for tasks like seq2seq decoding.
        """
        self.do_lower_case = do_lower_case
        self.split_on_punc = split_on_punc

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = convert_to_unicode(text)
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        text = self._tokenize_chinese_chars(text)

        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case:
                token = token.lower()
                token = self._run_strip_accents(token)
            if self.split_on_punc:
                split_tokens.extend(self._run_split_on_punc(token))
            else:
                split_tokens.append(token)

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text."""
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #     https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like all of the other languages.
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or    #
                (cp >= 0x3400 and cp <= 0x4DBF) or    #
                (cp >= 0x20000 and cp <= 0x2A6DF) or    #
                (cp >= 0x2A700 and cp <= 0x2B73F) or    #
                (cp >= 0x2B740 and cp <= 0x2B81F) or    #
                (cp >= 0x2B820 and cp <= 0x2CEAF) or
                (cp >= 0xF900 and cp <= 0xFAFF) or    #
                (cp >= 0x2F800 and cp <= 0x2FA1F)):    #
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """

        text = convert_to_unicode(text)

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens


def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat in ("Cc", "Cf"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False


def preprocess_text(inputs, remove_space=True, lower=False):
    """Preprocesses data by removing extra space and normalizing data.

    This method is used together with sentence piece tokenizer and is forked from:
    https://github.com/google-research/google-research/blob/master/albert/tokenization.py

    Args:
        inputs: The input text.
        remove_space: Whether to remove the extra space.
        lower: Whether to lowercase the text.

    Returns:
        The preprocessed text.

    """
    outputs = inputs
    if remove_space:
        outputs = " ".join(inputs.strip().split())

    if six.PY2 and isinstance(outputs, str):
        try:
            outputs = six.ensure_text(outputs, "utf-8")
        except UnicodeDecodeError:
            outputs = six.ensure_text(outputs, "latin-1")

    outputs = unicodedata.normalize("NFKD", outputs)
    outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
    if lower:
        outputs = outputs.lower()

    return outputs

def encode_pieces(sp_model, text, sample=False):
    """Segments text into pieces.

    This method is used together with sentence piece tokenizer and is forked from:
    https://github.com/google-research/google-research/blob/master/albert/tokenization.py


    Args:
        sp_model: A spm.SentencePieceProcessor object.
        text: The input text to be segmented.
        sample: Whether to randomly sample a segmentation output or return a
            deterministic one.

    Returns:
        A list of token pieces.
    """
    if six.PY2 and isinstance(text, six.text_type):
        text = six.ensure_binary(text, "utf-8")

    if not sample:
        pieces = sp_model.EncodeAsPieces(text)
    else:
        pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
    new_pieces = []
    for piece in pieces:
        piece = printable_text(piece)
        if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
            cur_pieces = sp_model.EncodeAsPieces(piece[:-1].replace(
                    SPIECE_UNDERLINE, ""))
            if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
                if len(cur_pieces[0]) == 1:
                    cur_pieces = cur_pieces[1:]
                else:
                    cur_pieces[0] = cur_pieces[0][1:]
            cur_pieces.append(piece[-1])
            new_pieces.extend(cur_pieces)
        else:
            new_pieces.append(piece)

    return new_pieces

def encode_ids(sp_model, text, sample=False):
    """Segments text and returns token ids.

    This method is used together with sentence piece tokenizer and is forked from:
    https://github.com/google-research/google-research/blob/master/albert/tokenization.py

    Args:
        sp_model: A spm.SentencePieceProcessor object.
        text: The input text to be segemented.
        sample: Whether to randomly sample a segmentation output or return a
            deterministic one.

    Returns:
        A list of token ids.
    """
    pieces = encode_pieces(sp_model, text, sample=sample)
    ids = [sp_model.PieceToId(piece) for piece in pieces]
    return ids

class FullSentencePieceTokenizer(object):
    """Runs end-to-end sentence piece tokenization.

    The interface of this class is intended to match that of the
    `FullTokenizer` class above, for easier usage.
    """

    def __init__(self, sp_model_file):
        """Inits FullSentencePieceTokenizer.

        Args:
            sp_model_file: The path to the sentence piece model file.
        """
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(sp_model_file)
        self.vocab = {
                self.sp_model.IdToPiece(i): i
                for i in six.moves.range(self.sp_model.GetPieceSize())
        }

    def tokenize(self, text):
        """Tokenizes text into pieces."""
        return encode_pieces(self.sp_model, text)

    def convert_tokens_to_ids(self, tokens):
        """Converts a list of tokens to a list of ids."""
        return [self.sp_model.PieceToId(printable_text(token)) for token in tokens]

    def convert_ids_to_tokens(self, ids):
        """Converts a list of ids to a list of tokens."""
        return [self.sp_model.IdToPiece(id_) for id_ in ids]

####  Load BERT model

In [None]:
# Download BERT uncased WWM model (https://tfhub.dev/tensorflow/bert_en_wwm_uncased_L-24_H-1024_A-16/1)

def load_bert_and_tokenizer(bert_model_path):
    # Load trainable BERT model and tokenizer
    bert_layer = hub.KerasLayer(bert_model_path,
                                trainable=True)
    if 'albert_en_large' in bert_model_path:
        sp_model_file = bert_layer.resolved_object.sp_model_file.asset_path.numpy()
        tokenizer = FullSentencePieceTokenizer(sp_model_file)
        return bert_layer,tokenizer
    
    elif 'bert_en_wwm_uncased_L-24_H-1024_A-16' in bert_model_path:
        vocab_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()
        do_lower_case = bert_layer.resolved_object.do_lower_case.numpy()
        tokenizer = FullTokenizer(vocab_file, do_lower_case)
        return bert_layer,tokenizer
    else:
        raise Exception("undefined model")



### Model building
Add a classification layer on top of BERT.
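
As a rough illustration of what the classification layer computes, the sketch below reproduces the idea in plain NumPy: take the hidden vector of the first token of each sequence and pass it through a single sigmoid unit. The random arrays, the helper name `classify_first_token`, and the tiny dimensions are assumptions made for illustration only; the actual model uses the BERT layer's output and a trained Keras `Dense` layer.

```python
import numpy as np

def sigmoid(x):
    """Logistic function mapping real values to (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))

def classify_first_token(sequence_output, weights, bias):
    """Apply a one-unit sigmoid layer to the first-token vector of each sequence.

    sequence_output: array of shape (batch, seq_len, hidden)
    weights: array of shape (hidden, 1); bias: scalar
    """
    first_token = sequence_output[:, 0, :]   # (batch, hidden)
    logits = first_token @ weights + bias    # (batch, 1)
    return sigmoid(logits)

# Hypothetical stand-ins for the encoder output and the dense-layer parameters
rng = np.random.default_rng(0)
sequence_output = rng.normal(size=(2, 8, 4))
weights = rng.normal(size=(4, 1))
probs = classify_first_token(sequence_output, weights, bias=0.0)
print(probs.shape)
```

Each row of `probs` is a relevance probability for one document, which is later compared with the 0.5 cutoff.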

In [None]:
# Note: this step requires substantial compute resources.
def build_model(bert_layer,max_seq_length,lr=0.000001):
    # Finetune BERT model and add a classification layer
   
    input_word_ids=tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                           name="input_word_ids")
    input_masks=tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                       name="input_mask")
    segment_ids=tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                        name="segment_ids")
    
    pooled_output, sequence_output = bert_layer([input_word_ids, input_masks, segment_ids])
    
    first_token_tensor=sequence_output[:, 0, :]
    
    output=tf.keras.layers.Dense(1, activation='sigmoid')(first_token_tensor)
    model = tf.keras.models.Model(inputs=[input_word_ids,input_masks,segment_ids], outputs=output)
    # `learning_rate` is the current argument name; `lr` is a deprecated alias
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr), loss='binary_crossentropy',
                      metrics=[f1,'accuracy'])
    return model

## Model training

In [None]:
def finetune_bert(bert_layer,max_seq_length,learning_rate,batch_size,
                  X_id_train, X_mask_train, X_seg_train,y_train,
                  X_id_valid,X_mask_valid,X_seg_valid,y_valid,
                  class_weight,filename,output_folder,epochs=15):
    # Train BERT model
    print("***** Finetune BERT model *****")
    model=build_model(bert_layer,max_seq_length,lr=learning_rate)


    file_path = filename+".hdf5"
    check_point = ModelCheckpoint(file_path, monitor = "val_loss", verbose = 1, save_best_only = True, mode = "min") # only save the best epoch
    early_stop = EarlyStopping(monitor = "val_loss", mode = "min", patience=3) # Stop when val_loss does not improve
    hist = model.fit([X_id_train, X_mask_train, X_seg_train], 
                     y_train, batch_size=batch_size, epochs=epochs, 
                     validation_data=([X_id_valid,X_mask_valid,X_seg_valid], y_valid), verbose=2,
                     class_weight=class_weight,
                     shuffle=True, callbacks = [check_point, early_stop])
    model.load_weights(filename+".hdf5") # Load the best epoch
    
    # Check loss values 
    loss=hist.history['loss']
    val_loss=np.array(hist.history['val_loss'])
    min_loss = val_loss.min()
    best_epoch=val_loss.argmin()
    max_f1=max(hist.history['val_f1'])
    print("max f1:%f"%max_f1)
    print("min_loss:%f"%min_loss)
    print("best epoch:%d"%best_epoch)
    
    with open('log.txt','a') as f:
        f.write("%s, f1: %.3f, min_loss: %.3f, best epoch: %d\n"%(filename,max_f1,min_loss,best_epoch))
    plot_loss(loss,val_loss,path=output_folder+'/'+filename+"_bert_Loss.png")
    return model
    

## Prediction

In [None]:
# plot confusion matrix 
def plot_con_matrix(confmat,path):
   
    fig, ax = plt.subplots(figsize=(3, 3))
    ax.matshow(confmat, cmap=plt.cm.Blues, alpha=0.3)
    for i in range(confmat.shape[0]):
        for j in range(confmat.shape[1]):
            ax.text(x=j, y=i, s=confmat[i,j], va='center', ha='center')
    plt.xlabel('predicted label')        
    plt.ylabel('true label')
    plt.savefig(path)
    
# prediction score
def prediction_scores(test_df,output_folder,filename):
    test_same=pd.read_csv(output_folder+'/'+filename,index_col=0)
    test_df.loc[test_same.index,'target']=test_same['target']
    test_df.loc[test_same.index,'pred_prob']=test_same['pred_prob']
    test_labels=test_df['label'].replace({-1:0})
    pred_class=np.array(test_df['target'].values)
    confmat=confusion_matrix(test_labels, pred_class)
    precision=precision_score(test_labels, pred_class)
    recall=recall_score(test_labels, pred_class)
    f1_value=f1_score(test_labels, pred_class)  # named f1_value to avoid shadowing the Keras metric f1()
    print("Scores for "+output_folder+'/'+filename)
    print("precision score:%.3f"%precision)
    print("recall score:%.3f"%recall)
    print("f1 score:%.3f"%f1_value)
    plot_con_matrix(confmat,'%s/confusion_matrix_bert.png'%output_folder)
    test_df.to_csv(output_folder+'/full_'+filename)
    
    
# Predict
def make_prediction(test_df,prediction,output_folder,submission_name):
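    # Flatten the (n, 1) model output, then threshold at 0.5 to obtain class labels
    prediction=np.asarray(prediction).ravel()
    pred_class=np.array(prediction>0.5,dtype=int)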
    test_df['target']=pred_class
    sub=test_df['target']
    sub.to_csv(output_folder+'/'+submission_name+'.csv',header=True)
    sub=pd.DataFrame(sub)
    sub['pred_prob']=prediction
    sub.to_csv(output_folder+'/'+submission_name+'_prob.csv',header=True)
    prediction_scores(test_df,output_folder,submission_name+'_prob.csv')


# display f1 score during running
def f1(y_true, y_pred):
    def recall(y_true, y_pred):
        """Recall metric.

        Only computes a batch-wise average of recall.

        Computes the recall, a metric for multi-label classification of
        how many relevant items are selected.
        """
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
        recall = true_positives / (possible_positives + K.epsilon())
        return recall

    def precision(y_true, y_pred):
        """Precision metric.

        Only computes a batch-wise average of precision.

        Computes the precision, a metric for multi-label classification of
        how many selected items are relevant.
        """
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
        precision = true_positives / (predicted_positives + K.epsilon())
        return precision
    precision = precision(y_true, y_pred)
    recall = recall(y_true, y_pred)
    return 2*((precision*recall)/(precision+recall+K.epsilon()))

# plot loss function
def plot_loss(loss,val_loss,path="Model_Loss.png"):
    # Plot training and validation loss values
    plt.figure()
    plt.plot(loss)
    plt.plot(val_loss)
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epochs')
    plt.legend(['train','valid'],loc='upper left')
    plt.savefig(path)



### Let's run the BERT model

In [None]:
#tf.enable_eager_execution() # For tensorflow 1
#gpus = tf.config.experimental.list_physical_devices('GPU')
#tf.config.experimental.set_memory_growth(gpus[0], True)
bert_model_path='https://tfhub.dev/tensorflow/bert_en_wwm_uncased_L-24_H-1024_A-16/1'
#bert_model_path='../albert_en_xlarge'

'''
Hyperparameters. 
'''
learning_rate=0.000008
lr_class_model=0.0005
batch_size=4
max_seq_length=512 # The longest sequence after conversion is 1200 tokens, but BERT accepts at most 512.


''' 
Data pre-processing
'''

# Load training data
train_df = pd.read_csv(f'{bert_data_path}/train_sample.csv',encoding='utf-8',na_filter= False,
                       index_col=0,dtype={'title':str,'abstract':str})
train_df['title_count']=train_df['title'].str.split().apply(len)
train_df['abstract_count']=train_df['abstract'].str.split().apply(len)
train_df['total_count']=train_df['title_count']+train_df['abstract_count']
train_df['label']=train_df['label'].replace({-1:0})
print(train_df['total_count'].max()) # Max word count

# Load test data
test_df  = pd.read_csv(f'{bert_data_path}/test_sample.csv',encoding='utf-8',na_filter= False,
                       index_col=0,dtype={'title':str,'abstract':str})
test_df['title_count']=test_df['title'].str.split().apply(len)
test_df['abstract_count']=test_df['abstract'].str.split().apply(len)
test_df['total_count']=test_df['title_count']+test_df['abstract_count']
test_df['label']=test_df['label'].replace({-1:0})
print(test_df['total_count'].max())

if if_finetune_bert_model: # Finetune BERT model
    # Load tokenizer for BERT model
    print("Load tokenizer")
    bert_layer,tokenizer=load_bert_and_tokenizer(bert_model_path)

    train_target=np.array(train_df['label'])

    # Imbalanced data: use class_weight to up-weight the positive class
    neg_target=(train_target==0).sum()
    pos_target=(train_target==1).sum()
    print("Negative samples: %d"%neg_target)
    print("Positive samples: %d"%pos_target)
    pos_weight=neg_target/pos_target
    class_weight = {0: 1, 1: pos_weight} 
    print("class_weight",class_weight)

    # Convert text format for BERT model
    train_texts_input_ids, train_texts_input_masks, train_texts_segment_ids=convert_text_format(
            train_df,tokenizer,max_seq_length) 

    test_texts_input_ids, test_texts_input_masks, test_texts_segment_ids=convert_text_format(
            test_df,tokenizer,max_seq_length)

    '''
    Split 80% data for training and 20% data for validation.
    '''    
    from sklearn.model_selection import train_test_split
    X_id_train, X_id_valid, X_mask_train, X_mask_valid,X_seg_train, X_seg_valid,y_train, y_valid = \
            train_test_split(
            train_texts_input_ids, train_texts_input_masks, 
            train_texts_segment_ids, 
            train_target, test_size=0.2)
    
    model=finetune_bert(bert_layer,max_seq_length,learning_rate,batch_size,
                  X_id_train, X_mask_train, X_seg_train,y_train,
                  X_id_valid,X_mask_valid,X_seg_valid,y_valid,
                  class_weight,bert_model_name,output_folder,epochs=10)
    model.summary()  # summary() prints itself; wrapping it in print() would also print None

    # Load the best epoch
    model=build_model(bert_layer,max_seq_length,lr=learning_rate)
    model.load_weights(bert_model_name+".hdf5")

    # Predict our PubMed test data 
    X_test=[test_texts_input_ids, test_texts_input_masks, test_texts_segment_ids] # input for BERT model
    submission_name='submission_bert'
    pred_bert = model.predict(X_test) 
    make_prediction(test_df,pred_bert,output_folder,submission_name)

In [None]:
def load_test_prediction(input_folder,output_folder,filename,model):
    # Load test data
    print('Load %s/%s.csv'%(input_folder,filename))
    test_df  = pd.read_csv('%s/%s.csv'%(input_folder,filename),encoding='utf-8',na_filter= False,
                           index_col=0,dtype={'title':str,'abstract':str}) # read test data
    test_df=test_df[test_df['abstract'].notnull()]
    test_df.fillna(' ',inplace=True)
    if 'meta' in filename:
        test_df=test_df[['pubmed_id','title','abstract']]
    else:
        test_df=test_df[['title','abstract']]

    test_texts_input_ids, test_texts_input_masks, test_texts_segment_ids=convert_text_format(
            test_df,tokenizer,max_seq_length) # texts to BERT model format
    X_test=[test_texts_input_ids, test_texts_input_masks, test_texts_segment_ids] # input for BERT model
    submission_name='%s_submission_bert'%filename
    prediction = model.predict(X_test)
    pred_class=np.array(prediction>0.5,dtype=int)
    test_df['prediction']=pred_class
    test_df['pred_prob']=prediction
    test_df.to_csv(output_folder+'/'+submission_name+'.csv',header=True)

In [None]:
if if_finetune_bert_model: # Finetune BERT model
    # Predict COVID-19 data
    filenames=['model_positive','model_negative']
    input_folder=f'{bert_data_path}'
    output_folder=f'{bert_data_path}'
    for filename in filenames:
        load_test_prediction(input_folder,output_folder,filename,model)

# Conclusion of document retrieval

In this section, we show the results of the trained model and discuss its performance. Note that the model is applied separately to the articles containing the KOI (keywords of interest) and to the articles without them. The probability cutoff for calling a positive prediction is 0.5.

| Cross-table          | Predicted positive | Predicted negative | Total |
| -----------          | -----------------: | -----------------: | ----: |
| Articles without KOI |               374  |              33019 | 33393 |

Although some articles do not contain the KOI, our model still predicted them to be relevant to the KOI. We sorted the articles without KOI by predicted probability and manually read some of them to check whether they are indeed related to the KOI. In the examples below, we highlight the text that conveys a risk-factor meaning and follow each example with our analysis.
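
The selection of articles for manual review can be sketched as follows. This is a minimal illustration on a toy table, not the notebook's actual prediction files: the `has_koi` flag, the `paper_id` values, and the helper name are hypothetical, while `pred_prob` and the 0.5 cutoff follow the text above.

```python
import pandas as pd

# Toy stand-in for the prediction output; `has_koi` is a hypothetical flag column
preds = pd.DataFrame(
    {
        "paper_id": ["a", "b", "c", "d"],
        "has_koi": [False, False, True, False],
        "pred_prob": [0.99, 0.12, 0.95, 0.61],
    }
)

CUTOFF = 0.5  # probability cutoff for calling a positive prediction

def top_predictions_without_koi(df, cutoff=CUTOFF):
    """Return predicted-positive articles lacking the KOI, highest score first."""
    no_koi = df[~df["has_koi"]]
    positives = no_koi[no_koi["pred_prob"] > cutoff]
    return positives.sort_values("pred_prob", ascending=False)

ranked = top_predictions_without_koi(preds)
print(ranked[["paper_id", "pred_prob"]])
```

Reading the rows of `ranked` from the top corresponds to inspecting the highest-scoring articles first.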

Positive examples with high prediction score: 

1. ***Article:*** Paper id: nj2707mv, Prediction Score: 0.990073
    - ***Title:*** Association of Dynamic Changes in the ***CD4 T-Cell Transcriptome With Disease Severity During Primary Respiratory Syncytial Virus Infection*** in Young Infants
    - ***Abstract:*** BACKGROUND: Nearly all children are infected with respiratory syncytial virus (RSV) within the first 2 years of life, with a minority developing severe disease (1%–3% hospitalized). We hypothesized that an assessment of the adaptive immune system, using CD4(+) T-lymphocyte transcriptomics, would identify gene expression correlates of disease severity. METHODS: Infants infected with RSV representing extremes of clinical severity were studied. Mild illness (n = 23) was defined as a ***respiratory rate (RR) < 55 and room air oxygen saturation (SaO(2)) ≥ 97%***, and severe illness (n = 23) was defined as RR ≥ 65 and SaO2 ≤ 92%. RNA from fresh, sort-purified CD4(+) T cells was assessed by RNA sequencing. RESULTS: ***Gestational age, age at illness onset, exposure to environmental tobacco smoke, bacterial colonization, and breastfeeding*** were associated (adjusted P < .05) with disease severity. RNA sequencing analysis reliably measured approximately 60% of the genome. Severity of RSV illness had the greatest effect size upon CD4 T-cell gene expression. Pathway analysis identified correlates of severity, including JAK/STAT, prolactin, and interleukin 9 signaling. We also identified genes and pathways associated with timing of symptoms and RSV group (A/B). CONCLUSIONS: These data suggest fundamental changes in adaptive immune cell phenotypes may be associated with RSV clinical severity.
    - ***Analysis:*** This paper discusses risk factors for disease severity during primary respiratory syncytial virus (RSV) infection, including gestational age, age at illness onset, exposure to environmental tobacco smoke, bacterial colonization, and breastfeeding. It is clearly a positive case, and the prediction score is correspondingly high.
    
2. ***Article:*** Paper id: 1aluscl8, Prediction Score: 0.986705
    - ***Title:*** Ebola Hemorrhagic Fever as a Public Health Emergency of International Concern; a Review Article
    - ***Abstract:*** Ebola hemorrhagic fever (EHF) was first reported in 1976 with two concurrent outbreaks of acute viral hemorrhagic fever centered in Yambuku (near the Ebola river), Democratic Republic of Congo, and in Nzara, Sudan. The current outbreak of the Ebola virus was started by reporting the first case in March 2014 in the forest regions of southeastern Guinea. Due to infection rates raising over 13,000% within a 6-month period, Ebola is now considered as a global public health emergency and on August 8(th), 2014 the World Health Organization (WHO) declared the epidemic to be a Public Health Emergency of International Concern. With more than 5000 involved cases and nearly 3000 deaths, this event has turned into the largest and most dangerous Ebola virus outbreak in the world. Based on the above-mentioned, the present article aimed to review the ***virologic characteristics, transmission, clinical manifestation, diagnosis, treatment, and prevention of Ebola virus disease.***
    - ***Analysis:*** This paper focuses on Ebola hemorrhagic fever (EHF). Since it aims to review the virologic characteristics, transmission, clinical manifestation, diagnosis, treatment, and prevention of the disease, we can infer that it discusses risk factors for EHF.

3. ***Article:*** Paper id: t02jngq0, Prediction Score: 0.98595
    - ***Title:*** A database of geopositioned Middle East Respiratory Syndrome Coronavirus occurrences
    - ***Abstract:*** As a World Health Organization Research and Development Blueprint priority pathogen, there is a need to better understand the ***geographic distribution of Middle East Respiratory Syndrome Coronavirus (MERS-CoV) and its potential to infect mammals and humans.*** This database documents cases of MERS-CoV globally, with specific attention paid to zoonotic transmission. An initial literature search was conducted in PubMed, Web of Science, and Scopus; after screening articles according to the inclusion/exclusion criteria, a total of 208 sources were selected for extraction and geo-positioning. Each MERS-CoV occurrence was assigned one of the following classifications based upon published contextual information: index, unspecified, secondary, mammal, environmental, or imported. In total, this database is comprised of 861 unique geo-positioned MERS-CoV occurrences. The purpose of this article is to share a collated MERS-CoV database and extraction protocol that can be utilized in future mapping efforts for both MERS-CoV and other infectious diseases. More broadly, it may also provide useful data for the development of targeted MERS-CoV surveillance, which would prove invaluable in preventing future ***zoonotic spillover***.
    - ***Analysis:*** This paper addresses specific risk factors for Middle East Respiratory Syndrome Coronavirus (MERS-CoV). The authors focus on two aspects: its geographic distribution and its potential to infect mammals and humans. By studying the database, readers can learn how these two factors affect the spread of MERS-CoV.

4. ***Article:*** Paper id: 9t5ncsig, Prediction Score: 0.985887
    - ***Title:*** Optimising Renewal Models for Real-Time Epidemic Prediction and Estimation
    - ***Abstract:*** The effective reproduction number, ***Rt, is an important prognostic for infectious disease epidemics. Significant changes in Rt can forewarn about new transmissions or predict the efficacy of interventions.*** The renewal model infers Rt from incidence data and has been applied to Ebola virus disease and pandemic influenza outbreaks, among others. This model estimates Rt using a sliding window of length k. While this facilitates real-time detection of statistically significant Rt fluctuations, inference is highly k -sensitive. Models with too large or small k might ignore meaningful changes or over-interpret noise-induced ones. No principled k -selection scheme exists. We develop a practical yet rigorous scheme using the accumulated prediction error (APE) metric from information theory. We derive exact incidence prediction distributions and integrate these within an APE framework to identify the k best supported by available data. We find that this k optimises short-term prediction accuracy and expose how common, heuristic k -choices, which seem sensible, could be misleading.
    - ***Analysis:*** This article describes a study of the renewal model, which infers the effective reproduction number, Rt, and has been applied to Ebola virus disease and pandemic influenza outbreaks. The effective reproduction number is a basic characteristic of an epidemic and an important risk factor for its spread. From this study, readers can learn how Rt affects the spread of Ebola virus disease and pandemic influenza; it may therefore shed light on the control and prevention of such diseases.

5. ***Article:*** Paper id: 5zcydnre, Prediction Score: 0.984713
    - ***Title:*** The time-series ***ages distribution*** of the reported COVID-2019 infected people suggests the undetected ***local spreading*** of COVID-2019 in Hubei and Guangdong provinces before 19th Jan 2020
    - ***Abstract:*** COVID-2019 is broken out in China. It becomes a severe public health disaster in one month. Find the period in which the spreading of COVID-2019 was overlooked, and understand the epidemiological characteristics of COVID-2019 in the period will provide valuable information for the countries facing the threats of COVID-2019. The most extensive epidemiological analysis of COVID-2019 ***shows that older people have lower infection rates compared to middle-aged persons***. Common sense is that older people prefer to report their illness and get treatment from the hospital compared to middle-aged persons. ***We propose a hypothesis that when the spreading of COVID-2019 was overlooked, we will find more older cases than the middle-aged cases.*** At first, we tested the hypothesis with 4597 COVID-2019 infected samples reported from 26th Nov 2019 to 17th Feb 2020 across the mainland of China. We found that 19th Jan 2020 is a critical time point. Few samples were reported before that day, and most of them were older ones. Then samples were explosively increased after that day, and many of them were middle-aged people. We have demonstrated the hypothesis to this step. Then, we grouped samples by their residences(provinces). We found that, in the provinces of Hubei and Guangdong, the ages of samples reported before 19th Jan 2020 are significantly higher than the ages of samples reported after that day. It suggests the COVID-2019 may be spreading in Hubei and Guangdong provinces before 19th Jan 2020 while people were unconscious of it. At last, we proposed that ***the ages distribution*** of each-day-reported samples could serve as a ***warning indicator*** of whether all potential COVID-2019 infected people are found. We think the power of our analysis is limit because 1. the work is data-driven, and 2. only ~5% of the COVID-2019 infected people in China are included in the study. 
      However, we believe it still shows some value for its ability to estimate the possible unfound COVID-2019 infected people.
    - ***Analysis:*** This article summarizes a study of COVID-19 cases in the Hubei and Guangdong provinces of China. The main focus of the study is the age distribution of patients; age is a key risk factor related to COVID-19. Specifically, the article points out that the age distribution can serve as a warning indicator of whether all potentially infected people have been found. This may shed light on treating patients better and slowing the spread of the disease.

In conclusion, the examples above show that these articles are indeed relevant, in some way and to some degree, to risk factors for the outbreak, spread, infection, or mortality of a disease, even though the keywords "risk" and "factor" never occur in them.


# Information extraction

In information extraction, we extract interesting relationships in the relevant articles to build an interactive network through which researchers can navigate the existing knowledge about COVID-19 on the topic defined by KOI.
A function was developed to extract relationships (edges) between pairs of concepts. Using this function, we tokenized each sentence in the titles and abstracts into words after removing stop words; word normalization and lemmatization were also performed. An edge was constructed for each pair of words in a sentence, and the words and edges from all sentences together form a network. This network contains many trivial edges. For example, if "protein interaction network" appears in a sentence and is treated as three separate words, the sentence yields three edges, one for each pair of the three words. Such edges make the network look very crowded, so we would like to filter out the less interesting ones.
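
The pairwise edge construction described above can be sketched in a few lines. This is a simplified illustration rather than the full pipeline: it skips stop-word removal and lemmatization, and the helper names `sentence_edges` and `merge_phrase` are hypothetical. It shows how treating a multi-word phrase as a single node removes the trivial edges among its words.

```python
from itertools import combinations

def sentence_edges(words):
    """Return one 'a|b' edge key for every pair of distinct words in a sentence."""
    unique_words = sorted(set(words))
    return ["|".join(pair) for pair in combinations(unique_words, 2)]

def merge_phrase(words, phrase):
    """Replace the words of `phrase` with one phrase node if all of them are present."""
    parts = phrase.split()
    if all(part in words for part in parts):
        return [w for w in words if w not in parts] + [phrase]
    return list(words)

words = ["protein", "interaction", "network", "virus"]
separate = sentence_edges(words)
merged = sentence_edges(merge_phrase(words, "protein interaction network"))
print(len(separate), separate)
print(len(merged), merged)
```

With four separate words the sentence contributes six edges; once the phrase is merged into one node, only the single edge between the phrase and the remaining word is left.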

To do that, we used a large number of PubMed articles that were retrieved with the KOI but are not in the COVID-19 dataset to construct a background network. We then removed from the COVID-19 network every edge that also appears in the background network.
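
A minimal sketch of this background filtering step is shown below. The edge keys and counts are made-up toy values and `subtract_background` is a hypothetical helper; the idea is simply to drop any COVID-19 edge whose key also occurs in the background edge dictionary and to rank what remains by count.

```python
def subtract_background(covid_edges, background_edges):
    """Keep only edges absent from the background network, most frequent first."""
    kept = {edge: count for edge, count in covid_edges.items()
            if edge not in background_edges}
    return dict(sorted(kept.items(), key=lambda item: item[1], reverse=True))

# Toy edge counts for illustration only
covid_edges = {"age|mortality": 7, "patient|study": 40, "ace2|receptor": 12}
background_edges = {"patient|study": 950, "cell|protein": 300}

filtered = subtract_background(covid_edges, background_edges)
print(filtered)
```

Note that the filter is membership-based: an edge is removed whenever it appears in the background at all, regardless of how often it occurs there.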


In [None]:
# Load titles of all papers in the covid-19 dataset to exclude papers
# in both the covid-19 dataset and the pubmed papers for background edges 
covid19_titles = set(df.title.str.strip().str.lower())

# Function for getting edges and related sentences
def get_edges_sents(df, phrases = '', pubmed = False):
    """
    Develop edges between lemmatized words or phrases in a sentence,
    and keep the relevant sentences for each edge.
    Return a dictionary of counts for each edge, and a dictionary of
    relevant sentences for each edge. 
    Args:
        df: a dataframe of titles and abstracts by ids.
        phrases: a string of phrases separated by ','; the words of each phrase are treated as a single phrase node.
        pubmed: for pubmed titles and abstracts, relevant sentences for each edge will not be kept.
    """
    assert isinstance(phrases, str)
    edges_and_counts = {}
    sents_per_edge = {}
    if phrases != '':
        phrases = ','.join([' '.join([wnlem.lemmatize(word.lower()) for word in phrase.split()]) for phrase in phrases.split(',')])
        phrases = [phrase.split() for phrase in phrases.split(',')]
    for idx, row in df.iterrows():
        if pubmed:
            title = row.title.strip().lower()  # covid19_titles is lower-cased, so compare in lower case
            if title.endswith('.'): title = title[:-1]
            if title in covid19_titles: continue
        title_abstract = ' '.join([row.title, row.abstract])
        if not row.title.strip().endswith('.'):
            title_abstract = ' '.join([row.title.strip() + '.', row.abstract])
        title_abstract = text_preprocess(title_abstract)
        sentences = sent_tokenize(title_abstract)
        for j, sentence in enumerate(sentences):
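            try:
                # `model` is expected to be the language-identification model loaded earlier
                # (returning labels such as '__label__en'), not the BERT classifier above.
                if model.predict(sentence)[0][0] != '__label__en': continue
            except Exception: continue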
            if 'ELECTRONIC SUPPLEMENTARY MATERIAL' in sentence: continue
            if 'CC-BY' in sentence: continue
            sentence = sentence.strip()
            if re.match(r'^(the )?author funder,', sentence): continue
            sentence = sentence_preprocess(sentence)
            word_list = word_tokens(sentence)
            if len(word_list) < 5 or len(word_list) > 50: continue
            word_list = list(set(word_list))
            if phrases != '':
                phrase_list = set()
                for phrase in phrases:
                    count = 0
                    for word in phrase:
                        if word in word_list: count += 1
                    if count == len(phrase):
                        word_list = [w for w in word_list if w not in phrase]
                        phrase_list.add(' '.join(phrase))
                word_list = list(phrase_list | set(word_list))
            word_list.sort()
            pairs = combinations(word_list, 2)
            for pair in pairs:
                word_pair = '|'.join(list(pair))
                edges_and_counts[word_pair] = edges_and_counts.get(word_pair, 0) + 1
                if pubmed: continue
                sentinfo = {'sentence':sentence, 'title':row.title, 'authors':row.authors,
                            'journal':row.journal, 'publish_time':row.publish_time, 'url':row.url}
                if word_pair in sents_per_edge: sents_per_edge[word_pair].append(sentinfo)
                else: sents_per_edge[word_pair] = [sentinfo]
    return edges_and_counts, sents_per_edge

# Function to write COVID-19 edges and sentences into a csv file
def write_edges_sents(covid19_edges, covid19_sents_per_edge, pubmed_edges, dataset_name):
    """
    Write the nodes of each edge, counts of the edge and relevant sentences of the edge
    into a csv file
    Args:
        covid19_edges: the dictionary of counts for each edge of the covid-19 subset data.
        covid19_sents_per_edge: the dictionary of relevant sentences for each edge of the covid-19 subset data.
        pubmed_edges: the dictionary of counts for each edge of the pubmed background edge data.
        dataset_name: a string of text as the csv file name
    """
    assert isinstance(dataset_name, str)
    # Use a context manager so the file is closed even if writing fails
    with open(f'{edges_path}/covid19_edges_sents_{dataset_name}.csv', 'w',
              encoding = 'utf-8', newline = '') as sents_file:
        writer = csv.writer(sents_file)
        for k, v in covid19_edges.items():
            if k in pubmed_edges: continue  # drop edges that also occur in the background network
            writer.writerow(k.split('|') + [str(v), covid19_sents_per_edge[k]])

Develop the edges for a task

In [None]:
# edges of a task
if not if_run_all:
    df_tiabs = pd.read_csv(f'{pubmed_tiabs_path}/covid19_tiabs_task_{task_num}.csv', na_filter= False)
    df_subset = pd.read_csv(f'{covid19_dataset_path}/covid19_subset_task_{task_num}.csv', na_filter= False)
pubmed_edges, _ = get_edges_sents(df_tiabs, KOI, pubmed = True)
covid19_edges, covid19_sents_per_edge = get_edges_sents(df_subset, KOI)
covid19_edges = dict(sorted(covid19_edges.items(), key=lambda x:x[1], reverse=True))
write_edges_sents(covid19_edges, covid19_sents_per_edge, pubmed_edges, f'task_{task_num}')

# Words that co-occur with KOI

To ensure that the constructed network captures the key concepts and relationships relevant to the KOI, we developed a function that extracts the set of words co-occurring with the KOI in the same sentence. This set of words is used together with the network to present the key concepts and relationships relevant to the KOI.
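
The counting scheme can be illustrated with a small, self-contained sketch. It assumes pre-tokenized toy documents and a single keyword, omits the noun filtering and phrase handling of the real function, and unlike that function does not count the keyword itself; `cooccur_counts` is a hypothetical name. For each word it records how many papers and how many sentences contain the word together with the keyword.

```python
def cooccur_counts(docs, keyword):
    """Map each word to (paper count, sentence count) of co-occurrence with keyword.

    docs: list of documents, each a list of sentences, each a list of words.
    """
    vocab = {}
    for sentences in docs:
        doc_counts = {}
        for words in sentences:
            unique = set(words)
            if keyword not in unique:
                continue  # only sentences containing the keyword contribute
            for w in unique - {keyword}:
                doc_counts[w] = doc_counts.get(w, 0) + 1
        for w, c in doc_counts.items():
            papers, sents = vocab.get(w, (0, 0))
            vocab[w] = (papers + 1, sents + c)
    return vocab

# Toy documents for illustration only
docs = [
    [["age", "risk", "mortality"], ["age", "risk"], ["virus", "genome"]],
    [["risk", "smoking"], ["age", "outcome"]],
]
vocab = cooccur_counts(docs, "risk")
print(vocab)
```

Keeping both counts makes it possible to distinguish a word that co-occurs with the keyword many times in one paper from a word that does so across many papers.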

In [None]:
# Function for getting the vocabulary that co-occurs with the keywords
def get_cooccur_vocab(df, phrases, if_and = False):
    """
    Build a vocabulary of nouns that co-occur with the keywords in the phrases.
    Return a dictionary mapping each noun to its paper count and sentence count.
    Args:
        df: a dataframe of titles and abstracts.
        phrases: a string of phrases separated by ','
        if_and: if True, keep only words that co-occur with all keywords in the phrases;
                if False, keep words that co-occur with any keyword in the phrases
    """
    assert isinstance(phrases, str) and phrases != ''
    cooccur_vocab = {}
    phrases = ','.join([' '.join([wnlem.lemmatize(word.lower()) for word in phrase.split()]) for phrase in phrases.split(',')])
    phrases = [phrase.split() for phrase in phrases.split(',')]
    for idx, row in df.iterrows():
        title_abstract = ' '.join([row.title, row.abstract])
        if not row.title.strip().endswith('.'):
            title_abstract = ' '.join([row.title.strip() + '.', row.abstract])
        title_abstract = text_preprocess(title_abstract)
        sentences = sent_tokenize(title_abstract)
        doc_vocab = {}
        for j, sentence in enumerate(sentences):
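            # skip sentences that are not detected as English or that the language model cannot handle
            try:
                if model.predict(sentence)[0][0] != '__label__en': continue
            except Exception: continue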
            if 'ELECTRONIC SUPPLEMENTARY MATERIAL' in sentence: continue
            if 'CC-BY' in sentence: continue
            sentence = sentence.strip()
            if re.match(r'^(the )?author funder,', sentence): continue
            sentence = sentence_preprocess(sentence)
            word_list = word_tokens(sentence)
            if len(word_list) < 5 or len(word_list) > 50: continue
            word_list = list(set(word_list))
            phrase_list = set()
            count_phrase = 0
            for phrase in phrases:
                count_word = 0
                for word in phrase:
                    if word in word_list: count_word += 1
                if count_word == len(phrase):
                    word_list = [w for w in word_list if w not in phrase]
                    phrase_list.add(' '.join(phrase))
                    count_phrase += 1
            if count_phrase == 0: continue
            noun_list = tagged_word_tokens(sentence)
            word_list = [w for w in word_list if w in noun_list]
            if if_and:
                if count_phrase != len(phrases): continue
            word_list = list(phrase_list | set(word_list))
            for w in word_list:
                doc_vocab[w] = doc_vocab.get(w, 0) + 1
        for w, c in doc_vocab.items():
            if w in cooccur_vocab:
                cooccur_vocab[w][0], cooccur_vocab[w][1] = cooccur_vocab[w][0] + 1, cooccur_vocab[w][1] + c
            else: cooccur_vocab[w] = [1, int(c)]
    return cooccur_vocab

Build the co-occurring vocabulary for a task

In [None]:
# co-occurring vocabulary for a task
# df_subset = pd.read_csv(f'{covid19_dataset_path}/covid19_subset_task_{task_num}.csv', na_filter= False)
df_subset = df_subset[['cord_uid', 'title', 'abstract']]
cooccur_vocab = get_cooccur_vocab(df_subset, KOI)
cooccur_vocab = dict(sorted(cooccur_vocab.items(), key = lambda x:x[1][1], reverse = True))
# write the vocabulary to disk; the context manager closes the file handle
with open(f'{edges_path}/covid19_cooccur_vocab_task_{task_num}.json', 'w', encoding = 'utf-8') as f:
    json.dump(cooccur_vocab, f, indent = 4)

# Display Knowledge Network
Using "risk factor" as the example key word, we first extract two types of edges (I and II) from the literature, as described above. Type I edges connect every possible pair of words in the same sentence; Type II edges link the selected key word(s) to the words that co-occur with them in the same sentence. 
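
As a rough illustration of the two edge types, the sketch below enumerates them for one tokenized sentence. The function name `sentence_edges` and the toy sentence are hypothetical, and the `term1|term2` key format is assumed from the way edge keys are split into `term1` and `term2` elsewhere in this notebook; the actual extraction is done by `get_edges_sents`.

```python
from itertools import combinations

def sentence_edges(tokens, keywords=None):
    """Return 'term1|term2' edge keys for one tokenized sentence.

    With keywords=None every word pair is returned (type I).
    With a set of keywords, only pairs touching a keyword are kept (type II).
    """
    terms = sorted(set(tokens))                # unique terms in a stable order
    pairs = combinations(terms, 2)             # all unordered pairs
    if keywords is not None:
        pairs = (p for p in pairs if p[0] in keywords or p[1] in keywords)
    return ["|".join(p) for p in pairs]

tokens = ["age", "obesity", "risk"]            # hypothetical preprocessed sentence
type1_edges = sentence_edges(tokens)
type2_edges = sentence_edges(tokens, keywords={"risk"})
```

In this toy case the type II edges are the subset of type I edges that involve the key word.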

We processed the edges in the following steps: 
```
1. Rank edges: 
    Both types of edges are ranked by count from high to low. Type II edges are ranked by the number of papers containing the edge rather than the number of sentences. 
2. Select nodes: 
    For type I edges, the top 100 nodes ranked by degree (the number of edges connected to the node), from high to low, were selected. For type II edges, the intersection of the top 100 type I nodes and all type II nodes was selected. 
```
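
The two steps can be sketched in plain Python as follows. This is a simplified illustration under stated assumptions, not the notebook's implementation (which uses `networkx` in `get_sub_edges`): the helper name `select_nodes`, the `term1|term2` key format, and the toy counts are all hypothetical.

```python
from collections import Counter

def select_nodes(edge_counts, cooccur_words, top_n=100):
    """Rank edges by count, then pick top-degree nodes.

    edge_counts: {'term1|term2': count} for type I edges.
    cooccur_words: type II words, already ordered by rank.
    """
    # Step 1: rank edges by count from high to low.
    ranked_edges = sorted(edge_counts.items(), key=lambda kv: kv[1], reverse=True)

    # Step 2: node degree = number of distinct edges touching the node.
    degree = Counter()
    for key, _ in ranked_edges:
        term1, term2 = key.split("|")
        degree[term1] += 1
        degree[term2] += 1
    type1_nodes = [node for node, _ in degree.most_common(top_n)]

    # Type II nodes: co-occurring words that are also top type I nodes.
    top = set(type1_nodes)
    type2_nodes = [w for w in cooccur_words if w in top]
    return ranked_edges, type1_nodes, type2_nodes

# Hypothetical toy input.
edge_counts = {"age|risk": 5, "obesity|risk": 3, "age|obesity": 1, "cough|fever": 2}
ranked_edges, type1_nodes, type2_nodes = select_nodes(
    edge_counts, cooccur_words=["risk", "fever", "age"], top_n=3)
```

With `top_n=3`, "fever" is dropped from the type II nodes because it is not among the three highest-degree type I nodes.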

In [None]:
# import some useful packages
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import networkx as nx
import pandas as pd
import random
import json
import os
import re
import difflib
import requests
import ipywidgets as widgets
from IPython.display import display
from random import sample 
from ipywidgets import HBox, Label

# text rendering in graphs: LaTeX disabled (usetex = False), serif font
import matplotlib as mpl
mpl.rc('text', usetex = False)
mpl.rc('font', family = 'serif')

%matplotlib inline

!pip install visJS2jupyter
# import visJS_2_jupyter 
from visJS2jupyter import visJS_module

In [None]:
# get_sub_edges extracts a subset of edges from the input edge dataframe
# df_edges: input edge dataframe
# node_cooccur: keywords in type II edges
# top_node: number of top type I nodes, ranked by degree, to display in the network 
# top_node_cooccur: number of type II nodes to display in the network
# edge_top_perc: percentage of top-ranked edges used to compute the top nodes

def get_sub_edges(df_edges, node_cooccur, top_node = 100, 
                  top_node_cooccur = 1, edge_top_perc= 10): 
    edge_num_total = df_edges.shape[0]
    edge_num = int(edge_num_total*edge_top_perc/100)
    edge_df_sub = df_edges.iloc[0:edge_num, :]
    G_covid19 = nx.from_pandas_edgelist(edge_df_sub,source='term1',target='term2')
    sorted_node = sorted(G_covid19.degree, key=lambda x: x[1], reverse=True)

    # select up to total_type1_node top nodes ranked by degree
    total_type1_node = min(1000, len(sorted_node))
    type1_nodes = [node for node, _ in sorted_node[:total_type1_node]]
    
    # all the co-occurring nodes that are also among the selected top-degree nodes
    type2_nodes = [node_c for node_c in node_cooccur if node_c in type1_nodes]
    
    # type1 nodes to be displayed
    selected_type1_nodes = []
    if len(sorted_node) < top_node:
        top_node = len(sorted_node)
        
    for i in range(0,top_node):
        selected_type1_nodes.append(sorted_node[i][0])
        
    selected_type2_nodes = []
    len_type2_nodesccur = len(type2_nodes)
    if top_node_cooccur > len_type2_nodesccur:
        top_node_cooccur = len_type2_nodesccur
    if top_node_cooccur > 0:
        selected_type2_nodes = type2_nodes[0:top_node_cooccur]

    for node in selected_type2_nodes:
        if node not in selected_type1_nodes:
            selected_type1_nodes.append(node)

    edge_df_sel = edge_df_sub[edge_df_sub.term1.isin(selected_type1_nodes) & edge_df_sub.term2.isin(selected_type1_nodes)] 
            
       
    return(selected_type1_nodes, len_type2_nodesccur, selected_type2_nodes, edge_df_sel)


# show_sentences formats the sentences shown when clicking on an edge
# edges_with_data: edges containing sentence info
# max_words: the maximum number of words on each line of a displayed sentence
# max_sentences: the maximum number of sentences to display
def show_sentences(edges_with_data, max_words = 16, max_sentences = 10):
    for i in range(len(edges_with_data)):
#         sents = eval(edges_with_data[i][2]['sentences'])
        sents = edges_with_data[i][2]['sentences']
        mystring = "Edge " + "--".join(edges_with_data[i][0:2]) + " derived from sentences like: <br/> <br/> "
        num_paper = len(sents)
        sel_paper = range(num_paper)
        if num_paper > max_sentences:
            sel_paper = sample(range(num_paper),max_sentences)
        for k in sel_paper:
            sent_p1 = sents[k]['sentence']
            sent_p2 = sents[k]['url']         
            TT_s = sent_p1.split(" ")
            if len(TT_s) > max_words:
                for j in range(0, int(len(TT_s)/max_words)+1):
                    mystring = mystring + " ".join(TT_s[max_words*j:max_words*j+max_words]) + " <br/>"
            else:
                mystring = mystring + " ".join(TT_s) + " <br/>"

            mystring = mystring + sent_p2 + " <br/><br/>"

        edges_with_data[i][2]['sentences'] = mystring

    return(edges_with_data)           
    

# generate_network draws the network
# edge_df_sel: selected edge dataframe, returned by get_sub_edges
# k_nodes: key nodes (key words) to highlight
# scaling_factor: controls the resolution of the network; a larger number gives a higher resolution
def generate_network(edge_df_sel, k_nodes, scaling_factor):
    key_nodes = k_nodes
    G_covid19 = nx.from_pandas_edgelist(edge_df_sel,source='term1',target='term2',edge_attr = ['score','sentences'])

    nodes = list(G_covid19.nodes()) # must cast to list to maintain compatibility between nx 1.11 and 2.0
    edges = list(G_covid19.edges()) # will return an "EdgeView" object in nx 2.0

    pos = nx.spring_layout(G_covid19,k=1.0)

    nodes = list(G_covid19.nodes()) # to make compatible between nx 1.11 and 2.0, must cast to list

    numnodes = len(nodes)
    edges = list(G_covid19.edges()) # to make compatible between nx 1.11 and 2.0, must cast to list
    edges_with_data = list(G_covid19.edges(data=True)) # to make compatible between nx 1.11 and 2.0, must cast to list
    numedges = len(edges)

    # add a node attributes to color-code by
    cc = nx.clustering(G_covid19)
    degree = dict(G_covid19.degree()) # to make compatible between nx 1.11 and 2.0, must cast to dict
    bc = nx.betweenness_centrality(G_covid19)
    nx.set_node_attributes(G_covid19, name = 'clustering_coefficient', values = cc) # must explicitly define arguments
    nx.set_node_attributes(G_covid19, name = 'degree', values = degree)             # for compatibility with nx 1.11 and 2.0
    nx.set_node_attributes(G_covid19, name = 'betweenness_centrality', values = bc)

    # set node_size to degree
    node_size = [int(float(n)/np.max(list(degree.values()))*300+1) for n in list(degree.values())]
    node_to_nodeSize = dict(zip(degree.keys(),node_size))

    sorted_nodeSize = sorted(node_to_nodeSize.items(), key=lambda kv: kv[1], reverse=True)
    max_nodeSize = sorted_nodeSize[0][1]
    
    # # add nodes to highlight (key node)
    nodes_HL = [0 if node in key_nodes else 0 for node in G_covid19.nodes()]  
    nodes_HL = dict(zip(G_covid19.nodes(),nodes_HL))

    nodes_shape=[]
    node_shape = ['hexagon' if (node in key_nodes) else 'dot' for node in G_covid19.nodes()]
    node_to_nodeShape=dict(zip(G_covid19.nodes(),node_shape))

    # add a field for node labels
    node_labels_temp = []
    # list_of_genes = list(np.setdiff1d(G_covid19.nodes(),d_list))
    for node in G_covid19.nodes():
        label_temp = node
        label_temp+= '<br/>'
        label_temp+='degree: ' + str(nx.degree(G_covid19,node)) + '<br/>'

        node_labels_temp.append(label_temp)

    node_labels = dict(zip(G_covid19.nodes(),node_labels_temp))

    node_titles = [node for node in G_covid19.nodes()]

    node_titles = dict(zip(G_covid19.nodes(),node_titles))


    edges_with_data = show_sentences(edges_with_data, max_words = 16)


#     node_to_color = visJS_module.return_node_to_color(G_covid19,field_to_map='degree',cmap=mpl.cm.jet,alpha = 1)
    node_to_color = dict(zip(nodes, ['rgba(85, 85, 85, 1)']*len(nodes)))
    for k_node in key_nodes:
        node_to_color[k_node] = 'green'

    edge_to_color = visJS_module.return_edge_to_color(G_covid19,field_to_map = 'score',cmap=mpl.cm.Blues,alpha=1)

    nodes_dict = [{"id":n,"degree":G_covid19.degree(n),"color":node_to_color[n],
                  "node_size":node_to_nodeSize[n],'border_width':nodes_HL[n],
                  "node_label":node_labels[n],
                  "title":node_labels[n],
                  "node_shape":node_to_nodeShape[n],
                  "x":pos[n][0]*500*scaling_factor,
                  "y":pos[n][1]*500*scaling_factor} for n in nodes
                  ]


    node_map = dict(zip(nodes,range(numnodes)))  # map to indices for source/target in edges

    # edges_dict = [{"source":node_map[edges[i][0]], "target":node_map[edges[i][1]], 
    #               "color":edge_to_color[edges[i]],"title":edges_with_data[i][2]['score']} for i in range(numedges)]
    edges_dict = [{"source":node_map[edges[i][0]], "target":node_map[edges[i][1]], 
                  "e_label":edges_with_data[i][2]['sentences']} for i in range(numedges)]




    p = visJS_module.visjs_network(nodes_dict, edges_dict,
                               node_color_highlight_border="white",
                               node_color_hover_border = 'orange',
                               node_color_hover_background = 'rgb(119, 130, 140, 1)',
                               node_color_border='black',
                               node_size_field='node_size',
                               node_size_transform='Math.sqrt',
                               node_size_multiplier=3*scaling_factor,
                               node_border_width=1*scaling_factor,
                               node_font_size = 25*scaling_factor,
                               hover = True,
                               edge_title_field='e_label',
                               edge_width = 3*scaling_factor,
                               edge_color_hover = 'red',
                               edge_color_opacity = 0.5,  
                               physics_enabled=False,
                               min_velocity=.5,
                               min_label_size=12*scaling_factor,
                               max_label_size=25*scaling_factor,
                               max_visible=10*scaling_factor,
                               scaling_factor=scaling_factor,
                               tooltip_delay = 300,
                               graph_id = 1,
                               graph_title = '' )
    
    return(p)


# show_network displays the network with control sliders
# df_edges: the original dataframe of edges
# node_cooccur: type II nodes
# s_factor: controls the resolution of the network; a larger number gives a higher resolution
# key_node: KEY_WORD

def show_network(df_edges, node_cooccur, s_factor = 1, key_node = ['risk factor']):
    output_network = widgets.Output()
    custom_style = {'description_width': 'initial'}
    len_type2_nodesccur = 50
    
    # define a slider to control the number of top type I nodes
    bounded_top_node = widgets.IntSlider(min=1, max=100, value=10, 
                                         step=1, description = "", 
                                         style=custom_style)
    
    # define a slider to control the number of top type II nodes
    bounded_top_key_node = widgets.IntSlider(min=0, max=len_type2_nodesccur, value=1, 
                                         step=1, description = "", 
                                         style=custom_style)
     
    def draw_network(num_t_node, num_t_key_node):
        output_network.clear_output()
        selected_type1_nodes, len_type2_nodesccur, selected_type2_nodes, edge_df_sel = get_sub_edges(df_edges, 
                                                                  node_cooccur = node_cooccur,  
                                                                  top_node = num_t_node, 
                                                                  top_node_cooccur = num_t_key_node)

        network = generate_network(edge_df_sel, k_nodes = selected_type2_nodes, scaling_factor = s_factor)

        with output_network:
            display(network)  


    def bounded_top_node_eventhandler(change):
        draw_network(change.new, bounded_top_key_node.value)
    def bounded_top_key_node_eventhandler(change):
        draw_network(bounded_top_node.value, change.new)

    bounded_top_node.observe(bounded_top_node_eventhandler, names='value')
    bounded_top_key_node.observe(bounded_top_key_node_eventhandler, names='value')

    display(HBox([Label('Number of top words'), bounded_top_node]))
    display(HBox([Label('Number of top words co-occurred with selected key words'), bounded_top_key_node]))
    display(output_network)

    
# get_df_edges_sents generates an edge dataframe from three inputs: 
# covid19_edges, covid19_sents_per_edge, pubmed_edges
# only COVID-19 edges that do not appear in pubmed_edges are kept
def get_df_edges_sents(covid19_edges, covid19_sents_per_edge, pubmed_edges):
    uniq_keys = []
    for key in covid19_edges:
        if key not in pubmed_edges:
            uniq_keys.append(key)
            
    uniq_covid19_edges = dict((k, covid19_edges[k]) for k in uniq_keys)
    uniq_covid19_sents_per_edge = dict((k, covid19_sents_per_edge[k]) for k in uniq_keys)
    df_covid19 = pd.DataFrame(uniq_covid19_edges.items())
    df_covid19.columns = ['terms', 'count']
    df_sents_per_edge = pd.DataFrame(uniq_covid19_sents_per_edge.items())
    df_sents_per_edge.columns = ['terms', 'sentences']
    df = pd.merge(df_covid19, df_sents_per_edge, on="terms")
    df[['term1','term2']] = df.terms.str.split('|', expand=True)
    df = df[['term1','term2', 'count', 'sentences']]
    return df


In [None]:
# generate an edge dataframe from the three edge inputs
edge_df = get_df_edges_sents(covid19_edges, covid19_sents_per_edge, pubmed_edges)
edge_df.columns = ['term1', 'term2', 'score', 'sentences']
# normalize the counts by the maximum count and round to two decimals
edge_df['score'] = edge_df['score']/edge_df.score.max()
edge_df['score'] = edge_df['score'].round(2)

## Plot network
The network is displayed with two sliders, "Number of top words" and "Number of top words co-occurred with selected key words", which control the number of Type I and Type II nodes, respectively. 

In [None]:
#################################################
### move the sliders to have network shown up ###
#################################################
show_network(df_edges = edge_df, node_cooccur = cooccur_vocab, s_factor = 1)

Note: A snapshot of the network is included below for reference because the interactive network may not render on Kaggle. If the network does not show up, please visit [our website](https://covid19.insilicom.com/) for the interactive version.

In [None]:
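# Image here is assumed to be IPython's display class, which accepts raw image bytes
from IPython.display import Image
url = input_path + 'figure/network1.png'
Image(requests.get(url).content, width=800, height=800)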