# Train CLMBR

This tutorial walks through the steps required to train a CLMBR model.

Note that CLMBR requires the GPU-enabled version of FEMR. See the [README](https://github.com/som-shahlab/femr#how-to-install-femr-with-cuda-support) for the relevant instructions.

Training CLMBR is a three-step process:

- Generating a dictionary
- Creating batches
- Training the model
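
Each of these steps is run as a command-line tool in the cells that follow. As a minimal sketch, the same commands could also be assembled as argument lists and run with `subprocess.run(..., check=True)`, so that a failing step raises an exception rather than returning a nonzero code. The helper names `build_command` and `run_step` are hypothetical and not part of FEMR; only the tool name and the `--data_path` flag are taken from this tutorial.

```python
import subprocess
import sys


def build_command(tool, output_dir, **flags):
    """Build an argument list such as [tool, output_dir, '--data_path', 'x']."""
    command = [tool, str(output_dir)]
    for name, value in flags.items():
        command += [f"--{name}", str(value)]
    return command


def run_step(command):
    """Run one step and raise CalledProcessError on a nonzero exit code."""
    subprocess.run(command, check=True)


# Same arguments as the dictionary step used later in this tutorial.
dictionary_command = build_command(
    "clmbr_create_dictionary", "trash/tutorial_5/dictionary", data_path="input/extract"
)
print(dictionary_command)

# Demonstrate run_step on a harmless command (the Python interpreter itself).
run_step([sys.executable, "-c", "print('step ok')"])
```

Passing an argument list instead of a formatted string also avoids shell quoting problems when paths contain spaces.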

In [1]:
import shutil
import os

TARGET_DIR = 'trash/tutorial_5'

if os.path.exists(TARGET_DIR):
    shutil.rmtree(TARGET_DIR)

os.mkdir(TARGET_DIR)

In [2]:
import os

EXTRACT_LOCATION = "input/extract"


"""
The first step of training CLMBR is creating a dictionary that maps codes to integers, which can then be used within a neural network.
"""

DICTIONARY_PATH = os.path.join(TARGET_DIR, "dictionary")
os.system(f"clmbr_create_dictionary {DICTIONARY_PATH} --data_path {EXTRACT_LOCATION}")

Banned 0 out of 4523
Got age statistics ... {"mean":834488.8237272064,"std":1516971.4691312013}


0
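
The role of the dictionary can be illustrated with a toy sketch. This is not FEMR's implementation: the `build_vocab` helper, the frequency-based ordering, and the example codes are all assumptions made for illustration.

```python
from collections import Counter


def build_vocab(patient_codes, vocab_size):
    """Map the most frequent codes to integer ids 0..vocab_size-1."""
    counts = Counter(code for codes in patient_codes for code in codes)
    # Sort by descending count, then alphabetically so ties are deterministic.
    ranked = sorted(counts, key=lambda code: (-counts[code], code))
    return {code: index for index, code in enumerate(ranked[:vocab_size])}


# Hypothetical codes for three patients.
patients = [
    ["DX/A", "RX/C"],
    ["DX/A", "DX/B"],
    ["DX/A", "DX/B", "LAB/D"],
]

vocab = build_vocab(patients, vocab_size=2)
# Codes outside the vocabulary are dropped when encoding.
encoded = [[vocab[c] for c in codes if c in vocab] for codes in patients]
print(vocab)
print(encoded)
```

In this sketch, limiting the vocabulary size means rare codes never reach the model, which is the kind of trade-off a vocabulary size setting controls.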

In [3]:
"""
The second step of training CLMBR is to prepare the batches that are fed into the neural network.
"""

CLMBR_BATCHES = os.path.join(TARGET_DIR, "clmbr_batches")

os.system(
    f"clmbr_create_batches {CLMBR_BATCHES} --data_path {EXTRACT_LOCATION} --dictionary {DICTIONARY_PATH} --task clmbr --transformer_vocab_size 2048"
)


2023-05-30 15:09:02,240 [MainThread  ] [INFO ]  Preparing batches with Namespace(directory='trash/tutorial_5/clmbr_batches', data_path='input/extract', dictionary_path='trash/tutorial_5/dictionary', task='clmbr', transformer_vocab_size=2048, clmbr_survival_dictionary_path=None, labeled_patients_path=None, is_hierarchical=False, seed=97, val_start=70, test_start=85, batch_size=16384, note_embedding_data=None, limit_to_patients_file=None, limit_before_date=None, num_clmbr_tasks=8192)
2023-05-30 15:09:02,295 [MainThread  ] [INFO ]  Wrote config ...
2023-05-30 15:09:02,295 [MainThread  ] [INFO ]  Starting to load


When mapping codes, dropped 0 out of 2048

2023-05-30 15:09:02,765 [MainThread  ] [INFO ]  Loaded


When mapping codes, dropped 0 out of 2048


2023-05-30 15:09:03,003 [MainThread  ] [INFO ]  Number of train patients 1


0
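
The log above reports `seed=97`, `val_start=70`, and `test_start=85`. One plausible reading, which is an assumption here and not confirmed by this tutorial, is that each patient is hashed into a bucket from 0 to 99 and assigned to a split by thresholds. The sketch below illustrates that idea with a hypothetical `assign_split` function; the hash function and the split names are also assumptions.

```python
import hashlib


def assign_split(patient_id, seed=97, val_start=70, test_start=85):
    """Assign a patient to a split from a deterministic hash bucket in [0, 100)."""
    digest = hashlib.sha256(f"{seed}:{patient_id}".encode()).digest()
    bucket = int.from_bytes(digest[:8], "big") % 100
    if bucket < val_start:
        return "train"
    if bucket < test_start:
        return "dev"
    return "test"


# Hypothetical patient ids 0..999.
splits = {pid: assign_split(pid) for pid in range(1000)}
counts = {
    name: sum(1 for split in splits.values() if split == name)
    for name in ("train", "dev", "test")
}
print(counts)
```

A hash-based assignment like this keeps each patient in the same split across runs, as long as the seed is unchanged.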

In [4]:
"""
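Given the batches, it is now possible to train CLMBR. By default, training runs for 100 epochs with early stopping.
This cell passes --max_iter 10 along with a small model configuration (one layer, hidden size 256).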
"""

MODEL_PATH = os.path.join(TARGET_DIR, "clmbr_model")


assert 0 == os.system(
    f"clmbr_train_model {MODEL_PATH} --data_path {EXTRACT_LOCATION} --batches_path {CLMBR_BATCHES} --learning_rate 1e-4 --rotary_type per_head --num_batch_threads 3 --max_iter 10 --n_layers 1 --hidden_size 256 --n_heads 4 --intermediate_size 256"
)

2023-05-30 15:09:04,484 [MainThread  ] [INFO ]  Training model with Namespace(directory='trash/tutorial_5/clmbr_model', data_path='input/extract', batches_path='trash/tutorial_5/clmbr_batches', learning_rate=0.0001, rotary_type='per_head', clmbr_survival_dim=None, num_batch_threads=3, start_from_checkpoint=None, freeze_weights=False, token_dropout=0, internal_dropout=0, weight_decay=0, max_iter=10, hidden_size=256, intermediate_size=256, n_heads=4, n_layers=1, attention_width=512, dev_batches_path=None, linear_probe_head=None, early_stopping_window_steps=None)
2023-05-30 15:09:04,496 [MainThread  ] [INFO ]  Got config {'data_path': 'input/extract', 'batch_info_path': 'trash/tutorial_5/clmbr_batches/batch_info.msgpack', 'seed': 97, 'task': {'type': 'clmbr', 'vocab_size': 8192}, 'transformer': {'vocab_size': 2048, 'hidden_size': 256, 'intermediate_size': 256, 'n_heads': 4, 'n_layers': 1, 'rotary': 'per_head', 'attention_width': 496, 'internal_dropout': 0, 'is_hierarchical': False, 'note_

When mapping codes, dropped 0 out of 2048


2023-05-30 15:09:04,686 [MainThread  ] [INFO ]  Loaded batches 1 1
2023-05-30 15:09:05.861837: W external/xla/xla/service/platform_util.cc:198] unable to create StreamExecutor for CUDA:7: failed initializing StreamExecutor for CUDA device ordinal 7: INTERNAL: failed call to cuDevicePrimaryCtxRetain: CUDA_ERROR_OUT_OF_MEMORY: out of memory; total memory reported: 34089730048
2023-05-30 15:09:05.868774: W external/xla/xla/service/platform_util.cc:198] unable to create StreamExecutor for CUDA:6: failed initializing StreamExecutor for CUDA device ordinal 6: INTERNAL: failed call to cuDevicePrimaryCtxRetain: CUDA_ERROR_OUT_OF_MEMORY: out of memory; total memory reported: 34089730048
2023-05-30 15:09:05.875473: W external/xla/xla/service/platform_util.cc:198] unable to create StreamExecutor for CUDA:5: failed initializing StreamExecutor for CUDA device ordinal 5: INTERNAL: failed call to cuDevicePrimaryCtxRetain: CUDA_ERROR_OUT_OF_MEMORY: out of memory; total memory reported: 34089730048
202

When mapping codes, dropped 0 out of 2048


2023-05-30 15:09:18,595 [MainThread  ] [INFO ]  Loss scale DynamicLossScale(loss_scale=Array(32768., dtype=float32), counter=array(0, dtype=int32), period=2000, factor=2, min_loss_scale=array(1., dtype=float32))
2023-05-30 15:09:18,615 [MainThread  ] [INFO ]  Train loss {'loss': 9.364985466003418, 'c_statistic': -9.364985466003418}
2023-05-30 15:09:18,625 [MainThread  ] [INFO ]  Dev loss {'loss': 9.403003692626953, 'c_statistic': -9.403003692626953}
2023-05-30 15:09:18,768 [MainThread  ] [INFO ]  Continuing to train ...
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
2023-05-30 15:09:21,375 [MainThread  ] [INFO ]  [Step 0]
2023-05-30 15:09:21,391 [MainThread  ] [INFO ]  Loss scale DynamicLossScale(loss_scale=Array(32768., dtype=float32), counter=Array(2, dtype=int32), period=2000, factor=2, min_loss_scale=array(1., dtype=float32))
2023-05-30 15:09:21,629 [MainThread  ] [INFO ]  Train loss {'loss': 9.365042686462402, 'c_statistic': -9.365042686462402

Compiling the transformer ... (16384,) (4096,)
WITHOUT AGE
WITHOUT AGE
Compiling the transformer ... (16384,) (4096,)
WITHOUT AGE
Compiling the transformer ... (16384,) (256,)
WITHOUT AGE
Compiling the transformer ... (16384,) (4096,)
WITHOUT AGE
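
The training step mentions early stopping. As a generic sketch of that technique, not FEMR's actual logic, the loop below stops once the dev loss has failed to improve for `patience` consecutive evaluations. The function name, the `patience` parameter, and the loss values are hypothetical.

```python
def train_with_early_stopping(dev_losses, patience=3):
    """Return (best_step, best_loss), stopping after `patience` evaluations without improvement."""
    best_step, best_loss = -1, float("inf")
    steps_without_improvement = 0
    for step, loss in enumerate(dev_losses):
        if loss < best_loss:
            best_step, best_loss = step, loss
            steps_without_improvement = 0
        else:
            steps_without_improvement += 1
            if steps_without_improvement >= patience:
                break  # Stop early; later evaluations are never reached.
    return best_step, best_loss


# Hypothetical dev losses, one per evaluation.
losses = [9.40, 9.10, 8.95, 8.97, 9.00, 9.05, 8.50]
best_step, best_loss = train_with_early_stopping(losses, patience=3)
print(best_step, best_loss)
```

With these values the loop stops after the sixth evaluation, so the final 8.50 is never seen. A larger `patience` trades extra compute for a lower chance of stopping too soon.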
