# Train CLMBR

This tutorial walks through the steps required to train a CLMBR model.

Note that CLMBR requires the GPU-enabled version of FEMR. See the [README](https://github.com/som-shahlab/femr#how-to-install-femr-with-cuda-support) for installation instructions.

Training CLMBR is a three-step process:

- Generating a dictionary
- Creating batches
- Training the model
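
Each step above is a command-line tool that this tutorial launches with `os.system`, which returns a status code rather than raising on failure. As an optional alternative, here is a minimal sketch of a wrapper that stops at the first failing step. `run_step` is a hypothetical helper name, not part of FEMR; it relies only on Python's standard `subprocess` module.

```python
import subprocess
import sys


def run_step(args):
    """Run one command given as a list of arguments.

    Raises subprocess.CalledProcessError if the command exits with a
    non-zero status, so a failed step cannot go unnoticed.
    """
    result = subprocess.run(args, check=True)
    return result.returncode


# Harmless stand-in command so the sketch runs without FEMR installed.
run_step([sys.executable, "-c", "print('step ok')"])
```

With this helper, the dictionary step could be written as `run_step(["clmbr_create_dictionary", DICTIONARY_PATH, "--data_path", EXTRACT_LOCATION])`, using the same arguments as the `os.system` call in this tutorial.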

In [1]:
import shutil
import os

TARGET_DIR = 'trash/tutorial_5'

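# Start from a clean output directory on every run.
if os.path.exists(TARGET_DIR):
    shutil.rmtree(TARGET_DIR)

# makedirs also creates the parent 'trash' directory if it is missing.
os.makedirs(TARGET_DIR)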

In [2]:
import os
import tempfile

EXTRACT_LOCATION = "input/extract"


"""
The first step of training CLMBR is creating a dictionary that maps codes to integers, which can then be used within a neural network.
"""

DICTIONARY_PATH = os.path.join(TARGET_DIR, "dictionary")
os.system(f"clmbr_create_dictionary {DICTIONARY_PATH} --data_path {EXTRACT_LOCATION}")

Banned 0 out of 4523
Got age statistics ... {"mean":834488.8237272067,"std":1516971.4691312027}


0
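
To make the idea of a dictionary concrete, here is a toy sketch of mapping codes to integer indices. It is not FEMR's implementation: the function names `build_code_dictionary` and `encode`, the frequency-based ranking, and the example codes are all assumptions chosen for illustration.

```python
from collections import Counter


def build_code_dictionary(patients, vocab_size):
    """Assign integer indices to the most frequent codes (illustrative only)."""
    counts = Counter(code for codes in patients for code in codes)
    # Rank by descending count; break ties alphabetically for determinism.
    ranked = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
    return {code: index for index, (code, _) in enumerate(ranked[:vocab_size])}


def encode(codes, code_to_index):
    """Map codes to integers, dropping codes that are not in the dictionary."""
    return [code_to_index[code] for code in codes if code in code_to_index]


# Made-up codes, one list per patient.
toy_patients = [
    ["dx/diabetes", "rx/metformin", "dx/hypertension"],
    ["dx/hypertension", "lab/hba1c"],
    ["dx/hypertension", "dx/diabetes"],
]

code_to_index = build_code_dictionary(toy_patients, vocab_size=3)
print(code_to_index)
print(encode(toy_patients[0], code_to_index))
```

The integer indices are what an embedding layer consumes; codes that fall outside the chosen vocabulary size are simply dropped in this sketch.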

In [3]:
"""
The second step of training CLMBR is preparing the batches that are fed into the neural network.
"""

CLMBR_BATCHES = os.path.join(TARGET_DIR, "clmbr_batches")

os.system(
    f"clmbr_create_batches {CLMBR_BATCHES} --data_path {EXTRACT_LOCATION} --dictionary {DICTIONARY_PATH} --task clmbr --transformer_vocab_size 2048"
)


2023-05-24 01:06:18,671 [MainThread  ] [INFO ]  Preparing batches with Namespace(directory='trash/tutorial_5/clmbr_batches', data_path='input/extract', dictionary_path='trash/tutorial_5/dictionary', task='clmbr', transformer_vocab_size=2048, clmbr_survival_dictionary_path=None, labeled_patients_path=None, is_hierarchical=False, seed=97, val_start=70, test_start=85, batch_size=16384, note_embedding_data=None, limit_to_patients_file=None, limit_before_date=None)
2023-05-24 01:06:18,686 [MainThread  ] [INFO ]  Wrote config ...
2023-05-24 01:06:18,686 [MainThread  ] [INFO ]  Starting to load
2023-05-24 01:06:18,772 [MainThread  ] [INFO ]  Loaded
2023-05-24 01:06:18,807 [MainThread  ] [INFO ]  Number of train patients 1


When mapping codes, dropped 0 out of 2048

0
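
The logged arguments above include `seed=97`, `val_start=70`, and `test_start=85`. One plausible reading, which is an assumption and not confirmed by this tutorial, is that each patient is placed in a pseudo-random bucket from 0 to 99 and the bucket decides the split. The sketch below illustrates that idea only; `assign_split` is a hypothetical name and the SHA-256 hash is a stand-in, not FEMR's actual hashing scheme.

```python
import hashlib


def assign_split(patient_id, seed=97, val_start=70, test_start=85):
    """Deterministically assign a patient to train, dev, or test (illustrative)."""
    # Assumption: hash the seed and patient id into a bucket in [0, 100).
    digest = hashlib.sha256(f"{seed}:{patient_id}".encode("utf-8")).digest()
    bucket = int.from_bytes(digest[:8], "big") % 100
    if bucket < val_start:
        return "train"
    if bucket < test_start:
        return "dev"
    return "test"


# The same patient id always lands in the same split.
print(assign_split(12345), assign_split(12345))
```

Under this reading, roughly 70% of patients would be used for training, 15% for the dev set, and 15% for testing, and the assignment would be reproducible for a fixed seed.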

In [4]:
"""
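With the batches prepared, CLMBR can now be trained. By default it will train for 100 epochs, with early stopping.
The command below also sets --max_iter 10 and a small one-layer model (--n_layers 1, --hidden_size 256).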
"""

MODEL_PATH = os.path.join(TARGET_DIR, "clmbr_model")


assert 0 == os.system(
    f"clmbr_train_model {MODEL_PATH} --data_path {EXTRACT_LOCATION} --batches_path {CLMBR_BATCHES} --learning_rate 1e-4 --rotary_type per_head --num_batch_threads 3 --max_iter 10 --n_layers 1 --hidden_size 256 --n_heads 4 --intermediate_size 256"
)

2023-05-24 01:06:19,829 [MainThread  ] [INFO ]  Training model with Namespace(directory='trash/tutorial_5/clmbr_model', data_path='input/extract', batches_path='trash/tutorial_5/clmbr_batches', learning_rate=0.0001, rotary_type='per_head', clmbr_survival_dim=None, num_batch_threads=3, start_from_checkpoint=None, freeze_weights=False, token_dropout=0, internal_dropout=0, weight_decay=0, max_iter=10, hidden_size=256, intermediate_size=256, n_heads=4, n_layers=1, attention_width=512, dev_batches_path=None, linear_probe_head=None, early_stopping_window_steps=None)
2023-05-24 01:06:19,840 [MainThread  ] [INFO ]  Got config {'data_path': 'input/extract', 'batch_info_path': 'trash/tutorial_5/clmbr_batches/batch_info.msgpack', 'seed': 97, 'task': {'type': 'clmbr', 'vocab_size': 8192}, 'transformer': {'vocab_size': 2048, 'hidden_size': 256, 'intermediate_size': 256, 'n_heads': 4, 'n_layers': 1, 'rotary': 'per_head', 'attention_width': 496, 'internal_dropout': 0, 'is_hierarchical': False, 'note_

When mapping codes, dropped 0 out of 2048


2023-05-24 01:06:20,139 [MainThread  ] [INFO ]  Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA Host
2023-05-24 01:06:20,140 [MainThread  ] [INFO ]  Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
2023-05-24 01:06:20,140 [MainThread  ] [INFO ]  Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
2023-05-24 01:06:21,606 [MainThread  ] [INFO ]  Got dummy batch {'num_indices': ((), dtype('int32'), StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)), 'num_patients': ((), dtype('int32'), StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)), 'offsets': ((512,), dtype('uint32'), StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)), 'patie

When mapping codes, dropped 0 out of 2048


2023-05-24 01:06:26,757 [MainThread  ] [INFO ]  Continuing to train ...
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
2023-05-24 01:06:29,471 [MainThread  ] [INFO ]  Loss scale DynamicLossScale(loss_scale=Array(32768., dtype=float32), counter=Array(1, dtype=int32), period=2000, factor=2, min_loss_scale=array(1., dtype=float32))
2023-05-24 01:06:30,226 [MainThread  ] [INFO ]  Train loss {'loss': 9.358352661132812, 'c_statistic': -9.358352661132812}
2023-05-24 01:06:30,656 [MainThread  ] [INFO ]  Dev loss {'loss': 9.386465072631836, 'c_statistic': -9.386465072631836}
2023-05-24 01:06:30,692 [MainThread  ] [INFO ]  Continuing to train ...
2023-05-24 01:06:32,092 [MainThread  ] [INFO ]  [Step 0]
2023-05-24 01:06:32,096 [MainThread  ] [INFO ]  Loss scale DynamicLossScale(loss_scale=Array(32768., dtype=float32), counter=Array(3, dtype=int32), period=2000, factor=2, min_loss_scale=array(1., dtype=float32))
2023-05-24 01:06:32,108 [MainThread  ] [INFO ]  

Compiling the transformer ... (16384,) (4096,)
Compiling the transformer ... (16384,) (4096,)
Compiling the transformer ... (16384,) (256,)
Compiling the transformer ... (16384,) (4096,)
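
The training step mentions early stopping. As a generic illustration of that technique, and not FEMR's actual training loop, here is a minimal sketch that stops once the dev loss has failed to improve for a fixed number of evaluations. The function name `early_stopping_step`, the `patience` parameter, and the loss values are all hypothetical.

```python
def early_stopping_step(dev_losses, patience):
    """Return (best_step, best_loss) using patience-based early stopping.

    dev_losses: dev-set losses, one per evaluation, in training order.
    patience: how many evaluations without improvement to tolerate.
    """
    best_loss = float("inf")
    best_step = -1
    for step, loss in enumerate(dev_losses):
        if loss < best_loss:
            # New best: remember it (a real loop would save a checkpoint here).
            best_loss, best_step = loss, step
        elif step - best_step >= patience:
            # No improvement for `patience` evaluations: stop training.
            break
    return best_step, best_loss


# Made-up dev losses: the later 8.6 is never reached because training stops first.
print(early_stopping_step([9.4, 9.0, 8.7, 8.8, 8.9, 8.6], patience=2))
```

The trade-off in this sketch is the usual one: a small `patience` saves compute but can stop before a later improvement, while a large one trains longer before giving up.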
