# Train MOTOR

This tutorial walks through the various steps to train a MOTOR model.

Training MOTOR is a four-step process:

- Training a tokenizer
- Prefitting MOTOR
- Preparing batches
- Training the model

In [1]:
import shutil
import os

# Notebook-wide configuration.

# Environment variables: Hugging Face datasets cache location and the
# Weights & Biases off-switch.
os.environ.update({
    "HF_DATASETS_CACHE": '/share/pi/nigam/projects/zphuo/.cache',
    "WANDB_DISABLED": "true",
})

# Directory that holds every artifact this tutorial reads or writes.
TARGET_DIR = 'trash/tutorial_6_INSEPCT'

# True  -> keep TARGET_DIR as it is and load the tokenizer stored inside it.
# False -> recreate TARGET_DIR from scratch and train a new tokenizer.
from_pretrained = True

# Value passed to every `num_proc=` argument in the cells below.
num_proc = 20

In [2]:
# Start from a clean output directory when training from scratch.
# When `from_pretrained` is True nothing is touched, so previously saved artifacts survive.
if not from_pretrained:
    if os.path.exists(TARGET_DIR):
        shutil.rmtree(TARGET_DIR)

    # makedirs (rather than two mkdir calls) also creates missing parents such as
    # 'trash/', so a fresh checkout does not fail with FileNotFoundError.
    os.makedirs(os.path.join(TARGET_DIR, 'motor_model'), exist_ok=True)

In [3]:
import datasets
import femr.index
import femr.splits

# Load the dataset and carve out the outer train/test split.
# The split helper is applied a second time further down, giving train / valid / test overall.

# Alternative input location, kept for reference:
# dataset = datasets.Dataset.from_parquet('input/meds/data/*')
parquet_folder = '/share/pi/nigam/projects/zphuo/data/PE/inspect/timelines_smallfiles_meds/data_subset/*'
dataset = datasets.Dataset.from_parquet(parquet_folder)

# Patient index over the full dataset; it supplies the patient ids and is
# needed again when the dataset is physically split.
index = femr.index.PatientIndex(dataset, num_proc=num_proc)

# Outer split with frac_test=0.15 (the positional 97 is presumably the hash seed — TODO confirm).
all_patient_ids = index.get_patient_ids()
main_split = femr.splits.generate_hash_split(all_patient_ids, 97, frac_test=0.15)

# Store the split inside the target directory, since it is important information
# about how the model was trained.
main_split_path = os.path.join(TARGET_DIR, "motor_model", "main_split.csv")
main_split.save_to_csv(main_split_path)

import pandas as pd

# Export the cohort file's own patient -> split assignment.
label_csv_subset = '/share/pi/nigam/projects/zphuo/data/PE/inspect/timelines_smallfiles_meds/cohort_0.2.0_master_file_anon_subset.csv'
label_df = pd.read_csv(label_csv_subset)
label_df = label_df[['patient_id', 'split']]

# Derive the destination from TARGET_DIR instead of a hard-coded absolute path, so the
# file follows TARGET_DIR when it changes and the notebook is not tied to one user's checkout.
# NOTE(review): this is the same file that `main_split.save_to_csv` wrote, so the hash split
# on disk is replaced by the cohort split while the in-memory `main_split` stays hash-based.
# Confirm that this overwrite is intended.
inspect_split_csv = os.path.join(TARGET_DIR, "motor_model", "main_split.csv")
label_df.to_csv(inspect_split_csv, index=False)

# Second application of the split helper: divide the outer training patients again
# (positional value 87, frac_test=0.15) to obtain the inner train/"test" pair.
train_split = femr.splits.generate_hash_split(main_split.train_patient_ids, 87, frac_test=0.15)

# Apply the outer split to the data; the result is addressed by split name below.
main_dataset = main_split.split_dataset(dataset, index)

# The inner split needs its own patient index, built over the outer training portion only.
outer_train = main_dataset['train']
outer_train_index = femr.index.PatientIndex(outer_train, num_proc=num_proc)
train_dataset = train_split.split_dataset(outer_train, outer_train_index)

Map (num_proc=20):   0%|          | 0/1916 [00:00<?, ? examples/s]

Map (num_proc=20):   0%|          | 0/1639 [00:00<?, ? examples/s]

In [4]:
import femr.models.tokenizer
from femr.models.tokenizer import FEMRTokenizer
import pickle

# Step 1: train a tokenizer. MOTOR needs a hierarchical tokenizer, and the
# ontology loaded here is handed to the tokenizer in the next cell.
# NOTE(review): pickle.load can execute arbitrary code from the file, so only
# point this at an ontology file from a trusted source.
ontology_path = 'input/meds/ontology.pkl'
with open(ontology_path, 'rb') as ontology_file:
    ontology = pickle.load(ontology_file)



In [5]:
# The tokenizer lives in the same directory as the model.
tokenizer_dir = os.path.join(TARGET_DIR, "motor_model")

if from_pretrained:
    # Reload the tokenizer saved by an earlier run.
    tokenizer = femr.models.tokenizer.FEMRTokenizer.from_pretrained(tokenizer_dir, ontology=ontology)
else:
    # Train a hierarchical tokenizer on the outer training split, then persist it.
    tokenizer = femr.models.tokenizer.train_tokenizer(
        main_dataset['train'],
        vocab_size=128,
        is_hierarchical=True,
        num_proc=num_proc,
        ontology=ontology,
    )
    tokenizer.save_pretrained(tokenizer_dir)

In [6]:
import femr.models.tasks

# Step 2: prefit the MOTOR model. This is necessary because piecewise
# exponential models are unstable without an initial fit.

# Task count depends on which data is loaded: 39 when the parquet path
# contains 'subset', 64 otherwise.
num_tasks = 39 if 'subset' in parquet_folder else 64

motor_task = femr.models.tasks.MOTORTask.fit_pretraining_task_info(
    main_dataset['train'],
    tokenizer,
    num_tasks=num_tasks,
    num_bins=4,
    final_layer_size=32,
    num_proc=num_proc,
)

# Tip: this fit is expensive, so saving `motor_task` with pickle is recommended
# to avoid recomputing it.


Map (num_proc=20):   0%|          | 0/1639 [00:00<?, ? examples/s]

ValueError: Couldn't find patient birthdate -- Patient has no birth events (which must be {meds.birth_code}): [{'time': datetime.datetime(2123, 1, 1, 0, 0), 'measurements': [{'code': 'SNOMED/3950001', 'text_value': None, 'numeric_value': None, 'datetime_value': None, 'metadata': {'end': None, 'omop_table': 'person', 'source_code': None, 'unit': None, 'visit_id': None}}]}, {'time': datetime.datetime(2123, 1, 1, 23, 59), 'measurements': [{'code': 'Race/2', 'text_value': None, 'numeric_value': None, 'datetime_value': None, 'metadata': {'end': None, 'omop_table': 'person', 'source_code': 'Declines to State | Asian', 'unit': None, 'visit_id': None}}, {'code': 'Gender/F', 'text_value': None, 'numeric_value': None, 'datetime_value': None, 'metadata': {'end': None, 'omop_table': 'person', 'source_code': '1 | 1', 'unit': None, 'visit_id': None}}, {'code': 'Ethnicity/Not Hispanic', 'text_value': None, 'numeric_value': None, 'datetime_value': None, 'metadata': {'end': None, 'omop_table': 'person', 'source_code': 'Non-Hispanic/Non-Latino | Non-Hispanic/Non-Latino', 'unit': None, 'visit_id': None}}]}, {'time': datetime.datetime(2152, 3, 18, 15, 15), 'measurements': [{'code': 'CPT4/87081', 'text_value': None, 'numeric_value': None, 'datetime_value': None, 'metadata': {'end': None, 'omop_table': 'measurement', 'source_code': 'GC CULTURE SCREEN', 'unit': None, 'visit_id': None}}, {'code': 'LOINC/28570-0', 'text_value': None, 'numeric_value': None, 'datetime_value': None, 'metadata': {'end': None, 'omop_table': 'note', 'source_code': None, 'unit': None, 'visit_id': None}}, {'code': 'LOINC/28570-0', 'text_value': None, 'numeric_value': None, 'datetime_value': None, 'metadata': {'end': None, 'omop_table': 'note', 'source_code': None, 'unit': None, 'visit_id': None}}, {'code': 'SNOMED/407707008', 'text_value': None, 'numeric_value': None, 'datetime_value': None, 'metadata': {'end': None, 'omop_table': 'measurement', 'source_code': 'CHLAMYDIA MOLECULAR DETECTION', 'unit': None, 
'visit_id': None}}]}, {'time': datetime.datetime(2152, 11, 3, 23, 59), 'measurements': [{'code': 'LOINC/11506-3', 'text_value': None, 'numeric_value': None, 'datetime_value': None, 'metadata': {'end': None, 'omop_table': 'note', 'source_code': None, 'unit': None, 'visit_id': '28512886.0'}}, {'code': 'LOINC/11506-3', 'text_value': None, 'numeric_value': None, 'datetime_value': None, 'metadata': {'end': None, 'omop_table': 'note', 'source_code': None, 'unit': None, 'visit_id': '28512886.0'}}]}, {'time': datetime.datetime(2152, 12, 11, 11, 50), 'measurements': [{'code': 'LOINC/28570-0', 'text_value': None, 'numeric_value': None, 'datetime_value': None, 'metadata': {'end': None, 'omop_table': 'note', 'source_code': None, 'unit': None, 'visit_id': '28512886.0'}}, {'code': 'SNOMED/117010004', 'text_value': None, 'numeric_value': None, 'datetime_value': None, 'metadata': {'end': None, 'omop_table': 'measurement', 'source_code': 'URINE CULTURE', 'unit': None, 'visit_id': '28512886.0'}}]}]

In [None]:
import femr.models.processor
import femr.models.tasks

# Third, we need to create batches.

processor = femr.models.processor.FEMRBatchProcessor(tokenizer, motor_task)

# We can do this one patient at a time
print("Convert a single patient")
example_batch = processor.collate([processor.convert_patient(train_dataset['train'][0], tensor_type='pt')])

# But generally we want to convert entire datasets.
# This conversion must run: the Trainer cell further down reads
# train_batches['train'] and train_batches['test'], so leaving it commented out
# makes the notebook fail with NameError on a fresh kernel.
print("Convert batches")
train_batches = processor.convert_dataset(train_dataset, tokens_per_batch=32, num_proc=num_proc, min_samples_per_batch=1)

# Convert our batches to pytorch tensors
print("Convert batches to pytorch")
train_batches.set_format("pt")
print("Done")

In [None]:
import transformers

# `femr.models.transformer` is used below but was never imported explicitly in this
# notebook; import it here rather than relying on another femr module pulling it in.
import femr.models.transformer

# Finally, given the batches, we can train MOTOR.
# We can use huggingface's trainer to do this.

# Deliberately tiny transformer so the tutorial trains quickly.
transformer_config = femr.models.transformer.FEMRTransformerConfig(
    vocab_size=tokenizer.vocab_size,
    is_hierarchical=tokenizer.is_hierarchical,
    n_layers=2,
    hidden_size=64,
    intermediate_size=64*2,
    n_heads=8,
)

# Combine the transformer settings with the prefit MOTOR task configuration.
config = femr.models.transformer.FEMRModelConfig.from_transformer_task_configs(transformer_config, motor_task.get_task_config())

model = femr.models.transformer.FEMRModel(config)

collator = processor.collate

trainer_config = transformers.TrainingArguments(
    # One pre-built batch per step.
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,

    output_dir='tmp_trainer',
    remove_unused_columns=False,
    num_train_epochs=100,

    # Evaluate and log every 20 steps.
    # NOTE(review): newer transformers releases rename `evaluation_strategy` to
    # `eval_strategy` — check against the installed version.
    eval_steps=20,
    evaluation_strategy="steps",

    logging_steps=20,
    logging_strategy='steps',

    prediction_loss_only=True,

    # The string "none" disables all reporting integrations; passing None does not.
    report_to="none",
)

In [None]:
# Each dataset element is already a complete batch, so the processor's own collate
# function is passed to the Trainer (with per-device batch size 1) instead of
# falling back to the Trainer's default collator.
trainer = transformers.Trainer(
    model=model,
    data_collator=processor.collate,
    train_dataset=train_batches['train'],
    # The inner 'test' split is used for periodic evaluation during training.
    eval_dataset=train_batches['test'],
    args=trainer_config,
)


trainer.train()