# SciBERT Embeddings Analysis
This is a basic tutorial and analysis of how to download and use BERT models to create naive embeddings, which can be used for exploring concepts in the COVID-19 literature corpus. [SciBERT](https://github.com/allenai/scibert) is a large transformer model trained on scientific text. Although this tutorial uses SciBERT specifically, you should be able to reuse the code for any HuggingFace transformer model. Using it allows us to explore several interesting questions:

    1. Are raw embeddings from the language model useful without any further fine-tuning?
    2. Can we construct a memory-efficient semantic search using these embeddings?
    3. What does the embedding space look like when visualized? Are similar articles clustered together?

Of course, long-term we would probably want richer embeddings, for example from a model trained on this corpus or pooled in a more sophisticated way. The embedding method used here is a simple mean over all the token vectors in a text passage. More advanced options include sentence transformers or document embedding techniques, but these (usually) require annotated training data. The nice thing about this method is that it is relatively simple to use out of the box.
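To make the "simple mean" concrete, here is a minimal sketch of mean pooling in NumPy. The toy array stands in for a model's output of shape (batch, sequence length, hidden size). The `mean_pool` name and the optional `attention_mask` argument are my own illustration and are not used elsewhere in this notebook.

```python
import numpy as np

def mean_pool(token_embeddings, attention_mask=None):
    """Average token vectors over the sequence axis.

    token_embeddings: array of shape (batch, seq_len, hidden)
    attention_mask:   optional (batch, seq_len) array, 1 = real token, 0 = padding
    """
    token_embeddings = np.asarray(token_embeddings, dtype=np.float64)
    if attention_mask is None:
        return token_embeddings.mean(axis=1)
    mask = np.asarray(attention_mask, dtype=np.float64)[..., None]  # (batch, seq_len, 1)
    summed = (token_embeddings * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1.0, None)  # guard against division by zero
    return summed / counts

# Toy stand-in for a model output: 1 passage, 3 tokens, hidden size 4.
toy = np.array([[[1.0, 0.0, 2.0, 4.0],
                 [3.0, 0.0, 2.0, 0.0],
                 [5.0, 6.0, 2.0, 2.0]]])
pooled = mean_pool(toy)
print(pooled.shape)  # (1, 4)
```

With a batch size of 1 there is no padding, so the plain mean is enough. The mask only matters if several passages of different lengths are batched together.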

In [None]:
!pip install transformers
!wget -O scibert_uncased.tar https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/huggingface_pytorch/scibert_scivocab_uncased.tar
!tar -xvf scibert_uncased.tar
import torch
from transformers import BertTokenizer, BertModel

We will use the SciBERT scivocab uncased model, as that is what is recommended on the official GitHub page.

In [None]:
model_version = 'scibert_scivocab_uncased'
do_lower_case = True
model = BertModel.from_pretrained(model_version)
tokenizer = BertTokenizer.from_pretrained(model_version, do_lower_case=do_lower_case)

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
def embed_text(text, model):
    input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)  # Batch size 1
    with torch.no_grad():  # inference only: skip building the autograd graph to save memory
        outputs = model(input_ids)
    last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
    return last_hidden_states

def get_similarity(em, em2):
    return cosine_similarity(em.detach().numpy(), em2.detach().numpy())

In [None]:
# We will use a mean of all word embeddings. To do that we will take mean over dimension 1 which is the sequence length.
coronavirus_em = embed_text("Coronavirus", model).mean(1)
mers_em = embed_text("Middle East Respiratory Virus", model).mean(1)
flu_em = embed_text("Flu", model).mean(1)
bog_em = embed_text("Bog", model).mean(1)
covid_2019 = embed_text("COVID-2019", model).mean(1)
print("Similarity for Coronavirus and Flu:" + str(get_similarity(coronavirus_em, flu_em)))
print("Similarity for Coronavirus and MERS:" + str(get_similarity(coronavirus_em, mers_em)))
print("Similarity for Coronavirus and COVID-2019:" + str(get_similarity(coronavirus_em, covid_2019)))
print("Similarity for Coronavirus and Bog:" + str(get_similarity(coronavirus_em, bog_em)))



Anecdotally, even the raw embeddings seem to capture at least some correlation between concepts. However, it is curious that Coronavirus and Bog score .64, as I'd expect those to be dissimilar.
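One possible explanation (an assumption on my part, not something verified on SciBERT here) is that mean-pooled vectors share a large common component, which pushes every cosine similarity up regardless of meaning. The sketch below uses purely synthetic vectors to show the effect. The offset size is arbitrary and chosen only for illustration.

```python
import numpy as np

def cosine(a, b):
    """Cosine similarity between two 1-D vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(0)
hidden = 768

# Synthetic "embeddings": unrelated random directions plus one shared offset.
unrelated = rng.normal(size=(50, hidden))
shared_offset = 1.5 * np.ones(hidden)  # arbitrary illustrative value
vectors = unrelated + shared_offset

# The shared offset alone makes unrelated vectors look similar.
raw_sim = cosine(vectors[0], vectors[1])

# Subtracting the mean vector of the collection removes the shared component.
centered = vectors - vectors.mean(axis=0)
centered_sim = cosine(centered[0], centered[1])

print(round(raw_sim, 2), round(centered_sim, 2))
```

If something similar is happening with the real embeddings, comparing scores against a baseline of unrelated terms (like "Bog") may be more informative than reading the absolute value.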

Let's now look at visualizing some of these vectors with [UMAP](https://towardsdatascience.com/how-exactly-umap-works-13e3040e1668). I'm choosing UMAP here due to the high dimensionality of the data (768-D). However, I will also add some PCA visualizations below if I have time.

In [None]:
!pip install umap-learn
import umap
reducer = umap.UMAP()

In [None]:
import os
import json 
def make_the_embeds(number_files, start_range=0, 
                    the_path="/kaggle/input/CORD-19-research-challenge/comm_use_subset/comm_use_subset", data_key=["metadata", "title"]):
    the_list = os.listdir(the_path)
    title_embedding_list = [] 
    title_list = []
    for i in range(start_range, number_files):
        file_name = the_list[i]
        final_path = os.path.join(the_path, file_name)
        with open(final_path) as f:
            data = json.load(f)
        try:
            tensor, title = make_data_embedding(data, data_key)
            title_embedding_list.append(tensor)
            title_list.append(title)
        except Exception as err:  # e.g. a missing title/abstract key or an over-long passage
            print("Invalid title/abstract:", err)
    return torch.cat(title_embedding_list, dim=0), title_list
        
def make_data_embedding(article_data, data_keys, method="mean", dim=1):
    data = article_data
    for key in data_keys:
        data = data[key]
    text = embed_text(data, model)
    if method == "mean":
        return text.mean(dim), data
    
embed_list, title_list = make_the_embeds(200)
red = reducer.fit_transform(embed_list.detach().numpy())

I found 200 to be a good chunk size for running quick analysis; doing a full plot can get kind of crowded and is slow to compute.

In [None]:
from bokeh.plotting import figure, show, output_notebook
from bokeh.models import HoverTool, ColumnDataSource, CategoricalColorMapper
from bokeh.palettes import Spectral10, Category20c
from bokeh.palettes import magma
import pandas as pd
output_notebook()

In [None]:
def make_plot(red, title_list, number=200, color = True, color_mapping_cat=None, color_cats = None, bg_color="white"):   
    digits_df = pd.DataFrame(red, columns=('x', 'y'))
    if color_mapping_cat:
        digits_df['colors'] = color_mapping_cat
    digits_df['digit'] = title_list
    datasource = ColumnDataSource(digits_df)
    plot_figure = figure(
    title='UMAP projection of the article title embeddings',
    plot_width=890,
    plot_height=600,
    tools=('pan, wheel_zoom, reset'),
    background_fill_color = bg_color
    )
    plot_figure.legend.location = "top_left"
    plot_figure.add_tools(HoverTool(tooltips="""
    <div>
    <div>
        <img src='@image' style='float: left; margin: 5px 5px 5px 5px'/>
    </div>
    <div>
        <span style='font-size: 10px; color: #224499'></span>
        <span style='font-size: 10px'>@digit</span>
    </div>
    </div>
    """))
    if color:   
        color_mapping = CategoricalColorMapper(factors=title_list, palette=magma(number))
        plot_figure.circle(
            'x',
            'y',
            source=datasource,
            color=dict(field='digit', transform=color_mapping),
            line_alpha=0.6,
            fill_alpha=0.6,
            size=7
        )
        show(plot_figure)
    elif color_mapping_cat:
        color_mapping = CategoricalColorMapper(factors=color_cats, palette=magma(len(color_cats)+2)[2:])
        plot_figure.circle(
            'x',
            'y',
            source=datasource,
            color=dict(field='colors', transform=color_mapping),
            line_alpha=0.6,
            fill_alpha=0.6,
            size=8,
            legend_field='colors'
        )
        show(plot_figure)
    else:
        
        plot_figure.circle(
            'x',
            'y',
            source=datasource,
            color=dict(field='digit'),
            line_alpha=0.6,
            fill_alpha=0.6,
            size=7
        )
        show(plot_figure)
    
make_plot(red, title_list, number=200)

There do seem to be a few interesting patterns when analyzing with UMAP. However, I believe more sophisticated methods could definitely improve the clustering of groups. Let's examine another chunk:

In [None]:
embed_list2, title_list2 = make_the_embeds(401, 201)
red2 = reducer.fit_transform(embed_list2.detach().numpy())
print(len(title_list2))
make_plot(red2, title_list2, number=198)

~~We'll attempt to make a plot of 1000 articles in that directory (plotting all 9000 made the notebook run out of RAM; warning: even this might crash your notebook). For fun we'll make these a different 1000 than what we already viewed.~~ Update: I'm going to comment this section out, as it doesn't provide much value and requires a lot of RAM, which hurts the downstream tasks. Feel free to run it on a separate fork if you want.

In [None]:
#max_len = len(os.listdir("/kaggle/input/CORD-19-research-challenge/2020-03-13/comm_use_subset/comm_use_subset"))
#embed_list, title_list_full = make_the_embeds(2000,1200)
#red_full = reducer.fit_transform(embed_list.detach().numpy())
#make_plot(red_full, title_list_full, 256, color=False)

**Visualizing with PCA**

PCA is one of the older embedding visualization techniques. The advantage of PCA is that it is a bit faster than UMAP and less RAM-intensive, so we can plot more results. For instance, an average run took 14.2 ms ± 146 µs per loop (mean ± std. dev. of 7 runs, 100 loops each). However, for now I'll also comment out this code, as we will need the RAM later on. We will still use PCA to plot our search results, though.

In [None]:
from sklearn.decomposition import PCA
pca = PCA(n_components=2, svd_solver='full')
#embed_list_pca, title_list_pca = make_the_embeds(1000)
#result = pca.fit_transform(embed_list_pca.detach().numpy())

In [None]:
#make_plot(result, title_list_pca, 200)

## Part 2 Search Attempts on Titles

We are now going to build a very simple semantic search engine on the titles of the articles. In theory, this should return the most similar articles for a given query. However, due to memory constraints I'm only going to run it on roughly 400 titles (the two chunks embedded above) instead of the whole corpus.
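The core idea can be sketched in a few lines of NumPy: normalise the query and every title vector, take dot products to get cosine similarities, and keep the highest scores. This is only an illustration with toy 3-D vectors. The function name `top_n_by_cosine` and the toy titles are hypothetical.

```python
import numpy as np

def top_n_by_cosine(query_vec, doc_matrix, titles, n=10):
    """Return the n (title, score) pairs most similar to the query, best first."""
    query = np.asarray(query_vec, dtype=np.float64).reshape(-1)
    docs = np.asarray(doc_matrix, dtype=np.float64)
    # Cosine similarity is the dot product of L2-normalised vectors.
    query = query / np.linalg.norm(query)
    docs = docs / np.linalg.norm(docs, axis=1, keepdims=True)
    scores = docs @ query
    best = np.argsort(-scores)[:n]  # indices of the highest scores first
    return [(titles[i], float(scores[i])) for i in best]

# Toy 3-D vectors standing in for 768-D title embeddings.
toy_titles = ["bat coronavirus origins", "hospital staffing", "zoonotic transmission"]
toy_docs = np.array([[0.9, 0.1, 0.0],
                     [0.0, 0.2, 1.0],
                     [0.8, 0.3, 0.1]])
toy_query = np.array([1.0, 0.2, 0.0])

results = top_n_by_cosine(toy_query, toy_docs, toy_titles, n=2)
for title, score in results:
    print(f"{score:.3f}  {title}")
```

Scoring all titles with one matrix product avoids a Python loop over embeddings, which should matter more as the number of titles grows.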

In [None]:
import collections
q1 = "COVID-19 infection origin and transmission from animals"
search_terms = embed_text(q1, model).mean(1)

In [None]:
def top_n_closest(search_term_embedding, title_embeddings, original_titles, n=10):
    proximity_dict = {}
    i = 0 
    for title_embedding in title_embeddings:
        proximity_dict[original_titles[i]] = {"score": get_similarity(title_embedding.unsqueeze(0),search_term_embedding), 
                                              "title_embedding":title_embedding.unsqueeze(0)}
        i+=1
    order_dict = collections.OrderedDict({k: v for k, v in sorted(proximity_dict.items(), key=lambda item: item[1]["score"])})
    proper_list = list(order_dict.keys())[-n:]
    return proper_list, order_dict
        

In [None]:
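# Search both chunks so the embeddings line up with title_list + title_list2.
top_titles, order_dict = top_n_closest(search_terms, torch.cat([embed_list, embed_list2], dim=0), title_list + title_list2)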

In [None]:
top_titles

The results actually don't seem that bad given the model doesn't have any specific training.

In [None]:
q2 = "coronavirus person to person transmission mechanisms"

In [None]:
search_terms2 = embed_text(q2, model).mean(1)
top_titles2, order_dict1 = top_n_closest(search_terms2, torch.cat([embed_list, embed_list2], dim=0), title_list + title_list2)
top_titles2

Another interesting thing we will try is to see if the returned search results occupy different places in the embedding space when plotting. In theory for a good search engine we probably would want very distinct clusters as these queries are different.

In [None]:
def remake_combine_dict_embeds_plot(titles_list, order_dicts, search_term_list):
    categories = []
    embeddings = [] 
    for i in range(0, len(order_dicts)):
        order_dict = order_dicts[i]
        titles = titles_list[i]
        for title in titles:
            embeddings.append(order_dict[title]["title_embedding"])
            categories.append(search_term_list[i])
    return embeddings, categories

embeds, cats = remake_combine_dict_embeds_plot([top_titles, top_titles2], [order_dict, order_dict1], [q1, q2] )

searches = [q1, q2]

In [None]:
title_list_full = top_titles + top_titles2
embeds2 = torch.cat(embeds, dim=0)
pca_res = pca.fit_transform(embeds2.detach().numpy())

In [None]:
make_plot(pca_res, title_list_full, 0, color=False, color_mapping_cat=cats, color_cats=searches, bg_color="black")

We can now see that there are definitely some overlapping titles (orange). Additionally, the "coronavirus person to person transmission mechanisms" query doesn't really seem to occupy a different region of the embedding space than the "COVID-19 infection origin and transmission from animals" query. Keep in mind, though, that this was only 400 titles due to the RAM constraints, not the entire corpus. If we queried the whole corpus, I'd suspect we would find no overlapping titles.

In [None]:
title_list_full

## Embedding Abstracts
Just for fun and to enrich our knowledge, let's try embedding abstracts. Unfortunately we are extremely limited on RAM, so in the next section we won't be able to go beyond 20 or so abstracts at a time.

In [None]:
absd_embeds, abs_orig = make_the_embeds(15, 0, data_key=['abstract', 0, "text"])

Since the abstracts would be hard to display in a UMAP plot with the hover tooltip, I won't plot them. Instead let's just look at these two:

In [None]:
abs_orig[0]

In [None]:
abs_orig[1]

In [None]:
get_similarity(absd_embeds[0].unsqueeze(0), absd_embeds[1].unsqueeze(0))

I honestly don't know enough about the subject area to tell if that is a good similarity score for those two.

## Finding Insights with abstracts
The end goal of this project is to find actionable insights about Coronavirus, its origins, potential treatment plans, and more. For this step, due to the high overhead of embedding methods, we would probably want to combine our semantic search with the BM25 index. In this scenario BM25 would return 25 (or however many the RAM could handle) candidate answers, and then we would rerank them based on SciBERT embeddings.
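Here is a rough sketch of that two-stage idea on toy data. The BM25 function is a small self-contained version of the Okapi formula (with a commonly used non-negative IDF variant), not the pre-built index used later in this notebook. The documents, 2-D vectors, and function names are all made up for illustration.

```python
import math
from collections import Counter
import numpy as np

def bm25_scores(query_tokens, tokenized_docs, k1=1.5, b=0.75):
    """Score every document against the query with the Okapi BM25 formula."""
    n_docs = len(tokenized_docs)
    avg_len = sum(len(d) for d in tokenized_docs) / n_docs
    doc_freq = Counter(term for d in tokenized_docs for term in set(d))
    scores = []
    for doc in tokenized_docs:
        tf = Counter(doc)
        score = 0.0
        for term in query_tokens:
            if term not in tf:
                continue
            idf = math.log(1 + (n_docs - doc_freq[term] + 0.5) / (doc_freq[term] + 0.5))
            norm = tf[term] + k1 * (1 - b + b * len(doc) / avg_len)
            score += idf * tf[term] * (k1 + 1) / norm
        scores.append(score)
    return scores

def retrieve_then_rerank(query_tokens, query_vec, tokenized_docs, doc_vecs, k=25, n=3):
    """Stage 1: keep the k best BM25 candidates. Stage 2: reorder them by cosine similarity."""
    lexical = bm25_scores(query_tokens, tokenized_docs)
    candidates = sorted(range(len(tokenized_docs)), key=lambda i: lexical[i], reverse=True)[:k]
    q = np.asarray(query_vec, dtype=np.float64)
    q = q / np.linalg.norm(q)

    def cos(i):
        return float(doc_vecs[i] @ q / np.linalg.norm(doc_vecs[i]))

    return sorted(candidates, key=cos, reverse=True)[:n]

# Toy corpus and hypothetical 2-D "embeddings".
docs = ["coronavirus transmission between human patients in hospitals".split(),
        "bat origin of the outbreak".split(),
        "soil chemistry of peat bogs".split()]
doc_vecs = np.array([[0.6, 0.6], [0.9, 0.1], [0.0, 1.0]])
query_tokens = "coronavirus bat human transmission".split()
query_vec = np.array([1.0, 0.1])

ranked = retrieve_then_rerank(query_tokens, query_vec, docs, doc_vecs, k=2, n=2)
print(ranked)
```

In this toy case BM25 prefers the first document because it matches more query words, while the embedding step moves the second document to the top. The third document never reaches the expensive stage at all, which is the point of the design.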

In [None]:
search_terms = embed_text("coronavirus bat to human transmission", model).mean(1)
top_embeddings, order_dict = top_n_closest(search_terms, absd_embeds, abs_orig, n=3)

In [None]:
top_embeddings[0]

In [None]:
top_embeddings[1]

### Using Semantic Search + the BM25 index
We will now integrate our semantic search results with BM25. Since RAM is at a premium, we will first delete the objects from the previous searches. The code below asks the BM25 index for the 19 most relevant results.

In [None]:
del red
del absd_embeds
del abs_orig
del red2
del embed_list
del embed_list2
del embeds 
import gc
gc.collect()

In [None]:
del Out[22]
del Out[26]
del Out[24]
del Out[6]
del Out[7]
gc.collect()

In [None]:
!pip install rank_bm25

In [None]:
# Let's download the BM25 index
import pickle
import os
!git clone https://github.com/CoronaWhy/CORD-19-QA 
os.chdir('CORD-19-QA')
!wget -O bert_bioasq_final-bm25.pkl https://publicweightsdata.s3.us-east-2.amazonaws.com/bert_bioasq_final-bm25.pkl
from bm25_index import BM25Index
the_index = pickle.load(open('bert_bioasq_final-bm25.pkl', 'rb'))

In [None]:
the_index = pickle.load(open('bert_bioasq_final-bm25.pkl', 'rb'))
abstracts = the_index.search("coronavirus bat to human transmission", 19)
abstract_embed = []
short_abstracts = []
abstracts.head()

Unfortunately we now face an additional problem: BERT models cannot handle text passages of more than 512 *tokens*. Since this is just a demo, I'm going to truncate long passages at 200 words. Why 200 words and not 512? Because the tokenizer splits many words into several subword tokens, so the token sequence is usually longer than the word count. Since I'm lazy and don't want to modify my embed_text function, I'm choosing 200 as an arbitrary length (that should not cross 512). There are a couple of ways we could solve this in a real-world setting:
1. Split the passage into chunks of at most 512 tokens. Feed in each chunk, then concatenate the resulting (1, n_tokens, 768) tensors along the sequence dimension. For instance, if the first chunk had 512 tokens and the second had 300, we would end up with a (1, 812, 768) tensor. We would then take the average over the entire sequence length to get a (1, 768) tensor, just like before.
2. We could attempt to use other models like Transformer-XL or Reformer. As these models do not have the 512-token constraint, they could take the whole text passage. However, at the moment I don't know of any pre-trained versions of these models trained on scientific data.
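The first option can be sketched as follows. To keep the example runnable without downloading a model, `fake_encoder` is a hypothetical stand-in that returns an array of the same shape a BERT model would, (1, number of tokens, 768). In a real version each chunk would also need its own special tokens, so the usable chunk size would be slightly under 512.

```python
import numpy as np

MAX_TOKENS = 512
HIDDEN = 768

def chunk_ids(token_ids, max_len=MAX_TOKENS):
    """Split a list of token ids into consecutive chunks of at most max_len."""
    return [token_ids[i:i + max_len] for i in range(0, len(token_ids), max_len)]

def embed_long_passage(token_ids, encode_chunk, max_len=MAX_TOKENS):
    """Encode each chunk, join the results along the sequence axis, then mean-pool."""
    pieces = [encode_chunk(chunk) for chunk in chunk_ids(token_ids, max_len)]
    all_tokens = np.concatenate(pieces, axis=1)  # (1, total_tokens, hidden)
    return all_tokens.mean(axis=1)               # (1, hidden)

def fake_encoder(chunk):
    """Hypothetical stand-in for the model: one deterministic vector per token id."""
    ids = np.asarray(chunk, dtype=np.float64)
    return np.tile(ids[None, :, None], (1, 1, HIDDEN))  # (1, len(chunk), hidden)

token_ids = list(range(812))  # pretend the tokenizer produced 812 ids
vector = embed_long_passage(token_ids, fake_encoder)
print([len(c) for c in chunk_ids(token_ids)], vector.shape)
```

Averaging over the concatenated tokens weights every token equally. Averaging the per-chunk means instead would give a short final chunk more influence per token.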

In [None]:
def embed_abstracts_from_bm25(abstracts):
    abstracts = abstracts["abstract"].tolist()
    abstract_embed = []
    for abstract in abstracts:
        if len(tokenizer.encode(abstract))<512:
            abstract_embed.append(embed_text(abstract, model).mean(1).squeeze(0))
        else: 
            # TO-DO truncate to max allowable length
            abstract2 = " ".join(abstract.split()[:200])
            print(len(abstract2.split()))
            abstract_embed.append(embed_text(abstract2, model).mean(1).squeeze(0))
    return abstracts, abstract_embed

In [None]:
abstracts, abstract_embed = embed_abstracts_from_bm25(abstracts)
search_terms2 = embed_text("coronavirus bat to human transmission", model).mean(1)

In [None]:
top_abs1, order_dict1 = top_n_closest(search_terms2, abstract_embed, abstracts)

In [None]:
top_abs1[0]

In [None]:
top_abs1[1]

In [None]:
abstracts_2 = the_index.search("COVID-19 person to person transmission dynamics", 19)
search_terms3 = embed_text("COVID-19 person to person transmission dynamics", model).mean(1)
search_term_list = ["coronavirus bat to human transmission", "COVID-19 person to person transmission dynamics"]
abstracts2, abstract_embed2 = embed_abstracts_from_bm25(abstracts_2)

In [None]:
print(len(abstract_embed2))
top_abs, order_dict2 = top_n_closest(search_terms3, abstract_embed2, abstracts2)
abstract_lists = [top_abs1, top_abs]  # top results for each of the two queries
order_dicts = [order_dict1, order_dict2]
embeds, cats = remake_combine_dict_embeds_plot(abstract_lists, order_dicts, search_term_list)
embeds2 = torch.cat(embeds, dim=0)
pca_res = pca.fit_transform(embeds2.detach().numpy())

In [None]:
make_plot(pca_res, top_abs1 + top_abs, number=200, color=False, color_mapping_cat=cats, color_cats=search_term_list, bg_color="white")

## Conclusion
We have a couple of key takeaways (note that I'll include the original question each bullet point references in parentheses).
* Out-of-the-box SciBERT embeddings seem to capture meaning surprisingly well (1).
* Embeddings use a lot of RAM, particularly when embedding abstracts. This makes an entirely semantic search impractical (2).
* BERT cannot accommodate text passages longer than 512 tokens, which would make it difficult to embed an article's full text.
* Visualized clusters of titles seemed to make sense, but could use evaluation by a biology/chemistry/drug expert (3).

## Next Steps
There are a bunch of possible next steps so I'm going to split them based on the area. 

**Evaluation**

* Get expert evaluation of title clusters to see if they actually make sense. 
* Figure out a way to plot abstracts in Bokeh in a visually appealing way. Have experts inspect abstract clusters.
* Continue to run semantic search on groups of returned BM25 results. Get experts to rank results. 


**Making better embeddings**

* Train a language model specifically on the COVID-19 corpus and evaluate its embeddings versus SciBERT
* Investigate sentence transformers (pre-trained on MedNLI/SciTail) and explore if these embeddings work better
* See if anyone has a Transformer-XL model (or equivalent) trained on scientific text.

**Performance/Productionizing**

* Explore and benchmark speed/space constraints of returning larger numbers of BM25 results to re-rank with embeddings.
* Figure out storage method for embeddings and iteratively embed each article.
* Research whether UMAP, as a dimensionality reduction technique, can return lower-dimensional embeddings. It is easier to load 100-D vectors into memory than 768-D ones.
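As a starting point for the storage bullet above, one option (an assumption on my part, not something benchmarked here) is to write pooled vectors into a memory-mapped `.npy` file batch by batch, so that only one batch is ever held in RAM. `fake_batches` is a hypothetical stand-in for running the model over the corpus.

```python
import os
import tempfile
import numpy as np

HIDDEN = 768

def write_embeddings(path, batches, n_docs, hidden=HIDDEN):
    """Fill an on-disk .npy array one batch at a time."""
    store = np.lib.format.open_memmap(path, mode="w+", dtype=np.float32,
                                      shape=(n_docs, hidden))
    row = 0
    for batch in batches:
        store[row:row + len(batch)] = batch
        row += len(batch)
    store.flush()  # make sure everything is written to disk
    return row

def fake_batches(n_batches, batch_size, hidden=HIDDEN):
    """Hypothetical stand-in for embedding the corpus in batches."""
    rng = np.random.default_rng(0)
    for _ in range(n_batches):
        yield rng.normal(size=(batch_size, hidden)).astype(np.float32)

path = os.path.join(tempfile.mkdtemp(), "title_embeddings.npy")
written = write_embeddings(path, fake_batches(4, 50), n_docs=200)

# Memory-mapped load: rows are read from disk on demand.
stored = np.load(path, mmap_mode="r")
print(written, stored.shape, os.path.getsize(path))
```

At 768 float32 values per article, each stored vector takes 3072 bytes, so the file size grows linearly and predictably with the number of articles.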

Hope you enjoyed this analysis. If you found it useful please upvote as it helps more people to see it. Also, feel free to ask any questions below.

