In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load in 

import html
import json
import os
import re
from collections import defaultdict
from re import finditer

import nltk
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import spacy
from IPython.core.display import HTML
from IPython.display import Image
from nltk.stem import PorterStemmer

# Input data files are available in the "../input/" directory.
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# Any results you write to the current directory are saved as output.

# Identifying Risk Factors associated with COVID-19
#### In this task we want to identify unique risk factors associated with COVID-19 from a corpus of research papers. A challenging part of this task is to automatically consolidate the findings from these papers into categories of risk factors. This notebook is built on the intuition that:
* Paragraphs that discuss risk factors likely also mention the disease they are associated with. For example, consider the extract below:

![](https://i.ibb.co/Fzzg93Z/covid-pic2.png)

* A simple heuristic here is to find sentences in paragraphs that contain one of the names of the COVID-19 disease along with one of the risk factors of interest (e.g. "smoking", "pregnancy"). Sentences - or pairs of adjacent sentences - that meet both of these criteria (and that also report a numeric value) can be considered valuable information to extract.

* Once we have identified paragraphs or sentences that are relevant, we want to filter down and tag these by type of factors for ease of categorizing.
---

![](https://i.ibb.co/j5r1v0F/covid-pic4.png)

In [None]:
base_path = "/kaggle/input/CORD-19-research-challenge/"
# Read identifier columns as strings so pandas does not infer a numeric dtype for them.
metadata_dtypes = {"pubmed_id": str, "Microsoft Academic Paper ID": str}
sources = pd.read_csv(os.path.join(base_path, "metadata.csv"), dtype=metadata_dtypes)
sources.head()

# Specify the sets of unigrams to look for.
#### Separately specify the set of unigrams for risk factors (e.g. "smoke", "pregnancy") and the unigram variants authors use to refer to COVID-19 (e.g. "COVID-19", "2019-nCoV", "SARS-CoV-2"). All terms are stemmed so that they can be matched against the stemmed paper text.

In [None]:
# Global state shared by the extraction functions below:
#   sent_tokens : for each paper_id, a cache (keyed by paragraph index) of the sentence- and
#                 word-tokenized, stemmed paragraphs, so each paragraph is tokenized only once.
#   sent_fulls  : for each paper_id, a cache of the sentence-tokenized but NOT word-tokenized paragraphs.
#   jdict       : (created in the processing cell) for each paragraph segment found to contain
#                 valuable data, stores the paper_id and the desired segment of text.
sent_tokens = defaultdict(lambda: defaultdict(lambda: ""))  # for each document, cache the tokenized sentences for easy revisits
sent_fulls = defaultdict(lambda: defaultdict(lambda: ""))

# IDs of papers that have a parsed full text; these are matched against the JSON "paper_id".
# A metadata "sha" entry can list several semicolon-separated hashes (one per PDF), so split
# them; missing (NaN) hashes are dropped.
full_text_shas = sources.loc[sources["has_full_text"] == True, "sha"].dropna()
valid_ids = {
    sha.strip()
    for entry in full_text_shas
    for sha in str(entry).split(";")
    if sha.strip()
}

stemmer = PorterStemmer()

# Raw target terms. targs keeps the same order: risk factors first, COVID-19 names last.
RISK_FACTOR_TERMS = {"smoke", "pulmonary", "pre-existing", "neo-natal", "natal", "pregnancy",
                     "pregnant", "economic", "social", "socio-economic"}
COVID_TERMS = {"covid-19", "covid19", "sars-cov-2", "2019-ncov"}

# Convert the target terms into their stemmed versions for compatibility in the matching stage.
# reverse_map lets us recover a readable topic name from a matched stem.
reverse_map = {}
targs = []
for term_group in (RISK_FACTOR_TERMS, COVID_TERMS):
    stemmed_group = set()
    for term in term_group:
        st = stemmer.stem(term)
        stemmed_group.add(st)
        reverse_map[st] = term
    targs.append(stemmed_group)

# Process each paragraph in the JSON file.

In [None]:

# Bracketed citation markers such as "[12]" or "[3, 13]"; their digits are not statistics.
_CITATION_RE = re.compile(r"\[\s*\d+(?:\s*,\s*\d+)*\]")
# A standalone integer: not glued to a letter, digit or hyphen, so the "19" in "COVID-19"
# and the "2" in "SARS-CoV-2" are ignored. The lookarounds do not consume the neighbouring
# characters, so every number in a sentence is visited (including one at the very start or end).
_STANDALONE_INT_RE = re.compile(r"(?<![A-Za-z0-9\-])(\d+)(?![A-Za-z0-9\-])")


def _has_statistic(sentence, min_year=1900, max_year=2020):
    '''Return True if `sentence` contains an integer that is neither a citation marker nor a likely year.

    Integers in [min_year, max_year] are treated as years and ignored. Every integer is
    checked, so a year mentioned before the actual statistic does not hide it.
    '''
    text = _CITATION_RE.sub("", sentence)
    return any(not (min_year <= int(m.group(1)) <= max_year)
               for m in _STANDALONE_INT_RE.finditer(text))


def validate_segment(segment, paper_id=None, cnt=None):
    '''
    Decide whether a paragraph links a risk factor to COVID-19 with supporting numbers.

    validate_segment() is our main function: each paragraph of each JSON file is passed to it.
    The paragraph is sentence- and word-tokenized and stemmed, then we look for sentences that
    mention a term from every group in `targs` (risk factors and COVID-19 names).

    The working hypothesis is that a paragraph discussing a risk factor related to COVID-19
    mentions both in close proximity: in one sentence, or in two adjacent sentences. Since we
    are also interested in numeric data, only segments containing at least one numeric value
    (citation markers and likely years excluded) are kept.

    Parameters
    ----------
    segment : str
        Raw paragraph text.
    paper_id : str, optional
        Paper identifier; together with `cnt` it keys the tokenization cache
        (`sent_tokens` / `sent_fulls`). When None nothing is cached.
    cnt : int, optional
        Index of the paragraph within the paper.

    Returns
    -------
    tuple (bool, str, set)
        (True, matched sentence(s), stemmed risk-factor terms found in them) when a qualifying
        segment exists, otherwise (False, "", set()).
    '''
    # quick heuristic to get rid of paragraphs that don't even discuss COVID-19 (or SARS-CoV-2)
    if "19" not in segment and "-2" not in segment:
        return False, "", set()

    # Reuse the cached tokenization of this paragraph if there is one; otherwise sentence- and
    # word-tokenize + stem it now and cache it for much faster subsequent passes.
    jtxt = None
    if paper_id and paper_id in sent_tokens and cnt in sent_tokens[paper_id]:
        jtxt = sent_tokens[paper_id][cnt]
        jtxt_base = sent_fulls[paper_id][cnt]
    if jtxt is None:
        jtxt_base = nltk.sent_tokenize(segment)
        jtxt = [[stemmer.stem(y.lower()) for y in nltk.word_tokenize(x)] for x in jtxt_base]
        if paper_id is not None:
            sent_tokens[paper_id][cnt] = jtxt
            sent_fulls[paper_id][cnt] = jtxt_base

    # Per sentence: does it report a number, and which target groups (indices into targs)
    # does it mention?
    n_groups = len(targs)
    sent_numerics = [_has_statistic(sent) for sent in jtxt_base]
    sent_founds = [{q for q, group in enumerate(targs) if any(word in group for word in words)}
                   for words in jtxt]

    # Accept the first sentence that covers every target group and reports a number, either
    # on its own or together with its preceding or following sentence.
    val_sent = None
    val_tags = None
    last = len(sent_founds) - 1
    for i, found in enumerate(sent_founds):
        numeric = sent_numerics[i]
        if len(found) == n_groups and numeric:
            val_sent, val_tags = jtxt_base[i], jtxt[i]
            break
        if i > 0 and len(found | sent_founds[i - 1]) == n_groups and (numeric or sent_numerics[i - 1]):
            val_sent = jtxt_base[i - 1] + " " + jtxt_base[i]
            val_tags = jtxt[i] + jtxt[i - 1]
            break
        if i < last and len(found | sent_founds[i + 1]) == n_groups and (numeric or sent_numerics[i + 1]):
            val_sent = jtxt_base[i] + " " + jtxt_base[i + 1]
            val_tags = jtxt[i] + jtxt[i + 1]
            break

    if not val_sent:
        return False, "", set()

    # Report which risk-factor stems were matched: every group except the last (COVID-19) one.
    risk_stems = set().union(*targs[:-1])
    return True, val_sent, set(val_tags) & risk_stems

# Find the Basic Reproduction Number (R_0)
- Heuristic: look for the string "basic reproduction number" followed closely by a number that is "reasonable": greater than 0 and at most 12.
- This heuristic helps avoid cases of mistakenly identifying citation numbers that appear near mentions of R_0 as the actual values.

In [None]:
def is_number(s):
    '''Return True if float() can parse `s` (note that this includes "nan" and "inf").'''
    try:
        float(s)
        return True
    except (TypeError, ValueError):
        return False


# BRN maps paper_id -> list of {"sentence", "value", "paper_id", "title"} records
# appended by find_reproduction_number().
BRN = defaultdict(lambda: [])


def find_reproduction_number(segment, paper_id=None, cnt=None, paper_title=None,
                             window=3, max_value=12):
    '''
    Record basic reproduction numbers (R0) reported in a COVID-19 paragraph.

    Given the raw paragraph text, check whether it mentions the basic reproduction number
    and, if so, what that number is. Each of the `window` tokens that directly follow the
    (stemmed) phrase "basic reproduction number" is inspected. A token is accepted when it
    parses as a number in (0, max_value]; larger values are more likely broken citation
    markers than real R0 values. Accepted values are appended to BRN[paper_id] together
    with the sentence and the paper title.

    Parameters
    ----------
    segment : str
        Raw paragraph text.
    paper_id : str, optional
        Paper identifier: key of BRN and, together with `cnt`, of the shared tokenization
        cache (`sent_tokens` / `sent_fulls`).
    cnt : int, optional
        Index of the paragraph within the paper.
    paper_title : str, optional
        Stored with each record for display.
    window : int
        Number of tokens after the phrase that are inspected for a value.
    max_value : float
        Largest value accepted as a plausible R0.

    Returns
    -------
    (False, "", set()) when the paragraph fails the cheap "19"/"-2" pre-filter, otherwise
    None. Findings are reported only through BRN.
    '''
    # quick heuristic to get rid of paragraphs that don't even discuss COVID-19 (or SARS-CoV-2)
    if "19" not in segment and "-2" not in segment:
        return False, "", set()

    # Reuse the cached tokenization of this paragraph if there is one; otherwise sentence- and
    # word-tokenize + stem it now and cache it for subsequent passes.
    jtxt = None
    if paper_id and paper_id in sent_tokens and cnt in sent_tokens[paper_id]:
        jtxt = sent_tokens[paper_id][cnt]
        jtxt_base = sent_fulls[paper_id][cnt]
    if jtxt is None:
        jtxt_base = nltk.sent_tokenize(segment)
        jtxt = [[stemmer.stem(y.lower()) for y in nltk.word_tokenize(x)] for x in jtxt_base]
        if paper_id is not None:
            sent_tokens[paper_id][cnt] = jtxt
            sent_fulls[paper_id][cnt] = jtxt_base

    # only consider paragraphs that explicitly name COVID-19 (last target group)
    if not any(word in targs[-1] for sent in jtxt for word in sent):
        return

    brn_stem = [stemmer.stem(x) for x in "basic reproduction number".split(" ")]
    n_phrase = len(brn_stem)
    for q, sent in enumerate(jtxt):
        # i is the index of the last token of a candidate phrase occurrence
        for i in range(n_phrase - 1, len(sent)):
            if sent[i - n_phrase + 1:i + 1] != brn_stem:
                continue
            # we found the phrase; is there a number among the tokens that follow it?
            for token in sent[i + 1:i + 1 + window]:
                if not is_number(token):
                    continue
                value = float(token)
                # rejects 0, negatives, NaN (all comparisons false) and implausibly large values
                if not (0 < value <= max_value):
                    continue
                BRN[paper_id].append({"sentence": jtxt_base[q], "value": value,
                                      "paper_id": paper_id, "title": paper_title})

# For each JSON file, send each body-text paragraph to validate_segment() and find_reproduction_number()

In [None]:
def process_file(jstruct):
    '''
    Run both extractors over every body-text paragraph of one parsed paper.

    Papers without a "paper_id", or whose id is not in `valid_ids`, are skipped.
    Each paragraph is passed to validate_segment() and find_reproduction_number();
    paragraphs accepted by validate_segment() are appended to jdict[paper_id] as
    {"text", "tags", "segment", "paper_id", "title"} records.
    '''
    if "paper_id" not in jstruct or jstruct["paper_id"] not in valid_ids:
        return
    paper_id = jstruct["paper_id"]

    # For now, keep things simple: risk factors are assumed to be discussed at the
    # paragraph level, so each paragraph is analysed independently.
    for cnt, paragraph in enumerate(jstruct["body_text"]):
        text = paragraph["text"]
        is_valid, val_seg, val_tags = validate_segment(text, paper_id, cnt)
        find_reproduction_number(text, paper_id, cnt, jstruct["metadata"]["title"])
        if is_valid:
            jdict[paper_id].append({"text": text, "tags": val_tags, "segment": val_seg,
                                    "paper_id": paper_id, "title": jstruct["metadata"]["title"]})

# Loop through each JSON file. Send it to process_file()

In [None]:

# jdict maps paper_id -> list of matched segment records built by process_file()
jdict = defaultdict(lambda: [])

# BRN holds the list of sentences that mention a COVID-19 basic reproduction number
BRN = defaultdict(lambda: [])

# Collect every JSON file under the input directory, in a deterministic (sorted) order.
file_list = sorted(
    os.path.join(dirname, filename)
    for dirname, _, filenames in os.walk('/kaggle/input')
    for filename in filenames
    if filename.endswith(".json")
)
total_files = len(file_list)

counter = 0
useds = set()  # progress percentages that have already been reported
for counter, path in enumerate(file_list, start=1):
    # `with` closes each file as soon as it is parsed instead of leaking the handle
    with open(path, "r", encoding="utf-8") as fh:
        process_file(json.load(fh))
    perc_complete = round((counter / total_files) * 100)
    # report progress every 5%, once per percentage value
    if perc_complete % 5 == 0 and perc_complete not in useds:
        useds.add(perc_complete)
        print("{} / {} => {}% complete".format(counter, total_files, perc_complete))

# Results completed. Display the findings.
- Summarize the basic reproduction numbers (R_0) that were extracted.
- For each of the risk factor tags, compile all the segments we found.
- Color-code the key terms for easier understanding.

In [None]:
# Flatten BRN into one record per extracted R0 mention, sorted by value for the table.
brn_objs = [
    {"value": entry["value"], "sentence": entry["sentence"], "title": entry["title"]}
    for found_objs in BRN.values()
    for entry in found_objs
]
brn_sort = sorted(brn_objs, key=lambda x: x["value"])
# The summary statistics and the "all values" list are computed over the distinct values.
all_brns = sorted({obj["value"] for obj in brn_objs})

htmlstr = "<span style='font-weight:bold;font-size:18px;'>Basic Reproduction Numbers (R<sub>0</sub>)</span><br />"
htmlstr += "<span style='font-size:16px;'>"
if all_brns:
    summary = [
        ("Average", round(np.mean(all_brns), 3)),
        ("Median", round(np.median(all_brns), 3)),
        ("Minimum", min(all_brns)),
        ("Maximum", max(all_brns)),
    ]
    for label, value in summary:
        htmlstr += "<span style='font-weight:normal;padding-right:20px;'>{} R<sub>0</sub>: </span><span>{}</span><br />".format(label, value)
else:
    # min()/max() raise on an empty sequence, so report the absence of results instead
    htmlstr += "<span>No basic reproduction numbers were found.</span><br />"
htmlstr += "<br /><span style='font-size:18px;font-weight:bold;'>All R<sub>0</sub> values found</span><br />"
htmlstr += "<span>{}</span>".format(",  ".join(map(str, all_brns))) + "<br />"
htmlstr += "<br /><span style='font-size:18px;font-weight:bold;'>References</span><br />"

tempstr = "<br /><div style='display:table;'>"
tempstr += ("<div style='display:table-row;'>"
            "<div style='display:table-cell;'>&nbsp;</div>"
            "<div style='display:table-cell;font-weight:bold;font-size:18px;padding-bottom:10px;'>R<sub>0</sub></div>"
            "<div style='display:table-cell;min-width:50px;'>&nbsp;</div>"
            "<div style='display:table-cell;font-weight:bold;font-size:18px;'>Extract</div>"
            "</div>")
for result in brn_sort:
    # Sentences and titles are raw paper text; escape them so characters such as "<"
    # (e.g. "R0 < 1") are shown literally instead of being parsed as markup.
    tempstr += "<div style='display:table-row;'>"
    tempstr += "<div style='display:table-cell;padding-right:30px;font-size:20px;'>•</div>"
    tempstr += "<div style='display:table-cell;'>{}</div>".format(result["value"])
    tempstr += "<div style='display:table-cell;'></div>"
    tempstr += "<div style='display:table-cell;'>{}<span style='color:#0099cc'>[{}]</span></div>".format(
        html.escape(result["sentence"]), html.escape(str(result["title"])))
    tempstr += "</div>"

tempstr += "</div>"
htmlstr += tempstr

htmlstr += "</span>"
display(HTML(htmlstr))


In [None]:

# Group the accepted segments by the (unstemmed) risk-factor term they matched.
topics = defaultdict(lambda: {"text": [], "title": [], "rawtag": ""})

for found_objs in jdict.values():
    for ele in found_objs:
        # for each tag (usually only one) see which topic this falls under
        for tag in ele["tags"]:
            topic = topics[reverse_map[tag]]
            topic["text"].append(ele["segment"])
            topic["title"].append(ele["title"])
            topic["rawtag"] = tag

# Build one HTML report per topic. The topic's own term is highlighted in yellow and
# COVID-19 names in pink. All paper text is HTML-escaped before it is wrapped in markup.
htmls = defaultdict(lambda: "")
for topic_name, topic in topics.items():
    htmlstr = "<div class='test_output'>"
    htmlstr += "<br /><div style='font-weight:bold;'>{}</div><br />".format(html.escape(topic_name))
    htmlstr += "<div style='display:table;'>"

    for segment, title in zip(topic["text"], topic["title"]):
        words = []
        for token in nltk.word_tokenize(segment):
            # lower-case before stemming, exactly as validate_segment() did when matching
            stem = stemmer.stem(token.lower())
            safe = html.escape(token)
            if stem == topic["rawtag"]:
                safe = "<span style='background-color:#FFCC33;'>" + safe + "</span>"
            elif stem in targs[-1]:
                safe = "<span style='background-color:#FF99FF;'>" + safe + "</span>"
            words.append(safe)

        formatted = " ".join(words) + "<span style='color:#0099cc;'> [" + html.escape(str(title)) + "]</span>"
        htmlstr += "<div style='display:table-row;'>"
        htmlstr += "<div style='display:table-cell;padding-right:15px;font-size:20px;'>•</div><div style='display:table-cell;'>" + formatted + "</div>"
        htmlstr += "</div>"

    htmlstr += "</div>"
    htmlstr += "</div>"
    htmls[topic_name] = htmlstr

In [None]:
# Segments tagged with the "social" risk factor.
social_html = htmls["social"]
display(HTML(social_html))

In [None]:
# Segments tagged with the "smoke" risk factor.
smoke_html = htmls["smoke"]
display(HTML(smoke_html))

In [None]:
# Pregnancy-related segments; each of the two terms is reported as its own topic.
for pregnancy_topic in ("pregnant", "pregnancy"):
    display(HTML(htmls[pregnancy_topic]))

In [None]:
# Segments tagged with the "pre-existing" risk factor.
pre_existing_html = htmls["pre-existing"]
display(HTML(pre_existing_html))

In [None]:
# Segments tagged with the "economic" risk factor.
economic_html = htmls["economic"]
display(HTML(economic_html))

In [None]:
# Neonatal segments; each of the two terms is reported as its own topic.
for natal_topic in ("neo-natal", "natal"):
    display(HTML(htmls[natal_topic]))