In [1]:
from transformers import BertTokenizer, BertModel
import pandas as pd
import numpy as np
from scipy.spatial.distance import cosine
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import re
import torch
import numba
import math
from numba import jit, cuda

  from .autonotebook import tqdm as notebook_tqdm


**Download Model**

In [2]:
model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states = True,)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


**Helper Functions**

In [3]:
def bert_text_preparation(text, tokenizer, clean=True):
    """Preparing the input for BERT
    
    Takes a string argument and performs
    pre-processing like adding special tokens,
    tokenization, tokens to ids, and tokens to
    segment ids. All tokens are mapped to seg-
    ment id = 1.
    
    Args:
        text (str): Text to be converted
        tokenizer (obj): Tokenizer object
            to convert text into BERT-re-
            adable tokens and ids
        
    Returns:
        list: List of BERT-readable tokens
        obj: Torch tensor with token ids
        obj: Torch tensor segment ids
    """
    marked_text = "[CLS] " + text + " [SEP]"
    if clean:
        tokenized_text = marked_text.split(" ")
    else:
        tokenized_text = tokenizer.tokenize(marked_text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    segments_ids = [1]*len(indexed_tokens)

    # Convert inputs to PyTorch tensors
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])

    return tokenized_text, tokens_tensor, segments_tensors

def get_bert_embeddings(tokens_tensor, segments_tensors, model):
    """Get embeddings from an embedding model
    
    Args:
        tokens_tensor (obj): Torch tensor size [n_tokens]
            with token ids for each token in text
        segments_tensors (obj): Torch tensor size [n_tokens]
            with segment ids for each token in text
        model (obj): Embedding model to generate embeddings
            from token and segment ids
    
    Returns:
        list: List of list of floats of size
            [n_tokens, n_embedding_dimensions]
            containing embeddings for each token
    """
    
    # Gradient calculation id disabled
    # Model is in inference mode
    with torch.no_grad():
        outputs = model(tokens_tensor, segments_tensors)
        # Removing the first hidden state
        # The first state is the input state
        hidden_states = outputs[2][1:]

    # Getting embeddings from the final BERT layer
    token_embeddings = hidden_states[-1]
    # Collapsing the tensor into 1-dimension
    token_embeddings = torch.squeeze(token_embeddings, dim=0)
    # Converting torchtensors to lists
    list_token_embeddings = [token_embed.tolist() for token_embed in token_embeddings]

    return list_token_embeddings

**Split the Dataset in 512 Chunks & Tokenize**

In [39]:
virus_type = ["COVID-19", "sars", "MERS", "Ebola"]
virus_study = ["origin", "evolution", "symptom", "examination"]
age_group = ["infant", "adult", "elderly"]

#all_topics = ["covid_19", "sars", "mers", "ebola", "origin", "evolution", "symptom", "examination", "infant", "adult", "elderly"]
all_topics = ["sars"]
topics_added = False

In [None]:
def generate_chunks(corpus, size=512, limit=None): #TODO: CHANGE TO GENERATOR and might need to introduce overlap
    # corpus is an array of strings where each item represents a document
    for doc in tqdm(corpus):
        if num == limit:
            break
        sents = doc.split(".")
        for s in sents:
            marked_s = "[CLS] " + s + " [SEP]"
            split_sent = marked_s.split(" ")
        
        ret_txt = np.array_split(split_sent, math.ceil(tokenized_text.shape[0]/size))
        if not topics_added:
            token_topics = [np.array(["[CLS]", w, "[SEP]"], dtype=str) for w in all_topics]
            ret_txt[:0] = token_topics
            topics_added = True
        
    
    return np.array(ret_txt, dtype=np.ndarray)

In [40]:
def chunkize(corpus, size=512, limit=None): #TODO: CHANGE TO GENERATOR and might need to introduce overlap
    tokenized_text = None
    num = 0
    # corpus is an array of strings where each item represents a document
    for doc in tqdm(corpus):
        if num == limit:
            break
        sents = doc.split(".")
        for s in sents:
            marked_s = "[CLS] " + s + " [SEP]"
            split_sent = marked_s.split(" ")
            if tokenized_text is None:
                tokenized_text = split_sent
            else:
                tokenized_text = np.append(tokenized_text, split_sent)
            
        num += 1
    print(tokenized_text.shape)
    ret_txt = np.array_split(tokenized_text, math.ceil(tokenized_text.shape[0]/size))
    token_topics = [np.array(["[CLS]", w, "[SEP]"], dtype=str) for w in all_topics]
    ret_txt[:0] = token_topics
    
    return np.array(ret_txt, dtype=np.ndarray)

def tokenize(chunk):        
    indexed_tokens = tokenizer.convert_tokens_to_ids(chunk)
    segments_ids = [1]*len(indexed_tokens)

    # Convert inputs to PyTorch tensors
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])

    return tokens_tensor, segments_tensors

**Read in Corpus & Define Topics**

In [6]:
with open('/shared/data2/pk36/multidim/covid_phrase_text.txt') as f:
    docs = f.readlines()

In [41]:
chunks = chunkize(docs, limit=100)
print(chunks.shape)

  0%|                                    | 100/29500 [02:20<11:26:03,  1.40s/it]

(286384,)
(561,)





**Compute Embeddings for All Words in Each Chunk**

In [42]:
word_embeddings = None # shape: (# words, 768)
all_words = []
add_topics = True

for chunk in tqdm(chunks):
    tokens_tensor, segments_tensors = tokenize(chunk)
    list_token_embeddings = np.array(get_bert_embeddings(tokens_tensor, segments_tensors, model)) # shape: (512 tokens in chunk, 768)
    all_words.extend(chunk)

    if word_embeddings is None:
        word_embeddings = list_token_embeddings
    else:
        word_embeddings = np.append(word_embeddings, list_token_embeddings, axis=0)

100%|█████████████████████████████████████████| 561/561 [04:53<00:00,  1.91it/s]


In [32]:
len(all_words)

286387

In [33]:
word_embeddings.shape

(286387, 768)

**Cosine-Similiarity**

_Word-to-Word:_

In [34]:
topic_col = []
word_col = []
cos_col = []

for i in tqdm(np.arange(len(all_topics))):
    topic_col.extend(np.repeat(all_topics[i], len(all_words)))
    word_col.extend(all_words)
    cos_col.extend(cosine_similarity(word_embeddings[i].reshape(1, -1), word_embeddings).reshape(-1, 1))

    
topic_col = np.array(topic_col)
word_col = np.array(word_col)
cos_col = np.array(cos_col).reshape((-1, ))

data = np.array([topic_col, word_col, cos_col]).T
cosine_df = pd.DataFrame(data, columns=['topic', 'word', 'cosine'])

100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.02it/s]


In [35]:
len(topic_col), len(word_col), len(cos_col)

(286387, 286387, 286387)

In [36]:
cosine_df

Unnamed: 0,topic,word,cosine
0,symptom,[CLS],1.0000000000000009
1,symptom,symptom,0.2028250618645076
2,symptom,[SEP],-0.10594274891158682
3,symptom,[CLS],0.4818992172118389
4,symptom,angiotensin_converting_enzyme,0.21770178984822242
...,...,...,...
286382,symptom,,0.11287630757556942
286383,symptom,[SEP],-0.07515445695102974
286384,symptom,[CLS],0.5731043226562862
286385,symptom,\n,0.08848080349404483


In [37]:
cosine_df["cosine"] = pd.to_numeric(cosine_df["cosine"])
word_df = cosine_df[(cosine_df.word != "[CLS]") 
                    & (cosine_df.word != "[SEP]")
                    & (cosine_df.word.str.len() > 3)
                    & ~(cosine_df.word.isin(all_topics))].sort_values(by=['cosine'], ascending=False)
word_df = word_df[word_df.word.str.isalnum()]
word_df.head(10)

Unnamed: 0,topic,word,cosine
115202,symptom,role,0.714869
102915,symptom,noted,0.710099
104451,symptom,virus,0.710076
71683,symptom,guarding,0.683049
274634,symptom,fixed,0.661575
204116,symptom,code,0.65788
97283,symptom,risk,0.657662
54275,symptom,virus,0.657354
87044,symptom,better,0.656019
151483,symptom,reported,0.650768


In [38]:
word_df.head(30)

Unnamed: 0,topic,word,cosine
115202,symptom,role,0.714869
102915,symptom,noted,0.710099
104451,symptom,virus,0.710076
71683,symptom,guarding,0.683049
274634,symptom,fixed,0.661575
204116,symptom,code,0.65788
97283,symptom,risk,0.657662
54275,symptom,virus,0.657354
87044,symptom,better,0.656019
151483,symptom,reported,0.650768


In [25]:
word_df[(word_df.topic == "ebola")]

Unnamed: 0,topic,word,cosine
974483,ebola,role,0.714869
962196,ebola,noted,0.710099
963732,ebola,virus,0.710076
930964,ebola,guarding,0.683049
1133915,ebola,fixed,0.661575
...,...,...,...
1078429,ebola,other,-0.140917
1116801,ebola,hand,-0.140965
970936,ebola,other,-0.146952
1119358,ebola,case,-0.149871


**Scratch Work (aka ignore)**

In [47]:
target_word_embeddings = []

for i in np.arange(0, len(text[0]), 512):
    tokenized_text, tokens_tensor, segments_tensors = bert_text_preparation(text[0][i:min(i+512, len(text[0]))], tokenizer)
    list_token_embeddings = get_bert_embeddings(tokens_tensor, segments_tensors, model)

    word = "covid_19" # virus_type[1]

    # Find the position 'bank' in list of tokens
    if word not in tokenized_text:
        if i == 0:
            print(tokenized_text)
        print(i, ": not found")
    else:
        print(i, ": found")
        word_index = tokenized_text.index(word)
        # Get the embedding for bank
        word_embedding = list_token_embeddings[word_index]
        target_word_embeddings.append(word_embedding)

['[CLS]', 'ang', '##iot', '##ens', '##in', '_', 'converting', '_', 'enzyme', '2', '(', 'ace', '##2', ')', 'as', 'a', 'sar', '##s', '_', 'co', '##v', '_', '2', 'receptor', 'molecular', '_', 'mechanisms', 'and', 'potential', 'therapeutic', '_', 'target', 'sar', '##s', '_', 'co', '##v', '_', '2', 'has', 'been', 'sequence', '##d', '3', '.', 'a', 'phylogenetic', 'analysis', '3', ',', '4', 'found', 'a', 'bat', 'origin', 'for', 'the', 'sar', '##s', '_', 'co', '##v', '_', '2', '.', 'there', 'is', 'a', 'diversity', 'of', 'possible', 'intermediate', '_', 'hosts', 'for', 'sar', '##s', '_', 'co', '##v', '_', '2', ',', 'including', 'pang', '##olin', '##s', ',', 'but', 'not', 'mice', 'and', 'rats', '5', '.', 'there', 'are', 'many', 'similarities', 'of', 'sar', '##s', '_', 'co', '##v', '_', '2', 'with', 'the', 'original', 'sar', '##s', 'co', '##v', '.', 'using', 'computer', 'modeling', ',', 'xu', 'et', 'al', '.', '6', 'found', 'that', 'the', 'spike', '_', 'proteins', 'of', 'sar', '##s', '_', 'co', '#

In [28]:
word_embeddings = None # shape: (# words, 768)
all_words = []
add_topics = True

for doc in tqdm(docs):
    split_text = doc.split(".") # split corpus into list of sentences
    if add_topics:
        split_text[:0] = all_topics
        add_topics = False

    for sentence in split_text:
        tokenized_text, tokens_tensor, segments_tensors = bert_text_preparation(sentence, tokenizer)
        list_token_embeddings = np.array(get_bert_embeddings(tokens_tensor, segments_tensors, model))[1:-1] # shape: (# words in sentence, 768)
        all_words.extend(tokenized_text[1:-1])

        if word_embeddings is None:
            word_embeddings = list_token_embeddings
        else:
            word_embeddings = np.append(word_embeddings, list_token_embeddings, axis=0)
    
    # word_embeddings.append(list_token_embeddings[1:-1]) # ignore CLS and SEP tokens

  5%|█▍                              | 1328/29500 [5:54:55<125:29:12, 16.04s/it]


KeyboardInterrupt: 

In [10]:
# Calculating the distance between the embeddings of the word in all the given contexts of the word

list_of_distances = []
for text1, embed1 in zip(split_text, target_word_embeddings):
    for text2, embed2 in zip(split_text, target_word_embeddings):
        cos_dist = 1 - cosine(embed1, embed2)
        list_of_distances.append([text1, text2, cos_dist])

distances_df = pd.DataFrame(list_of_distances, columns=['text1', 'text2', 'distance'])

In [11]:
word_df = distances_df[distances_df.text1 == word].sort_values(by=['distance'], ascending=False)
word_df.head(10)

Unnamed: 0,text1,text2,distance
0,sars,sars,1.0
1,sars,angiotensin_converting_enzyme 2 ( ace2 ) as a ...,0.615406
3,sars,there is a diversity of possible intermediate...,0.602366
9,sars,"5 identity in amino_acid sequences 6 and , im...",0.547345
14,sars,this similarity with sars cov is critical bec...,0.53167
15,sars,it is required for host_cell entry and subseq...,0.527674
10,sars,wan et al,0.520226
11,sars,4 reported that residue 394 ( glutamine ) in ...,0.51776
13,sars,"thus , the sars_cov_2 spike_protein was predi...",0.512084
12,sars,further analysis even suggested that sars_cov...,0.509995


In [14]:
word_df["text2"][2]

' a phylogenetic analysis 3 , 4 found a bat origin for the sars_cov_2 '