### Labrador Embeddings

In [1]:
import pandas as pd
import torch
import numpy as np
import torch.nn as nn

from src.preprocessing import preprocess_df, TextEncoder, set_labels_features
from sklearn.model_selection import train_test_split

from src.labrador import Labrador
from src.tokenizers import LabradorTokenizer
from src.dataset import LabradorDataset

from src.train import train_labrador

### Constants

In [2]:
# Dataset:
FILE = 'data/morning_lab_values.csv'
COLUMNS = ['Bic', 'Crt', 'Pot', 'Sod', 'Ure', 'Hgb', 'Plt', 'Wbc']

# Device:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# device = torch.device('mps') if torch.cuda.is_available() else torch.device('cpu') # Apple Silicon

# Data loader: 
test_size = 0.2
batch_size = 256
num_workers = 4
max_len = 10
MASKING = 0.20

# Model:
embedding_dim = 756
hidden_dim = 756
transformer_heads = 12
num_blocks = 12
transformer_feedforward_dim = 3072
dropout_rate = 0.3
continuous_head_activation = 'relu'

# Training:
optimizer = 'Adam'
num_epochs = 2
save_model = True
model_path = 'labrador_model.pth'
categorical_loss_weight = 1.0
continuous_loss_weight = 1.0


### Read dataset

In [3]:
df = pd.read_csv(FILE)
df.head()

Unnamed: 0,hadm_id,subject_id,itemid,charttime,charthour,storetime,storehour,chartday,valuenum,cnt
0,,10312413,51222,2173-06-05 08:20:00,8,2173-06-05 08:47:00,8,2173-06-05,12.8,8
1,25669789.0,10390828,51222,2181-10-26 07:55:00,7,2181-10-26 08:46:00,8,2181-10-26,9.4,8
2,26646522.0,10447634,51222,2165-03-07 06:55:00,6,2165-03-07 07:23:00,7,2165-03-07,11.1,8
3,27308928.0,10784877,51222,2170-05-11 06:00:00,6,2170-05-11 06:43:00,6,2170-05-11,10.3,8
4,28740988.0,11298819,51222,2142-09-13 07:15:00,7,2142-09-13 09:23:00,9,2142-09-13,10.2,8


### Preprocessing

In [4]:
# MinMaxScaler by default
mrl = preprocess_df(df, columns_to_scale=COLUMNS)

### Generate Sequences

In [None]:
# Generate the sequences:
text_encoder = TextEncoder(Repetition_id=True, labs_as_num=True, return_lists=True)
mrl, grouped_mrl = text_encoder.encode_text(mrl)

In [None]:
mrl.head(3)

itemid,subject_id,hadm_id,chartday,Bic,Crt,Pot,Sod,Ure,Hgb,Plt,Wbc,nstr,lab_ids,lab_values
0,10000032,22595853.0,2180-05-07,0.530612,0.007895,0.258621,0.609524,0.088028,0.585253,0.027731,0.004782,Bic 0.5306122448979591 Crt 0.00789473684210526...,"[Bic, Crt, Pot, Sod, Ure, Hgb, Plt, Wbc]","[0.5306122448979591, 0.007894736842105262, 0.2..."
1,10000032,22841357.0,2180-06-27,0.469388,0.007895,0.318966,0.504762,0.102113,0.571429,0.055462,0.007515,Bic 0.46938775510204084 Crt 0.0078947368421052...,"[Bic, Crt, Pot, Sod, Ure, Hgb, Plt, Wbc]","[0.46938775510204084, 0.007894736842105262, 0...."
2,10000032,25742920.0,2180-08-06,0.489796,0.015789,0.413793,0.504762,0.130282,0.557604,0.053782,0.008539,Bic 0.48979591836734687 Crt 0.0157894736842105...,"[Bic, Crt, Pot, Sod, Ure, Hgb, Plt, Wbc]","[0.48979591836734687, 0.015789473684210523, 0...."


In [None]:
# See the largest element in: 
grouped_mrl.head(3)

Unnamed: 0,hadm_id,nstr,lab_ids,lab_values
0,20000019.0,[Bic 0.4489795918367347 Crt 0.0289473684210526...,"[Bic, Crt, Pot, Sod, Ure, Hgb, Plt, Wbc]","[0.4489795918367347, 0.02894736842105263, 0.17..."
1,20000024.0,[Bic 0.46938775510204084 Crt 0.028947368421052...,"[Bic, Crt, Pot, Sod, Ure, Hgb, Plt, Wbc]","[0.46938775510204084, 0.02894736842105263, 0.3..."
2,20000034.0,[Bic 0.4489795918367347 Crt 0.0605263157894736...,"[Bic, Crt, Pot, Sod, Ure, Hgb, Plt, Wbc]","[0.4489795918367347, 0.06052631578947368, 0.28..."


#### Train and Test Split

In [None]:
lab_ids = grouped_mrl.lab_ids.values
lab_values = grouped_mrl.lab_values.values

lab_ids_train, lab_ids_test, lab_values_train, lab_values_test = train_test_split(lab_ids, lab_values, test_size=test_size, random_state=42)

#### Tokenize

In [None]:
tokenizer = LabradorTokenizer()
# Get unique lab ids:
unique_ids = set(np.concatenate(lab_ids_train))
# train the tokenizer:
tokenizer.train(unique_ids)

In [None]:
# Example of how to use the tokenizer:
tokenizer.tokenize_batch(lab_ids_train[:5], lab_values_train[:5], max_length=10)

{'input_ids': array([[ 2,  0,  7,  1,  3,  4,  5,  6, 10, 10],
        [ 2,  0,  7,  1,  3,  4,  5,  6, 10, 10],
        [ 2,  0,  7,  1,  3,  4,  5,  6, 10, 10],
        [ 2,  0,  7,  1,  3,  4,  5,  6, 10, 10],
        [ 2,  0,  7,  1,  3,  4,  5,  6, 10, 10]]),
 'continuous': array([[6.73469388e-01, 7.89473684e-02, 1.46551724e-01, 6.09523810e-01,
         3.80281690e-01, 6.31336406e-01, 4.15966387e-02, 6.94523511e-03,
         1.00000000e+01, 1.00000000e+01],
        [6.12244898e-01, 2.89473684e-02, 1.37931034e-01, 7.04761905e-01,
         9.85915493e-02, 3.50230415e-01, 7.85714286e-02, 1.04747808e-02,
         1.00000000e+01, 1.00000000e+01],
        [5.30612245e-01, 2.10526316e-02, 2.06896552e-01, 6.28571429e-01,
         5.63380282e-02, 4.42396313e-01, 1.18067227e-01, 1.04747808e-02,
         1.00000000e+01, 1.00000000e+01],
        [4.48979592e-01, 1.84210526e-02, 2.32758621e-01, 6.28571429e-01,
         3.16901408e-02, 6.77419355e-01, 1.07983193e-01, 1.04747808e-02,
         1.

### Dataloader

In [None]:
dataset_train = LabradorDataset(continuous=lab_values_train, categorical=lab_ids_train, tokenizer=tokenizer, max_len=max_len, masking_prob=MASKING)
dataset_test = LabradorDataset(continuous=lab_values_test, categorical=lab_ids_test, tokenizer=tokenizer, max_len=max_len, masking_prob=MASKING)

# Dataloader:
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=True, num_workers=num_workers)

{'[MASK]': 8, '[NULL]': 9, '[PAD]': 10, 'Crt': 0, 'Sod': 1, 'Bic': 2, 'Ure': 3, 'Hgb': 4, 'Plt': 5, 'Wbc': 6, 'Pot': 7}
{'[MASK]': 8, '[NULL]': 9, '[PAD]': 10, 'Crt': 0, 'Sod': 1, 'Bic': 2, 'Ure': 3, 'Hgb': 4, 'Plt': 5, 'Wbc': 6, 'Pot': 7}


### Model

In [None]:
mask_token = tokenizer.mask_token #-1
null_token = tokenizer.null_token #-2
pad_token = tokenizer.pad_token #-3
vocab_size = tokenizer.vocab_size()

In [None]:
model = Labrador(mask_token=mask_token, pad_token=pad_token, null_token=null_token, vocab_size=vocab_size, embedding_dim=embedding_dim, transformer_heads=transformer_heads, num_blocks=num_blocks, transformer_feedforward_dim=transformer_feedforward_dim, include_head=True, continuous_head_activation=continuous_head_activation, dropout_rate=dropout_rate)
model

Labrador(
  (categorical_embedding_layer): Embedding(14, 756)
  (continuous_embedding_layer): ContinuousEmbedding(
    (special_token_embeddings): Embedding(3, 756)
    (dense1): Linear(in_features=1, out_features=756, bias=True)
    (dense2): Linear(in_features=756, out_features=756, bias=True)
    (layernorm): LayerNorm((756,), eps=1e-05, elementwise_affine=True)
  )
  (projection_layer): Linear(in_features=1512, out_features=756, bias=True)
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (att): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=756, out_features=756, bias=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=756, out_features=3072, bias=True)
        (1): ReLU()
        (2): Linear(in_features=3072, out_features=756, bias=True)
      )
      (layernorm1): LayerNorm((756,), eps=1e-05, elementwise_affine=True)
      (layernorm2): LayerNorm((756,), eps=1e-05, elementwise_affine=True)
      (dropout1): D

### Train the model

In [None]:
# Loss functions
categorical_loss_fn = nn.CrossEntropyLoss()
continuous_loss_fn = nn.MSELoss()

In [None]:
# Train and validate the model
trained_model = train_labrador(model, train_loader, test_loader, categorical_loss_fn, continuous_loss_fn, optimizer=optimizer, num_epochs=num_epochs, device=device, save_model=save_model, model_path=model_path, categorical_loss_weight=categorical_loss_weight, continuous_loss_weight=continuous_loss_weight)

  0%|          | 1/3308 [00:00<22:40,  2.43it/s]

tensor([[[0.0674, 0.0647, 0.0985,  ..., 0.0599, 0.1077, 0.1267],
         [0.0658, 0.0559, 0.0849,  ..., 0.0861, 0.0947, 0.1106],
         [0.0743, 0.0714, 0.0937,  ..., 0.0733, 0.0990, 0.1159],
         ...,
         [0.0699, 0.0752, 0.0907,  ..., 0.0611, 0.1016, 0.1097],
         [0.0759, 0.0696, 0.0782,  ..., 0.0743, 0.1014, 0.1023],
         [0.0728, 0.0561, 0.0918,  ..., 0.0661, 0.0980, 0.1088]],

        [[0.0567, 0.0602, 0.0954,  ..., 0.0649, 0.0868, 0.1314],
         [0.0796, 0.0656, 0.0858,  ..., 0.0834, 0.0957, 0.1009],
         [0.0702, 0.0679, 0.1127,  ..., 0.0842, 0.0849, 0.1072],
         ...,
         [0.0760, 0.0590, 0.0786,  ..., 0.0893, 0.0764, 0.0976],
         [0.0704, 0.0695, 0.0809,  ..., 0.0708, 0.0882, 0.1215],
         [0.0808, 0.0641, 0.0803,  ..., 0.0757, 0.0802, 0.1162]],

        [[0.0603, 0.0641, 0.1025,  ..., 0.0661, 0.1133, 0.1247],
         [0.0680, 0.0702, 0.0878,  ..., 0.0783, 0.1028, 0.0907],
         [0.0786, 0.0670, 0.0943,  ..., 0.0739, 0.0951, 0.

100%|██████████| 3308/3308 [04:14<00:00, 13.02it/s]
  return torch._native_multi_head_attention(
  1%|          | 6/827 [00:00<00:48, 16.97it/s]

tensor([[[1.9272e-01, 1.0978e-01, 1.1388e-01,  ..., 1.1461e-06,
          2.4443e-06, 1.5044e-06],
         [1.9272e-01, 1.0978e-01, 1.1388e-01,  ..., 1.1461e-06,
          2.4443e-06, 1.5044e-06],
         [1.9272e-01, 1.0978e-01, 1.1388e-01,  ..., 1.1461e-06,
          2.4443e-06, 1.5044e-06],
         ...,
         [1.9272e-01, 1.0978e-01, 1.1388e-01,  ..., 1.1461e-06,
          2.4443e-06, 1.5044e-06],
         [1.9272e-01, 1.0978e-01, 1.1388e-01,  ..., 1.1461e-06,
          2.4443e-06, 1.5044e-06],
         [1.9272e-01, 1.0978e-01, 1.1388e-01,  ..., 1.1461e-06,
          2.4443e-06, 1.5044e-06]],

        [[1.9272e-01, 1.0978e-01, 1.1388e-01,  ..., 1.1461e-06,
          2.4443e-06, 1.5044e-06],
         [1.9272e-01, 1.0978e-01, 1.1388e-01,  ..., 1.1461e-06,
          2.4443e-06, 1.5044e-06],
         [1.9272e-01, 1.0978e-01, 1.1388e-01,  ..., 1.1461e-06,
          2.4443e-06, 1.5044e-06],
         ...,
         [1.9272e-01, 1.0978e-01, 1.1388e-01,  ..., 1.1461e-06,
          2.444

100%|██████████| 827/827 [00:18<00:00, 44.32it/s]


Epoch 1/15, Training Loss: 5.867214852982306, Validation Loss: nan


  2%|▏         | 79/3308 [00:06<04:25, 12.15it/s]


KeyboardInterrupt: 

### Test the model

In [None]:
dataset_test = LabradorDataset(continuous=lab_values_test, categorical=lab_ids_test, tokenizer=tokenizer, max_len=max_len, masking_prob=0)
# Dataloader:
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

{'[MASK]': 8, '[NULL]': 9, '[PAD]': 10, 'Crt': 0, 'Sod': 1, 'Bic': 2, 'Ure': 3, 'Hgb': 4, 'Plt': 5, 'Wbc': 6, 'Pot': 7}


In [None]:
from tqdm import tqdm
import torch
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import numpy as np

def test_model(model, test_loader, device, labs_list):
    model.to(device)
    model.eval()

    metrics = {lab: {'rmse': [], 'mae': [], 'r2': []} for lab in labs_list}

    with torch.no_grad():
        for lab in labs_list:
            print(f'Evaluating {lab}: ')
            lab_token = test_loader.dataset.tokenizer.vocab[lab]

            preds = []
            true_vals = []
            count = 0

            for batch in tqdm(test_loader, leave=True):
                lab_idx = (batch['input_ids'] == lab_token)
                batch['continuous'][lab_idx] = torch.tensor(test_loader.dataset.tokenizer.mask_token, dtype=torch.float32, device=device)

                input_ids = batch['input_ids'].to(device)
                continuous = batch['continuous'].to(device)
                attn_mask = batch['attention_mask'].to(device)
                labels_continuous = batch['labels_continuous'].to(device)
                
                if count == 0:
                    #print(f'Input ids: {input_ids}')
                    #print(f'Continuous: {continuous}')
                    #print(f'Attn mask: {attn_mask}')
                    #print(f'Labels continuous: {labels_continuous}')
                    pass

                outputs = model(input_ids, continuous, attn_mask=attn_mask)
                continuous_output = outputs['continuous_output'].squeeze(-1)
                if count == 0:
                    #print(f'Continuous output: {continuous_output}')
                    pass

                masked_cont_indices = (continuous == test_loader.dataset.tokenizer.mask_token).to(device)
                batch_preds = continuous_output[masked_cont_indices]
                batch_labels = labels_continuous[masked_cont_indices].to(device)

                preds.extend(batch_preds.tolist())
                true_vals.extend(batch_labels.tolist())
                
                if count == 0:
                    print(f'Preds: {batch_preds.tolist()}')
                    print(f'True vals: {batch_labels.tolist()}')
                    count += 1

            rmse = np.sqrt(mean_squared_error(true_vals, preds))
            mae = mean_absolute_error(true_vals, preds)
            r2 = r2_score(true_vals, preds)

            metrics[lab]['rmse'].append(rmse)
            metrics[lab]['mae'].append(mae)
            metrics[lab]['r2'].append(r2)

            print(f'RMSE: {rmse:.3f}')
            print(f'MAE: {mae:.3f}')
            print(f'R2: {r2:.3f}')
            print('-------------------')

    return metrics


In [None]:
test_model(model, test_loader, device, COLUMNS)

Evaluating Bic: 


  1%|          | 5/827 [00:00<00:49, 16.59it/s]

Preds: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
True vals: [0.4285714328289032, 0.5306122303009033, 0.4897959232330322, 0.5102040767669678, 0.5714285969734192, 0.4693877696990967, 0.5102040767669678, 0.3877550959587097, 0.3469387888908386, 0.5102040767669678, 0.5306122303009033, 0.4897959232330322, 0.4285714328289032, 0.4285714328289032, 0.5510203838348389, 0.4285714328289032, 0.5102040767669678, 0.4693877696990967, 0.3877550959587097, 0.6326530575752258, 0.4285714328289032, 0.5510203838348389, 0.5102040767669678, 0.4897959232330322, 0.5918367505073547, 0.40816327929496765, 0.4693877696990967, 0.3469387888908386, 0.5306122303009033, 0.5102040767669678, 0.5306122303009033, 0.5918367505073547, 0.3877550959587097,

 35%|███▍      | 287/827 [00:06<00:12, 42.75it/s]


KeyboardInterrupt: 