In [1]:
import pandas as pd
import torch
from src.preprocessing import preprocess_df, random_train_test_split, TextEncoder
from src.embeddings import get_embeddings

In [2]:
### Constants
FILE = 'data/morning_lab_values.csv' # Set path to the dataset
COLUMNS = ['Bic', 'Crt', 'Pot', 'Sod', 'Ure', 'Hgb', 'Plt', 'Wbc']
BINS = 10

REPEAT_ID = True # If repetition_id is True, <<lab_id>> <<lab_id>><<lab_value_str>> else: <<lab_id>><<lab_value_str
USE_LAB_ID = True # If lab_id is True, <<lab_id>><<lab_value_str>> else: <<lab_value_str>>

# Link for the models: https://huggingface.co/dsrestrepo
# 1. "dsrestrepo/BERT_Lab_Values_10B_no_lab_id_no_repetition" (no repetition_id, no lab_id) -> Set REPEAT_ID = False, USE_LAB_ID = False
#MODEL = 'dsrestrepo/BERT_Lab_Values_10B_no_lab_id_no_repetition'
# 2. "dsrestrepo/BERT_Lab_Values_10B_lab_id_no_repetition" (no repetition_id, lab_id) -> Set REPEAT_ID = False, USE_LAB_ID = True
#MODEL = 'dsrestrepo/BERT_Lab_Values_10B_lab_id_no_repetition'
# 3. "dsrestrepo/BERT_Lab_Values_10B_lab_id_repetition" (repetition_id, lab_id) -> Set REPEAT_ID = True, USE_LAB_ID = True
#MODEL = 'dsrestrepo/BERT_Lab_Values_10B_lab_id_repetition'

MODEL = "dsrestrepo/BERT_Lab_Values_10B_lab_id_repetition" # There are 3 models available: "dsrestrepo/BERT_Lab_Values_10B_no_lab_id_no_repetition", "dsrestrepo/BERT_Lab_Values_10B_lab_id_no_repetition", "dsrestrepo/BERT_Lab_Values_10B_lab_id_repetition"


device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# device = torch.device('mps') if torch.cuda.is_available() else torch.device('cpu') # Apple slilicon

### Read the dataframe

In [3]:
df = pd.read_csv(FILE)
df.head()

Unnamed: 0,hadm_id,subject_id,itemid,charttime,charthour,storetime,storehour,chartday,valuenum,cnt
0,,10312413,51222,2173-06-05 08:20:00,8,2173-06-05 08:47:00,8,2173-06-05,12.8,8
1,25669789.0,10390828,51222,2181-10-26 07:55:00,7,2181-10-26 08:46:00,8,2181-10-26,9.4,8
2,26646522.0,10447634,51222,2165-03-07 06:55:00,6,2165-03-07 07:23:00,7,2165-03-07,11.1,8
3,27308928.0,10784877,51222,2170-05-11 06:00:00,6,2170-05-11 06:43:00,6,2170-05-11,10.3,8
4,28740988.0,11298819,51222,2142-09-13 07:15:00,7,2142-09-13 09:23:00,9,2142-09-13,10.2,8


### Preprocessing

In [4]:
mrl = preprocess_df(df, scaler='log', columns_to_scale=COLUMNS, num_bins=BINS)

In [5]:
text_encoder = TextEncoder(bins=BINS, Repetition_id=REPEAT_ID, lab_id=USE_LAB_ID)
mrl, grouped_mrl = text_encoder.encode_text(mrl)

In [6]:
# In this case mrl is the dataframe grouped by admission ID and grouped_mrl is the dataframe grouped by patiend ID
mrl.head()

itemid,subject_id,hadm_id,chartday,Bic,Crt,Pot,Sod,Ure,Hgb,Plt,Wbc,nstr
0,10000032,22595853.0,2180-05-07,7,0,7,3,6,8,0,1,Bic BicH Crt CrtA Pot PotH Sod SodD Ure UreG H...
1,10000032,22841357.0,2180-06-27,4,0,9,0,7,8,2,3,Bic BicE Crt CrtA Pot PotJ Sod SodA Ure UreH H...
2,10000032,25742920.0,2180-08-06,5,1,9,0,8,8,2,4,Bic BicF Crt CrtB Pot PotJ Sod SodA Ure UreI H...
3,10000032,25742920.0,2180-08-07,3,1,9,0,7,7,1,2,Bic BicD Crt CrtB Pot PotJ Sod SodA Ure UreH H...
4,10000032,29079034.0,2180-07-24,3,1,9,0,7,8,0,1,Bic BicD Crt CrtB Pot PotJ Sod SodA Ure UreH H...


In [7]:
# In this case mrl is the dataframe grouped by admission ID and grouped_mrl is the dataframe grouped by patiend ID
grouped_mrl.head()

Unnamed: 0,hadm_id,nstr
0,20000019.0,[Bic BicD Crt CrtE Pot PotA Sod SodD Ure UreF ...
1,20000024.0,[Bic BicE Crt CrtE Pot PotJ Sod SodG Ure UreH ...
2,20000034.0,[Bic BicD Crt CrtI Pot PotJ Sod SodH Ure UreH ...
3,20000041.0,[Bic BicF Crt CrtE Pot PotD Sod SodC Ure UreE ...
4,20000057.0,[Bic BicA Crt CrtE Pot PotG Sod SodD Ure UreF ...


### Generate the "sentences" of lab values

In [8]:
text = mrl['nstr'].tolist()#.apply(lambda x: ' '.join(x)).tolist()
train, test = random_train_test_split(text)

### Generate the embeddings

In [9]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForMaskedLM.from_pretrained(MODEL)

In [10]:
sample_test = test[:10]

embeddings = get_embeddings(model=model, tokenizer=tokenizer, texts=sample_test)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [11]:
# The shape of the resulting vector is (batch_size, sequence_length, embedding_size)
embeddings.shape

(10, 16, 768)