In [1]:
import pandas as pd
import torch
from src.preprocessing import preprocess_df, random_train_test_split, TextEncoder
from src.embeddings import get_embeddings

In [2]:
### Constants
FILE = 'data/morning_lab_values.csv' # Set path to the dataset
COLUMNS = ['Bic', 'Crt', 'Pot', 'Sod', 'Ure', 'Hgb', 'Plt', 'Wbc']
BINS = 10

REPEAT_ID = True # If repetition_id is True, <<lab_id>> <<lab_id>><<lab_value_str>> else: <<lab_id>><<lab_value_str
USE_LAB_ID = True # If lab_id is True, <<lab_id>><<lab_value_str>> else: <<lab_value_str>>

# Link for the models: https://huggingface.co/dsrestrepo
# 1. "dsrestrepo/BERT_Lab_Values_10B_no_lab_id_no_repetition" (no repetition_id, no lab_id) -> Set REPEAT_ID = False, USE_LAB_ID = False
#MODEL = 'dsrestrepo/BERT_Lab_Values_10B_no_lab_id_no_repetition'
# 2. "dsrestrepo/BERT_Lab_Values_10B_lab_id_no_repetition" (no repetition_id, lab_id) -> Set REPEAT_ID = False, USE_LAB_ID = True
#MODEL = 'dsrestrepo/BERT_Lab_Values_10B_lab_id_no_repetition'
# 3. "dsrestrepo/BERT_Lab_Values_10B_lab_id_repetition" (repetition_id, lab_id) -> Set REPEAT_ID = True, USE_LAB_ID = True
#MODEL = 'dsrestrepo/BERT_Lab_Values_10B_lab_id_repetition'

MODEL = "dsrestrepo/BERT_Lab_Values_10B_lab_id_repetition" # There are 3 models available: "dsrestrepo/BERT_Lab_Values_10B_no_lab_id_no_repetition", "dsrestrepo/BERT_Lab_Values_10B_lab_id_no_repetition", "dsrestrepo/BERT_Lab_Values_10B_lab_id_repetition"


device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# device = torch.device('mps') if torch.cuda.is_available() else torch.device('cpu') # Apple slilicon

### Read the dataframe

In [3]:
df = pd.read_csv(FILE)
df.head()

Unnamed: 0,hadm_id,subject_id,itemid,charttime,charthour,storetime,storehour,chartday,valuenum,cnt
0,,10312413,51222,2173-06-05 08:20:00,8,2173-06-05 08:47:00,8,2173-06-05,12.8,8
1,25669789.0,10390828,51222,2181-10-26 07:55:00,7,2181-10-26 08:46:00,8,2181-10-26,9.4,8
2,26646522.0,10447634,51222,2165-03-07 06:55:00,6,2165-03-07 07:23:00,7,2165-03-07,11.1,8
3,27308928.0,10784877,51222,2170-05-11 06:00:00,6,2170-05-11 06:43:00,6,2170-05-11,10.3,8
4,28740988.0,11298819,51222,2142-09-13 07:15:00,7,2142-09-13 09:23:00,9,2142-09-13,10.2,8


### Preprocessing

In [4]:
mrl = preprocess_df(df, scaler='log', columns_to_scale=COLUMNS, num_bins=BINS)

In [5]:
text_encoder = TextEncoder(bins=BINS, Repetition_id=REPEAT_ID, lab_id=USE_LAB_ID)
mrl, grouped_mrl = text_encoder.encode_text(mrl)

In [6]:
# In this case mrl is the dataframe grouped by admission ID and grouped_mrl is the dataframe grouped by patiend ID
mrl.head()

itemid,subject_id,hadm_id,chartday,Bic,Crt,Pot,Sod,Ure,Hgb,Plt,Wbc,nstr
0,10000032,22595853.0,2180-05-07,7,0,7,3,6,8,0,1,Bic BicH Crt CrtA Pot PotH Sod SodD Ure UreG H...
1,10000032,22841357.0,2180-06-27,4,0,9,0,7,8,2,3,Bic BicE Crt CrtA Pot PotJ Sod SodA Ure UreH H...
2,10000032,25742920.0,2180-08-06,5,1,9,0,8,8,2,4,Bic BicF Crt CrtB Pot PotJ Sod SodA Ure UreI H...
3,10000032,25742920.0,2180-08-07,3,1,9,0,7,7,1,2,Bic BicD Crt CrtB Pot PotJ Sod SodA Ure UreH H...
4,10000032,29079034.0,2180-07-24,3,1,9,0,7,8,0,1,Bic BicD Crt CrtB Pot PotJ Sod SodA Ure UreH H...


In [7]:
# In this case mrl is the dataframe grouped by admission ID and grouped_mrl is the dataframe grouped by patiend ID
grouped_mrl.head()

Unnamed: 0,hadm_id,nstr
0,20000019.0,[Bic BicD Crt CrtE Pot PotA Sod SodD Ure UreF ...
1,20000024.0,[Bic BicE Crt CrtE Pot PotJ Sod SodG Ure UreH ...
2,20000034.0,[Bic BicD Crt CrtI Pot PotJ Sod SodH Ure UreH ...
3,20000041.0,[Bic BicF Crt CrtE Pot PotD Sod SodC Ure UreE ...
4,20000057.0,[Bic BicA Crt CrtE Pot PotG Sod SodD Ure UreF ...


### Generate the "sentences" of lab values

In [8]:
text = mrl['nstr'].tolist()#.apply(lambda x: ' '.join(x)).tolist()
train, test = random_train_test_split(text)

### Load Model from Hugging Face

In [9]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForMaskedLM.from_pretrained(MODEL)

### Predict with the model

In [12]:
# See some samples of the test set
subsample = test[:1000]
subsample

['Bic BicI Crt CrtE Pot PotJ Sod SodA Ure UreG Hgb HgbG Plt PltI Wbc WbcD',
 'Bic BicA Crt CrtB Pot PotA Sod SodE Ure UreB Hgb HgbA Plt PltI Wbc WbcJ',
 'Bic BicI Crt CrtE Pot PotJ Sod SodG Ure UreJ Hgb HgbC Plt PltA Wbc WbcD',
 'Bic BicE Crt CrtE Pot PotF Sod SodI Ure UreC Hgb HgbH Plt PltE Wbc WbcD',
 'Bic BicI Crt CrtC Pot PotH Sod SodF Ure UreI Hgb HgbE Plt PltC Wbc WbcI',
 'Bic BicG Crt CrtG Pot PotJ Sod SodE Ure UreJ Hgb HgbD Plt PltD Wbc WbcC',
 'Bic BicC Crt CrtF Pot PotE Sod SodA Ure UreF Hgb HgbC Plt PltI Wbc WbcF',
 'Bic BicD Crt CrtC Pot PotC Sod SodF Ure UreE Hgb HgbE Plt PltF Wbc WbcD',
 'Bic BicA Crt CrtC Pot PotI Sod SodG Ure UreC Hgb HgbJ Plt PltH Wbc WbcE',
 'Bic BicH Crt CrtG Pot PotJ Sod SodD Ure UreG Hgb HgbI Plt PltG Wbc WbcB',
 'Bic BicG Crt CrtH Pot PotF Sod SodC Ure UreI Hgb HgbD Plt PltH Wbc WbcD',
 'Bic BicA Crt CrtF Pot PotJ Sod SodH Ure UreH Hgb HgbB Plt PltD Wbc WbcE',
 'Bic BicD Crt CrtF Pot PotI Sod SodG Ure UreF Hgb HgbE Plt PltE Wbc WbcB',
 'Bic BicI C

In [13]:
# Let's run the embeddings prediction. To do so let's run the model and in each iteration we will mask the token in the position i and predict it.
def get_predictions(model, tokenizer, texts, batch_size=32, USE_LAB_ID=False, REPEAT_ID=False, print_outs=False):
    
    for i, lab_id in enumerate(COLUMNS):
        print(f'Predicting {lab_id}')
        subsample_copy = texts.copy()
        real_values = []
        indices = []
        for j, text in enumerate(texts):
            subsample_copy[j] = subsample_copy[j].split(' ')
            if USE_LAB_ID:
                look_for = lab_id
                # get index of the lab_id
                
                if REPEAT_ID:
                    index = subsample_copy[j].index(look_for)
                    real_values.append(subsample_copy[j][index+1])
                    subsample_copy[j][index+1] = tokenizer.mask_token
                    subsample_copy[j] = ' '.join(subsample_copy[j])
                    indices.append(index+2) # +2 because the initial token is the CLS token and the lab_id is repeated
                else:
                    # looking for index of something including the lab_id
                    index = [i for i, s in enumerate(subsample_copy[j]) if look_for in s][0]
                    real_values.append(subsample_copy[j][index])
                    subsample_copy[j][index] = tokenizer.mask_token
                    subsample_copy[j] = ' '.join(subsample_copy[j])
                    indices.append(index+1) # +1 because the initial token is the CLS token
            else:
                real_values.append(subsample_copy[j][i])
                subsample_copy[j][i] = tokenizer.mask_token
                subsample_copy[j] = ' '.join(subsample_copy[j])
                indices.append(i + 1) # +1 because the initial token is the CLS token
        
        
        # split the subsample into batches
        n_iterations = len(subsample_copy)//batch_size
        batches = []
        indices_batch = []
        for iteration in range(n_iterations):
            indices_batch.append(indices[iteration*batch_size:(iteration+1)*batch_size])
            batches.append(subsample_copy[iteration*batch_size:(iteration+1)*batch_size])
        if len(subsample_copy) % batch_size != 0:
            batches.append(subsample_copy[n_iterations*batch_size:])
            indices_batch.append(indices[n_iterations*batch_size:])
            
        predictions = []
        for count, batch in enumerate(batches):
            if count == 0:
                print(f'Predicting {batch}')
            if count % 10 == 0:
                print(f'Batch {count+1}/{len(batches)}')
            
            # Get the tokens
            tokens = tokenizer(batch, return_tensors='pt', padding=True)
            tokens.to(device)
            
            # Predict the masked token for the whole batch
            model.to(device)
            outputs = model(**tokens)
            
            # The predictions are in the logits, we need to get the argmax to get the predicted token
            predictions_batch = outputs.logits.argmax(-1)
            #print(f'Predictions: {predictions_batch}')
            
            # Get the predictions for the masked token (lab_id)
            predictions_lab = [predictions_batch[b, idx] for b, idx in enumerate(indices_batch[count])]
            #print(f'Predictions lab: {predictions_lab}')
            
            # Decode the predictions of the lab id
            predictions_batch = [tokenizer.decode(prediction.item()) for prediction in predictions_lab]
            
            if print_outs:
                print(f'Predictions: {predictions_batch}')
                print(f'Real values: {real_values[count*batch_size:(count+1)*batch_size]}')
                
            predictions.extend(predictions_batch)
            
        #Calculate the accuracy
        accuracy = sum([1 for real, pred in zip(real_values, predictions) if real == pred])/len(real_values)
        print(f'Accuracy: {accuracy}')
        

In [14]:
get_predictions(model, tokenizer, subsample, batch_size=32, USE_LAB_ID=USE_LAB_ID, REPEAT_ID=REPEAT_ID)

Predicting Bic
Predicting ['Bic [MASK] Crt CrtE Pot PotJ Sod SodA Ure UreG Hgb HgbG Plt PltI Wbc WbcD', 'Bic [MASK] Crt CrtB Pot PotA Sod SodE Ure UreB Hgb HgbA Plt PltI Wbc WbcJ', 'Bic [MASK] Crt CrtE Pot PotJ Sod SodG Ure UreJ Hgb HgbC Plt PltA Wbc WbcD', 'Bic [MASK] Crt CrtE Pot PotF Sod SodI Ure UreC Hgb HgbH Plt PltE Wbc WbcD', 'Bic [MASK] Crt CrtC Pot PotH Sod SodF Ure UreI Hgb HgbE Plt PltC Wbc WbcI', 'Bic [MASK] Crt CrtG Pot PotJ Sod SodE Ure UreJ Hgb HgbD Plt PltD Wbc WbcC', 'Bic [MASK] Crt CrtF Pot PotE Sod SodA Ure UreF Hgb HgbC Plt PltI Wbc WbcF', 'Bic [MASK] Crt CrtC Pot PotC Sod SodF Ure UreE Hgb HgbE Plt PltF Wbc WbcD', 'Bic [MASK] Crt CrtC Pot PotI Sod SodG Ure UreC Hgb HgbJ Plt PltH Wbc WbcE', 'Bic [MASK] Crt CrtG Pot PotJ Sod SodD Ure UreG Hgb HgbI Plt PltG Wbc WbcB', 'Bic [MASK] Crt CrtH Pot PotF Sod SodC Ure UreI Hgb HgbD Plt PltH Wbc WbcD', 'Bic [MASK] Crt CrtF Pot PotJ Sod SodH Ure UreH Hgb HgbB Plt PltD Wbc WbcE', 'Bic [MASK] Crt CrtF Pot PotI Sod SodG Ure UreF H