In [1]:
import re
import warnings
warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

# Use the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Tokenizer and encoder are loaded from the same checkpoint name.
checkpoint_name = "neuml/pubmedbert-base-embeddings"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
model = AutoModel.from_pretrained(checkpoint_name)
model = model.to(device)

In [2]:
def replace_longest_text_with_halves(texts):
    """Return a copy of ``texts`` with its longest entry split into two halves.

    The longest entry (by character count; the first one wins on ties) is cut
    at its character midpoint and replaced, in position, by the two pieces.
    For an odd length the second piece is one character longer.

    The input sequence is left untouched: the result is always a new list, so
    callers can keep using their original list after the call.

    Args:
        texts: Sequence of strings.

    Returns:
        A new list containing one more element than ``texts``. A falsy input
        (empty list, ``None``) is returned unchanged.
    """
    if not texts:
        return texts

    longest_idx = max(range(len(texts)), key=lambda i: len(texts[i]))
    longest = texts[longest_idx]
    mid = len(longest) // 2

    # Work on a copy so the caller's list is not modified as a side effect.
    result = list(texts)
    result[longest_idx:longest_idx + 1] = [longest[:mid], longest[mid:]]
    return result

def process_text_vectorization_with_pad(text_list, max_tokens=512):
    """Embed each text as the mean of the first-token vectors of its sentences.

    Every text is split on ``'. '``, the resulting sentences are tokenized as
    one padded batch, run through the module-level ``model`` without gradient
    tracking, and the first-token (position 0) hidden state of each sentence
    is averaged into a single vector.

    If the padded batch is longer than ``max_tokens``, the longest sentence is
    halved repeatedly (via ``replace_longest_text_with_halves``) until the
    batch fits. Truncation is applied only as a last resort.

    Args:
        text_list: Sequence of strings to embed.
        max_tokens: Largest padded sequence length passed to the model. The
            default of 512 is the threshold this code has always used
            (NOTE(review): presumably the encoder's position limit - confirm
            against the model config).

    Returns:
        A list with one CPU tensor per input text; each tensor has a leading
        dimension of 1 followed by the model's hidden size.
    """
    cls_list = []
    for text in tqdm(text_list):
        # Drop empty/whitespace-only fragments (e.g. produced by a trailing
        # '. '): they would otherwise contribute a content-free vector to the
        # mean. Fall back to the raw text if nothing else is left.
        sentences = [s for s in text.split('. ') if s.strip()] or [text]
        tokenized = tokenizer(sentences, add_special_tokens=True, padding=True, return_tensors="pt")

        # A single halving is not guaranteed to be enough, so keep splitting
        # the longest sentence until the padded batch fits. The length guard
        # stops the loop once no sentence can be split any further.
        while tokenized['input_ids'].size(dim=1) > max_tokens and max(map(len, sentences)) > 1:
            # Pass a copy so this works whether or not the helper mutates its input.
            sentences = replace_longest_text_with_halves(list(sentences))
            tokenized = tokenizer(sentences, add_special_tokens=True, padding=True, return_tensors="pt")

        # Safety net: never hand the model a batch longer than max_tokens.
        if tokenized['input_ids'].size(dim=1) > max_tokens:
            tokenized = tokenizer(
                sentences,
                add_special_tokens=True,
                padding=True,
                truncation=True,
                max_length=max_tokens,
                return_tensors="pt",
            )

        # The tokenizer already returns tensors; just move them to the device.
        tokenized = {k: v.to(device) for k, v in tokenized.items()}

        with torch.no_grad():
            hidden_state = model(**tokenized)

        # Position 0 of every sequence, averaged over the sentences of this text.
        cls_state = hidden_state.last_hidden_state[:, 0, :].to('cpu')
        cls_list.append(cls_state.mean(dim=0, keepdim=True))
    return cls_list

In [3]:
# Load the two text tables; both files are plain comma-separated CSVs
# (comma is pandas' default delimiter).
title_abstract_texts_X = pd.read_csv('title_abstract_texts_X.csv')
title_abstract_texts_all = pd.read_csv('title_abstract_texts_all.csv')

In [4]:
# Pull the column labelled '0' out of each table as a plain Python list.
texts_X = title_abstract_texts_X['0'].tolist()
all_texts = title_abstract_texts_all['0'].tolist()

In [5]:
def add_space_after_period(text):
    """Insert a space after every period that is directly followed by text.

    A period is left alone when it is already followed by a space or when it
    is the last character, so the function adds no double spaces and no
    trailing space (which would become an empty sentence after a later
    ``split('. ')``). Applying it twice gives the same result as applying it
    once.

    Note that periods inside numbers are still separated ('0.05' becomes
    '0. 05'), exactly as before.

    Args:
        text: The string to normalise.

    Returns:
        The string with '. ' in place of each period that was followed by a
        non-space character.
    """
    # Lookahead: only match a period whose next character exists and is not a space.
    return re.sub(r'\.(?=[^ ])', '. ', text)

# Where splitting on '. ' leaves some piece longer than 512 characters,
# rewrite that text with add_space_after_period.
for idx, text in enumerate(tqdm(texts_X)):
    longest_piece = max(len(piece) for piece in text.split('. '))
    if longest_piece > 512:
        texts_X[idx] = add_space_after_period(text)

100%|██████████| 8676/8676 [00:00<00:00, 51603.10it/s]


In [6]:
# Embed every entry of texts_X; the result is a list holding one CPU tensor per text.
text_vect_X = process_text_vectorization_with_pad(texts_X)

100%|██████████| 8676/8676 [03:52<00:00, 37.39it/s]


In [7]:
# Take row 0 of each per-text tensor as a NumPy vector, stack the vectors into
# one 2-D array (one row per text), and wrap it in a DataFrame for display.
cls_X_array = np.array([vec[0].numpy() for vec in text_vect_X])
cls_X_matrix = pd.DataFrame(cls_X_array)
cls_X_matrix

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
0,0.373889,-0.703256,0.374886,-0.297805,0.161632,0.228160,0.130224,-0.150067,-0.343453,-0.595476,...,-0.005207,-0.349102,-0.575527,-0.305271,-0.070357,-0.467309,-0.244757,1.005329,0.556296,0.534453
1,-0.383222,0.652075,0.335900,-0.960231,-0.298706,0.095438,-0.667535,1.033348,0.512542,-0.711529,...,0.360285,-0.426473,-0.404499,-0.479120,-0.503823,-0.514091,0.053274,0.475853,-0.519811,0.537110
2,-0.065497,0.426198,-0.163131,-0.299998,0.326258,0.097311,-0.568513,0.555474,0.396552,-0.293481,...,0.031321,0.094037,-0.084511,0.204504,0.142914,-0.823774,-0.189775,0.010047,-0.442566,-0.506690
3,-0.112857,0.226508,0.653670,-0.457181,0.007463,0.048010,-0.127826,1.189659,-0.208718,-0.230247,...,0.125705,0.173145,-0.579066,-0.349503,-0.030507,-0.290230,0.397293,0.463619,-0.085654,0.014183
4,-0.287080,-0.491167,-0.577795,0.316255,0.068432,0.143247,-0.072142,0.535331,0.063482,0.084803,...,0.606081,-0.154836,-0.117076,0.051268,-0.528468,-0.494634,0.126287,0.305325,0.117943,-0.201062
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8671,-0.322825,0.751444,-0.186054,-0.641733,-0.122731,-0.273129,-0.680939,0.579574,0.076472,-0.274429,...,0.341475,0.056225,0.022454,0.692398,-0.028677,-0.281807,0.546542,-0.541992,-0.250782,-0.222607
8672,-0.569759,0.565158,-0.185162,-0.553001,-0.171846,-0.266699,-0.840665,0.427140,-0.185592,-0.450418,...,0.077524,0.221461,-0.054692,0.545349,0.134292,-0.341989,0.703267,-0.203889,-0.166344,-0.125695
8673,-0.280795,0.135812,-0.184267,-0.054235,-0.143007,-0.127818,-0.753047,0.245499,-0.023102,-0.160456,...,0.144803,-0.142981,0.437549,0.181531,-0.184472,-0.145505,0.320822,-0.343802,0.995426,-0.123898
8674,-0.361582,0.450111,-0.223125,0.106407,0.123348,-0.613265,-1.186897,0.381976,0.299640,-0.723795,...,-0.144107,-0.099030,0.461965,0.385374,0.127489,-0.420503,0.254312,-0.499778,0.929133,-0.336045


In [8]:
# Save the embedding matrix to CSV; index=False leaves the row labels out of the file.
cls_X_matrix.to_csv('cls_X_matrix_pubmedbert.csv', index=False)

In [9]:
# Same check as for texts_X: if any '. '-separated piece is longer than
# 512 characters, rewrite that text with add_space_after_period.
for idx, text in enumerate(tqdm(all_texts)):
    longest_piece = max(len(piece) for piece in text.split('. '))
    if longest_piece > 512:
        all_texts[idx] = add_space_after_period(text)

100%|██████████| 73154/73154 [00:00<00:00, 170848.05it/s]


In [10]:
# Embed every entry of all_texts; the result is a list holding one CPU tensor per text.
vect_abs_all = process_text_vectorization_with_pad(all_texts)

100%|██████████| 73154/73154 [34:32<00:00, 35.30it/s]


In [11]:
# Take row 0 of each per-text tensor as a NumPy vector, stack the vectors into
# one 2-D array (one row per text), and wrap it in a DataFrame for display.
cls_all_array = np.array([vec[0].numpy() for vec in vect_abs_all])
cls_all_matrix = pd.DataFrame(cls_all_array)
cls_all_matrix

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
0,-0.130965,-0.260179,-0.268914,0.162671,-0.364543,-0.201208,-0.784305,0.795193,-0.097517,0.106002,...,-0.128758,-0.043944,0.312121,-0.541282,0.698506,0.166163,-0.070943,-0.686027,0.971027,0.363682
1,0.012001,0.548081,0.222156,-0.525367,0.429316,-0.172972,-1.517008,0.568559,1.075855,-0.221001,...,-0.682216,0.331362,0.206088,-0.963773,0.400253,-0.431421,0.308654,-0.849815,0.193067,-0.240808
2,-0.309866,0.417456,0.230139,-0.387211,0.090072,-0.248781,-0.748646,0.849247,0.785837,-0.057873,...,0.091029,0.154633,0.091867,-0.467128,0.197994,0.038066,0.340556,-0.582569,-0.083999,0.333746
3,-0.224888,0.214068,-0.113817,-0.202996,-0.382809,-0.343487,-0.103549,0.442259,0.055885,0.375157,...,0.192205,0.114896,0.552517,-0.751992,0.456497,0.230951,0.231220,-1.142182,0.206221,0.094950
4,0.067757,0.702381,-0.263730,-0.357449,0.195924,-0.477647,-0.636396,0.534445,0.076726,0.041163,...,-0.161210,0.046652,0.199827,-0.168400,0.413084,0.006566,-0.106204,-0.675783,-0.210757,0.244947
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
73149,-0.350987,-0.059362,0.194035,-0.993575,0.054872,0.075118,-0.956566,0.687091,0.522308,-0.001237,...,0.090175,0.206072,0.046186,0.369802,0.055501,-0.099584,0.218414,-0.306627,0.292805,0.077348
73150,-0.217519,0.593597,0.186192,-0.727252,0.359484,0.248115,-0.779221,0.399183,-0.004739,-0.956394,...,0.067595,-0.013809,-0.008955,0.447578,-0.127351,-0.886581,0.522533,-0.054824,0.445839,-0.463332
73151,0.113743,0.170187,0.187120,-0.615201,0.021587,-0.098802,-0.678996,0.568642,-0.138681,-0.585308,...,-0.055075,-0.106992,0.037750,0.461464,0.225559,-0.481519,-0.127510,-0.285708,0.239695,0.183276
73152,0.160824,-0.034283,0.071946,-0.816152,0.052463,-0.147809,-0.765057,0.658510,0.401252,0.027657,...,-0.462864,0.220455,-0.184843,-0.240730,-0.159212,-0.430714,0.017962,-0.307833,-0.211650,0.104236


In [12]:
# Save the embedding matrix to CSV; index=False leaves the row labels out of the file.
cls_all_matrix.to_csv('cls_all_matrix_pubmedbert.csv', index=False)