In [1]:
import pandas as pd
from monai.data import PersistentDataset, ImageDataset, DataLoader
from monai.transforms import (
    LoadImaged,
    EnsureChannelFirst,
    EnsureChannelFirstd,
    Compose,
    RandRotate90,
    RandRotate90d,
    Resize,
    Resized,
    ScaleIntensity,
    ScaleIntensityd,
)

from utils import CustomToOneChanneld, CustomToOneChannel, set_up_motor_task
import os
from tqdm import tqdm
from datetime import datetime
import femr
import datasets
import pickle
from networks import DenseNet121_TTE
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [25]:
label_column = '12_month_PH'
nii_folder='./trash/tutorial_TTE/data'
import numpy as np
prop_train = 1
prop_val = 1
use_cachedataset = True
model_save_path='./trash/tutorial_TTE/data'
batch_size = 1

image_paths_train=['./trash/tutorial_TTE/data/831469698_2129-09-04_01_46_00.nii.gz']
labels_train=[0]
pin_memory = False
learning_rate = 1e-6
device='cpu'
use_checkpoint=False
dropout_prob=0.3

# this would take hours to train a tokenizer so we just use the pre-trained one
from_pretrained_tokenizer=True
month_date_hour='022807' # when the pre-trained tokenizer was trained
num_proc=8
label_csv = './trash/tutorial_TTE/data/labels.csv'
ontology_path = './trash/tutorial_TTE/data/ontology.pkl'
parquet_folder = './trash/tutorial_TTE/data/parquet/'
inference=False
vocab_size=512
final_layer_size=512
num_tasks=200 # a subset of vocab given not every patient has all vocab

In [3]:
if use_cachedataset:
    train_transforms = Compose(
        [LoadImaged(keys=["image"]), ScaleIntensityd(keys=["image"]), EnsureChannelFirstd(keys=["image"]), Resized(keys=["image"], spatial_size=(224, 224, 224)), RandRotate90d(keys=["image"]), CustomToOneChanneld(keys=["image"])]
    )
    val_transforms = Compose(
        [LoadImaged(keys=["image"]), ScaleIntensityd(keys=["image"]), EnsureChannelFirstd(keys=["image"]), Resized(keys=["image"], spatial_size=(224, 224, 224)), CustomToOneChanneld(keys=["image"])]
    )
    
else:
    # Define transforms
    train_transforms = Compose(
        [ScaleIntensity(), EnsureChannelFirst(), Resize((224, 224, 224)), RandRotate90(), CustomToOneChannel()]
    )

    val_transforms = Compose(
        [ScaleIntensity(), EnsureChannelFirst(), Resize((224, 224, 224)), CustomToOneChannel()]
    )

In [20]:
def set_up_motor_task(TARGET_DIR, from_pretrained_tokenizer, month_date_hour, num_proc, label_csv, ontology_path, inference, vocab_size, START_TIME, parquet_folder, final_layer_size, num_tasks):
    ###### set up MOTOR task ########
    ################################
    


    from_pretrained = from_pretrained_tokenizer


    if month_date_hour is None:
        month_date_hour = datetime.now().strftime("%m%d%H")
    # if path not exists, create the folder
    if not os.path.exists(TARGET_DIR):
        os.mkdir(TARGET_DIR)
    if not os.path.exists(os.path.join(TARGET_DIR, f'motor_model_{month_date_hour}')):
        os.mkdir(os.path.join(TARGET_DIR, f'motor_model_{month_date_hour}'))


    parquet_folder = os.path.join(parquet_folder, 'data', '*')
    dataset = datasets.Dataset.from_parquet(parquet_folder)


    import femr.index
    print('indexing patients...')
    index = femr.index.PatientIndex(dataset, num_proc=num_proc)
    print('time used indexing patients:', datetime.now() - START_TIME)
    
    inspect_split_csv = os.path.join(TARGET_DIR, f'motor_model_{month_date_hour}', "main_split.csv")

    
    import femr.splits

    main_split = femr.splits.PatientSplit.load_from_csv(inspect_split_csv)

    main_dataset = main_split.split_dataset(dataset, index)
    train_dataset = main_dataset['train']
    print(train_dataset['patient_id'])

    # Note, we need to use a hierarchical tokenizer for MOTOR


    with open(ontology_path, 'rb') as f:
        ontology = pickle.load(f)
    
    # load pretrained tokenizer
    tokenizer = femr.models.tokenizer.FEMRTokenizer.from_pretrained(os.path.join(TARGET_DIR, f'motor_model_{month_date_hour}'), ontology=ontology)

    # 

    time_used = datetime.now() - START_TIME
    print(f"Time used tokenzier: {time_used}")
    
    print("Prefitting MOTOR task...")
    
    with open(os.path.join(TARGET_DIR, f'motor_model_{month_date_hour}', "motor_task.pkl"), 'rb') as f:
        motor_task = pickle.load(f)

    time_used = datetime.now() - START_TIME
    print(f"Time used motor task: {time_used}")
    
    # Third, we need to create batches. 

    processor = femr.models.processor.FEMRBatchProcessor(tokenizer, motor_task)

    index_train = femr.index.PatientIndex(train_dataset, num_proc=num_proc)
    print(index_train.get_patient_ids())

    # We can do this one patient at a time

    time_used = datetime.now() - START_TIME
    print(f"Time used index: {time_used}")
    
   
    
    ###### set up MOTOR task ########
    ################################
    
    return motor_task, tokenizer, train_dataset, None, None, processor, index_train, None, None, num_tasks

In [21]:
if use_cachedataset:
    
    data_train = []
    for i in range(len(image_paths_train)):
        one_entry = {'image': image_paths_train[i]}
        data_train.append(one_entry)
    train_ds = PersistentDataset(
        data=data_train,
        transform=train_transforms,
        cache_dir=os.path.join(model_save_path, 'cache_dir'),
    )

else:
    # create a training data 
    train_ds = ImageDataset(
        image_files=image_paths_train,
        labels=labels_train,
        transform=train_transforms,
    )
    

train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=False, num_workers=3, pin_memory=pin_memory
    )


In [22]:
motor_task, tokenizer, train_dataset, valid_dataset, test_dataset, processor, index_train, index_valid, index_test, num_tasks = set_up_motor_task(model_save_path, from_pretrained_tokenizer, month_date_hour, num_proc, label_csv, ontology_path, inference, vocab_size, datetime.now(), parquet_folder, final_layer_size, num_tasks)

num_proc must be <= 1. Reducing num_proc to 1 for dataset of size 1.


indexing patients...


Map: 100%|██████████| 1/1 [00:00<00:00, 144.37 examples/s]

time used indexing patients: 0:00:00.022646
[831469698]



num_proc must be <= 1. Reducing num_proc to 1 for dataset of size 1.


Time used tokenzier: 0:00:03.851011
Prefitting MOTOR task...
Time used motor task: 0:00:03.851723


Map: 100%|██████████| 1/1 [00:00<00:00, 202.86 examples/s]

dict_keys([831469698])
Time used index: 0:00:03.860437





In [23]:
model = DenseNet121_TTE(spatial_dims=3, in_channels=1, out_channels=2, time_bins=motor_task.time_bins, pretraining_task_info=motor_task.get_task_config().task_kwargs['pretraining_task_info'], final_layer_size=motor_task.final_layer_size, vocab_size=tokenizer.vocab_size, device=device, use_checkpoint=use_checkpoint, dropout_prob=dropout_prob).to(device)

In [26]:
step = -1
optimizer = torch.optim.Adam(model.parameters(), learning_rate)
model.train()
for batch_data in tqdm(train_loader):
    step += 1
    patient_id = int(image_paths_train[step].replace(nii_folder+ "/", '').split('_')[0])
    ct_time = ' '.join(image_paths_train[step].replace(nii_folder+ "/", '').replace('.nii.gz', '').split('_')[1:])
    ct_time = datetime.strptime(ct_time, '%Y-%m-%d %H %M %S')
    for idx, event in enumerate(train_dataset[index_train.get_index(patient_id)]['events']):
        if event['time'] == ct_time:
            offset = idx
            
    example_batch = processor.collate([processor.convert_patient(train_dataset[index_train.get_index(patient_id)], tensor_type='pt', offset=offset, max_patient_length=vocab_size)])
    
    if use_cachedataset:
        inputs = batch_data['image'].to(device)
    else:
        inputs = batch_data[0].to(device)
        
    optimizer.zero_grad()
    loss, _, features = model(inputs, example_batch['batch'], return_logits=False)
    loss.backward()
    optimizer.step()
    print('loss:', loss.item())

100%|██████████| 1/1 [00:17<00:00, 17.36s/it]

loss: 1.739653468132019



