In [15]:
from transformers import BertTokenizer, BertModel
import pandas as pd
import numpy as np
from scipy.spatial.distance import cosine
import torch

**Download Model**

In [3]:
model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states = True,)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


**Helper Functions**

In [4]:
def bert_text_preparation(text, tokenizer, clean=True):
    """Preparing the input for BERT
    
    Takes a string argument and performs
    pre-processing like adding special tokens,
    tokenization, tokens to ids, and tokens to
    segment ids. All tokens are mapped to seg-
    ment id = 1.
    
    Args:
        text (str): Text to be converted
        tokenizer (obj): Tokenizer object
            to convert text into BERT-re-
            adable tokens and ids
        
    Returns:
        list: List of BERT-readable tokens
        obj: Torch tensor with token ids
        obj: Torch tensor segment ids
    """
    marked_text = "[CLS] " + text + " [SEP]"
    if clean:
        tokenized_text = marked_text.split(" ")
    else:
        tokenized_text = tokenizer.tokenize(marked_text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    segments_ids = [1]*len(indexed_tokens)

    # Convert inputs to PyTorch tensors
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])

    return tokenized_text, tokens_tensor, segments_tensors

def get_bert_embeddings(tokens_tensor, segments_tensors, model):
    """Get embeddings from an embedding model
    
    Args:
        tokens_tensor (obj): Torch tensor size [n_tokens]
            with token ids for each token in text
        segments_tensors (obj): Torch tensor size [n_tokens]
            with segment ids for each token in text
        model (obj): Embedding model to generate embeddings
            from token and segment ids
    
    Returns:
        list: List of list of floats of size
            [n_tokens, n_embedding_dimensions]
            containing embeddings for each token
    """
    
    # Gradient calculation id disabled
    # Model is in inference mode
    with torch.no_grad():
        outputs = model(tokens_tensor, segments_tensors)
        # Removing the first hidden state
        # The first state is the input state
        hidden_states = outputs[2][1:]

    # Getting embeddings from the final BERT layer
    token_embeddings = hidden_states[-1]
    # Collapsing the tensor into 1-dimension
    token_embeddings = torch.squeeze(token_embeddings, dim=0)
    # Converting torchtensors to lists
    list_token_embeddings = [token_embed.tolist() for token_embed in token_embeddings]

    return list_token_embeddings

**Read in Corpus & Define Topics**

In [29]:
virus_type = ["COVID-19", "sars", "MERS", "Ebola"]
virus_study = ["origin", "evolution", "symptom", "examination"]
age_group = ["infant", "adult", "elderly"]

all_words = ["COVID-19", "sars", "MERS", "Ebola", "origin", "evolution", "symptom", "examination", "infant", "adult", "elderly"]

In [5]:
with open('/shared/data2/pk36/multidim/covid_phrase_text.txt') as f:
    text = f.readlines()
text = text[0]

split_text = text.split(".")

word = "sars"
split_text.insert(0, word)

**Compute Embeddings**

In [6]:
target_word_embeddings = []

for t in split_text:
    tokenized_text, tokens_tensor, segments_tensors = bert_text_preparation(t, tokenizer)
    list_token_embeddings = get_bert_embeddings(tokens_tensor, segments_tensors, model)
    
    if word not in tokenized_text:
        continue;
    else:
        word_index = tokenized_text.index(word)
        # Get the embedding for bank
        word_embedding = list_token_embeddings[word_index]
        target_word_embeddings.append(word_embedding)

In [7]:
len(target_word_embeddings)

16

**Cosine-Similiarity**

In [10]:
# Calculating the distance between the embeddings of the word in all the given contexts of the word

list_of_distances = []
for text1, embed1 in zip(split_text, target_word_embeddings):
    for text2, embed2 in zip(split_text, target_word_embeddings):
        cos_dist = 1 - cosine(embed1, embed2)
        list_of_distances.append([text1, text2, cos_dist])

distances_df = pd.DataFrame(list_of_distances, columns=['text1', 'text2', 'distance'])

In [11]:
word_df = distances_df[distances_df.text1 == word].sort_values(by=['distance'], ascending=False)
word_df.head(10)

Unnamed: 0,text1,text2,distance
0,sars,sars,1.0
1,sars,angiotensin_converting_enzyme 2 ( ace2 ) as a ...,0.615406
3,sars,there is a diversity of possible intermediate...,0.602366
9,sars,"5 identity in amino_acid sequences 6 and , im...",0.547345
14,sars,this similarity with sars cov is critical bec...,0.53167
15,sars,it is required for host_cell entry and subseq...,0.527674
10,sars,wan et al,0.520226
11,sars,4 reported that residue 394 ( glutamine ) in ...,0.51776
13,sars,"thus , the sars_cov_2 spike_protein was predi...",0.512084
12,sars,further analysis even suggested that sars_cov...,0.509995


In [14]:
word_df["text2"][2]

' a phylogenetic analysis 3 , 4 found a bat origin for the sars_cov_2 '

**Scratch Work (aka ignore)**

In [47]:
target_word_embeddings = []

for i in np.arange(0, len(text[0]), 512):
    tokenized_text, tokens_tensor, segments_tensors = bert_text_preparation(text[0][i:min(i+512, len(text[0]))], tokenizer)
    list_token_embeddings = get_bert_embeddings(tokens_tensor, segments_tensors, model)

    word = "covid_19" # virus_type[1]

    # Find the position 'bank' in list of tokens
    if word not in tokenized_text:
        if i == 0:
            print(tokenized_text)
        print(i, ": not found")
    else:
        print(i, ": found")
        word_index = tokenized_text.index(word)
        # Get the embedding for bank
        word_embedding = list_token_embeddings[word_index]
        target_word_embeddings.append(word_embedding)

['[CLS]', 'ang', '##iot', '##ens', '##in', '_', 'converting', '_', 'enzyme', '2', '(', 'ace', '##2', ')', 'as', 'a', 'sar', '##s', '_', 'co', '##v', '_', '2', 'receptor', 'molecular', '_', 'mechanisms', 'and', 'potential', 'therapeutic', '_', 'target', 'sar', '##s', '_', 'co', '##v', '_', '2', 'has', 'been', 'sequence', '##d', '3', '.', 'a', 'phylogenetic', 'analysis', '3', ',', '4', 'found', 'a', 'bat', 'origin', 'for', 'the', 'sar', '##s', '_', 'co', '##v', '_', '2', '.', 'there', 'is', 'a', 'diversity', 'of', 'possible', 'intermediate', '_', 'hosts', 'for', 'sar', '##s', '_', 'co', '##v', '_', '2', ',', 'including', 'pang', '##olin', '##s', ',', 'but', 'not', 'mice', 'and', 'rats', '5', '.', 'there', 'are', 'many', 'similarities', 'of', 'sar', '##s', '_', 'co', '##v', '_', '2', 'with', 'the', 'original', 'sar', '##s', 'co', '##v', '.', 'using', 'computer', 'modeling', ',', 'xu', 'et', 'al', '.', '6', 'found', 'that', 'the', 'spike', '_', 'proteins', 'of', 'sar', '##s', '_', 'co', '#