In [1]:

import torch
from transformers import BertTokenizer, RobertaTokenizer, AutoTokenizer
import sys
import os
from tqdm import tqdm
sys.path.append("/home/jovyan/20230406_ArticleClassifier/ArticleClassifier")

import src.general.global_variables as gv
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname('data_loader.py'), os.path.pardir)))
from src.data.data_loader import DataLoader

from src.general.utils import cc_path

import pandas as pd
import numpy as np

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
model_version = 'scibert_scivocab_uncased'
do_lower_case = True
# model = BertModel.from_pretrained(model_version)
# model = torch.load(cc_path(f'models/embedders/finetuned_bert_56k_20e_3lay_best_iter.pt'), map_location=device)
# model = torch.load(cc_path(f'models/baselines/paula_finetuned_bert_56k_10e_tka.pt'))
# model = torch.load(cc_path(f'models/embedders/litcovid_finetuned_bert_56k_20e_3lay_best_iter.pt'))
# model = torch.load(cc_path(f'models/embedders/litcovid_finetuned_biobert_56k_20e_3lay_best_iter.pt'))
# model = torch.load(cc_path(f'models/embedders/litcovid_finetuned_bert_56k_20e_3lay_best_iter_meta.pt'))
model = torch.load(cc_path(f'models/embedders/litcovid_pretrained_best_iter_meta_stopwords.pt'))


# model = model.base_model

tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', do_lower_case=do_lower_case)
# tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.1', do_lower_case=do_lower_case)

In [3]:
from sklearn.metrics.pairwise import cosine_similarity
def embed_text(text, model):
    # print(text)
    encoded_text = tokenizer.encode(text, max_length=512, truncation=True)
    input_ids = torch.tensor(encoded_text).unsqueeze(0).to(device)  # Batch size 1
    outputs = model(input_ids)
    last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
    return last_hidden_states 

def get_similarity(em, em2):
    return cosine_similarity(em.detach().cpu().numpy(), em2.detach().cpu().numpy())

In [4]:
# We will use a mean of all word embeddings. To do that we will take mean over dimension 1 which is the sequence length.
coronavirus_em = embed_text(["Coronavirus"]*50000, model).mean(1)
mers_em = embed_text("Middle East Respiratory Virus", model).mean(1)
flu_em = embed_text("Flu", model).mean(1)
bog_em = embed_text("Bog", model).mean(1)
covid_2019 = embed_text("COVID-2019", model).mean(1)

print("Similarity for Coronavirus and Flu:" + str(get_similarity(coronavirus_em, flu_em)))
print("Similarity for Coronavirus and MERs:" + str(get_similarity(coronavirus_em, mers_em)))
print("Similarity for Coronavirus and COVID-2019:" + str(get_similarity(coronavirus_em, covid_2019)))
print("Similarity for Coronavirus and Bog:" + str(get_similarity(coronavirus_em, bog_em)))

TypeError: TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]

In [5]:
def scibert_embedding(emb_dat, embedding_dim):
    """Create the SciBERT embedding"""

    print('Initiating DataFrame for saving embedding...')
    embedding_cols = [f'd{i}' for i in range(embedding_dim)]
    embedded_df = pd.DataFrame(columns=['pui'] + embedding_cols)
    embedded_df['pui'] = emb_dat.loc[:, 'pui']
    embedded_df.set_index('pui', inplace=True)

    emb_dat.set_index('pui', inplace=True)

    # create embeddings
    print('Creating embeddings for all documents...')
    for idx, sentence in tqdm(emb_dat.iterrows(), total=len(emb_dat)):
        embedded_df.loc[idx] = embed_text(sentence['title'] + sentence['journal'] + " " + sentence['pub_type'].replace(';', ' ') + ' ' + ' '.join(sentence['keywords'])  + ' ' + sentence['abstract'], model).mean(1).detach().cpu().numpy()

    embedded_df.reset_index(names='pui', inplace=True)

    return embedded_df

In [6]:
loc_dict = {
        'processed_csv': cc_path('data/processed/litcovid/litcovid_articles_cleaned.csv')
    }
data_loader = DataLoader(loc_dict)
processed_df = data_loader.load_processed_csv()
label_columns = processed_df.loc[:, ~processed_df.columns.isin(
    ['file_name', 'title', 'keywords', 'abstract', 'abstract_2', 'authors', 'organization', 'chemicals',
     'num_refs', 'date-delivered', 'labels_m', 'labels_a', 'journal', 'pub_type', 'doi', 'label', 'label_m', 'list_label'])]
print(label_columns)

label_columns = label_columns.astype(int)

embedding_type = 'bert'

data_for_embedding = processed_df.dropna(subset=['abstract'])
data_for_embedding.loc[:, 'labels_m'] = data_for_embedding.loc[:, 'labels_m'].fillna('')
# data_for_embedding.loc[:, 'list_label'] = data_for_embedding.loc[:, 'labels_m'].str.split(',')

embedding_dim = 768
embedded_df = scibert_embedding(data_for_embedding, embedding_dim)

embedded_df.to_csv(cc_path(f'data/processed/litcovid/litcovid_embeddings_{embedding_type}_finetuned_20230529_meta_stopwords.csv'), index=False)

            pui  Case Report  Diagnosis  Epidemic Forecasting  Mechanism   
0      32519164            0          0                     0          1  \
1      32691006            0          0                     0          0   
2      32858315            1          0                     0          0   
3      32985329            0          0                     0          0   
4      32812051            0          0                     0          0   
...         ...          ...        ...                   ...        ...   
33694  34291555            0          0                     0          1   
33695  34204119            0          1                     0          0   
33696  34228511            0          1                     0          0   
33697  34226900            0          1                     0          0   
33698  34218774            0          0                     0          0   

       Prevention  Transmission  Treatment  
0               0             0          1

100%|██████████| 33671/33671 [09:56<00:00, 56.44it/s]


In [9]:
print(embedded_df)

            pui        d0        d1        d2        d3        d4        d5   
0      32519164  0.921391 -0.365566  0.187476  0.501266 -0.146242 -0.470866  \
1      32519164  0.921391 -0.365566  0.187476  0.501266 -0.146242 -0.470866   
2      32691006  1.149948  0.035349  0.229485  1.761266 -0.178359   -0.1766   
3      32691006  1.149948  0.035349  0.229485  1.761266 -0.178359   -0.1766   
4      32858315  0.363713 -0.870381 -0.068459  1.121618 -0.855506 -0.215016   
...         ...       ...       ...       ...       ...       ...       ...   
45219  34291555  1.062982 -0.053552 -0.007932  1.004114 -0.085894 -0.632584   
45220  34204119 -0.016456 -0.113315 -0.568399  1.168578 -0.207969 -1.403214   
45221  34228511 -0.141193 -0.292873 -0.882129  1.261617 -0.577297 -0.684213   
45222  34226900  0.847777 -0.083107  0.614294  1.733181  0.005081 -0.661331   
45223  34218774  0.913131  0.034869  0.467554  1.703317 -0.551943 -0.322953   

             d6        d7        d8  ...      d758 