# Train MOTOR

This tutorial walks through the various steps to train a MOTOR model.

Note that MOTOR requires the GPU-enabled version of FEMR. See the [README](https://github.com/som-shahlab/femr#how-to-install-femr-with-cuda-support) for installation instructions.

Training MOTOR is a four-step process:

- Generating a dictionary
- Generating a survival dictionary
- Creating batches
- Training the model
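Outside a notebook, the same four steps can be chained in a plain Python script. The sketch below reuses exactly the commands and flags from the cells in this tutorial, but swaps `os.system` for `subprocess.run(..., check=True)` so that a failing step raises an error instead of returning a status code that is easy to ignore. The helper names `motor_commands` and `run_pipeline` are our own and are not part of FEMR.

```python
import os
import subprocess

# Paths used throughout this tutorial.
TARGET_DIR = "trash/tutorial_7"
EXTRACT_LOCATION = "input/extract"
DICTIONARY_PATH = os.path.join(TARGET_DIR, "dictionary")
SURVIVAL_DICTIONARY_PATH = os.path.join(TARGET_DIR, "survival_dictionary")
BATCHES_PATH = os.path.join(TARGET_DIR, "clmbr_batches")
MODEL_PATH = os.path.join(TARGET_DIR, "survival_clmbr_model")


def motor_commands() -> list[list[str]]:
    """Return the four CLI invocations from this tutorial as argument lists (hypothetical helper)."""
    return [
        ["clmbr_create_dictionary", DICTIONARY_PATH, "--data_path", EXTRACT_LOCATION],
        ["clmbr_create_survival_dictionary", SURVIVAL_DICTIONARY_PATH,
         "--data_path", EXTRACT_LOCATION, "--num_buckets", "8", "--size", "1024"],
        ["clmbr_create_batches", BATCHES_PATH, "--data_path", EXTRACT_LOCATION,
         "--dictionary", DICTIONARY_PATH, "--task", "survival_clmbr",
         "--transformer_vocab_size", "2048",
         "--clmbr_survival_dictionary_path", SURVIVAL_DICTIONARY_PATH,
         "--is_hierarchical"],
        ["clmbr_train_model", MODEL_PATH, "--data_path", EXTRACT_LOCATION,
         "--batches_path", BATCHES_PATH, "--learning_rate", "1e-4",
         "--rotary_type", "per_head", "--num_batch_threads", "3",
         "--max_iter", "10", "--n_layers", "1", "--hidden_size", "256",
         "--n_heads", "4", "--intermediate_size", "256",
         "--clmbr_survival_dim", "512"],
    ]


def run_pipeline() -> None:
    """Run every step in order, stopping at the first failure (hypothetical helper)."""
    os.makedirs(TARGET_DIR, exist_ok=True)
    for args in motor_commands():
        # check=True raises CalledProcessError on a non-zero exit code.
        subprocess.run(args, check=True)
```

Calling `run_pipeline()` requires the FEMR command-line tools to be on your `PATH`. Unlike the first notebook cell, it does not delete an existing output directory first.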

In [1]:
import shutil
import os

TARGET_DIR = 'trash/tutorial_7'

if os.path.exists(TARGET_DIR):
    shutil.rmtree(TARGET_DIR)

# makedirs also creates the parent "trash" directory if it does not exist yet.
os.makedirs(TARGET_DIR)

In [2]:
import os

EXTRACT_LOCATION = "input/extract"


"""
The first step is creating a dictionary that maps codes to integer ids that a neural network can use.
MOTOR reuses the CLMBR command-line tools, which is why these commands are named clmbr_*.
"""

DICTIONARY_PATH = os.path.join(TARGET_DIR, "dictionary")
os.system(f"clmbr_create_dictionary {DICTIONARY_PATH} --data_path {EXTRACT_LOCATION}")

Banned 0 out of 4523
Got age statistics ... {"mean":834488.8237272066,"std":1516971.4691312015}


0
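Conceptually, the dictionary is a lookup table from medical codes to integer ids. The toy sketch below illustrates that idea only. It assumes codes are ranked by frequency and truncated to a fixed vocabulary size (the later `--transformer_vocab_size 2048` flag suggests a cap of this kind), and it does not reproduce the on-disk format or the ranking rules of `clmbr_create_dictionary`. The function names and example codes are hypothetical.

```python
from collections import Counter


def build_code_dictionary(patients: list[list[str]], vocab_size: int) -> dict[str, int]:
    """Map the `vocab_size` most frequent codes to ids 0..vocab_size-1 (toy version)."""
    counts = Counter(code for events in patients for code in events)
    # Sort by descending count, then alphabetically so ties break deterministically.
    ranked = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
    return {code: idx for idx, (code, _) in enumerate(ranked[:vocab_size])}


def encode(events: list[str], dictionary: dict[str, int]) -> list[int]:
    """Convert codes to ids, dropping codes that are not in the dictionary."""
    return [dictionary[code] for code in events if code in dictionary]


# Hypothetical toy patients; the codes are illustrative only.
patients = [
    ["ICD10/E11.9", "LOINC/4548-4", "ICD10/E11.9"],
    ["ICD10/I10", "LOINC/4548-4"],
    ["ICD10/E11.9", "RxNorm/860975"],
]
dictionary = build_code_dictionary(patients, vocab_size=3)
print(dictionary)
print(encode(patients[2], dictionary))  # RxNorm/860975 is out of vocabulary and dropped
```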

In [3]:
"""
The second step is creating a survival dictionary, which selects the codes used as time-to-event prediction targets (--size 1024) and the time bins used to model when they occur (--num_buckets 8).
"""

SURVIVAL_DICTIONARY_PATH = os.path.join(TARGET_DIR, "survival_dictionary")
os.system(f"clmbr_create_survival_dictionary {SURVIVAL_DICTIONARY_PATH} --data_path {EXTRACT_LOCATION} --num_buckets 8 --size 1024")

Banned 0 out of 4523
Starting to process 
RAEDY
Got total weight 2.85142
0 175680 332640 492480 658080 822240 999360 1244160 
Got total weight 1024
0 175680 332640 492480 658080 822240 999360 1244160 
Got total weight 2.85142
0 105120 175680 250560 332640 411840 492480 576000 658080 735840 822240 907200 999360 1110240 1244160 1441440 
Got total weight 1024
0 105120 175680 250560 332640 411840 492480 576000 658080 735840 822240 907200 999360 1110240 1244160 1441440 


0
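The output lists the time-bin boundaries chosen by `--num_buckets 8` (a longer 16-boundary list is printed as well, but the training config later shows the 8-boundary one as `time_bins`). The MOTOR paper frames pretraining as piecewise exponential time-to-event modeling, where each target code has a constant hazard inside each bin. The sketch below assumes that reading, with left-closed bins `[b_i, b_{i+1})` and an open-ended last bin, and shows how an event or censoring time maps to a bin and contributes to the log-likelihood. It is an illustration, not FEMR's loss code, and the output does not state the units of the boundaries.

```python
import bisect
import math

# Boundaries reported for --num_buckets 8 (also echoed as "time_bins" in the training config).
TIME_BINS = [0, 175680, 332640, 492480, 658080, 822240, 999360, 1244160]


def bin_index(t: float, bins: list[int] = TIME_BINS) -> int:
    """Index of the bin containing t, assuming bins [b_i, b_{i+1}) and an open last bin."""
    if t < bins[0]:
        raise ValueError("time must be non-negative")
    return bisect.bisect_right(bins, t) - 1


def piecewise_exponential_log_likelihood(
    t: float, is_event: bool, hazards: list[float], bins: list[int] = TIME_BINS
) -> float:
    """Log-likelihood of one observation under a constant hazard per bin (assumed model)."""
    k = bin_index(t, bins)
    log_lik = 0.0
    for i in range(k + 1):
        # Time spent at risk inside bin i before the event or censoring time.
        end = t if i == k else bins[i + 1]
        log_lik -= hazards[i] * (end - bins[i])
    if is_event:
        # An observed event also contributes the log hazard of the bin it falls in.
        log_lik += math.log(hazards[k])
    return log_lik


print(bin_index(200_000))  # 1
```

Censored patients (`is_event=False`) only contribute the survival term, which is what lets a time-to-event objective learn from patients whose outcome was never observed.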

In [4]:
"""
The third step is preparing the batches that are actually fed into the neural network.
"""

SURVIVAL_CLMBR_BATCHES = os.path.join(TARGET_DIR, "clmbr_batches")

command = f"clmbr_create_batches {SURVIVAL_CLMBR_BATCHES} --data_path {EXTRACT_LOCATION} --dictionary {DICTIONARY_PATH} --task survival_clmbr --transformer_vocab_size 2048 --clmbr_survival_dictionary_path  {SURVIVAL_DICTIONARY_PATH} --is_hierarchical"

print(command)

os.system(command)


clmbr_create_batches trash/tutorial_7/clmbr_batches --data_path input/extract --dictionary trash/tutorial_7/dictionary --task survival_clmbr --transformer_vocab_size 2048 --clmbr_survival_dictionary_path  trash/tutorial_7/survival_dictionary --is_hierarchical


2023-07-21 16:53:11,495 [MainThread  ] [INFO ]  Preparing batches with Namespace(directory='trash/tutorial_7/clmbr_batches', data_path='input/extract', dictionary_path='trash/tutorial_7/dictionary', task='survival_clmbr', transformer_vocab_size=2048, clmbr_survival_dictionary_path='trash/tutorial_7/survival_dictionary', labeled_patients_path=None, is_hierarchical=True, seed=97, val_start=80, test_start=85, batch_size=16384, note_embedding_data=None, limit_to_patients_file=None, limit_before_date=None, num_clmbr_tasks=8192)
2023-07-21 16:53:11,535 [MainThread  ] [INFO ]  Wrote config ...
2023-07-21 16:53:11,536 [MainThread  ] [INFO ]  Starting to load


When mapping codes, dropped 0 out of 2048

2023-07-21 16:53:12,293 [MainThread  ] [INFO ]  Loaded


When mapping codes, dropped 0 out of 2048


2023-07-21 16:53:12,642 [MainThread  ] [INFO ]  Number of train patients 1


0
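The batching log reports `seed=97`, `val_start=80` and `test_start=85`, and the cell ends with `Number of train patients 1`, so the small demo extract leaves only one patient in the training split. We read those numbers as percentage cut-offs on a deterministic per-patient hash (roughly 80% train, 5% validation, 15% test). That interpretation and the SHA-256 scheme below are assumptions for illustration, not FEMR's implementation; `split_bucket` and `assign_split` are hypothetical names.

```python
import hashlib


def split_bucket(patient_id: int, seed: int = 97, num_buckets: int = 100) -> int:
    """Hash (seed, patient_id) into one of `num_buckets` buckets, reproducibly."""
    digest = hashlib.sha256(f"{seed}:{patient_id}".encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") % num_buckets


def assign_split(patient_id: int, seed: int = 97, val_start: int = 80, test_start: int = 85) -> str:
    """Buckets [0, val_start) -> train, [val_start, test_start) -> val, the rest -> test."""
    bucket = split_bucket(patient_id, seed)
    if bucket < val_start:
        return "train"
    if bucket < test_start:
        return "val"
    return "test"


# Count how many hypothetical patient ids fall into each split.
counts = {"train": 0, "val": 0, "test": 0}
for pid in range(10_000):
    counts[assign_split(pid)] += 1
print(counts)
```

Hashing on the patient id keeps every event of a patient in the same split and makes the split stable across runs with the same seed, which matters when the same extract is batched more than once.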

In [6]:
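"""
Given the batches, we can now train the MOTOR model. By default, training runs for 100 epochs with early stopping; this demo caps it at 10 iterations (--max_iter 10) and uses a tiny one-layer transformer so that it finishes quickly.
"""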

MODEL_PATH = os.path.join(TARGET_DIR, "survival_clmbr_model")


assert 0 == os.system(
    f"clmbr_train_model {MODEL_PATH} --data_path {EXTRACT_LOCATION} --batches_path {SURVIVAL_CLMBR_BATCHES} --learning_rate 1e-4 --rotary_type per_head --num_batch_threads 3 --max_iter 10 --n_layers 1 --hidden_size 256 --n_heads 4 --intermediate_size 256 --clmbr_survival_dim  512"
)

2023-07-21 16:53:53,445 [MainThread  ] [INFO ]  Training model with Namespace(directory='trash/tutorial_7/survival_clmbr_model', data_path='input/extract', batches_path='trash/tutorial_7/clmbr_batches', learning_rate=0.0001, rotary_type='per_head', clmbr_survival_dim=512, num_batch_threads=3, start_from_checkpoint=None, freeze_weights=False, token_dropout=0, internal_dropout=0, weight_decay=0, max_iter=10, hidden_size=256, intermediate_size=256, n_heads=4, n_layers=1, attention_width=512, dev_batches_path=None, linear_probe=None, early_stopping_window_steps=None, with_age_beta=False)
2023-07-21 16:53:53,464 [MainThread  ] [INFO ]  Got config {'data_path': 'input/extract', 'batch_info_path': 'trash/tutorial_7/clmbr_batches/batch_info.msgpack', 'seed': 97, 'task': {'type': 'survival_clmbr', 'num_time_bins': 8, 'num_codes': 1024, 'dim': 512, 'time_bins': (0, 175680, 332640, 492480, 658080, 822240, 999360, 1244160)}, 'transformer': {'vocab_size': 2048, 'hidden_size': 256, 'intermediate_siz

When mapping codes, dropped 0 out of 2048


2023-07-21 16:53:53,717 [MainThread  ] [INFO ]  Loaded batches 1 1
2023-07-21 16:53:54,313 [MainThread  ] [INFO ]  Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Host Interpreter CUDA
2023-07-21 16:53:54,315 [MainThread  ] [INFO ]  Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
2023-07-21 16:53:54,315 [MainThread  ] [INFO ]  Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
2023-07-21 16:53:55,437 [MainThread  ] [INFO ]  Got dummy batch {'num_indices': ((), dtype('int32'), StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)), 'num_patients': ((), dtype('int32'), StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)), 'offsets': ((512,), dtype('uint32'), Str

When mapping codes, dropped 0 out of 2048


2023-07-21 16:54:03,078 [MainThread  ] [INFO ]  Loss scale DynamicLossScale(loss_scale=Array(32768., dtype=float32), counter=array(0, dtype=int32), period=2000, factor=2, min_loss_scale=array(1., dtype=float32))
2023-07-21 16:54:03,252 [MainThread  ] [INFO ]  Train loss {'loss': 0.08670195192098618, 'loss2': 0.08670195192098618, 'c_statistic': -0.08670195192098618}
2023-07-21 16:54:03,274 [MainThread  ] [INFO ]  Dev loss {'loss': 0.0877375677227974, 'loss2': 0.0877375677227974, 'c_statistic': -0.0877375677227974}
2023-07-21 16:54:03,495 [MainThread  ] [INFO ]  Continuing to train ...
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
2023-07-21 16:54:08,809 [MainThread  ] [INFO ]  [Step 0]
2023-07-21 16:54:11,635 [MainThread  ] [INFO ]  Loss scale DynamicLossScale(loss_scale=Array(32768., dtype=float32), counter=Array(2, dtype=int32), period=2000, factor=2, min_loss_scale=array(1., dtype=float32))
2023-07-21 16:54:12,956 [MainThread  ] [INFO ]  Train l

Compiling the transformer ... (16384,) (4096,)
Compiling the transformer ... (16384,) (4096,)
Compiling the transformer ... (16384,) (256,)
Compiling the transformer ... (16384,) (4096,)
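The `Loss scale` lines in the training log print a `DynamicLossScale` with an initial scale of 32768, `period=2000` and `factor=2`; this repr matches the class of that name in DeepMind's jmp mixed-precision library. Dynamic loss scaling multiplies the loss by a large factor before backpropagation so that small low-precision gradients do not underflow, then divides the gradients by the same factor. The simplified pure-Python sketch below mirrors the usual update rule: shrink the scale (down to `min_loss_scale`) when gradients overflow, and grow it after `period` consecutive finite steps. It is a re-implementation for illustration, not the jmp source, and it assumes (without the log confirming it) that the model trains in reduced precision.

```python
import math
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class DynamicLossScale:
    """Simplified dynamic loss scale, modeled on the fields printed in the log."""

    loss_scale: float = 32768.0  # initial value shown in the log
    counter: int = 0             # consecutive steps with finite gradients
    period: int = 2000
    factor: int = 2
    min_loss_scale: float = 1.0

    def scale(self, loss: float) -> float:
        """Multiply the loss before differentiation."""
        return loss * self.loss_scale

    def unscale(self, grads: list[float]) -> list[float]:
        """Undo the scaling on the resulting gradients."""
        return [g / self.loss_scale for g in grads]

    def adjust(self, grads_finite: bool) -> "DynamicLossScale":
        """Return the loss scale to use for the next step."""
        if not grads_finite:
            # Overflow (inf/NaN gradients): shrink the scale and reset the counter.
            smaller = max(self.min_loss_scale, self.loss_scale / self.factor)
            return replace(self, loss_scale=smaller, counter=0)
        if self.counter == self.period - 1:
            # `period` finite steps in a row: try a larger scale, unless it overflows.
            bigger = self.loss_scale * self.factor
            return replace(self, loss_scale=bigger if math.isfinite(bigger) else self.loss_scale, counter=0)
        return replace(self, counter=self.counter + 1)


def all_finite(grads: list[float]) -> bool:
    """True if no gradient is inf or NaN."""
    return all(math.isfinite(g) for g in grads)
```

In a training loop, the gradients of `scale(loss)` are passed through `unscale`, the optimizer update is applied only when `all_finite(grads)` holds, and `adjust(all_finite(grads))` yields the scale for the next step.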
