In [1]:
import sys 
import os
import pandas as pd
import sklearn.metrics as skm
import numpy as np
import torch
import time
import torch.nn as nn
import datetime
import pickle

from datasets import Dataset
from tokenizers import *
from tokenizers.processors import BertProcessing
import json
os.environ["TOKENIZERS_PARALLELISM"] = "false"

print(sys.version)
print(torch.__version__)

# maximum sequence length
max_length = 16 # note in paper, 128 used

# pad to max_length
def repeat_first_and_last(lst):
    first_element = lst[0]
    last_element = lst[-1]
    return [first_element] + lst + [last_element]

def adjust_visit_ids(ser):
    min_value = ser[0]
    adjusted_vis = [v - min_value + 1 for v in ser]
    return adjusted_vis

3.10.9 (main, Dec  7 2022, 13:15:23) [GCC 9.4.0]
2.0.1+cu117


# Read in and format data

In [2]:
# read in pkl file
disease = pd.read_pickle('DUMMY_EHRBERT_DATA.pkl')
print(len(disease))
disease

12


Unnamed: 0,patid,code_list,age_ids,visit_ids,year_ids,gender_ids,eth_ids,imd_ids
0,1,"[421235014, 275301017, 82343012, 396742015, 25...","[10, 10, 11, 11, 11]","[1, 1, 2, 3, 3]","[2, 2, 3, 3, 3]","[2, 2, 2, 2, 2]","[1, 1, 1, 1, 1]","[5, 5, 5, 5, 5]"
1,10,"[405621000000116, 281181000006116, 303867016, ...","[2, 3, 4, 4, 5]","[1, 2, 3, 4, 5]","[1, 5, 6, 6, 12]","[2, 2, 2, 2, 2]","[4, 4, 4, 4, 4]","[9, 9, 9, 9, 9]"
2,11,"[368051000006111, 303392010, 288711015, 167359...","[2, 3, 4, 4, 5]","[1, 2, 3, 4, 5]","[1, 5, 6, 6, 12]","[1, 1, 1, 1, 2]","[5, 5, 5, 5, 5]","[1, 1, 1, 1, 1]"
3,12,"[184656013, 906061000006116, 969331000006115]","[10, 10, 11]","[1, 1, 2]","[2, 2, 3]","[2, 2, 2]","[6, 6, 6]","[2, 2, 2]"
4,2,[348110010],[1],[1],[8],[1],[2],[4]
5,3,"[19794011, 19794011, 19794011, 645331000006112]","[12, 12, 13, 13]","[1, 2, 3, 4]","[9, 10, 11, 12]","[2, 2, 2, 2]","[3, 3, 3, 3]","[3, 3, 3, 3]"
6,4,"[1484866013, 736361000006117, 411198017, 40562...","[6, 6, 7, 7, 7, 8, 9, 9]","[1, 1, 2, 2, 3, 4, 5, 5]","[2, 2, 3, 3, 3, 4, 7, 7]","[1, 1, 1, 1, 1, 1, 1, 1]","[4, 4, 4, 4, 4, 4, 4, 4]","[2, 2, 2, 2, 2, 2, 2, 2]"
7,5,"[304875018, 1786700015, 253940015, 253940015, ...","[2, 3, 4, 4, 5]","[1, 2, 3, 4, 5]","[1, 5, 6, 6, 12]","[1, 1, 1, 1, 1]","[5, 5, 5, 5, 5]","[6, 6, 6, 6, 6]"
8,6,"[294870013, 294603013, 2871720012, 31102100000...","[10, 10, 11, 11, 11]","[1, 1, 2, 3, 3]","[2, 2, 3, 3, 3]","[2, 2, 2, 2, 2]","[1, 1, 1, 1, 1]","[8, 8, 8, 8, 8]"
9,7,[136211012],[1],[1],[8],[1],[2],[7]


### Format data - truncation and padding

In [3]:
# Truncate codes over max_len to max_len
# using adjusted max_seq-length to 2 less than max length to account for addition of CLS and SEP

var_list = ['code_list','age_ids','visit_ids','year_ids','gender_ids','eth_ids','imd_ids']
max_seq_length = max_length-2

for var in var_list:
    disease[var] = disease[var].apply(lambda x: x[-max_seq_length:])
    
# Remove list from within code column:
disease['code_list'] = [','.join(map(str, l)) for l in disease['code_list']]

# Remove commas from within code column:
disease['code_list'] = disease['code_list'].str.replace(',',' ')

# Repeat first and last elements (to account for CLS and SEP tokens) then pad with zeros
var_list = ['age_ids','visit_ids','year_ids','gender_ids','eth_ids','imd_ids']
for var in var_list:
    disease[var] = disease[var].apply(lambda x: repeat_first_and_last(x)) # repeating end elements to match [CLS] and [SEP]
    disease[var] = disease[var].apply(lambda x: x + [0] * (max_length - len(x)) if len(x) < max_length else x[:max_length])

# recode visit number/ID to start at 1 where relevant (e.g. if truncation results in starting at value >1)
# Apply to visit_ids column
disease['visit_ids'] = disease['visit_ids'].apply(adjust_visit_ids)
assert len(disease.loc[disease['visit_ids'].apply(lambda x: x[0] > 1)]) == 0
disease

Unnamed: 0,patid,code_list,age_ids,visit_ids,year_ids,gender_ids,eth_ids,imd_ids
0,1,421235014 275301017 82343012 396742015 256478018,"[10, 10, 10, 11, 11, 11, 11, 0, 0, 0, 0, 0, 0,...","[1, 1, 1, 2, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[5, 5, 5, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
1,10,405621000000116 281181000006116 303867016 2592...,"[2, 2, 3, 4, 4, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 2, 3, 4, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 5, 6, 6, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0...","[2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
2,11,368051000006111 303392010 288711015 1673591000...,"[2, 2, 3, 4, 4, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 2, 3, 4, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 5, 6, 6, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0...","[1, 1, 1, 1, 1, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[5, 5, 5, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
3,12,184656013 906061000006116 969331000006115,"[10, 10, 10, 11, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[1, 1, 1, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[2, 2, 2, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[6, 6, 6, 6, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
4,2,348110010,"[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[8, 8, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
5,3,19794011 19794011 19794011 645331000006112,"[12, 12, 12, 13, 13, 13, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 2, 3, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[9, 9, 10, 11, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0,...","[2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[3, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[3, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
6,4,1484866013 736361000006117 411198017 405621000...,"[6, 6, 6, 7, 7, 7, 8, 9, 9, 9, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 2, 2, 3, 4, 5, 5, 5, 0, 0, 0, 0, 0, 0]","[2, 2, 2, 3, 3, 3, 4, 7, 7, 7, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]","[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0]","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0]"
7,5,304875018 1786700015 253940015 253940015 14848...,"[2, 2, 3, 4, 4, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 2, 3, 4, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 5, 6, 6, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0...","[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[5, 5, 5, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[6, 6, 6, 6, 6, 6, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
8,6,294870013 294603013 2871720012 311021000006113...,"[10, 10, 10, 11, 11, 11, 11, 0, 0, 0, 0, 0, 0,...","[1, 1, 1, 2, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[8, 8, 8, 8, 8, 8, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
9,7,136211012,"[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[8, 8, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[7, 7, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"


In [4]:
# convert to dataset for transformer model
disease.index=disease['patid']
train_dataset = Dataset.from_pandas(pd.DataFrame(disease[['code_list','age_ids','visit_ids','year_ids',
                                                   'gender_ids','eth_ids','imd_ids']]), preserve_index=True)
train_dataset

Dataset({
    features: ['code_list', 'age_ids', 'visit_ids', 'year_ids', 'gender_ids', 'eth_ids', 'imd_ids', 'patid'],
    num_rows: 12
})

### Tokenizer
Using WordLevel tokenizer (whole word) and add 5 special tokens

In [5]:
# save the dataset to dummy_train.txt to use as input for tokenizer

def dataset_to_text(dataset, output_filename="data.txt"):
    with open(output_filename, "w") as f:
        for t in dataset["code_list"]:
            print(t, file=f)
        
dataset_to_text(train_dataset, "dummy_train.txt")
files = ["dummy_train.txt"]

# Train tokenizer
tokenizer = Tokenizer(models.WordLevel(unk_token="[UNK]"))
tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit()

tokenizer.post_processor = BertProcessing(("SEP", 2), ("CLS", 1))

special_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
trainer = trainers.WordLevelTrainer(special_tokens=special_tokens)

tokenizer.train(files=files, trainer=trainer)

# to use, need to use wrapped tokenizer
from transformers import PreTrainedTokenizerFast
wrapped_tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer,
    # tokenizer_file="ehrbert_tokenizer.json", # Can load directly, otherwise
    bos_token="[CLS]",
    eos_token="[SEP]",
    unk_token="[UNK]",
    pad_token="[PAD]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    mask_token="[MASK]",
)
wrapped_tokenizer.save_pretrained("ehrbert_tokenizer")

# Now apply, with truncation and padding
def encode_with_truncation(examples):
  """Mapping function to tokenize the sentences passed with truncation"""
  return wrapped_tokenizer(examples["code_list"], truncation=True, padding="max_length",
                   max_length=max_length, return_special_tokens_mask=True)

# tokenizing the train dataset - batched here set to false
train_dataset = train_dataset.map(encode_with_truncation, batched=False)
train_dataset.set_format("torch")
train_dataset

2023-12-24 13:05:34.756402: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-12-24 13:05:34.804247: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


  0%|          | 0/12 [00:00<?, ?ex/s]

Dataset({
    features: ['code_list', 'age_ids', 'visit_ids', 'year_ids', 'gender_ids', 'eth_ids', 'imd_ids', 'patid', 'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask'],
    num_rows: 12
})

In [6]:
# set train/test split:
d = train_dataset.train_test_split(test_size=0.2, seed=42)
d["train"], d["test"]


(Dataset({
     features: ['code_list', 'age_ids', 'visit_ids', 'year_ids', 'gender_ids', 'eth_ids', 'imd_ids', 'patid', 'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask'],
     num_rows: 9
 }),
 Dataset({
     features: ['code_list', 'age_ids', 'visit_ids', 'year_ids', 'gender_ids', 'eth_ids', 'imd_ids', 'patid', 'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask'],
     num_rows: 3
 }))

In [7]:
# Determine unique values for each variable
AGE_SIZE = len({x for l in (disease['age_ids'].to_list()) for x in l})
print(AGE_SIZE)
VISIT_SIZE = len({x for l in (disease['visit_ids'].to_list()) for x in l})
print(VISIT_SIZE)
#assert VISIT_SIZE == (max_seq_length+1) # check should be max_len-2, plus 1 for 0 term (padding)
YEAR_SIZE = len({x for l in (disease['year_ids'].to_list()) for x in l})
print(YEAR_SIZE)
GENDER_SIZE = len({x for l in (disease['gender_ids'].to_list()) for x in l})
print(GENDER_SIZE)
ETH_SIZE = len({x for l in (disease['eth_ids'].to_list()) for x in l})
print(ETH_SIZE)
IMD_SIZE = len({x for l in (disease['imd_ids'].to_list()) for x in l})
print(IMD_SIZE)


14
6
13
3
7
11


### EHR-BERT model

In [8]:
from ehr_bert import EHRBertForMaskedLM, EHRBertModel, EHRBertEmbeddings, EHRBertConfig
from transformers import DataCollatorForLanguageModeling, TrainingArguments, Trainer

In [9]:
# EHR-BERT model configuration

model_config = EHRBertConfig(vocab_size=wrapped_tokenizer.vocab_size, # default=30522
                          max_position_embeddings=max_length, # default=512
                          num_hidden_layers=6, # default=12
                          num_attention_heads=12, # default=12
                          hidden_size=288,    #default=768
                          intermediate_size=512,   #default=3072
                          age_size=AGE_SIZE,
                          visit_size=VISIT_SIZE,
                          gender_size=GENDER_SIZE,
                          year_size=YEAR_SIZE,
                          eth_size=ETH_SIZE,
                          imd_size=IMD_SIZE,
                          segment_include=False,
                          position_include=False
                         )

model = EHRBertForMaskedLM(config=model_config)

# create data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=wrapped_tokenizer, mlm=True, mlm_probability=0.15)

In [10]:
# define accuracy metric during training
from datasets import load_metric
acc = load_metric('accuracy')
prec = load_metric('precision')
f1 = load_metric("f1")

def compute_metrics(eval_pred):
    labels = eval_pred.label_ids
    predictions = eval_pred.predictions[0]

    indices = [[i for i, x in enumerate(labels[row]) if x != -100] for row in range(len(labels))]
    labels = [labels[row][indices[row]] for row in range(len(labels))]
    labels = [item for sublist in labels for item in sublist]
    predictions = [predictions[row][indices[row]] for row in range(len(predictions))]
    predictions = [item for sublist in predictions for item in sublist]

    results1 = acc.compute(predictions=predictions, references=labels)
    results3 = prec.compute(predictions=predictions, references=labels, average='weighted')
    results5 = f1.compute(predictions=predictions, references=labels, average='weighted')
    
    return {"Accuracy": results1["accuracy"],
            "Precision-weighted": results3["precision"],
            "F1-weighted": results5["f1"]}

def preprocess_logits_for_metrics(logits, labels):
    """
    This function reduces GPU memory overload.
    """
    pred_ids = torch.argmax(logits, dim=-1)
    return pred_ids, labels

  acc = load_metric('accuracy')


In [11]:
#set epochs and step size - num_evals = number of evaluations over all steps
batch_size = 1
grad_steps = 1
num_evals = 5
epochs = 10

eval_steps = int(((np.floor((len(train_dataset)/batch_size)/grad_steps))*epochs)/num_evals)
eval_steps

24

In [12]:
# define training arguments
training_args = TrainingArguments(
    output_dir='ehrbert',          # path to save model checkpoint
    evaluation_strategy="steps",   # evaluate each `logging_steps` steps, or at epoch
    overwrite_output_dir=True,      
    num_train_epochs=epochs,        
    per_device_train_batch_size=batch_size, 
    gradient_accumulation_steps=grad_steps, 
    per_device_eval_batch_size=batch_size,   
    learning_rate=3e-5,            # default = 5e-5; BEHRT uses 3e-5, MedBERT uses 5e-5
    eval_accumulation_steps = 1000,  
    logging_steps=eval_steps,      # evaluate every X steps
    save_steps=eval_steps,          
    fp16=True                       
)

In [13]:
# initialize the trainer
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=d["train"],
    eval_dataset=d["test"],
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics = preprocess_logits_for_metrics, # this ensures only output relevant tensors
    tokenizer = wrapped_tokenizer
)

Using cuda_amp half precision backend


In [14]:
# train the model
trainer.train()

The following columns in the training set don't have a corresponding argument in `EHRBertForMaskedLM.forward` and have been ignored: special_tokens_mask, code_list, patid. If special_tokens_mask, code_list, patid are not expected by `EHRBertForMaskedLM.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 9
  Num Epochs = 10
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 90
  Number of trainable parameters = 3891404
You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss,Accuracy,Precision-weighted,F1-weighted
24,2.6642,3.434896,0.0,0.0,0.0
48,2.1801,,0.0,0.0,0.0
72,3.1294,,0.0,0.0,0.0


The following columns in the evaluation set don't have a corresponding argument in `EHRBertForMaskedLM.forward` and have been ignored: special_tokens_mask, code_list, patid. If special_tokens_mask, code_list, patid are not expected by `EHRBertForMaskedLM.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 3
  Batch size = 1
  _warn_prf(average, modifier, msg_start, len(result))
Saving model checkpoint to ehrbert/checkpoint-24
Configuration saved in ehrbert/checkpoint-24/config.json
Model weights saved in ehrbert/checkpoint-24/pytorch_model.bin
tokenizer config file saved in ehrbert/checkpoint-24/tokenizer_config.json
Special tokens file saved in ehrbert/checkpoint-24/special_tokens_map.json
The following columns in the evaluation set don't have a corresponding argument in `EHRBertForMaskedLM.forward` and have been ignored: special_tokens_mask, code_list, patid. If special_tokens_mask, code_list, patid are not expected by `EHRBertForMaskedLM.fo

TrainOutput(global_step=90, training_loss=2.8057572258843315, metrics={'train_runtime': 4.0196, 'train_samples_per_second': 22.39, 'train_steps_per_second': 22.39, 'total_flos': 33377875200.0, 'train_loss': 2.8057572258843315, 'epoch': 10.0})