# Understanding Covid 19 with Topic Modeling and Sentence Embeddings

#### Submission for COVID-19 Open Research Dataset Challenge (CORD-19)

## Install and Load Packages

In [None]:
# install scispacy
!pip install scispacy
!pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.2.4/en_core_sci_lg-0.2.4.tar.gz

# install langdetect
!pip install langdetect

In [None]:
# enable widgets
!jupyter nbextension enable --py --sys-prefix widgetsnbextension

In [None]:
# Helper packages.
from IPython.core.display import display, HTML
import os
import pandas as pd
# Use fully-qualified option names: in newer pandas the short pattern 'max_rows'
# also matches 'styler.render.max_rows' and set_option raises OptionError.
pd.set_option('display.max_colwidth', 1000)
pd.set_option('display.max_rows', 100)
import numpy as np
np.set_printoptions(threshold=10000)
import pickle
import matplotlib.pyplot as plt
from datetime import datetime
import re
import json
from tqdm.auto import tqdm
import textwrap
import importlib as imp
from scipy.spatial.distance import cdist
import gc

# Packages with tools for text processing.
# if you have not downloaded stopwords, run the following line
from nltk.corpus import stopwords
from nltk.stem.porter import PorterStemmer
import nltk
nltk.download('stopwords')
import scispacy
import spacy

# Packages for working with text data.
from sklearn.feature_extraction.text import CountVectorizer

# Packages for getting data ready for and building a LDA model
import gensim
from gensim import corpora, models
from gensim.models.coherencemodel import CoherenceModel
from langdetect import detect

# Package for FastText
import fasttext

# Other plotting tools.
import pyLDAvis
import pyLDAvis.gensim
from wordcloud import WordCloud
from IPython.display import display, Markdown, Latex
import ipywidgets as widgets

# Print current directory
print('Current directory is {}'.format(os.getcwd()))

# Extend notebook to full width
display(HTML("<style>.container {width:100% !important; }</style>"))

# Check python version
from platform import python_version
print('Current python version is {}'.format(python_version()))

## Set Directories

In [None]:
input_data_path = '/kaggle/input/CORD-19-research-challenge/'
working_data_path = '/kaggle/input/cov19-pickles/'

## Set Global Variables

In [None]:
source_column = 'text' # if abstract, change to 'abstract'
id_colname = 'cord_uid' # id in metadata for each article
split_sentence_by = '(?<=\\.) ?(?![0-9a-z])' # sentence splitter

## Data Preprocessing

`meta_full_text` is loaded from a pickle in the next cell. The indented code below it, which generated that pickle, performs these steps:
1.     import `metadata`
1.     parse publish time to datetime object
1.     get full path to .pdf or .pmc
1.     get full text
1.     drop records with empty abstract and empty body text
1.     check duplicated text: most likely due to publications on different journals in which case we keep the latest
1.     check duplicated cord_uid: most likely due to publications on different journals in which case we keep the latest
1.     drop redundant columns and save to pickle

In [None]:
# load from pickle (generated below)
meta_full_text = pickle.load(open(working_data_path + 'all_papers.pkl', 'rb'))

    # 1. import metadata
    metadata = pd.read_csv(input_data_path + 'metadata.csv', encoding='utf-8').replace({np.nan: None})
    print(metadata.shape)
    print('\nNumber of NA for each column: ')
    metadata.isnull().sum(axis=0)

    # 2. parse publish time to datetime object
    def str2time(s):
        """Parse a ``publish_time`` string into a ``datetime``.

        The formats below are tried in order and the first one that parses wins.
        ``pd.NaT`` is returned when none of them match, or when ``s`` is not a
        string (e.g. ``None``).
        """
        formats = (
            '%Y-%m-%d',  # e.g. 2020-03-15
            '%Y %B',     # e.g. 2020 March
            '%Y %b',     # e.g. 2020 Mar
            '%Y %B %b',  # NOTE(review): unusual pattern, kept for compatibility
            '%Y %b %d',  # e.g. 2020 Mar 15
            '%Y',        # e.g. 2020
        )
        for fmt in formats:
            try:
                return datetime.strptime(s, fmt)
            except (ValueError, TypeError):
                # ValueError: s does not match fmt; TypeError: s is not a str.
                continue
        return pd.NaT


    metadata['full_text_file_path'] = None
    for i in tqdm(metadata.index):
        # label-based read, matching the label-based writes below
        row = metadata.loc[i]

        # 3. get full path to .pdf or .pmc: prioritize pdf as source, if none, search for pmc
        json_files = row.pdf_json_files or row.pmc_json_files
        full_text_file_path = [path.strip() for path in json_files.split(';')] if json_files else []

        # drop 3-letter month ranges (e.g. ' Jan-Feb') and season words, then parse publish time
        publish_time = re.sub(' ([a-zA-Z]{3}-[a-zA-Z]{3})|(Spring)|(Summer)|(Autumn)|(Fall)|(Winter)','', row.publish_time or '').strip()
        publish_time = str2time(publish_time)

        # .at writes exactly one cell, so the list is stored as a single value;
        # .loc may try to spread a list value across several cells.
        metadata.at[i, 'publish_time'] = publish_time
        metadata.at[i, 'full_text_file_path'] = full_text_file_path

    # 4. extract full text from JSON files
    def get_paper_info(json_data):
        """Return all ``body_text`` paragraph texts of a paper JSON, joined by single spaces."""
        paragraphs = (paragraph['text'] for paragraph in json_data['body_text'])
        return ' '.join(paragraphs)

    # Read every JSON file listed for each paper and store the joined body text.
    full_text = []
    for r in tqdm(metadata.to_dict(orient='records')):
        record = []
        # paths in full_text_file_path are relative to input_data_path (see step 3)
        for p in r['full_text_file_path']:
            with open(input_data_path + p, 'r', encoding='utf-8') as f:
                data = json.load(f)
                record.append(get_paper_info(data))
        # np.unique drops identical texts coming from several JSON files (and sorts them);
        # papers without any JSON file get None
        full_text_ = '\n'.join(np.unique(record)) if len(record) > 0 else None
        full_text.append(full_text_)
    metadata['full_text'] = full_text

    # 5. drop records with empty abstract AND empty full text
    meta_full_text = metadata
    meta_full_text[source_column]= np.where(meta_full_text['full_text'].isnull(), meta_full_text['abstract'], meta_full_text['full_text'])
    meta_full_text = meta_full_text.dropna(subset = [source_column]).reset_index(drop=True)

    # 6. check duplicated text: most likely due to publications on different journals - in which case we keep the latest one
    print('In total, {} of the rows have a duplicated {} column, and there are a total of {} duplicated {} entries.'.format(sum([len(g) for k, g in meta_full_text.groupby(source_column) if len(g) > 1]), source_column, len([1 for k, g in meta_full_text.groupby(source_column) if len(g) > 1]), source_column))
    meta_full_text = meta_full_text.sort_values('publish_time', ascending=False).drop_duplicates(source_column)

    # 7. check duplicated cord_uid: most likely due to publications on different journals - in which case we keep the latest one
    print('In total, {} of the rows have a duplicated {} column, and there are a total of {} duplicated {} entries.'.format(sum([len(g) for k, g in meta_full_text.groupby(id_colname) if len(g) > 1]), id_colname, len([1 for k, g in meta_full_text.groupby(id_colname) if len(g) > 1]), id_colname))
    meta_full_text = meta_full_text.sort_values('publish_time', ascending=False).drop_duplicates(id_colname)

    # 8. remove metadata from memory to clear space
    del metadata
    gc.collect()

    # 9. drop redundant columns and save to pickle
    meta_full_text.drop(['sha', 'pmcid', 'pubmed_id', 's2_id', 'license', 'mag_id', 'arxiv_id', 'pdf_json_files', 'pmc_json_files', 'full_text_file_path', 'full_text'], inplace=True, axis=1)

    print(meta_full_text.shape)
    print(meta_full_text.columns)
    print('number of unique cord_uid is {}'.format(len(meta_full_text.cord_uid.unique())))

    pickle.dump(meta_full_text, open(working_data_path + 'all_papers.pkl', 'wb'))

## Topic Modeling

In [None]:
print(source_column)

In [None]:
corpus = meta_full_text[source_column]

The first thing we want to do is explore the themes of this corpus, i.e. figure out what the main topics in this literature are. We can use LDA to do that. Note that the `source_column` we are using here is column `text`: mostly full text, and the abstract when full text is not available. This should give us comprehensive coverage of all tokens provided in the literature.

Here are the steps for this data exploration:
1.     build tokenizer to parse out valid tokens and get their count with `CountVectorizer`
1.     find the optimal number of topics such that, overall, tokens are assigned to as few topics as possible and documents are assigned to as few topics as possible (in practice, we pick the number of topics with the highest coherence score)
1.     visualize topic representations for the optimal number of topics

### 1. Parse out valid tokens and get their count

`valid_tokens` and `X` are loaded from pickles in the next cell. The indented code below it shows how they were generated.

In [None]:
valid_tokens = pickle.load(open(working_data_path + 'TM_valid_tokens.pkl', 'rb')) # valid tokens after parsing (CountVectorizer vocabulary)
X = pickle.load(open(working_data_path + 'TM_X.pkl', 'rb')) # document-term count matrix over valid_tokens

    # Load SpaCy for lemmatization
    nlp_lg = spacy.load('en_core_sci_lg',disable=['tagger', 'parser', 'ner'])
    nlp_lg.max_length = np.max([len(t) for t in corpus.values])


    # Establish stop words

    # default stop words
    stop_words=stopwords.words('english')

    # custom CORD19 stop words, mostly from Daniel Wolffram's submission "Topic Modeling: Finding Related articles"
    cord_stopwords = ['doi', 'preprint', 'copyright', 'peer', 'reviewed', 'org', 'https', 'et', 'al', 'author', 'figure', 'rights', 'reserved', 
                      'permission', 'used', 'using', 'biorxiv', 'medrxiv', 'license', 'fig', 'fig.', 'al.', 'Elsevier', 'PMC', 'CZI','-PRON-',
                      'abstract']
    # all stop words
    for word in tqdm(cord_stopwords):
        if (word not in stop_words):
            stop_words.append(word)
        else:
            continue

    # update SpaCy stop words list
    for w in tqdm(stop_words):
        nlp_lg.vocab[w].is_stop = True

    # Build tokenizer
    # langdetect is non-deterministic unless its factory is seeded; fix the seed so the
    # English-only filter (and therefore the vocabulary) is reproducible between runs.
    from langdetect import DetectorFactory
    DetectorFactory.seed = 0

    # The regexes are raw strings (the originals relied on invalid escapes such as '\S'
    # inside normal strings) and are compiled once instead of rebuilt on every call.
    # Characters deleted before tokenization: '(', ')', prime and middle dot.
    _CHAR_RM_RE = re.compile(r'[(]|[)]|[′·]')
    # Prefixes (with their leading space) replaced by a single space before tokenization.
    _CHAR_RM_SPC_RE = re.compile(r' no[nt]-| non| low-| high-')
    # Tokens dropped after lemmatization: tokens containing 'www'/'http' links, '-' followed
    # by 1-9 or '.', one of the symbols ∼≈≥≤≦⩾⩽→μ followed by more characters, or 'a=b'.
    _TOKEN_RM_RE = re.compile(r'(www.\S+)|(-[1-9.])|([∼≈≥≤≦⩾⩽→μ]\S+)|(\S+=\S+)|(http\S+)')

    def spacy_tokenizer(sentence):
        """Tokenize and lemmatize an English text with SpaCy (``nlp_lg``), dropping noise.

        Returns ``[]`` when the text has no ASCII letters or langdetect does not classify
        it as English ('en'). Otherwise returns the lemmas of tokens that are not numbers,
        stop words, punctuation or whitespace, excluding lemmas that match the noise
        patterns above, contain fewer than two ASCII letters, or do not start with an
        ASCII letter or digit.
        """
        # removes substrings before it's tokenized and lemmatized
        sentence = _CHAR_RM_SPC_RE.sub(' ', _CHAR_RM_RE.sub('', sentence))

        if re.sub('[^a-zA-Z]', '', sentence).strip() == '':
            return []
        if detect(sentence) != 'en':  # only focus on english literature
            return []

        lemmas = [word.lemma_ for word in nlp_lg(sentence)
                  if not (word.like_num or word.is_stop or word.is_punct or word.is_space)]
        return [lemma for lemma in lemmas
                if not _TOKEN_RM_RE.search(lemma)
                and len(re.findall('[a-zA-Z]', lemma)) > 1
                and re.search('^[a-zA-Z0-9]', lemma)]

    # Test tokenizer
    sentence_test = '($2196.8)/case (in)fidelity μg μg/ml a=b2 www.website.org α-gal 2-len a.'
    spacy_tokenizer(sentence_test)

    # Initialize `CountVectorizer`. Remove common and sparse terms
    vec = CountVectorizer(max_df = .8, min_df = .001, tokenizer = spacy_tokenizer)

    # Transform the list of snippets into DTM.
    X = vec.fit_transform(tqdm(corpus))

    # get_feature_names() was deprecated in scikit-learn 1.0 and removed in 1.2;
    # prefer get_feature_names_out() and keep valid_tokens a plain list of str.
    if hasattr(vec, 'get_feature_names_out'):
        valid_tokens = vec.get_feature_names_out().tolist()
    else:
        valid_tokens = vec.get_feature_names()

    # pickle (with-blocks close the files once written)
    with open(working_data_path + 'TM_X.pkl', 'wb') as f:
        pickle.dump(X, f)
    with open(working_data_path + 'TM_valid_tokens.pkl', 'wb') as f:
        pickle.dump(valid_tokens, f)

### 2. find the optimal number of topics

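We need `texts` (tokenized texts with repetition), `dictionary` (map from word IDs to words) and `bow_corpus` (count by word ID) before training LDA models. In the next cell, `texts` and `bow_corpus` are loaded from pickles and `dictionary` is rebuilt from `texts`. The indented code further below shows how the pickles were generated.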

In [None]:
np.random.seed(1)
texts = pickle.load(open(working_data_path + 'TM_texts.pkl', 'rb'))

dictionary = gensim.corpora.Dictionary(texts)
bow_corpus = pickle.load(open(working_data_path + 'TM_bow_corpus.pkl', 'rb'))

In [None]:
# Show the first document next to its bag-of-words (word id, count) pairs.
bow_doc_1 = bow_corpus[0]
print(corpus[corpus.index[0]])
for word_id, word_count in tqdm(bow_doc_1):
    print("Word {} (\"{}\") appears {} time.".format(word_id, dictionary[word_id], word_count))

*code to generate `texts`*

    # Expand the document-term matrix back into token lists: one list per document,
    # each token repeated by its count. Only the non-zero entries of each CSR row are
    # visited; X.toarray() would materialise a dense n_docs x n_tokens matrix and loop
    # over every (mostly zero) cell in Python.
    X_csr = X.tocsr()
    X_csr.sort_indices()  # tokens in column (vocabulary) order, as in a dense left-to-right scan
    texts = []
    for i in tqdm(range(X_csr.shape[0])):
        start, end = X_csr.indptr[i], X_csr.indptr[i + 1]
        text = []
        for j, occurrence in zip(X_csr.indices[start:end], X_csr.data[start:end]):
            if occurrence > 0:
                text.extend([valid_tokens[j]] * int(occurrence))
        texts.append(text)

    with open(working_data_path + 'TM_texts.pkl', 'wb') as f:
        pickle.dump(texts, f)

*code to generate `dictionary` and `bow_corpus`*

    dictionary = gensim.corpora.Dictionary(texts)
    bow_corpus = [dictionary.doc2bow(doc) for doc in texts]
    pickle.dump(bow_corpus, open(working_data_path + 'TM_bow_corpus.pkl', 'wb'))

Before optimizing the model, we delete redundant variables.

In [None]:
del X, valid_tokens
gc.collect()

Now we try to find the optimal model by trying various numbers of topics, from 10 to 19 (`range(10, 20)`), and comparing their coherence scores `coherence_values`.

In [None]:
limit=20; start=10; step=1;

In [None]:
model_list = pickle.load(open(working_data_path + 'TM_model_list.pkl', 'rb'))
coherence_values = pickle.load(open(working_data_path + 'TM_coherence_values.pkl', 'rb'))

In [None]:
# Number of topics whose model had the highest coherence score.
x = range(start, limit, step)
topic_num = x[int(np.argmax(coherence_values))]

plt.plot(x, coherence_values)
plt.title("Optimal Number of Topics is " + str(topic_num))
plt.xlabel("Num Topics")
plt.ylabel("Coherence score")
# labels must be a sequence: ("coherence_values") is just a str, which matplotlib
# iterates character by character (the line would be labelled "c").
plt.legend(("coherence_values",), loc='best')
plt.show()

    def compute_coherence_values(dictionary, corpus, texts, limit, start = 2, step = 3):
        """Train one LDA model per topic count in ``range(start, limit, step)`` and score it.

        Each model is an ``LdaMulticore`` with ``random_state=1``, scored with the
        ``c_v`` coherence measure computed on ``texts``.

        Returns:
            (model_list, coherence_values): the trained models and their coherence
            scores, both in the order of ``range(start, limit, step)``.
        """
        coherence_values = []
        model_list = []
        for num_topics in tqdm(range(start, limit, step)):
            model = gensim.models.LdaMulticore(corpus = corpus, id2word = dictionary, num_topics = num_topics, random_state = 1)
            model_list.append(model)
            coherencemodel = CoherenceModel(model = model, texts = texts, dictionary = dictionary, coherence = 'c_v')
            # Compute the score once and reuse it: calling get_coherence() a second time
            # just to print it repeats work.
            coherence = coherencemodel.get_coherence()
            coherence_values.append(coherence)
            print('Number of topics: {}, Coherence value: {}'.format(num_topics, coherence))

        return model_list, coherence_values

    model_list, coherence_values = (compute_coherence_values(dictionary = dictionary, 
                                                             corpus = bow_corpus,
                                                             texts = texts, 
                                                             start = start, limit = limit, step = step))

    pickle.dump(model_list, open(working_data_path + 'TM_model_list.pkl', 'wb'))
    pickle.dump(coherence_values, open(working_data_path + 'TM_coherence_values.pkl', 'wb'))

In [None]:
del model_list, coherence_values
gc.collect()

We train the final LDA model with the number of topics that has the largest coherence score (`topic_num`):

In [None]:
lda_model = pickle.load(open(working_data_path+'TM_lda_model.pkl','rb'))

    # Retrain LDA with the best topic count. random_state matches the seed used during
    # model selection, so this run is reproducible.
    lda_model = gensim.models.LdaMulticore(bow_corpus, num_topics = topic_num, id2word = dictionary, workers = 4, passes = 2, random_state = 1)
    print(lda_model)
    with open(working_data_path+'TM_lda_model.pkl','wb') as f:
        pickle.dump(lda_model, f)

### 3. visualize topic representations for the optimal number of topics

Print Topics:

In [None]:
for idx, topic in lda_model.print_topics(-1):
    print('Topic: {} Word: {}'.format(idx, topic))

Visualize topic modeling results:

In [None]:
from IPython.display import HTML
HTML(filename=working_data_path + 'TM_lda_vis.html')

*code to generate visualization*

    vis = pyLDAvis.gensim.prepare(lda_model, bow_corpus, dictionary)
    pyLDAvis.display(vis)
    pyLDAvis.save_html(vis, working_data_path + 'TM_lda_vis.html')

Visualize wordclouds:

In [None]:
cols = ['#029386','#f97306','#ff796c','#cb416b','#fe01b1',
        '#fd411e','#be03fd','#1fa774','#04d9ff','#c9643b',
        '#7ebd01','#155084','#fd4659','#06b1c4','#8b88f8',
        '#029386','#f97306']

In [None]:
topics = lda_model.show_topics(num_words=20,num_topics=topic_num,formatted=False)
# color_func reads the global loop variable `i` of the plotting loop at generation time.
# The modulo keeps the lookup valid when there are more topics than colours in `cols`.
cloud = WordCloud(background_color='black',color_func=lambda *args, **kwargs: cols[i % len(cols)],prefer_horizontal=1.0, font_step=1, width=350,height=200)

In [None]:
# Make word clouds for all topics
fig, axes = plt.subplots(3, 6, figsize=(25,10), sharex=True, sharey=True)

for i, ax in tqdm(enumerate(axes.flatten())):
    if i < len(topics):
        fig.add_subplot(ax)
        topic_words = dict(topics[i][1])
        cloud.generate_from_frequencies(topic_words, max_font_size=50)
        plt.gca().imshow(cloud)
        plt.gca().set_title('Topic ' + str(i), fontdict=dict(size=16))
        plt.gca().axis('off')
    else:
        ax.axis('off')

plt.subplots_adjust(wspace=0, hspace=0)
plt.axis('off')
plt.margins(x=0, y=0)
plt.tight_layout()
plt.show()

remove redundant variables before proceeding:

In [None]:
del texts, dictionary, bow_corpus, lda_model
gc.collect()

## Search Covid Literature using `fastText` embeddings

In this last section, we use `fastText` to generate word embeddings for tokens in the existing literature and calculate sentence embeddings for sentences in covid-related literature. We define covid-related literature as the papers that satisfy at least one of the conditions below:

In [None]:
# covid earliest date
cov_earliest_date = datetime.strptime('2019-12-01', "%Y-%m-%d")
# covid key terms: regex alternatives matched against lower-cased text ('\W' matches one
# non-word character such as '-' or ' '); '新型冠状病毒' is Chinese for "novel coronavirus"
cov_key_terms = ['covid\\W19','covid19', 'covid', '2019\\Wncov', '2019ncov', 'ncov\\W2019','sars\\Wcov\\W2', 'sars\\Wcov2', '新型冠状病毒']
# covid related terms (a single regex, also matched against lower-cased text)
cov_related_terms = '(novel|new)( beta| )coronavirus'

* the paper contains *covid key terms* anywhere in the paper
* the paper contains *covid related terms* anywhere in the paper AND is published after *2019-12-01*
* the paper is marked as *WHO #Covidence* AND, EITHER contains *related terms* OR is published after *2019-12-01*

Our goal in this section is to build a search tool that uses sentence embeddings to rank sentences by their cosine similarity to that of a search term. We proceed as follows:
1. Get covid-related literature and split them into sentences to apply word embeddings upon
2. Train fastText model on the entire literature to get word embeddings
3. Pre-calculate sentence embeddings using those word embeddings 
4. Build search tool to rank sentences by their cosine similarity to the query term

I train the fastText model with `cbow` (instead of `skipgram`) and number of epochs equaling 3:

In [None]:
selected_m = 'cbow'
selected_epoch = 3

### 1. Get covid-related literature and split them into sentences to apply word embeddings upon

We generate word embeddings from sentences in **all text** - full text, or abstract if full text is not available.

In [None]:
print(source_column)

selected_text = 'raw_' + source_column
model_name_suffix = selected_m + '_' + selected_text + '_epoch' + str(selected_epoch)

We calculate sentence embeddings for sentences in abstracts, so all of our search results are sentences from abstracts only. We can switch to `'text'` if we want to search all text.

In [None]:
search_column = 'abstract' 

search_text = 'raw_' + search_column
search_name_suffix = selected_m + '_' + search_text + '_epoch' + str(selected_epoch)

First, we get covid-related literature:

In [None]:
def get_covid19(data):
    """Return the papers of ``data`` that count as COVID-19 literature.

    A paper is kept when it
    * contains a covid key term (``cov_key_terms``) in ``source_column``, or
    * contains a covid related term (``cov_related_terms``) AND its ``publish_time``
      is on/after ``cov_earliest_date``, or
    * has a ``who_covidence_id`` AND (contains a related term OR is on/after that date).

    Side effect: the boolean columns ``WHO_covidence``, ``contain_key_terms``,
    ``contain_related_terms`` and ``after_earliest_date`` are added to ``data`` in place.
    Returns a new DataFrame with a fresh RangeIndex.
    """
    text = data[source_column].str.lower()
    # na=False: a missing text never matches, rather than putting NaN into the mask
    key_terms_mask = text.str.contains('|'.join(cov_key_terms), na=False)
    related_terms_mask = text.str.contains(cov_related_terms, na=False)

    data['WHO_covidence'] = data['who_covidence_id'].notnull()
    data['contain_key_terms'] = key_terms_mask
    data['contain_related_terms'] = related_terms_mask
    data['after_earliest_date'] = data.publish_time >= cov_earliest_date

    is_covid = (data.contain_key_terms
                | (data.contain_related_terms & data.after_earliest_date)
                | (data.WHO_covidence & (data.contain_related_terms | data.after_earliest_date)))
    # reset_index on the filtered result, not in place on a slice (avoids SettingWithCopyWarning)
    covid19 = data[is_covid].reset_index(drop=True)
    print("There are a total number of {} papers satisfying the above definition".format(len(covid19)))
    return covid19

In [None]:
covid19 = get_covid19(meta_full_text)
print(covid19.shape)
covid19[:1]

remove redundant variables:

In [None]:
del meta_full_text
gc.collect()

Next we split them into sentences, and for each sentence we generate a dictionary to lookup paper-level info. 

In [None]:
sents_in_paper = pickle.load(open(working_data_path + 'fasttext_model_' + search_column + '_sents_in_paper.pkl', 'rb'))
paper_lookup = pickle.load(open(working_data_path + 'fasttext_model_' + search_column + '_paper_lookup.pkl', 'rb'))

    def sentence_to_paper(df, id_colname, text_colname, topic_colname_prefix, split_sentence_by):
        """Index the sentences of ``df`` and build a per-paper lookup.

        Returns:
            sents_in_paper: ``{sentence: (paper_id, sentence_order)}``. Sentences come from
                ``re.split(split_sentence_by, text)`` (a None text is treated as "").
                ``sentence_order`` is a counter shared across all papers that increases in
                reading order, so it preserves the order of sentences within a paper. When a
                sentence occurs more than once, the first occurrence wins.
            paper_lookup: ``{str(paper_id): record}`` for the first row of each id. The
                record is the row as a dict plus a ``topic_colname_prefix`` entry holding
                every column whose name starts with that prefix.
        """
        records = df.to_dict(orient='records')  # 'row' abbreviation was removed in pandas 2.0

        # link sentences to a paper: sents_in_paper
        sents_in_paper = dict()
        sent_order = 1
        for paper in records:
            text = paper[text_colname] if paper[text_colname] is not None else ""
            paper_id = str(paper[id_colname])  # same key type as paper_lookup
            for sent in re.split(split_sentence_by, text):
                if sent not in sents_in_paper:
                    sents_in_paper[sent] = (paper_id, sent_order)
                    sent_order += 1

        # lookup paper information: paper_lookup
        paper_lookup = dict()
        for paper in records:
            paper_id = str(paper[id_colname])
            if paper_id not in paper_lookup:
                paper[topic_colname_prefix] = {k: v for k, v in paper.items() if k.startswith(topic_colname_prefix)}
                paper_lookup[paper_id] = paper

        return sents_in_paper, paper_lookup


    sents_in_paper, paper_lookup = sentence_to_paper(covid19, 
                                                     id_colname=id_colname, 
                                                     text_colname=search_column, 
                                                     topic_colname_prefix='topic', 
                                                     split_sentence_by=split_sentence_by)

    pickle.dump(sents_in_paper, open(working_data_path + 'fasttext_model_' + search_column + '_sents_in_paper.pkl', 'wb'))
    pickle.dump(paper_lookup, open(working_data_path + 'fasttext_model_' + search_column + '_paper_lookup.pkl', 'wb'))

remove redundant variables:

In [None]:
del covid19
gc.collect()

### 2. Train fastText model on the entire literature to get word embeddings

In [None]:
model = fasttext.load_model(working_data_path + 'fasttext_model_' + model_name_suffix)

    # create file with individual sentence on each line
    sentence_file = working_data_path + 'fasttext_model_' + source_column + '_by_sentence.txt'
    with open(sentence_file, 'w', encoding='utf-8') as f:
        for txt in filter(None, corpus.values):
            # End every document with a newline; without it the last sentence of one
            # document and the first sentence of the next end up on the same line.
            f.write('\n'.join(re.split(split_sentence_by, txt)) + '\n')

    # run model
    model = fasttext.train_unsupervised(sentence_file,
                                        model = selected_m,
                                        epoch = selected_epoch)
    model.save_model(working_data_path + 'fasttext_model_' + model_name_suffix)

In [None]:
# Dimension of the model's word/sentence vectors. get_dimension() returns it directly,
# instead of copying the whole output matrix just to measure one row.
emb_len = model.get_dimension()

### 3. Pre-calculate sentence embeddings using those word embeddings

In [None]:
X = pickle.load(open(working_data_path + 'fasttext_model_' + search_name_suffix + '_X.pkl', 'rb'))

    # get sentence embeddings: one row per sentence, in sents_in_paper key order
    # (Quicksearch pairs rows with keys positionally). pd.np was removed in pandas 2.0,
    # and filling a numpy array is much faster than row-wise DataFrame .iloc assignment.
    sentences = list(sents_in_paper.keys())
    embeddings = np.empty((len(sentences), emb_len))
    for i, sent in enumerate(tqdm(sentences)):
        # fastText's get_sentence_vector rejects text that contains a newline
        embeddings[i] = model.get_sentence_vector(sent.replace('\n', ' '))
    X = pd.DataFrame(embeddings)

    with open(working_data_path + 'fasttext_model_' + search_name_suffix + '_X' + '.pkl', 'wb') as f:
        pickle.dump(X, f)

### 4. Build search tool to rank sentences by their cosine similarity to the query term

In [None]:
# create search
class Quicksearch:
    """Interactive sentence search over pre-computed sentence embeddings.

    Args:
        modl: object with ``get_sentence_vector(str)`` (e.g. a fastText model),
            used to embed the query.
        emb_len: embedding dimension (stored; not used by the methods below).
        sentences: ``{sentence: (paper_id, sentence_order)}``. Its key order must match
            the rows of ``sentence_embeddings``.
        sentence_embeddings: 2-D array-like with one embedding row per sentence.
        paper_lookup: ``{paper_id: record}``, where each record has ``publish_time``,
            ``title``, ``journal``, ``url`` and ``topic`` entries.
    """

    def __init__(self, modl, emb_len, sentences, sentence_embeddings, paper_lookup):
        self.modl = modl
        self.emb_len = emb_len
        self.sentences = sentences
        self.sentence_embeddings = sentence_embeddings
        self.paper_lookup = paper_lookup

    def get_candidate_ranking(self, sent):
        """Return ``[(sentence, cosine distance to sent), ...]`` in ``self.sentences`` order."""
        # fastText's get_sentence_vector raises ValueError on text containing a newline,
        # which the multi-line Textarea search box can produce.
        y = self.modl.get_sentence_vector(sent.replace('\n', ' '))
        scores = cdist(self.sentence_embeddings, [y], 'cosine').ravel()
        return list(zip(self.sentences.keys(), scores))

    def term(self, init, placeholder, description):
        """Build the search-term text area widget."""
        return widgets.Textarea(value=init,
                                placeholder=placeholder,
                                description=description,
                                layout=widgets.Layout(width='90%', display='flex'))

    def sort(self, init, options, description):
        """Build the sort-order dropdown widget."""
        return widgets.Dropdown(options=options,
                                value=init,
                                description=description,
                                layout=widgets.Layout(width='90%', display='flex'))

    def top(self, init, maxx, description):
        """Build the slider choosing how many top sentences to show."""
        return widgets.IntSlider(min=1,
                                 max=maxx,
                                 value=init,
                                 description=description,
                                 layout=widgets.Layout(width='90%', display='flex'))

    def search(self, term, sort_by, show_top):
        """Print the papers containing the ``show_top`` sentences closest to ``term``.

        ``sort_by`` is a result column: 'publish_time' sorts newest first, any other
        column (e.g. 'rank') sorts ascending.
        """
        if term == '':
            print('')
            return
        term = term.lower()

        # smallest cosine distance first; sorted() is stable, so ties keep sentence order
        ranked_sentences = sorted(self.get_candidate_ranking(term), key=lambda pair: pair[1])[:show_top]

        # for each sentence, record content, rank, order in paper
        # for each paper, record its highest ranked sentence
        sent_rank, paper_rank = [], dict()
        for i, (sentence, _score) in enumerate(ranked_sentences):
            paper_id, sentence_order = self.sentences[sentence]
            sent_rank.append({'rank': i + 1,
                              'sentence': sentence,
                              'paper id': paper_id,
                              'sentence_order': sentence_order})
            paper_rank.setdefault(paper_id, i + 1)

        if not sent_rank:
            print('Search Results for ' + '\033[1m' + '"' + term.upper() + '\033[0m' + '"\n')
            print('No results.')
            return

        # for each paper, lookup information on that paper
        final_result = []
        for key, group in pd.DataFrame(sent_rank).groupby('paper id'):
            paper = self.paper_lookup[key]
            in_paper_order = sorted(zip(group['sentence'].values, group['sentence_order'].values),
                                    key=lambda pair: pair[1])
            final_result.append({'rank': paper_rank[key],
                                 'publish_time': paper['publish_time'],
                                 'title': paper['title'],
                                 'journal': paper['journal'],
                                 'url': paper['url'],
                                 'topic': paper['topic'],
                                 'sentences': [sent for sent, _order in in_paper_order]})
        final_result = pd.DataFrame(final_result)

        # print search results
        if_ascend = sort_by != 'publish_time'

        print('Search Results for ' + '\033[1m' + '"' + term.upper() + '\033[0m' + '"\n')

        for _, r in final_result.sort_values(by=[sort_by], ascending=if_ascend).iterrows():
            # missing metadata is printed as a placeholder instead of breaking concatenation
            title = 'NA' if pd.isnull(r['title']) else r['title']
            url = "" if pd.isnull(r['url']) else r['url']
            journal = 'NA' if pd.isnull(r['journal']) else r['journal']
            publish_time = '' if pd.isnull(r['publish_time']) else datetime.strftime(r['publish_time'], '%Y-%m-%d')
            sentences = '...'.join(r['sentences'])

            print('\033[1m' + title + '\033[0m')
            print('\033[1m' + 'Results: ' + '\033[0m' + sentences)
            print('\033[1m' + 'Publish Time: ' + '\033[0m' + publish_time)
            print('\033[1m' + 'Journal: ' + '\033[0m' + journal)
            print('\033[1m' + 'Link: ' + '\033[0m' + url)
            print('\n')

In [None]:
# define quicksearch
quicksearch = Quicksearch(model, emb_len, sents_in_paper, X, paper_lookup)

# set up init options
init_show = 10
init_max = 100
init_sort = 'publish_time'
init_search = 'incubation period'
init_options = {'Most Recent': 'publish_time', 'Most Similar': 'rank'}

# set up widget
term = quicksearch.term(init=init_search, placeholder='', description='Search: ')
sort_by = quicksearch.sort(init=init_sort, options=init_options, description='Sort By: ')
show_top = quicksearch.top(init=init_show, maxx=init_max, description='Filter # of Sentences to Show: ')
show_top.style.handle_color='darkred'
term.style.description_width = '100px'
sort_by.style.description_width = '100px'
show_top.style.description_width = '180px'

search = widgets.interactive(quicksearch.search, 
                             term = term, 
                             sort_by = sort_by, 
                             show_top = show_top)

In [None]:
display(search)