In [13]:

import torch
from transformers import BertTokenizer, BertModel
import sys
import os
from tqdm import tqdm
sys.path.append("/home/jovyan/20230406_ArticleClassifier/ArticleClassifier")

import src.general.global_variables as gv
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname('data_loader.py'), os.path.pardir)))
from src.data.data_loader import DataLoader

from src.general.utils import cc_path

import pandas as pd
import numpy as np

In [7]:
model_version = 'scibert_scivocab_uncased'
do_lower_case = True
# model = BertModel.from_pretrained(model_version)
model = torch.load(cc_path(f'models/embedders/finetuned_bert_10k_20e_2lay.pt'))
tokenizer = BertTokenizer.from_pretrained(model_version, do_lower_case=do_lower_case)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RobertaTokenizer'. 
The class this function is called from is 'BertTokenizer'.


In [8]:
from sklearn.metrics.pairwise import cosine_similarity
def embed_text(text, model):
    # print(text)
    encoded_text = tokenizer.encode(text, max_length=512, truncation=True)
    input_ids = torch.tensor(encoded_text).unsqueeze(0)  # Batch size 1
    outputs = model(input_ids)
    last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
    return last_hidden_states 

def get_similarity(em, em2):
    return cosine_similarity(em.detach().numpy(), em2.detach().numpy())

In [9]:
# We will use a mean of all word embeddings. To do that we will take mean over dimension 1 which is the sequence length.
coronavirus_em = embed_text(["Coronavirus"]*50000, model).mean(1)
mers_em = embed_text("Middle East Respiratory Virus", model).mean(1)
flu_em = embed_text("Flu", model).mean(1)
bog_em = embed_text("Bog", model).mean(1)
covid_2019 = embed_text("COVID-2019", model).mean(1)
print("Similarity for Coronavirus and Flu:" + str(get_similarity(coronavirus_em, flu_em)))
print("Similarity for Coronavirus and MERs:" + str(get_similarity(coronavirus_em, mers_em)))
print("Similarity for Coronavirus and COVID-2019:" + str(get_similarity(coronavirus_em, covid_2019)))
print("Similarity for Coronavirus and Bog:" + str(get_similarity(coronavirus_em, bog_em)))

Similarity for Coronavirus and Flu:[[0.5923033]]
Similarity for Coronavirus and MERs:[[0.55641955]]
Similarity for Coronavirus and COVID-2019:[[0.5912089]]
Similarity for Coronavirus and Bog:[[0.6268568]]


In [11]:
def scibert_embedding(emb_dat, embedding_dim):
        

    print('Initiating DataFrame for saving embedding...')
    embedding_cols = [f'd{i}' for i in range(embedding_dim)]
    embedded_df = pd.DataFrame(columns=['pui'] + embedding_cols)
    embedded_df['pui'] = emb_dat.loc[:, 'pui']
    embedded_df.set_index('pui', inplace=True)

    emb_dat.set_index('pui', inplace=True)
    print(embedded_df)
    # create embeddings
    print('Creating embeddings for all documents...')
    for idx, sentence in tqdm(emb_dat.iterrows(), total=len(emb_dat)):
        embedded_df.loc[idx] = embed_text(sentence['title'] + ' ' + ' '.join(sentence['keywords'])  + ' ' + sentence['abstract'], model).mean(1).detach().numpy()

    embedded_df.reset_index(names='pui', inplace=True)

    return embedded_df

In [14]:
loc_dict = {
        'processed_csv': cc_path('data/processed/canary/articles_cleaned.csv')
    }
data_loader = DataLoader(loc_dict)
processed_df = data_loader.load_processed_csv()
label_columns = processed_df.loc[:, ~processed_df.columns.isin(
    ['file_name', 'pui', 'title', 'keywords', 'abstract', 'abstract_2', 'authors', 'organization', 'chemicals',
     'num_refs', 'date-delivered', 'labels_m', 'labels_a'])]
label_columns = label_columns.astype(int)

embedding_type = 'scibert'

data_for_embedding = processed_df.dropna(subset=['abstract'])
data_for_embedding.loc[:, 'labels_m'] = data_for_embedding.loc[:, 'labels_m'].fillna('')
# data_for_embedding.loc[:, 'list_label'] = data_for_embedding.loc[:, 'labels_m'].str.split(',')

embedding_dim = 768
embedded_df = scibert_embedding(data_for_embedding, embedding_dim)


embedded_df.to_csv(cc_path(f'data/processed/canary/embeddings_{embedding_type}_finetuned_20230425.csv'), index=False)

Initiating DataFrame for saving embedding...
             d0   d1   d2   d3   d4   d5   d6   d7   d8   d9  ... d758 d759   
pui                                                           ...             
624531411   NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  ...  NaN  NaN  \
625340088   NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  ...  NaN  NaN   
625805682   NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  ...  NaN  NaN   
626662493   NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  ...  NaN  NaN   
626822402   NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  ...  NaN  NaN   
...         ...  ...  ...  ...  ...  ...  ...  ...  ...  ...  ...  ...  ...   
2011621972  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  ...  NaN  NaN   
2011622024  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  ...  NaN  NaN   
2011622065  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  ...  NaN  NaN   
2011626864  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  ...  NaN  NaN   
2011632

  1%|          | 1028/117310 [02:01<3:49:52,  8.43it/s]


KeyboardInterrupt: 

In [10]:
print(embedded_df)

               pui        d0        d1        d2        d3        d4   
0        624531411   0.72099 -0.400038 -0.552218 -0.312934  0.073463  \
1        625340088  0.464988  0.499274 -0.497702    0.1075  0.027397   
2        625805682  0.467386  0.674454 -0.209557  0.117852 -0.217963   
3        626662493  0.472561 -0.147881 -0.394192  0.710131 -0.021184   
4        626822402  0.511049 -0.261134 -0.338247  0.283631 -0.385034   
...            ...       ...       ...       ...       ...       ...   
117305  2011621972  0.139334  0.152547 -0.112082  0.257826 -0.277364   
117306  2011622024  0.751879  -0.34464  -0.45699  0.444919 -0.024674   
117307  2011622065  0.039875 -0.182126 -0.041164  0.411543 -0.063595   
117308  2011626864  0.419988 -0.486235 -0.540463  0.171615 -0.112762   
117309  2011632199  0.442814 -0.096023 -0.229158  0.260837 -0.375718   

              d5        d6        d7        d8  ...      d758      d759   
0      -0.037645  0.411679  0.853509 -0.563616  ... -0.29486