In [None]:
import torch
import os
# Load every serialized paper batch found in the test-set directory into one
# flat list.
path = "/kaggle/input/papers-for-biobert/testset/"
# os.listdir() returns entries in arbitrary order; sorting makes the order of
# `papers` reproducible from run to run.
files = sorted(os.listdir(path))
papers = []
for f in files:
    # NOTE(review): torch.load unpickles the file, so these inputs must be trusted.
    papers.extend(torch.load(os.path.join(path, f)))

In [None]:
from ipywidgets import interact_manual, widgets

class bcolors:
    """ANSI escape sequences used to colour and style text printed to the terminal.

    Prefix a string with one of the styling constants and terminate it with
    ``ENDC`` to restore the default rendering.
    """

    # Reset sequence: ends whatever styling is currently active.
    ENDC = '\033[0m'

    # Text attributes.
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    # Foreground colours (codes 91-96).
    FAIL = '\033[91m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    OKBLUE = '\033[94m'
    HEADER = '\033[95m'
    OKCYAN = '\033[96m'

def clean_text(text):
    """Strip leading numeric and punctuation tokens from a piece of text.

    The text is tokenized with ``word_tokenize``; every token up to (but not
    including) the first one that is neither a number nor made up solely of
    punctuation characters is dropped, and the remaining tokens are joined
    with single spaces.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string, or ``""`` when the text contains no token other
        than numbers and punctuation.
    """
    tokens = word_tokenize(text)

    for i, token in enumerate(tokens):
        # Check character by character so multi-character punctuation tokens
        # (e.g. "...", "--") are skipped too; a substring test against
        # string.punctuation would miss them.
        is_punctuation = all(ch in string.punctuation for ch in token)
        if not token.isdigit() and not is_punctuation:
            return " ".join(tokens[i:])

    # Nothing but numbers/punctuation: return an empty string instead of
    # falling back to the last token.
    return ""

In [None]:
%%capture
# BioBERT dependencies.
# subprocess is imported here; a later cell uses it to launch bert-serving-start.
import subprocess
# TensorFlow is pinned to 1.15 because TensorFlow 2.0 didn't work with the pretrained BioBERT weights.
!pip install tensorflow==1.15
# Install bert-as-service: server and client, pinned to the same version.
!pip install bert-serving-server==1.10.0
!pip install bert-serving-client==1.10.0

# Copy the pretrained BioBERT files into the working directory, then rename the
# three checkpoint files (index, data, meta) from the "model.ckpt-1000000" prefix
# to the "bert_model.ckpt" prefix, which is the naming convention expected by
# bert-serving-start.
!cp /kaggle/input/biobert-pretrained /kaggle/working -r
%mv /kaggle/working/biobert-pretrained/biobert_v1.1_pubmed/model.ckpt-1000000.index /kaggle/working/biobert-pretrained/biobert_v1.1_pubmed/bert_model.ckpt.index
%mv /kaggle/working/biobert-pretrained/biobert_v1.1_pubmed/model.ckpt-1000000.data-00000-of-00001 /kaggle/working/biobert-pretrained/biobert_v1.1_pubmed/bert_model.ckpt.data-00000-of-00001
%mv /kaggle/working/biobert-pretrained/biobert_v1.1_pubmed/model.ckpt-1000000.meta /kaggle/working/biobert-pretrained/biobert_v1.1_pubmed/bert_model.ckpt.meta

In [None]:
%%time

# Launch the bert-as-service server in the background, serving the renamed
# BioBERT checkpoint.
bert_command = 'bert-serving-start -model_dir /kaggle/working/biobert-pretrained/biobert_v1.1_pubmed -max_seq_len=None -max_batch_size=32 -num_worker=2'
# The server's stdout is discarded rather than piped: nothing in this notebook
# reads the pipe, and an unread pipe blocks the child once the OS pipe buffer
# is full.
process = subprocess.Popen(bert_command.split(), stdout=subprocess.DEVNULL)

# Start the BERT client. It takes about 10 seconds for the bert server to start, which delays the client
from bert_serving.client import BertClient

bc = BertClient()

In [None]:
#%%time
from scipy.spatial.distance import cosine
import pickle
import numpy as np
from nltk.corpus import stopwords 
from nltk.tokenize import word_tokenize 
import string
import string 


# Load the pickled summary-embeddings object into `emb_df`. Code in this
# notebook reads its 'embedding', 'paper_id', 'summary' and 'sum_score' columns.
# NOTE(review): pickle.load executes arbitrary code from the file; only load
# pickles from a trusted source.
with open('/kaggle/input/summary-embeddings/summary_embs_df.pkl','rb') as f:
    emb_df = pickle.load(f)


def cosine_distance(v1, v2):
    """Return the cosine similarity of two vectors.

    Despite its name, this is ``1 - scipy.spatial.distance.cosine(v1, v2)``,
    i.e. a similarity: larger values mean the vectors point in more similar
    directions.

    Args:
        v1: First vector.
        v2: Second vector.

    Returns:
        One minus the SciPy cosine distance between ``v1`` and ``v2``.
    """
    return 1 - cosine(v1, v2)

def answer_query(query,num_summaries=5):
    """Print the summaries of the papers that best match a query.

    The query is encoded with the module-level client ``bc`` and compared,
    via ``cosine_distance``, with every row of ``emb_df['embedding']``. The
    best-scoring papers are printed from most to least similar.

    NOTE: a later cell of this notebook defines another ``answer_query`` with
    a different signature; once that cell runs, it replaces this function.

    Args:
        query: Free-text question.
        num_summaries: Number of papers to print. Fewer are printed when
            ``emb_df`` holds fewer rows.

    Returns:
        None. All output is printed.
    """
    ## Encode the query with biobert
    qemb = bc.encode([query])

    relevant_embeddings = emb_df

    ## Compute similarities between the query and every embedding
    similarities = np.array([cosine_distance(qemb[0], emb) for emb in relevant_embeddings['embedding']])

    # argsort is ascending, so reverse it to rank best-first; cap the number
    # of results at the number of rows available.
    shown = max(0, min(num_summaries, len(similarities)))
    best_first = np.argsort(similarities)[::-1][:shown]

    ## Print everything
    print('Generated summaries of '+str(num_summaries)+' most relevant papers for query:')
    print('"'+query+'"')
    for pos in best_first:
        # argsort yields positions, so index positionally with .iloc.
        row = relevant_embeddings.iloc[pos]
        print('-----------------')
        print("From paper with Title : "+str(row['paper_id']))
        print("Important sentences : ")
        for sentence in row['summary'].split('.'):
            print (clean_text(sentence))
            print('...')
        print('With average sentence importance score : '+str(row['sum_score']))

In [None]:
## Given a query, returns the answer :)
def answer_query(query,bc,filter_dataset=False,use_synonyms=False,count=5,num_results=5):
    """Print the papers whose summary embeddings best match a query.

    Args:
        query: Free-text question.
        bc: Client whose ``encode`` method turns a list of strings into
            embeddings; it is forwarded to ``get_similar_texts``.
        filter_dataset: When true, search only the papers returned by
            ``snomed_filter(query=query, count=count)``; otherwise search all
            of ``emb_df``.
        use_synonyms: When true, search with every query returned by
            ``get_synonymous_queries(query)`` instead of ``query`` alone.
        count: Threshold forwarded to ``snomed_filter``.
        num_results: Maximum number of papers to print (default 5). Fewer are
            printed when fewer candidate papers exist.

    Returns:
        None. All output is printed.
    """
    separator = bcolors.OKGREEN  + "----------------------------------------------------------------------------------" + bcolors.ENDC

    print (bcolors.OKCYAN + "Answers for the query: " + bcolors.ENDC)
    print (bcolors.OKBLUE + query + bcolors.ENDC)
    print(separator)

    if use_synonyms:
        queries = get_synonymous_queries(query)
    else:
        queries = [query]

    if filter_dataset:
        relevant_embeddings = snomed_filter(query=query,count=count)
        print("Retrieved "+str(relevant_embeddings.shape[0])+ ' papers with more than '+str(count)+' appearances of relevant SNOMED terms to the query.')
    else:
        relevant_embeddings = emb_df

    # Positions of the candidate papers, ordered from most to least similar.
    similar_texts = get_similar_texts(queries,list(relevant_embeddings['embedding']),bc = bc)

    # Slicing (rather than a fixed range) avoids an IndexError when fewer than
    # num_results candidates are available.
    for pos in similar_texts[:max(num_results, 0)]:
        # get_similar_texts returns positions, so index positionally with .iloc.
        row = relevant_embeddings.iloc[pos]
        print(bcolors.OKCYAN + 'Paper ID : ' + bcolors.ENDC + bcolors.FAIL + str(row['paper_id']) + bcolors.ENDC)
        print(bcolors.OKCYAN + 'Summary/Abstract: ' + bcolors.ENDC)
        print(row['summary'])
        print(separator)



## Given a query, returns embeddings of papers which contain relevant terms to the query according to SNOMED
def snomed_filter(query,count=1):
    """Select the rows of ``emb_df`` whose papers relate to the query's words.

    Each alphabetic, non-stopword word of the query is looked up with
    ``find_relevant_papers``. A paper is kept when more than ``count`` of the
    words retrieved it. If fewer than 20 papers qualify, the union of all
    retrieved papers is used instead; if that is still fewer than 20, the
    whole of ``emb_df`` is returned.

    Args:
        query: Free-text question.
        count: A paper must be retrieved by more than this many query words
            to be kept.

    Returns:
        ``emb_df`` itself when the fallback applies, otherwise a re-indexed
        subset of its rows.
    """
    from collections import Counter

    min_papers = 20  # below this many papers the selection is widened

    print("IN SNOMED FILTER COUNT IS "+str(count))
    ## split query into words and remove stopwords
    stop_words = set(stopwords.words('english')) | {'is', 'are'}
    # Lower-case BEFORE the stopword test: the stopword list is lower-case, so
    # capitalised stopwords such as "What" or "The" would otherwise slip through.
    words = [w.lower() for w in word_tokenize(query)]
    words = [w for w in words if w.isalpha() and w not in stop_words]

    ## search all words for relevant paper ids
    rel_ids = [find_relevant_papers(word) for word in words]

    ## count, in one pass, how many query words retrieved each paper
    hits_per_paper = Counter(pid for ids in rel_ids for pid in ids)
    relevant_papers = {pid for pid, hits in hits_per_paper.items() if hits > count}

    ## if intersection is very small just return the union
    if len(relevant_papers) < min_papers:
        relevant_papers = set(hits_per_paper)

    ## Finally, if union is still small, relevant embeddings are all papers
    if len(relevant_papers) < min_papers:
        print("WARNING: The union of snomed ids was too small, searching all papers")
        relevant_embeddings = emb_df
    else:
        keep = emb_df['paper_id'].isin(list(relevant_papers))
        relevant_embeddings = emb_df[keep].reset_index(drop=True)

    return relevant_embeddings


## Given a single term, searches for papers in which a term or a relevant term appears
def find_relevant_papers(search_term):
    """Return the papers annotated with SNOMED terms that match a search term.

    Rows of ``snomed_words`` whose 'snomed_superclasses' or
    'snomed_descriptions' column contains ``search_term`` (case-insensitive)
    are collected, and their 'word' values are printed. The SNOMED ids listed
    in those rows are then looked up in ``id_dict`` to gather the papers.

    Args:
        search_term: The term to look for. It is passed to
            ``Series.str.contains`` and therefore interpreted as a regular
            expression.

    Returns:
        A list of the distinct papers found (empty when nothing matches).
    """
    import pandas as pd

    ## Get rows with relevant term in 'snomed_superclasses' column
    superclass_hits = snomed_words[snomed_words['snomed_superclasses'].str.contains(search_term,case=False)]
    ## Get rows with relevant term in 'snomed_descriptions' column
    description_hits = snomed_words[snomed_words['snomed_descriptions'].str.contains(search_term,case=False)]
    # pd.concat replaces DataFrame.append, which newer pandas releases no longer provide.
    relevant_words = pd.concat([superclass_hits, description_hits], ignore_index=True)

    for r in relevant_words['word']:
        print("\t\t"+r)

    ## Get the distinct SNOMED ids contained in retrieved rows.
    # Each 'snomed_ids' value is a string; the slice and replaces strip its two
    # outer characters on each side and any quotes before splitting on ', '.
    relevant_ids = set()
    for s in relevant_words['snomed_ids']:
        relevant_ids.update(s[2:-2].replace('"','').replace("'",'').split(', '))

    ## Search for appearances of found SNOMED ids in our annotation set
    relevant_papers = set()
    if relevant_ids:
        # Set union avoids re-copying a growing list for every id.
        for snomed_id in relevant_ids:
            relevant_papers.update(id_dict[snomed_id])
        print("\t"+str(len(relevant_papers))+" papers contain at least one of the above terms, based on the keyword: "+search_term)
    else:
        print('\tNo relevant papers')
    return list(relevant_papers)


def get_similar_texts(input_texts,target_embeddings,bc):
    """Rank target embeddings by their summed similarity to the input texts.

    Args:
        input_texts: Texts to encode with ``bc.encode``.
        target_embeddings: Sequence of embeddings to rank.
        bc: Object providing an ``encode`` method for ``input_texts``.

    Returns:
        An array of positions into ``target_embeddings``, ordered from the
        highest to the lowest cumulative ``cosine_distance`` score.
    """
    query_vectors = bc.encode(input_texts)

    # One row per input text, one column per target embedding.
    score_matrix = np.zeros((len(query_vectors), len(target_embeddings)))
    for row, query_vector in enumerate(query_vectors):
        score_matrix[row] = np.array(
            [cosine_distance(query_vector, target) for target in target_embeddings]
        )

    # Sum the scores over all input texts, then sort in descending order by
    # negating the totals before the (ascending) argsort.
    totals = score_matrix.sum(axis=0)
    return np.argsort(-1 * totals)

In [None]:
# Default question pre-filled in the search widget.
def_query = "What do we know about vaccines and therapeutics? What has been published concerning research and development and evaluation efforts of vaccines and therapeutics?"


@interact_manual
def search_articles(
    query=def_query,
    filter_dataset=[False, True],
    use_synonyms=[False, True],
    count=widgets.IntSlider(min=0, max=6, step=1, value=4),
):
    """Interactive front end: forward the widget values to ``answer_query``.

    The parameter defaults define the widgets that ``interact_manual`` builds;
    the search itself uses the module-level client ``bc``.
    """
    answer_query(
        query,
        bc,
        filter_dataset=filter_dataset,
        use_synonyms=use_synonyms,
        count=count,
    )