## Step 1: Choose the dataset, task, and model to record in the job database

In [None]:
import sys
import warnings

# Silence every warning so the notebook output stays readable.
warnings.filterwarnings('ignore')
# Three levels up is added to the import path — presumably so the local
# `pyhealth` package resolves; confirm against the repository layout.
sys.path.append("../../../")

# ------- user entry to db ----
# These three values are stored on the Job row created in Step 3.
DATASET = 'mimic iii'
TASK = 'drug recommendation'
MODEL = 'RETAIN'
# -----------------------------

## Step 2: Load the data and define the model

In [None]:
import pyhealth.datasets.datasets as datasets
import pyhealth.models.models as models

# Load MIMIC-III.
# NOTE(review): DATASET and TASK from Step 1 are only logged to the db; the
# dataset class itself is hard-coded here — confirm it matches DATASET.
dataset = datasets.MIMIC_III()

# Build the dataloaders and the model for the model name chosen in Step 1.
# Resolving both from MODEL (instead of the literal "RETAIN") keeps what is
# trained in sync with what Step 3 records in the job table; an unknown
# MODEL now fails loudly with AttributeError instead of silently training
# a different model.
dataset.get_dataloader(MODEL)

model_cls = getattr(models, MODEL)
# NOTE(review): `emb_dim` is the RETAIN constructor argument used originally;
# confirm other model classes accept it before changing MODEL.
model = model_cls(
    dataset=dataset,
    emb_dim=64,
)

## Step 3: Register the pending job in the Postgres database

In [None]:
from datetime import datetime, timedelta, timezone
sys.path.append("../../")
from pyhealth_web.app import Job, db

# Next run id = largest existing run_id + 1, or 0 for an empty table (the
# same value count() gives when ids are contiguous from 0). count() alone
# would hand out an id that is already taken as soon as any row has been
# deleted, and Step 5 would then update the wrong job(s).
# Assumes run_id is an integer column, as the original count()-based id implies.
_max_run_id = db.session.query(db.func.max(Job.run_id)).scalar()
RUN_ID = 0 if _max_run_id is None else _max_run_id + 1

# Naive local wall-clock time.
# NOTE(review): `timezone` is imported but unused — confirm whether the
# trigger_time column is meant to hold UTC instead.
TRIGGER_TIME = datetime.now()
RUN_STATS = 'Pending'
example_job = Job(
    run_id = RUN_ID,
    trigger_time = TRIGGER_TIME,
    dataset = DATASET,
    task_name = TASK,
    model = MODEL,
    run_stats = RUN_STATS,
    downloads = "",
)
db.session.add(example_job)
try:
    db.session.commit()
except Exception:
    # Roll back so the session stays usable for later cells (e.g. the
    # Step 5 status update), then surface the original error.
    db.session.rollback()
    raise

## Step 4: Train the healthcare ML model

In [None]:
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Checkpoints are written under ./../data/model_cpt; the best one is
# reloaded for testing below.
checkpoint_callback = ModelCheckpoint(dirpath='./../data/model_cpt')

# Single GPU, one epoch, progress-bar refresh rate of 5.
# NOTE(review): `gpus` and `progress_bar_refresh_rate` are older Trainer
# arguments — confirm they are accepted by the installed pytorch_lightning.
trainer = Trainer(
    callbacks=[checkpoint_callback],
    gpus=1,
    max_epochs=1,
    progress_bar_refresh_rate=5,
)

# Fit on the training split, validating on the validation split.
trainer.fit(
    model,
    train_dataloaders=dataset.train_loader,
    val_dataloaders=dataset.val_loader,
)

# Evaluate the best checkpoint on the test split.
# NOTE(review): no `monitor` is set on the checkpoint callback — confirm what
# best_model_path refers to. Output goes to ./data while checkpoints live in
# ./../data; confirm the two locations are intentional.
best_ckpt = checkpoint_callback.best_model_path
model.summary(
    output_path='./data',
    test_dataloaders=dataset.test_loader,
    ckpt_path=best_ckpt,
)

## Step 5: Record the run outcome in the database

In [None]:
# -------- user enter --------
# Truthy if Step 4 completed correctly, falsy otherwise.
success = 1
# ----------------------------

# Choose the final status and downloads value for this run's Job row.
# NOTE(review): "##" looks like a placeholder for a real download link.
if not success:
    final_fields = {'run_stats': 'FAILURE', 'downloads': ""}
else:
    final_fields = {'run_stats': 'SUCCESS', 'downloads': "##"}

Job.query.filter_by(run_id=RUN_ID).update(final_fields)
db.session.commit()