## Step 1: choose the dataset, task, and model shown in the DB job list

In [1]:
import sys
import warnings

# Make the repository root importable (presumably so the local `pyhealth`
# package resolves from source — confirm against the repo layout).
sys.path.append("../../../")
# Hide library warnings so the cell output stays readable.
warnings.filterwarnings('ignore')

# ------- user entry to db ----
# Labels recorded on the job row created in Step 3.
DATASET = 'mimic iii'            # dataset name shown in the job list
TASK = 'drug recommendation'     # task name shown in the job list
MODEL = 'RETAIN'                 # model name shown in the job list
# -----------------------------

## Step 2: load the data and define the model

In [2]:
import pyhealth.datasets.datasets as datasets
import pyhealth.models.models as models

# load MIMIC-III
dataset = datasets.MIMIC_III()

# Build the train/val/test dataloaders for the model chosen in Step 1.
# Using MODEL instead of re-typing "RETAIN" keeps the model that is actually
# trained in sync with the model name written to the job DB in Step 3.
dataset.get_dataloader(MODEL)

# Resolve the model class by the same name, e.g. MODEL == 'RETAIN' -> models.RETAIN.
# An unknown name raises AttributeError here instead of silently training RETAIN.
model_cls = getattr(models, MODEL)
# NOTE(review): emb_dim is passed to whichever class MODEL names — confirm other
# model classes accept it before switching MODEL.
model = model_cls(
    dataset=dataset,
    emb_dim=64,
)

100%|█| 39363/39363 [00:18<00:00, 2089.86
100%|█████████████████████████████████████████████████████████████████████████████████████| 46520/46520 [00:20<00:00, 2219.35it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 45354/45354 [00:19<00:00, 2301.92it/s]


generated .pat_to_visit!
generated .visit_dict!
----- preparing code mappings -----
source loaded from https://drive.google.com/uc?id=1I2G6fsBDXDiAK95qFWwtnl3Ib2MaLeCx
source loaded from https://drive.google.com/uc?id=1d2HzsByXrPadvjaKDOEaOt78OkAZOrjC
source loaded from https://drive.google.com/uc?id=199i8mP2gMQNhwUe-2ZNmIr5fhiBbzVlK
source loaded from https://drive.google.com/uc?id=1Z11J4st1sI44jPborls9jIxzcpF-GpGt
mapping finished: RxNorm -> ATC4
mapping finished: ATC4 -> RxNorm
load time: 11.606762409210205s
-----------------------------------------


100%|█████████████████████████████████████████████████████████████████████████████████████| 58976/58976 [00:18<00:00, 3165.83it/s]


generated .encoded_visit_dict!
generated .maps (for code to index mappings)!
load severe ddi pairs from https://drive.google.com/uc?id=1R88OIhn-DbOYmtmVYICmjBSOIsEljJMh!
ddi info is from https://drive.google.com/file/d/1mnPc0O0ztz0fkv3HF-dpmBb8PLWsEoDz/view?usp=sharing!
generated train/val/test dataloaders for RETAIN model!


## Step 3: register the job in the Postgres DB

In [3]:
from datetime import datetime, timedelta, timezone
sys.path.append("../../")
from pyhealth_web.app import Job, db

# Next run id = current max run_id + 1 (0 for an empty table).
# Deriving it from Job.query.count() collides with an existing id as soon as any
# job row has been deleted; Step 5 updates rows by run_id, so a duplicate id would
# clobber another job's status or make this insert fail.
# NOTE(review): assumes run_id is an integer column; two notebooks registering at
# the same moment can still race — a DB sequence would remove that.
max_run_id = db.session.query(db.func.max(Job.run_id)).scalar()
RUN_ID = 0 if max_run_id is None else max_run_id + 1
TRIGGER_TIME = datetime.now()  # naive local time
RUN_STATS = 'Pending'
example_job = Job(
    run_id = RUN_ID,
    trigger_time = TRIGGER_TIME,
    dataset = DATASET,
    task_name = TASK,
    model = MODEL,
    run_stats = RUN_STATS,
    downloads = "",
)
try:
    db.session.add(example_job)
    db.session.commit()
except Exception:
    # Reset the session so later cells (e.g. Step 5) can still use it.
    db.session.rollback()
    raise

## Step 4: train the healthcare ML model

In [None]:
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

MAX_EPOCHS = 1
# Fall back to CPU when no CUDA device is visible: a hard-coded gpus=1 makes the
# Trainer raise on CPU-only machines, so the notebook could not be re-run there.
N_GPUS = 1 if torch.cuda.is_available() else 0

# set trainer with checkpoint
# No `monitor` is given, so checkpoints are not ranked by a validation metric.
checkpoint_callback = ModelCheckpoint(dirpath='./../data/model_cpt')
trainer = Trainer(
    gpus=N_GPUS,
    max_epochs=MAX_EPOCHS,
    progress_bar_refresh_rate=5,
    callbacks=[checkpoint_callback],
)

# train model
trainer.fit(
    model=model,
    train_dataloaders=dataset.train_loader,
    val_dataloaders=dataset.val_loader,
)

# Evaluate on the test split from the saved checkpoint. Without `monitor`,
# `best_model_path` presumably refers to the most recently saved checkpoint rather
# than the best-validation epoch — verify against the installed Lightning version.
# NOTE(review): results go to ./data while checkpoints go to ./../data — confirm
# both locations are intended.
model.summary(
    output_path='./data',
    test_dataloaders=dataset.test_loader,
    ckpt_path=checkpoint_callback.best_model_path,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /mnt/disks/ssd_new/github/PyHealth-OMOP/pyhealth_web/jupyter-pool/pool/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name      | Type       | Params
-----------------------------------------
0 | embedding | Sequential | 517 K 
1 | alpha_gru | GRU        | 25.0 K
2 | beta_gru  | GRU        | 25.0 K
3 | alpha_li  | Linear     | 65    
4 | beta_li   | Linear     | 4.2 K 
5 | output    | Linear     | 10.9 K
-----------------------------------------
582 K     Trainable params
28.2 K    Non-trainable params
611 K     Total params
2.445     Total estimated model params size (MB)


Epoch 0:  16%|█████████▎                                                  | 835/5368 [00:08<00:44, 101.07it/s, loss=1.37, v_num=0]

## Step 5: update the job status in the DB

In [None]:
# -------- user enter --------
# Set to 0 if training in Step 4 did not complete.
success = 1
# ----------------------------

if success:
    # NOTE(review): "##" looks like a placeholder download link — confirm.
    new_status = dict(run_stats='SUCCESS', downloads="##")
else:
    new_status = dict(run_stats='FAILURE', downloads="")

try:
    # Query.update returns the number of matched rows. Exactly one job must match:
    # 0 means the status would be silently lost; >1 means run_ids are duplicated
    # and other jobs would be overwritten.
    n_updated = Job.query.filter_by(run_id=RUN_ID).update(new_status)
    if n_updated != 1:
        raise RuntimeError(
            f"expected to update exactly 1 job with run_id={RUN_ID}, updated {n_updated}"
        )
    db.session.commit()
except Exception:
    # Discard the partial update and keep the session usable.
    db.session.rollback()
    raise