### Periodic retraining using a window of the most recent encounters

In [1]:
import datetime
import os
import random
import scipy.stats
from datetime import date
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as st
from matplotlib.colors import ListedColormap
from scipy.stats import pearsonr, spearmanr
import pickle
from alibi_detect.cd.pytorch import HiddenOutput
import torch.nn as nn
import torch.optim as optim
from cyclops.utils.file import load_pickle, save_pickle

from cyclops.processors.column_names import (
    ADMIT_TIMESTAMP,
    DISCHARGE_TIMESTAMP,
    ENCOUNTER_ID,
    EVENT_NAME,
    EVENT_TIMESTAMP,
    EVENT_VALUE,
    RESTRICT_TIMESTAMP,
    TIMESTEP,
)

from drift_detection.baseline_models.temporal.pytorch.optimizer import Optimizer, EarlyStopper
from drift_detection.drift_detector.clinical_applicator import ClinicalShiftApplicator
from drift_detection.drift_detector.detector import Detector
from drift_detection.drift_detector.experimenter import Experimenter
from drift_detection.drift_detector.plotter import plot_drift_samples_pval
from drift_detection.drift_detector.reductor import Reductor
from drift_detection.drift_detector.tester import DCTester, TSTester
from drift_detection.gemini.constants import DIAGNOSIS_DICT, ACADEMIC, COMMUNITY, HOSPITALS
from drift_detection.gemini.utils import get_use_case_params, impute, prep, import_dataset_hospital
from models.temporal.utils import (
    get_device,
    load_checkpoint,
)
from cyclops.processors.constants import ALL, FEATURES, MEAN, NUMERIC, ORDINAL, STANDARD
from cyclops.processors.feature.vectorize import (
    Vectorized,
    intersect_vectorized,
    split_vectorized,
    vec_index_exp,
)
from drift_detection.retrainers.periodic import PeriodicRetrainer
from drift_detection.drift_detector.utils import get_serving_data, get_temporal_model

## Get parameters

In [2]:
# Experiment configuration: dataset / task and the on-disk experiment root.
DATASET = "gemini"
USE_CASE = "mortality"
PATH = "/mnt/nfs/project/delirium/drift_exp/OCT-18-2022/"
TIMESTEPS = 6
AGGREGATION_TYPE = "time"

# Interactive selections. ID starts as the chosen split name and is extended
# below with every filter that is applied; it is used to name cached/saved files.
SHIFT = input("Select data split: ")
ID = SHIFT
DIAGNOSIS_TRAJECTORY = input("Select diagnosis trajectory to filter on: ")
HOSPITAL = input("Select hospital to filter on: ")

# Filters later passed to tab_features.slice(...) when building data streams.
splice_map = {"hospital_id": HOSPITALS}

if DIAGNOSIS_TRAJECTORY != "all":
    diagnosis_trajectory = "_".join(DIAGNOSIS_DICT[DIAGNOSIS_TRAJECTORY])
    ID = f"{ID}_{diagnosis_trajectory}"          # diagnosis filter is a suffix
    splice_map["diagnosis_trajectory"] = [diagnosis_trajectory]

if HOSPITAL != "all":
    ID = f"{HOSPITAL}_{ID}"                      # hospital filter is a prefix
    splice_map["hospital_id"] = [HOSPITAL]

use_case_params = get_use_case_params(DATASET, USE_CASE)

model_file = "simulated_deployment_reweight_positive_lstm_1.pt"
MODEL_PATH = os.path.join(PATH, DATASET, USE_CASE, "saved_models", model_file)

Select data split:  simulated_deployment
Select diagnosis trajectory to filter on:  all
Select hospital to filter on:  all


## Get data

In [3]:
# Seed Python's RNG before loading (presumably consumed by the shuffle inside
# import_dataset_hospital — TODO confirm which RNG it uses).
random.seed(1)

tab_features = load_pickle(use_case_params.TAB_FEATURES_FILE)

# Re-fetch the use-case parameters with the same arguments as the config cell.
use_case_params = get_use_case_params(DATASET, USE_CASE)

# Train / validation / test splits, each as an (X, y) pair.
(X_tr_final, y_tr), (X_val_final, y_val), (X_t_final, y_t) = import_dataset_hospital(
    use_case_params.TAB_VEC_COMB,
    ID,
    shuffle=True,
    train_frac=0.8,
)

2023-02-21 16:15:39,259 [1;37mINFO[0m cyclops.utils.file - Loading pickled data from /mnt/nfs/project/delirium/drift_exp/OCT-18-2022/gemini/mortality/./data/tab_features.pkl
2023-02-21 16:15:39,469 [1;37mINFO[0m cyclops.utils.file - Loading pickled data from /mnt/nfs/project/delirium/drift_exp/OCT-18-2022/gemini/mortality/./data/4-final/aligned_comb_train_X_simulated_deployment.pkl
2023-02-21 16:15:39,894 [1;37mINFO[0m cyclops.utils.file - Loading pickled data from /mnt/nfs/project/delirium/drift_exp/OCT-18-2022/gemini/mortality/./data/4-final/aligned_comb_train_y_simulated_deployment.pkl
2023-02-21 16:15:39,942 [1;37mINFO[0m cyclops.utils.file - Loading pickled data from /mnt/nfs/project/delirium/drift_exp/OCT-18-2022/gemini/mortality/./data/4-final/aligned_comb_val_X_simulated_deployment.pkl
2023-02-21 16:15:40,042 [1;37mINFO[0m cyclops.utils.file - Loading pickled data from /mnt/nfs/project/delirium/drift_exp/OCT-18-2022/gemini/mortality/./data/4-final/aligned_comb_val_y_s

## Create data streams

In [4]:
# Date range handed to get_serving_data when the target data streams are built.
START_DATE, END_DATE = date(2019, 1, 1), date(2020, 8, 1)

In [5]:
# Build (or load from cache) the daily target data streams used to simulate
# deployment between START_DATE and END_DATE.
print("Get target data streams...")
data_streams_filepath = os.path.join(
    PATH,
    DATASET,
    USE_CASE,
    "drift",
    "data_streams_" + ID + "_" + str(START_DATE) + "_" + str(END_DATE) + ".pkl",
)

if os.path.exists(data_streams_filepath):
    # Cached result; delete the pickle to force a rebuild.
    data_streams = load_pickle(data_streams_filepath)
else:
    tab_vectorized = load_pickle(use_case_params.TAB_VECTORIZED_FILE)
    comb_vectorized = load_pickle(use_case_params.COMB_VECTORIZED_FILE)

    # Keep only the encounters selected by the hospital / diagnosis filters.
    ids = tab_features.slice(splice_map)
    tab_vectorized = tab_vectorized.take_with_index(ENCOUNTER_ID, ids)
    # intersect tabular and temporal vectors of source data
    tab_vectorized, comb_vectorized = intersect_vectorized(
        [tab_vectorized, comb_vectorized], axes=ENCOUNTER_ID
    )
    comb_vectorized.add_normalizer(
        EVENT_NAME,
        normalization_method=STANDARD,
    )
    # Separate the target event(s) from the feature events, then impute and
    # standardize the features.
    X, y = comb_vectorized.split_out(EVENT_NAME, use_case_params.TEMP_TARGETS)
    X = impute(X)
    X.fit_normalizer()
    X.normalize()

    X_final = prep(X)
    # Flatten the 3-D array into rows indexed by (encounter_id, timestep).
    # The timestep range is taken from X_final itself (rather than a hard-coded
    # count) so the index length always equals the reshaped row count.
    n_encounters, n_timesteps, n_features = X_final.shape
    ind = pd.MultiIndex.from_product(
        [X.indexes[1], range(n_timesteps)], names=[ENCOUNTER_ID, TIMESTEP]
    )
    X_final = pd.DataFrame(
        X_final.reshape(n_encounters * n_timesteps, n_features),
        index=ind,
        columns=X.indexes[2],
    )
    # This reshape only succeeds when axes 0 and 2 of y.data have size 1.
    y_final = y.data.reshape(y.data.shape[1], y.data.shape[3])

    data_streams = get_serving_data(
        X_final,
        y_final,
        tab_features.data,
        START_DATE,
        END_DATE,
        stride=1,
        window=1,
        encounter_id="encounter_id",
        admit_timestamp="admit_timestamp",
    )
    save_pickle(data_streams, data_streams_filepath)

2023-02-21 16:15:41,862 [1;37mINFO[0m cyclops.utils.file - Loading pickled data from /mnt/nfs/project/delirium/drift_exp/OCT-18-2022/gemini/mortality/drift/data_streams_simulated_deployment_2019-01-01_2020-08-01.pkl


Get target data streams...


## Get shift detector

In [9]:
DR_TECHNIQUE = "BBSDs_trained_LSTM"
TESTER_METHOD = "ks"
THRESHOLD = 0.01   # p-value below which a window is counted as an alarm
UPDATE_REF = 25000  # passed to detector.fit as update_x_ref={'last': UPDATE_REF}

print("Get Shift Reductor...")
reductor = Reductor(
    dr_method=DR_TECHNIQUE,
    model_path=MODEL_PATH,
    n_features=X_tr_final.shape[2],
    var_ret=0.8,
)

# Two-sample testers vs. domain-classifier testers.
tstesters = ["lk", "lsdd", "mmd", "tabular", "ctx_mmd", "chi2", "fet", "ks"]
dctesters = ["spot_the_diff", "classifier", "classifier_uncertainty"]
CONTEXT_TYPE = None
REPRESENTATION = None

if TESTER_METHOD in tstesters:
    tester = TSTester(
        tester_method=TESTER_METHOD,
    )
    # Some two-sample testers need one extra user choice. ("lk" is a
    # two-sample tester, so its prompt belongs in this branch.)
    if TESTER_METHOD == "ctx_mmd":
        CONTEXT_TYPE = input("Select context type: ")
    elif TESTER_METHOD == "lk":
        REPRESENTATION = input("Select learned kernel representation: ")
elif TESTER_METHOD in dctesters:
    MODEL_METHOD = input("Select model method: ")
    tester = DCTester(
        tester_method=TESTER_METHOD,
        model_method=MODEL_METHOD,
    )
else:
    # Fail here with a clear message; otherwise `tester` is undefined and
    # Detector(...) below raises an unhelpful NameError.
    raise ValueError(
        f"Tester method invalid or not supported: {TESTER_METHOD!r}"
    )

print("Get Shift Detector...")
# Note: the retraining loop re-creates and re-fits a detector for every run.
detector = Detector(
    reductor=reductor,
    tester=tester
)

detector.fit(
    X_val_final,
    backend="pytorch",
    device="cuda",
    model_path=MODEL_PATH,
    batch_size=32,
    verbose=0,
    alternative="greater",
    correction="bonferroni",
    input_dim=X_tr_final.shape[2],
    update_x_ref={'last': UPDATE_REF}
)

Get Shift Reductor...
Loading model...
Get Shift Detector...


## Retrain

In [10]:
# Retraining mode:
#   "update"  -> load the saved checkpoint with load_checkpoint (model + optimizer
#                wrapper) and fine-tune for n_epochs = 1 per retraining step.
#   "retrain" -> keep the freshly constructed model and build a new Optimizer
#                (Adagrad, StepLR, EarlyStopper) with n_epochs = 64.
retrain = "update"
model_name = "lstm"
output_dim = 1
input_dim = X_tr_final.shape[2]
hidden_dim = 64
layer_dim = 2
dropout = 0.2
last_timestep_only = False
device = get_device()

model_params = {
    "device": device,
    "input_dim": input_dim,
    "hidden_dim": hidden_dim,
    "layer_dim": layer_dim,
    "output_dim": output_dim,
    "dropout_prob": dropout,
    "last_timestep_only": last_timestep_only,
}

if model_name not in ["rnn", "gru", "lstm"]:
    # Fail fast: the retraining loop needs `model`, `opt` and `n_epochs`,
    # none of which would be defined for an unsupported model.
    raise ValueError(f"Unsupported model: {model_name!r}")

model = get_temporal_model(model_name, model_params).to(device)

if retrain == "update":
    print("Update.")
    model, opt, _ = load_checkpoint(MODEL_PATH, model)
    n_epochs = 1
elif retrain == "retrain":
    print("Retrain.")
    n_epochs = 64
    learning_rate = 2e-3
    weight_decay = 1e-6
    clipping_value = 1
    # Ratio of negative to positive training labels, doubled, used to up-weight positives.
    reweight_positive = (y_tr == 0).sum() / (y_tr == 1).sum() * 2
    loss_fn = nn.BCEWithLogitsLoss(reduction="none")
    optimizer = optim.Adagrad(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=128, gamma=0.5)
    activation = nn.Sigmoid()
    earlystopper = EarlyStopper(patience=3, min_delta=0)
    opt = Optimizer(
        model=model,
        loss_fn=loss_fn,
        optimizer=optimizer,
        activation=activation,
        lr_scheduler=lr_scheduler,
        reweight_positive=reweight_positive,
        earlystopper=earlystopper,
        clipping_value=clipping_value
    )
else:
    # NOTE(review): no optimizer or epoch count is set in this mode, so the
    # retraining loop cannot run with it.
    print("Do nothing.")

Update.


In [11]:
import torch  # only torch.nn / torch.optim are bound at the top; needed for seeding

SAMPLE = 1000
STAT_WINDOW = 14
RETRAIN_WINDOW = 120
LOOKUP_WINDOW = 0
STRIDE = 1
FREQ = 120
N_RUNS = 5
# THRESHOLD and UPDATE_REF come from the shift-detector cell (single source of truth).


def summarize_alarms(p_values, threshold):
    """Summarize the windows whose p-value falls below `threshold`.

    Returns (n_alarms, mean_p_value, ci) where ci is the 95% Student-t interval
    of the mean. The mean is NaN when there are no alarms, and the interval is
    (NaN, NaN) with fewer than two alarms (it needs >= 1 degree of freedom).
    """
    alarm_pvals = [p for p in p_values if p < threshold]
    n_alarms = len(alarm_pvals)
    mean_pval = np.mean(alarm_pvals) if n_alarms else np.nan
    if n_alarms >= 2:
        ci = st.t.interval(
            0.95,
            n_alarms - 1,
            loc=mean_pval,
            scale=st.sem(alarm_pvals),
        )
    else:
        ci = (np.nan, np.nan)
    return n_alarms, mean_pval, ci


all_runs = []
for run_seed in range(N_RUNS):
    # Seed every RNG in play (including torch, which drives training) so each
    # run is reproducible.
    random.seed(run_seed)
    np.random.seed(run_seed)
    torch.manual_seed(run_seed)

    # Fresh detector per run so reference-window updates do not leak across runs.
    detector = Detector(
        reductor=reductor,
        tester=tester
    )

    detector.fit(
        X_val_final,
        backend="pytorch",
        device="cuda",
        model_path=MODEL_PATH,
        batch_size=32,
        verbose=0,
        alternative="greater",
        correction="bonferroni",
        input_dim=X_tr_final.shape[2],
        update_x_ref={'last': UPDATE_REF}
    )

    # Reset the model from the checkpoint before each run.
    # NOTE(review): this runs in every mode, so with retrain == "retrain" the
    # freshly built optimizer from the setup cell is replaced — TODO confirm intended.
    model, opt, _ = load_checkpoint(MODEL_PATH, model)

    retrainer = PeriodicRetrainer(
        shift_detector=detector,
        optimizer=opt,
        model=model,
        model_name=model_name,
    )

    results = retrainer.retrain(
        data_streams=data_streams,
        sample=SAMPLE,
        stat_window=STAT_WINDOW,
        lookup_window=LOOKUP_WINDOW,
        retrain_window=RETRAIN_WINDOW,
        stride=STRIDE,
        model_path=MODEL_PATH,
        freq=FREQ,
        n_epochs=n_epochs,
        correct_only=0,
        positive_only=0,
        verbose=1,
    )
    all_runs.append(results)

    n_alarms, mean_pval, ci = summarize_alarms(results["p_val"], THRESHOLD)
    print(n_alarms, " alarms with avg p-value of ", mean_pval, ci)


  0%|          | 0/534 [00:00<?, ?it/s][A

Calibrating drift detector...



  9%|▉         | 50/534 [00:13<02:06,  3.84it/s][A
 19%|█▊        | 100/534 [00:46<03:23,  2.13it/s][A
 28%|██▊       | 150/534 [01:32<04:10,  1.53it/s][A

Triggered at  2019-05-02 - 2019-05-16 	P-Value:  0.42004663
Retrain  lstm  on:  2019-01-02 - 2019-05-02
[1/1] Training loss: 2.0002	                 Validation loss: 1.8167



 37%|███▋      | 200/534 [02:24<04:25,  1.26it/s][A
 47%|████▋     | 250/534 [03:13<04:04,  1.16it/s][A

Triggered at  2019-08-30 - 2019-09-13 	P-Value:  0.00021613427
Retrain  lstm  on:  2019-05-02 - 2019-08-30
[1/1] Training loss: 1.8579	                 Validation loss: 1.6847



 56%|█████▌    | 300/534 [04:04<03:33,  1.10it/s][A
 66%|██████▌   | 350/534 [04:56<02:56,  1.04it/s][A
 75%|███████▍  | 400/534 [05:55<02:17,  1.03s/it][A

Triggered at  2019-12-28 - 2020-01-11 	P-Value:  0.16019556
Retrain  lstm  on:  2019-08-30 - 2019-12-28
[1/1] Training loss: 1.8456	                 Validation loss: 1.8506



 84%|████████▍ | 450/534 [07:03<01:35,  1.13s/it][A
 94%|█████████▎| 500/534 [08:02<00:38,  1.14s/it][A

Triggered at  2020-04-26 - 2020-05-10 	P-Value:  0.2059353
Retrain  lstm  on:  2019-12-28 - 2020-04-26
[1/1] Training loss: 1.9521	                 Validation loss: 2.0074



550it [09:38,  1.05s/it]                         [A


118  alarms with avg p-value of  0.0017927435 (0.0012644641834446413, 0.002321022722204663)



  0%|          | 0/534 [00:00<?, ?it/s][A

Calibrating drift detector...



  9%|▉         | 50/534 [00:12<02:01,  3.97it/s][A
 19%|█▊        | 100/534 [00:47<03:27,  2.09it/s][A
 28%|██▊       | 150/534 [01:33<04:14,  1.51it/s][A

Triggered at  2019-05-02 - 2019-05-16 	P-Value:  0.42004663
Retrain  lstm  on:  2019-01-02 - 2019-05-02
[1/1] Training loss: 2.0003	                 Validation loss: 1.7942



 37%|███▋      | 200/534 [02:25<04:26,  1.25it/s][A
 47%|████▋     | 250/534 [03:13<04:04,  1.16it/s][A

Triggered at  2019-08-30 - 2019-09-13 	P-Value:  0.00021979728
Retrain  lstm  on:  2019-05-02 - 2019-08-30
[1/1] Training loss: 1.8144	                 Validation loss: 1.6775



 56%|█████▌    | 300/534 [04:06<03:36,  1.08it/s][A
 66%|██████▌   | 350/534 [05:06<03:06,  1.01s/it][A
 75%|███████▍  | 400/534 [06:13<02:29,  1.12s/it][A

Triggered at  2019-12-28 - 2020-01-11 	P-Value:  0.1842064
Retrain  lstm  on:  2019-08-30 - 2019-12-28
[1/1] Training loss: 1.8378	                 Validation loss: 1.8628



 84%|████████▍ | 450/534 [07:31<01:45,  1.25s/it][A
 94%|█████████▎| 500/534 [08:36<00:43,  1.27s/it][A

Triggered at  2020-04-26 - 2020-05-10 	P-Value:  0.19692628
Retrain  lstm  on:  2019-12-28 - 2020-04-26
[1/1] Training loss: 1.9599	                 Validation loss: 1.9565



550it [10:18,  1.12s/it]                         [A


114  alarms with avg p-value of  0.0018560472 (0.001338116286122813, 0.002373978102779374)



  0%|          | 0/534 [00:00<?, ?it/s][A

Calibrating drift detector...



  9%|▉         | 50/534 [00:11<01:55,  4.19it/s][A
 19%|█▊        | 100/534 [00:48<03:29,  2.07it/s][A
 28%|██▊       | 150/534 [01:41<04:38,  1.38it/s][A

Triggered at  2019-05-02 - 2019-05-16 	P-Value:  0.42004663
Retrain  lstm  on:  2019-01-02 - 2019-05-02
[1/1] Training loss: 2.0158	                 Validation loss: 1.7662



 37%|███▋      | 200/534 [02:41<05:00,  1.11it/s][A
 47%|████▋     | 250/534 [03:36<04:36,  1.03it/s][A

Triggered at  2019-08-30 - 2019-09-13 	P-Value:  9.166498e-05
Retrain  lstm  on:  2019-05-02 - 2019-08-30
[1/1] Training loss: 1.8270	                 Validation loss: 1.6777



 56%|█████▌    | 300/534 [04:27<03:50,  1.02it/s][A
 66%|██████▌   | 350/534 [05:24<03:10,  1.04s/it][A
 75%|███████▍  | 400/534 [06:31<02:32,  1.13s/it][A

Triggered at  2019-12-28 - 2020-01-11 	P-Value:  0.10913437
Retrain  lstm  on:  2019-08-30 - 2019-12-28
[1/1] Training loss: 1.8384	                 Validation loss: 1.8632



 84%|████████▍ | 450/534 [07:42<01:42,  1.22s/it][A
 94%|█████████▎| 500/534 [08:46<00:42,  1.24s/it][A

Triggered at  2020-04-26 - 2020-05-10 	P-Value:  0.2187429
Retrain  lstm  on:  2019-12-28 - 2020-04-26
[1/1] Training loss: 1.9994	                 Validation loss: 2.0252



550it [10:31,  1.15s/it]                         [A


120  alarms with avg p-value of  0.0019882608 (0.001445643191067081, 0.002530878356816668)



  0%|          | 0/534 [00:00<?, ?it/s][A

Calibrating drift detector...



  9%|▉         | 50/534 [00:16<02:41,  2.99it/s][A
 19%|█▊        | 100/534 [00:55<04:00,  1.81it/s][A
 28%|██▊       | 150/534 [01:47<04:50,  1.32it/s][A

Triggered at  2019-05-02 - 2019-05-16 	P-Value:  0.42004663
Retrain  lstm  on:  2019-01-02 - 2019-05-02
[1/1] Training loss: 1.9945	                 Validation loss: 1.7957



 37%|███▋      | 200/534 [02:46<05:06,  1.09it/s][A
 47%|████▋     | 250/534 [03:42<04:40,  1.01it/s][A

Triggered at  2019-08-30 - 2019-09-13 	P-Value:  0.00030898795
Retrain  lstm  on:  2019-05-02 - 2019-08-30
[1/1] Training loss: 1.8418	                 Validation loss: 1.6975



 56%|█████▌    | 300/534 [04:40<04:04,  1.05s/it][A
 66%|██████▌   | 350/534 [05:41<03:21,  1.10s/it][A
 75%|███████▍  | 400/534 [06:48<02:37,  1.17s/it][A

Triggered at  2019-12-28 - 2020-01-11 	P-Value:  0.18003738
Retrain  lstm  on:  2019-08-30 - 2019-12-28
[1/1] Training loss: 1.8522	                 Validation loss: 1.8244



 84%|████████▍ | 450/534 [07:59<01:45,  1.25s/it][A
 94%|█████████▎| 500/534 [08:56<00:41,  1.22s/it][A

Triggered at  2020-04-26 - 2020-05-10 	P-Value:  0.18319826
Retrain  lstm  on:  2019-12-28 - 2020-04-26
[1/1] Training loss: 1.9712	                 Validation loss: 1.9521



550it [10:30,  1.15s/it]                         [A


120  alarms with avg p-value of  0.0019707424 (0.0014615894676616248, 0.0024798952593038026)



  0%|          | 0/534 [00:00<?, ?it/s][A

Calibrating drift detector...



  9%|▉         | 50/534 [00:13<02:09,  3.74it/s][A
 19%|█▊        | 100/534 [00:50<03:39,  1.98it/s][A
 28%|██▊       | 150/534 [01:37<04:22,  1.46it/s][A

Triggered at  2019-05-02 - 2019-05-16 	P-Value:  0.42004663
Retrain  lstm  on:  2019-01-02 - 2019-05-02
[1/1] Training loss: 2.0136	                 Validation loss: 1.8037



 37%|███▋      | 200/534 [02:29<04:33,  1.22it/s][A
 47%|████▋     | 250/534 [03:18<04:08,  1.14it/s][A

Triggered at  2019-08-30 - 2019-09-13 	P-Value:  0.00011129042
Retrain  lstm  on:  2019-05-02 - 2019-08-30
[1/1] Training loss: 1.8478	                 Validation loss: 1.7018



 56%|█████▌    | 300/534 [04:12<03:41,  1.06it/s][A
 66%|██████▌   | 350/534 [05:05<03:00,  1.02it/s][A
 75%|███████▍  | 400/534 [06:03<02:18,  1.03s/it][A

Triggered at  2019-12-28 - 2020-01-11 	P-Value:  0.1344901
Retrain  lstm  on:  2019-08-30 - 2019-12-28
[1/1] Training loss: 1.9554	                 Validation loss: 1.8303



 84%|████████▍ | 450/534 [07:10<01:34,  1.13s/it][A
 94%|█████████▎| 500/534 [08:05<00:38,  1.12s/it][A

Triggered at  2020-04-26 - 2020-05-10 	P-Value:  0.14785217
Retrain  lstm  on:  2019-12-28 - 2020-04-26
[1/1] Training loss: 1.9698	                 Validation loss: 2.0140



550it [09:36,  1.05s/it]                         [A

120  alarms with avg p-value of  0.0021184336 (0.0015623217555367033, 0.002674545430145355)





In [12]:
# Persist every run's results under a name that encodes the experiment settings.
run_tag = "_".join(
    [
        "periodic",
        ID,
        f"retrainwindow{RETRAIN_WINDOW}",
        f"statwindow{STAT_WINDOW}",
        f"lookupwindow{LOOKUP_WINDOW}",
        f"update{UPDATE_REF}",
        f"epoch{n_epochs}",
        f"sample{SAMPLE}",
        f"freq{FREQ}",
        "retraining",
        retrain,
    ]
)
results_path = os.path.join(PATH, DATASET, USE_CASE, run_tag + ".pkl")
save_pickle(all_runs, results_path)

2023-02-21 17:12:46,067 [1;37mINFO[0m cyclops.utils.file - Pickling data to /mnt/nfs/project/delirium/drift_exp/OCT-18-2022/gemini/mortality/periodic_simulated_deployment_retrainwindow120_statwindow14_lookupwindow0_update25000_epoch1_sample1000_freq120_retraining_update.pkl


'/mnt/nfs/project/delirium/drift_exp/OCT-18-2022/gemini/mortality/periodic_simulated_deployment_retrainwindow120_statwindow14_lookupwindow0_update25000_epoch1_sample1000_freq120_retraining_update.pkl'