# The Replication and Extension of "DuETT: Dual Event Time Transformer for Electronic Health Records"
**Team #73 (Solo): Runhua Yang (runhua2@illinois.edu) UIN: 678314629**

# Summary and Critical Findings (so far)

The DuETT model is reported to be state-of-the-art (SOTA) on the MIMIC-IV 2.0 and PhysioNet-2012 datasets for mortality prediction and phenotype classification. The authors have made the source code for the DuETT model, along with the training and evaluation pipelines for PhysioNet-2012, publicly available. However, citing licensing restrictions, they do not provide the data pipeline for MIMIC-IV 2.0, which hampers reproducibility. I suspect that licensing concerns may not be the sole reason the data pipeline code was not published. Additionally, the training and evaluation code for the MIMIC-IV dataset has not been released, despite the authors' claims of strong generalizability.

While the data processing for PhysioNet-2012 is relatively straightforward and standardized through the `torchtime.data` library, the situation is less clear for the MIMIC-IV benchmark. The paper's methodology builds on a MIMIC-III benchmark implementation but adds features and engineering, which complicates benchmarking. __Because the paper provides only sparse details, I dedicated a considerable amount of time to developing a data pipeline for MIMIC-IV 2.0, which involved modifying 14 files (500 lines added, 69 lines removed). I also wrote the training and evaluation code for MIMIC-IV 2.0, adding 2 files with 386 lines of code.__

I deliberately avoided requesting code or details from the authors that were not disclosed initially, focusing instead on identifying, through my own trial runs, the factors that actually drive the final performance. This raises a pertinent question: did the authors put the same effort into feature inclusion (85 chart event variables and 29 lab event variables added for ICU mortality) and processing for the baselines they compared against, such as XGBoost? Without comprehensive and equivalent feature handling, it is difficult to claim that DuETT is unequivocally superior to other methods. I plan to investigate this further.


# Code Repo

- Model/Training/Eval: https://github.com/ryuiuc/dlh-model

- Data Pipeline: https://github.com/ryuiuc/dlh-data

- Link to .ipynb file: https://github.com/ryuiuc/dlh-model/blob/master/dlh_TEAM73_UIN678314629.ipynb

# Paper for Discussion

Labach, A., Pokhrel, A., Huang, X.S., Zuberi, S., Yi, S.E., Volkovs, M.,
Poutanen, T., & Krishnan, R.G. (2023). **DuETT: Dual Event Time Transformer for
Electronic Health Records.** Proceedings of Machine Learning Research, 219:1–26,
2023.

# Introduction


## Background of the problem
1. What type of problem: The problem addressed in the paper revolves around the effective modeling and analysis of Electronic Health Records (EHR) data for various clinical predictions and insights. Specifically, the paper focuses on the development of a machine learning model capable of handling the unique challenges posed by EHR data, such as high sparsity, irregular observations, and the need to leverage the temporal and event-type dimensions for accurate predictions.

2. Importance/Meaning of solving the problem: Solving this problem is crucial for improving patient care and hospital operations. Accurate predictions based on EHR data can lead to timely interventions, reduced readmissions, and better management of healthcare resources. Moreover, understanding the structure and patterns within EHR data can provide valuable insights into patient health trends and inform clinical decision-making processes.

3. Difficulty of the problem: The complexity of EHR data, characterized by its high-dimensionality, sparsity, and irregularity, presents significant challenges for traditional time series analysis methods. The need for a model that can effectively capture the relationships between different types of observations and the temporal dynamics of patient health adds to the difficulty.

4. State of the art methods and effectiveness: Prior to the proposed method, various models, including recurrent neural networks (RNNs) and gradient-boosting models such as XGBoost, have been applied to EHR data with varying degrees of success. However, these models often struggle to fully exploit the structured relationships within the data, especially for multivariate time series with missing values.

## Paper explanation
1. What did the paper propose: The paper introduces DuETT (Dual Event Time Transformer), a novel architecture that extends the capabilities of Transformers to handle both time and event type dimensions in EHR data. DuETT is designed to provide robust representations from EHR data by transforming sparse time series into a regular sequence with fixed length, thus reducing computational complexity and enabling the use of larger and deeper neural networks.

2. Innovations of the method: The key innovations of DuETT include its ability to attend over both time and event type dimensions, the use of self-supervised learning (SSL) for model pre-training, and the design of an input representation that incorporates event information, static variables, and aggregates observations in a way that is computationally efficient. The paper also proposes a novel SSL training scheme that performs masked modeling of measured event values and missingness across both dimensions.

3. How well the proposed method works (in its own metrics): The proposed DuETT model outperforms state-of-the-art deep learning models on multiple downstream tasks from the MIMIC-IV and PhysioNet-2012 EHR datasets. It demonstrates superior performance in tasks such as mortality prediction and phenotype classification, showcasing its effectiveness in learning from EHR data.

4. Contribution to the research regime: As stated in the paper, the authors believe that their core contributions are:
   
    (1) The novel DuETT architecture design, which extends Transformers to exploit both time and event modalities of EHR data.

    (2) The design of input representation.

    (3) A novel self-supervised training scheme that performs masked modelling of measured event values and missingness across both time and event dimensions.

    (4) A thorough empirical evaluation of the approach on the MIMIC-IV (Johnson et al., 2022) and PhysioNet-2012 (Silva et al., 2012) hospital EHR datasets, demonstrating state-of-the-art performances (beating XGBoost specifically).

NOTE: This paper focuses on classification and prediction tasks; it does not address time-to-event forecasting, a popular research topic and application area for similar datasets. In addition, the study is restricted to numeric inputs and does not exploit text or other data modalities.



# Scope of Reproducibility

1. The DuETT model can be replicated accurately enough to reproduce the original results, validating its effectiveness in handling the complexities of EHR data. The goal is to achieve comparable performance and thereby substantiate its state-of-the-art status.
    Owing to constrained computational resources and time, only the following components will be addressed:
    - MIMIC-IV 2.0 ICU Mortality
    - PhysioNet-2012 Mortality
    - Excluded:
        - MIMIC-IV ED Transfer to ICU task
        - MIMIC-IV ICU Phenotyping

2. The outcomes of the ablation study should correspond with the findings reported in the paper. Ablations to be covered (in progress until the end of the project):
    - Attention Mechanisms: Due to limitations in time and resources, this project will focus on replicating the "Time Transformer only" experiment, which showed the largest drop in performance.
    - Self-Supervised Learning: Due to limitations in time and resources, this project will focus on replicating the "Event type masking only" experiment, which showed the largest drop in performance. Note: This suggests there is room to improve the current masking approach, particularly along the time dimension, which could boost the model's overall performance.
    - The remaining ablation studies in the paper are logical and align with expectations; however, due to time and resource constraints, they will not be verified in this project.
   ![](https://uiuc-dlh.s3.amazonaws.com/duett_ablations.png)

# Methodology

## Data

### MIMIC-IV 2.0 Data:
- Obtain the license and download the data.
- Run the code below to pre-process the data and generate the input data for the ICU mortality task (follow the README.md in the repo below).

```python
# The dlh-data repository is public, so no credentials are needed. Never hard-code a GitHub token in a notebook;
# for a private fork, export GITHUB_TOKEN in the shell before starting Jupyter and clone
# https://ryuiuc:${GITHUB_TOKEN}@github.com/ryuiuc/dlh-data instead.
!git clone https://github.com/ryuiuc/dlh-data mimic-iv-benchmarks
```

```
Cloning into 'mimic-iv-benchmarks'...
remote: Enumerating objects: 396, done.
remote: Counting objects: 100% (396/396), done.
remote: Compressing objects: 100% (338/338), done.
remote: Total 396 (delta 69), reused 366 (delta 53), pack-reused 0
Receiving objects: 100% (396/396), 27.32 MiB | 22.27 MiB/s, done.
Resolving deltas: 100% (69/69), done.
Updating files: 100% (73/73), done.
```


**Input Data Pipeline**
1. Clone the repo.
```
       git clone https://github.com/ryuiuc/dlh-data mimic-iv-benchmarks
       cd mimic-iv-benchmarks/
```
2. The following command takes MIMIC-IV CSVs, generates one directory per `SUBJECT_ID` and writes ICU stay information to `data/{SUBJECT_ID}/stays.csv`, diagnoses to `data/{SUBJECT_ID}/diagnoses.csv`, and events to `data/{SUBJECT_ID}/events.csv`. This step might take around an hour.
```
       python -m mimic4benchmark.scripts.extract_subjects ./mimic-iv data/root/
```
3. The following command attempts to fix some issues (e.g., missing ICU stay IDs) and removes events with missing information. About 80% of events remain after removing all suspicious rows (more information can be found in [`mimic4benchmark/scripts/more_on_validating_events.md`](mimic4benchmark/scripts/more_on_validating_events.md)).
```
       python -m mimic4benchmark.scripts.validate_events data/root/
```
4. The next command breaks up per-subject data into separate episodes (one per ICU stay). Time series of events are stored in ```{SUBJECT_ID}/episode{#}_timeseries.csv``` (where # counts distinct episodes), while episode-level information (patient age, gender, ethnicity, height, weight) and outcomes (mortality, length of stay, diagnoses) are stored in ```{SUBJECT_ID}/episode{#}.csv```. This script requires two files: one that maps event ITEMIDs to clinical variables and another that defines valid ranges for clinical variables (for detecting outliers, etc.). **Outlier detection is disabled in the current version.**
```
       python -m mimic4benchmark.scripts.extract_episodes_from_subjects data/root/
```
5. The next command splits the whole dataset into training and testing sets. Note that the train/test split is the same for all tasks.
```
       python -m mimic4benchmark.scripts.split_train_and_test data/root/
```
6. The following commands generate task-specific datasets, which can later be used by the models. These commands are independent: if you work on only one benchmark task, you can run just the corresponding command. Only the in-hospital mortality command is needed here.
```
       python -m mimic4benchmark.scripts.create_in_hospital_mortality data/root/ data/in-hospital-mortality/
```
After the above commands finish, there will be a directory `data/{task}` for the in-hospital mortality task (here `data/in-hospital-mortality/`).
This directory has two sub-directories: `train` and `test`.
Each contains one time-series file per ICU stay and a file named `listfile.csv`, which lists all samples in that set.
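The five steps above can also be chained from Python instead of the shell `&&` chain used in the next cell. The sketch below is a hypothetical helper (`pipeline_commands` and `run_pipeline` are my own names, not part of the repo); it only assumes the `mimic4benchmark.scripts.*` module paths and arguments listed above, and, like `&&`, it stops at the first failing step.

```python
import subprocess
import sys


def pipeline_commands(mimic_dir, root="data/root/", task_dir="data/in-hospital-mortality/"):
    """Build the five pipeline steps as argument lists (hypothetical helper)."""
    run = [sys.executable, "-m"]
    scripts = "mimic4benchmark.scripts."
    return [
        run + [scripts + "extract_subjects", mimic_dir, root],
        run + [scripts + "validate_events", root],
        run + [scripts + "extract_episodes_from_subjects", root],
        run + [scripts + "split_train_and_test", root],
        run + [scripts + "create_in_hospital_mortality", root, task_dir],
    ]


def run_pipeline(commands, dry_run=False):
    """Run each step in order; check=True raises CalledProcessError on failure, like `&&`."""
    executed = []
    for cmd in commands:
        print(" ".join(cmd), flush=True)
        if not dry_run:
            subprocess.run(cmd, check=True)
        executed.append(cmd)
    return executed
```

For example, `run_pipeline(pipeline_commands("./mimic-iv"), dry_run=True)` only prints the commands; without `dry_run` they are executed, which should be done from inside the `mimic-iv-benchmarks` directory.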

```python
# In a terminal: cd mimic-iv-benchmarks, then run the command below (assumes the MIMIC-IV 2.0 data is under mimic-iv/2.0/hosp/)
# nohup sh -c "python -m mimic4benchmark.scripts.extract_subjects mimic-iv/2.0/hosp/ data/root/ && python -m mimic4benchmark.scripts.validate_events data/root/ && python -m mimic4benchmark.scripts.extract_episodes_from_subjects data/root/ && python -m mimic4benchmark.scripts.split_train_and_test data/root/ && python -m mimic4benchmark.scripts.create_in_hospital_mortality data/root/ data/in-hospital-mortality/" > nohup_output.log 2>&1 &
```

```
START:
	ICUSTAY_IDs: 76943
	HADM_IDs: 69639
	SUBJECT_IDs: 53569
REMOVE PATIENTS AGE < 18:
	ICUSTAY_IDs: 76943
	HADM_IDs: 69639
	SUBJECT_IDs: 53569
FILTER FOR PATIENTS WITH EVENT RECORDS:
	Before filtering: 76943 stays
	After filtering: 76935 stays
	Number of stays removed: 8
FILTER FOR PATIENTS WITH mimic-iv-patient-split.json:
	Before filtering: 76935 stays
	After filtering: 76935 stays
	Number of stays removed: 0
Breaking up stays by subjects: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 53569/53569 [06:22<00:00, 140.09it/s]
Breaking up diagnoses by subjects: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 53569/53569 [05:28<00:00, 163.16it/s]
Processing CHARTEVENTS table: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 329822285/329822285 [55:09<00:00, 99667.55it/s]
Processing LABEVENTS table: 100%|██
```

```python
# Sample input data (one test-set episode)
#!head /home/ubuntu/duett/mimic-iv-benchmarks/data/in-hospital-mortality/test/19872834_episode1_timeseries.csv
```

```
Hours,18 Gauge Dressing Occlusive,18 Gauge placed in outside facility,20 Gauge Dressing Occlusive,20 Gauge placed in outside facility,20 Gauge placed in the field,ART BP Diastolic,ART BP Systolic,Admission Weight (lbs.),Alarms On,Ambulatory aid,Anion Gap,Anion gap,Arterial Blood Pressure diastolic,Arterial Blood Pressure systolic,BUN,Base Excess,Bicarbonate,Braden Activity,Braden Friction/Shear,Braden Mobility,Braden Moisture,Braden Nutrition,Braden Sensory Perception,Calcium non-ionized,"Calcium, Total",Calculated Total CO2,Capillary Refill L,Capillary Refill R,Chloride,Chloride (serum),Creatinine,Creatinine (serum),Currently experiencing pain,Daily Weight,Dialysis patient,Difficulty swallowing,ETOH,Eye Care,Fraction inspired oxygen,GCS - Eye Opening,GCS - Motor Response,GCS - Verbal Response,GLUCOSE,Gait/Transferring,Glucose (serum),Glucose (whole blood),Glucose (whole blood) (soft),Glucose finger stick (range 70-100),Goal Richmond-RAS Scale,HCO3 (serum),HEMOGLOBIN,Heart Rate,Heart R
```

The columns above fully comply with the paper's design. NOTE: This is before one-hot encoding of the categorical variables.

A sanity check showed that the train and test sets for the ICU mortality task **exactly** reproduce the figures stated in the paper. (Note: where the paper says "instances," it refers to individual patients, while "samples" denote valid episodes.)
```
For the ICU mortality task, our training set consists of a total of 19,699 instances with a positive mortality rate of 12.95%, our validation set contains 4,257 instances with a mortality rate of 13.55%, and our test set contains 4,245 instances with a mortality rate of 12.39%.
```
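A check like this can be scripted from each split's `listfile.csv`. The sketch below is a hypothetical helper, assuming only what the code in this notebook already uses: a `stay` column with file names of the form `{SUBJECT_ID}_episode{#}_timeseries.csv` and a binary `y_true` label. Because the paper's "instances" are patients, it reports distinct subjects alongside the number of samples (episodes).

```python
import pandas as pd


def summarize_split(listfile):
    """Sample count, distinct-patient count, and mortality rate (%) for one split."""
    # '19872834_episode1_timeseries.csv' -> '19872834'
    subjects = listfile["stay"].str.split("_").str[0]
    return {
        "samples": len(listfile),                 # valid episodes (ICU stays)
        "subjects": subjects.nunique(),           # distinct patients
        "mortality_pct": round(100 * listfile["y_true"].mean(), 2),
    }
```

For the test split this is `summarize_split(pd.read_csv("data/in-hospital-mortality/test/listfile.csv"))`; for training and validation it should be applied to the two parts of the shuffled `train/listfile.csv`, since the validation set is carved out of it.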

#### Here's what I accomplished, with all code submitted to GitHub:

1. Adapted the code base to work with MIMIC-IV 2.0, building on the existing code for MIMIC-III and MIMIC-IV 1.0.
2. Conducted feature engineering following the guidelines in the paper, incorporating 85 chart event variables and 29 lab event variables.
3. Aligned the filtering logic with the one described in the paper.
4. Applied one-hot encoding to the categorical variables.
5. Wrapped the data in `class MIMICIVDataset(Dataset)` and `class MIMICIVDataModule(pl.LightningDataModule)` for PyTorch Lightning training.
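Step 4 can be illustrated in isolation. The sketch below mirrors the encoding loop in `MIMICIVDataset.setup`: missing values get the sentinel category -100, the category set is fixed up front so every stay produces the same dummy columns, and the dummies replace the original column at the same position so columns stay aligned across stays. The function name `one_hot_in_place` is hypothetical, and `dtype=float` is my choice so the frame stays numeric under pandas 2.x, where `get_dummies` returns booleans by default.

```python
import numpy as np
import pandas as pd

NAN_CATEGORY = -100  # sentinel category for "not observed"


def one_hot_in_place(df, column, categories):
    """Replace `column` with fixed-vocabulary one-hot columns at the same position."""
    cats = sorted(set(categories) | {NAN_CATEGORY})  # deterministic column order
    codes = df[column].fillna(NAN_CATEGORY).astype(int)
    dummies = pd.get_dummies(pd.Categorical(codes, categories=cats), prefix=column, dtype=float)
    dummies.index = df.index
    pos = df.columns.get_loc(column)
    rest = df.drop(columns=[column])
    return pd.concat([rest.iloc[:, :pos], dummies, rest.iloc[:, pos:]], axis=1)


# Toy stay with one missing GCS value
stay = pd.DataFrame({
    "Hours": [0.0, 1.0, 2.0],
    "GCS - Eye Opening": [4.0, np.nan, 1.0],
    "Heart Rate": [80.0, 82.0, 79.0],
})
encoded = one_hot_in_place(stay, "GCS - Eye Opening", range(1, 5))
print(list(encoded.columns))
```

Fixing the vocabulary (rather than letting `get_dummies` infer it per file) is what keeps the input width identical for every stay, which the fixed `d_time_series_num()` in the dataset class relies on.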

```python
import os
from multiprocessing import Manager

import pandas as pd
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm


class MIMICIVDataset(Dataset):
    """MIMIC-IV 2.0 ICU mortality stays, binned into DuETT's (time bin x event) input format."""

    def __init__(self, data_dir, split_name, n_timesteps=32, use_temp_cache=False, **kwargs):
        # The validation split is carved out of the benchmark's train/ directory (see setup)
        if split_name == 'val':
            self.data_dir = os.path.join(data_dir, 'train')
        else:
            self.data_dir = os.path.join(data_dir, split_name)
        self.split_name = split_name
        self.n_timesteps = n_timesteps
        self.temp_cache = Manager().dict() if use_temp_cache else None
        self.train_prop = 0.7
        self.val_prop = 0.15
        self.X = self.y = None
        # Categorical variables to one-hot encode; commented-out variables stay numeric
        self.predefined_categories = {
            #"Braden Activity": set(range(1, 5)),
            #"Braden Friction/Shear": set(range(1, 4)),
            #"Braden Mobility": set(range(1, 5)),
            #"Braden Moisture": set(range(1, 5)),
            #"Braden Nutrition": set(range(1, 5)),
            #"Braden Sensory Perception": set(range(1, 5)),
            "GCS - Eye Opening": set(range(1, 5)),
            "GCS - Motor Response": set(range(1, 7)),
            "GCS - Verbal Response": set(range(0, 6)),  # Including "No Response-ETT" as 0
            #"Goal Richmond-RAS Scale": set(range(-5, 1)),  # Negative and zero
            #"Pain Level": set(range(0, 9)),  # Includes "Unable to Score"
            #"Pain Level Response": set(range(0, 9)),  # Includes "Unable to Score"
            #"Richmond-RAS Scale": set(range(-5, 5)),  # Negative through positive
            #"Strength L Arm": set(range(0, 6)),
            #"Strength L Leg": set(range(0, 6)),
            #"Strength R Arm": set(range(0, 6)),
            #"Strength R Leg": set(range(0, 6)),
            #"Ambulatory aid": set(range(0, 8)),  # Including "Furniture" as 7
            "Capillary Refill L": set(range(1, 3)),
            "Capillary Refill R": set(range(1, 3)),
            #"Gait/Transferring": set(range(1, 6)),
            #"History of falling (within 3 mnths)": set(range(0, 2)),  # Yes or No
            #"IV/Saline lock": set(range(0, 2)),  # Yes or No
            "Mental status": set(range(1, 3)),
            "Marital Status": set(range(1, 7)),  # Includes '' as 6
            "Insurance": set(range(1, 6)),  # Includes '' as 5
            "Admission Location": set(range(1, 14)),  # Includes '' as 13
            "Admission Type": set(range(1, 12)),  # Includes '' as 11
            "Ethnicity": set(range(0, 5)),  # Includes multiple ethnic groups and '' as 0
            "First Care Unit": set(range(1, 12))  # Includes '' as 11
        }

    def setup(self):
        # Load the list of stays and shuffle it with a fixed seed for reproducibility
        stay_list = pd.read_csv(os.path.join(self.data_dir, "listfile.csv"))
        stay_list = stay_list.sample(frac=1, random_state=2020)

        # The first val_prop / (val_prop + train_prop) of the shuffled train/ list is validation, the rest training;
        # the test split uses the benchmark's test/ directory unchanged
        num_stays = len(stay_list)
        num_val = int(num_stays * self.val_prop / (self.val_prop + self.train_prop))
        if self.split_name == 'val':
            stay_list = stay_list.iloc[:num_val]
        elif self.split_name == 'train':
            stay_list = stay_list.iloc[num_val:]

        timeseries_data = []
        labels = []
        nan_category = -100  # sentinel category for "not observed"

        # Load and one-hot encode each stay
        for _, row in tqdm(stay_list.iterrows(), total=stay_list.shape[0], desc=f'Loading {self.split_name} data'):
            ts_data = pd.read_csv(os.path.join(self.data_dir, row['stay']))
            ts_data = ts_data.apply(pd.to_numeric, errors='coerce')
            for column, categories in self.predefined_categories.items():
                codes = ts_data[column].fillna(nan_category).astype(int)
                # sorted() fixes the dummy-column order; a set has no guaranteed order
                ts_data[column] = pd.Categorical(codes, categories=sorted(categories | {nan_category}))
                # dtype=float keeps the frame numeric (pandas >= 2.0 returns bool dummies by default)
                dummies = pd.get_dummies(ts_data[column], prefix=column, dtype=float)

                # Replace the original column with its dummies at the same position
                col_index = ts_data.columns.get_loc(column)
                ts_data = ts_data.drop(columns=[column])
                ts_data = pd.concat([ts_data.iloc[:, :col_index], dummies, ts_data.iloc[:, col_index:]], axis=1)

            timeseries_data.append(torch.tensor(ts_data.values, dtype=torch.float32))
            labels.append(row['y_true'])

        # Pad every stay with NaN rows to max_length so that all stays can be stacked into one tensor
        max_length = 1250
        min_padding_length = max_length
        preprocessed_data = []
        for ts_data in timeseries_data:
            padding_length = max_length - ts_data.shape[0]
            min_padding_length = min(min_padding_length, padding_length)
            if padding_length > 0:
                padding = torch.full((padding_length, ts_data.shape[1]), float('nan'), dtype=torch.float32)
                ts_data = torch.cat((ts_data, padding), dim=0)
            else:
                # Longer stays would make torch.stack fail; keep the first max_length rows
                ts_data = ts_data[:max_length]
            preprocessed_data.append(ts_data)
        # A negative min_padding_length means at least one stay was truncated
        print(f'max_length={max_length}, min_padding_length={min_padding_length}')
        self.X = torch.stack(preprocessed_data)
        self.y = torch.tensor(labels, dtype=torch.long).unsqueeze(1)

        # Per-column statistics over observed (non-NaN) values, used for z-normalization in __getitem__
        self.means, self.stds, self.maxes, self.mins = [], [], [], []
        for i in range(self.X.shape[2]):
            vals = self.X[:, :, i].flatten()
            vals = vals[~torch.isnan(vals)]
            if vals.numel() > 0:
                self.means.append(vals.mean())
                self.stds.append(vals.std())
                self.maxes.append(vals.max())
                self.mins.append(vals.min())
            else:
                self.means.append(0)
                self.stds.append(1)
                self.maxes.append(float('nan'))
                self.mins.append(float('nan'))

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, i):
        if self.temp_cache is not None and i in self.temp_cache:
            return self.temp_cache[i]

        # Drop padding rows (padding has NaN in the Hours column)
        ins = self.X[i, ~torch.isnan(self.X[i, :, 0]), :]
        time = ins[:, 0] / 24  # column 0 is Hours; convert to days
        n_ts, n_static = self.d_time_series_num(), self.d_static_num()

        # First n_ts channels: normalized value per bin; next n_ts channels: observation counts per bin
        x_ts = torch.zeros((self.n_timesteps, n_ts * 2))
        for i_t, t in enumerate(time):
            # Bin index proportional to t / time[-1]; the last observation always goes into the final bin
            bin_idx = self.n_timesteps - 1 if t == time[-1] else int(t / time[-1] * self.n_timesteps)
            for i_ts in range(1, n_ts + 1):
                x_i = ins[i_t, i_ts]
                if not torch.isnan(x_i).item():
                    x_ts[bin_idx, i_ts - 1] = (x_i - self.means[i_ts]) / (self.stds[i_ts] + 1e-7)
                    x_ts[bin_idx, i_ts - 1 + n_ts] += 1
        bin_ends = torch.arange(1, self.n_timesteps + 1) / self.n_timesteps * time[-1]

        # Static (episode-level) features follow the time-series columns; take them from the first row
        x_static = torch.zeros(n_static)
        for i_tab in range(n_ts + 1, n_ts + n_static + 1):
            x_i = (ins[0, i_tab] - self.means[i_tab]) / (self.stds[i_tab] + 1e-7)
            x_static[i_tab - n_ts - 1] = x_i.nan_to_num(0.)

        x = (x_ts, x_static, bin_ends)
        y = self.y[i, 0]
        if self.temp_cache is not None:
            self.temp_cache[i] = (x, y)

        return x, y

    def d_static_num(self):
        return 61

    def d_time_series_num(self):
        return 161

    def d_target(self):
        return 1

    def pos_frac(self):
        return torch.mean(self.y.float()).item()


def collate_into_seqs(batch):
    xs, ys = zip(*batch)
    return zip(*xs), ys


class MIMICIVDataModule(pl.LightningDataModule):
    def __init__(self, data_path='./mimic-iv-benchmarks/data/in-hospital-mortality/', use_temp_cache=False,
                 batch_size=8, num_workers=1, prefetch_factor=2, verbose=0, **kwargs):
        self.use_temp_cache = use_temp_cache
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor

        self.data_path = data_path
        self.ds_train = MIMICIVDataset(self.data_path, 'train', use_temp_cache=use_temp_cache)
        self.ds_val = MIMICIVDataset(self.data_path, 'val', use_temp_cache=use_temp_cache)
        self.ds_test = MIMICIVDataset(self.data_path, 'test', use_temp_cache=use_temp_cache)

        self.prepare_data_per_node = False

        self.dl_args = {'batch_size': self.batch_size, 'prefetch_factor': self.prefetch_factor,
                        'collate_fn': collate_into_seqs, 'num_workers': num_workers}

    def setup(self, stage=None):
        if stage is None:
            self.ds_train.setup()
            self.ds_val.setup()
            self.ds_test.setup()
        elif stage == 'fit':
            self.ds_train.setup()
            self.ds_val.setup()
        elif stage == 'validate':
            self.ds_val.setup()
        elif stage == 'test':
            self.ds_test.setup()

    def prepare_data(self):
        pass

    def _log_hyperparams(self):
        pass

    def train_dataloader(self):
        return DataLoader(self.ds_train, shuffle=True, **self.dl_args)

    def val_dataloader(self):
        return DataLoader(self.ds_val, **self.dl_args)

    def test_dataloader(self):
        return DataLoader(self.ds_test, **self.dl_args)

    def d_static_num(self):
        return self.ds_train.d_static_num()

    def d_time_series_num(self):
        return self.ds_train.d_time_series_num()

    def d_target(self):
        return self.ds_train.d_target()

    def pos_frac(self):
        return self.ds_train.pos_frac()
```
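The train/validation carve-out in `setup` can be checked on its own. The sketch below is a hypothetical standalone function (`split_train_val` is my name) that reproduces the same logic: shuffle `train/listfile.csv` with `random_state=2020`, then give the first `val_prop / (train_prop + val_prop)`, about 15/85 of the stays, to validation.

```python
import pandas as pd


def split_train_val(stay_list, train_prop=0.7, val_prop=0.15, seed=2020):
    """Return (train, val) exactly as MIMICIVDataset.setup slices the shuffled train/ listfile."""
    shuffled = stay_list.sample(frac=1, random_state=seed)
    num_val = int(len(shuffled) * val_prop / (val_prop + train_prop))
    return shuffled.iloc[num_val:], shuffled.iloc[:num_val]


# Toy listfile with 100 stays
listfile = pd.DataFrame({
    "stay": [f"{i}_episode1_timeseries.csv" for i in range(100)],
    "y_true": [int(i % 8 == 0) for i in range(100)],
})
train, val = split_train_val(listfile)
print(len(train), len(val))
```

Because the split is made at the stay level, a patient with several episodes can end up in both training and validation; grouping by the subject-ID prefix of `stay` before sampling would avoid that, at the cost of diverging from the split used here.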

### PhysioNet-2012 Data: `torchtime.data.PhysioNet2012(self.split_name, train_prop=0.7, val_prop=0.15, time=False, seed=0)`

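### Time-series Binning Strategy: see Section 3.1 (Data/Input Binning) of the paper
```
# Excerpt of the binning loop in __getitem__. The PhysioNet-2012 version hard-codes range(1, 37)
# (36 time-series variables); d_time_series_num() makes the same loop work for MIMIC-IV.
x_ts = torch.zeros((self.n_timesteps, self.d_time_series_num() * 2))
for i_t, t in enumerate(time):
    # Bin index proportional to t / time[-1]; the last observation always goes into the final bin
    bin_idx = self.n_timesteps - 1 if t == time[-1] else int(t / time[-1] * self.n_timesteps)
    for i_ts in range(1, self.d_time_series_num() + 1):  # column 0 is time
        x_i = ins[i_t, i_ts]
        if not torch.isnan(x_i).item():
            # Value channel: z-normalized value (a later observation in the same bin overwrites earlier ones)
            x_ts[bin_idx, i_ts - 1] = (x_i - self.means[i_ts]) / (self.stds[i_ts] + 1e-7)
            # Count channel: number of observations of this variable in the bin
            x_ts[bin_idx, i_ts - 1 + self.d_time_series_num()] += 1
# Right edge (end time) of each of the n_timesteps equal-width bins
bin_ends = torch.arange(1, self.n_timesteps + 1) / self.n_timesteps * time[-1]
```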
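To see the binning on concrete numbers, here is a NumPy sketch of the same loop, written as a standalone function (`bin_time_series` is a hypothetical name). It follows the rule above (bin index proportional to `t / time[-1]`, last observation in the final bin, value channels overwritten, count channels accumulated) and additionally clips the bin index and guards against a zero-length stay, which the original loop does not do.

```python
import numpy as np


def bin_time_series(times, values, means, stds, n_timesteps=32):
    """Bin irregular observations into n_timesteps equal-width bins.

    times:  (T,) observation times, ascending (days in the original code)
    values: (T, D) raw measurements, NaN where a variable was not observed
    means, stds: (D,) per-variable normalization statistics
    Returns x_ts with shape (n_timesteps, 2 * D) and the right edge of each bin.
    """
    times = np.asarray(times, dtype=float)
    values = np.asarray(values, dtype=float)
    means = np.asarray(means, dtype=float)
    stds = np.asarray(stds, dtype=float)
    D = values.shape[1]
    x_ts = np.zeros((n_timesteps, 2 * D), dtype=np.float32)
    t_end = times[-1]
    for t, row in zip(times, values):
        if t == t_end or t_end <= 0:
            b = n_timesteps - 1  # last observation (or zero-length stay) goes to the last bin
        else:
            b = min(max(int(t / t_end * n_timesteps), 0), n_timesteps - 1)
        observed = np.flatnonzero(~np.isnan(row))
        # Value channels: z-normalized value; a later observation in the same bin overwrites an earlier one
        x_ts[b, observed] = (row[observed] - means[observed]) / (stds[observed] + 1e-7)
        # Count channels: number of observations of each variable in this bin
        x_ts[b, D + observed] += 1
    bin_ends = np.arange(1, n_timesteps + 1) / n_timesteps * t_end
    return x_ts, bin_ends


# Three observations of two variables over one day, binned into 4 bins
x, ends = bin_time_series(
    [0.0, 0.5, 1.0],
    [[1.0, np.nan], [3.0, 4.0], [np.nan, 6.0]],
    means=[0.0, 0.0], stds=[1.0, 1.0], n_timesteps=4,
)
print(ends)  # [0.25 0.5  0.75 1.  ]
```

Bin 1 stays all zeros here, so empty bins are indistinguishable from a zero value in the value channels; the count channels are what tell the model a bin had no observations.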

## Model

### Model Description/Architecture (quoted from the paper)
```
The overall structure of our DuETT model is a series of DuETT layers followed by classification or self-supervised learning heads. Each DuETT layer is made up of two Transformer sublayers that attend along the event and time dimensions respectively. The first sublayer consists of multi-head attention over events followed by a feed-forward network operating along the event dimension, which can be collectively identified as an event transformer layer; the second sublayer consists of multi-head attention over time bins followed by a feed-forward network operating along the time dimension, the time transformer layer. The dual attention architecture enables our model to capture the two important modalities of EHR data, namely the types of events that are observed for a given patient and the times at which they are observed. Event-type and time bin embeddings are injected just before their respective sublayers. Embedding injections are done throughout the entire network, rather than just before the first layer, to ensure access and to emphasize the ordering information of data, especially in upper layers.
```
- **Layer Number/Size/Type**: The DuETT model uses 2 DuETT layers, i.e., 4 Transformer sublayers in total. The Transformers have an internal feedforward dimension of 512. The classification head has one hidden layer of size 64 and batch normalization after the hidden layer. The static data encoder has one hidden layer of size 128 and batch normalization after the hidden layer. The implementation uses 32 time steps.

  Below is the detailed breakdown of the actual layers and their respective number of parameters. Please note that minor discrepancies may arise depending on the treatment of certain columns as categorical, which could result in an expansion of the input dimensions.
```
   | Name                         | Type             | Params
-------------------------------------------------------------------
0  | special_embeddings           | Embedding        | 192   
1  | embedding_layers             | ModuleList       | 302 K 
2  | n_obs_embedding              | Embedding        | 16    
3  | event_transformers           | ModuleList       | 1.8 M 
4  | full_event_embedding         | Embedding        | 128 K 
5  | time_transformers            | ModuleList       | 8.7 M 
6  | full_time_embedding          | Sequential       | 245 K 
7  | full_rep_embedding           | Embedding        | 3.9 K 
8  | head                         | Sequential       | 249 K 
9  | pretrain_value_proj          | Sequential       | 626 K 
10 | pretrain_presence_proj       | Sequential       | 626 K 
11 | predict_events_proj          | Sequential       | 25.4 K
12 | predict_events_presence_proj | Sequential       | 25.4 K
13 | tab_encoder                  | Sequential       | 11.3 K
14 | train_auroc                  | AUROC            | 0     
15 | val_auroc                    | AUROC            | 0     
16 | train_ap                     | AveragePrecision | 0     
17 | val_ap                       | AveragePrecision | 0     
18 | test_auroc                   | AUROC            | 0     
19 | test_ap                      | AveragePrecision | 0     
-------------------------------------------------------------------
12.7 M    Trainable params
0         Non-trainable params
12.7 M    Total params
50.953    Total estimated model params size (MB)
```
![](https://uiuc-dlh.s3.amazonaws.com/duett_arch.png)
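To make "an event sublayer followed by a time sublayer" concrete, below is a minimal NumPy sketch of one dual-attention layer on a hidden tensor of shape (batch, time bins, event types, d). It is a deliberately stripped-down illustration, not the code in `duett.py`: it assumes single-head attention with no learned query/key/value projections, and omits the feed-forward networks, normalization, and embedding injections described above. All function names are hypothetical.

```python
import numpy as np


def self_attention(x):
    """Single-head scaled dot-product self-attention over axis -2 of x (shape (..., L, d)).

    Simplification: queries, keys and values are x itself (no learned projections).
    """
    d = x.shape[-1]
    scores = x @ np.swapaxes(x, -1, -2) / np.sqrt(d)  # (..., L, L)
    scores -= scores.max(axis=-1, keepdims=True)        # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x                                  # (..., L, d)


def event_sublayer(h):
    """h: (batch, n_time, n_event, d). Within each time bin, event types attend to each other."""
    return h + self_attention(h)


def time_sublayer(h):
    """For each event type, time bins attend to each other (time and event axes swapped)."""
    ht = np.swapaxes(h, 1, 2)                           # (batch, n_event, n_time, d)
    return np.swapaxes(ht + self_attention(ht), 1, 2)


def dual_attention_layer(h):
    """One simplified DuETT-style layer: event sublayer, then time sublayer, each with a residual."""
    return time_sublayer(event_sublayer(h))


rng = np.random.default_rng(0)
h = rng.normal(size=(2, 32, 10, 16))  # 2 stays, 32 time bins, 10 event types, 16-dim embeddings
out = dual_attention_layer(h)
print(out.shape)  # (2, 32, 10, 16)
```

Because each sublayer attends over only one axis at a time, the attention cost per layer grows like n_time·n_event² + n_event·n_time² rather than (n_time·n_event)² for attention over all (time, event) cells, which is part of why the fixed 32-bin representation keeps deeper models affordable.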


- **Activation Function**: The model employs activation functions such as ReLU or other non-linear functions to introduce non-linearity into the model, enabling it to capture complex patterns within the data.
- **Training Objectives**:

    - **Loss Function**: The training process involves a self-supervised learning (SSL) pre-training stage followed by a supervised fine-tuning stage. During SSL, the model is trained with pseudo-tasks that produce robust representations without the need for explicit labels. The loss function for pre-training includes both value and presence losses to capture the clinical priors and the sparsity structure of the EHR data.

      ![](https://uiuc-dlh.s3.amazonaws.com/duett_loss.png)
    - **Optimizer**: The model is optimized with AdamW for stable and efficient training. The learning rate follows a linear warmup and then an inverse square-root decay.
    - **Weight of Each Loss Term**: The weights of each loss term are determined through hyperparameter tuning to balance the contributions of the value and presence predictions.

- **Pretrained Model**: The DuETT model benefits from SSL pre-training, which lets it learn useful representations from the data without relying on labels. This is particularly advantageous for EHR data, where labeled samples may be scarce. It performs self-supervised pre-training for 300 epochs using AdamW.
- **Fine-tuning**: DuETT is fine-tuned for 30 epochs on MIMIC-IV and 50 epochs on PhysioNet.
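
The warmup-then-inverse-square-root schedule implemented by the `WarmUpCallback` in the training code can be written as a pure function of the step count. The sketch below mirrors that callback's arithmetic (warmup length `warmup_steps`, decay constant `decay` defaulting to `warmup_steps`); the function name `warmup_invsqrt_lr` and the example base learning rate of `1e-3` are illustrative assumptions, not values taken from the paper.

```python
from typing import Optional


def warmup_invsqrt_lr(step: int, base_lr: float, warmup_steps: int = 1000,
                      decay: Optional[int] = None) -> float:
    """Learning rate at a given optimizer step (hypothetical helper mirroring WarmUpCallback).

    - Warmup: lr grows linearly from 0 towards base_lr over `warmup_steps` steps.
    - Afterwards: lr = base_lr * sqrt(decay / (step - warmup_steps + decay)),
      so it equals base_lr right after warmup and halves after 3 * decay more steps.
    """
    if decay is None:
        decay = warmup_steps  # same default as WarmUpCallback
    if step < warmup_steps:
        return step / warmup_steps * base_lr
    return base_lr * (decay / (step - warmup_steps + decay)) ** 0.5


if __name__ == "__main__":
    for s in (0, 500, 1000, 4000, 16000):
        print(s, warmup_invsqrt_lr(s, base_lr=1e-3))
```

With the fine-tuning setting `WarmUpCallback(steps=1000)`, the learning rate reaches its base value at step 1000 and has halved by step 4000.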

### Model Code

[`duett.py`](https://github.com/layer6ai-labs/DuETT/blob/master/duett.py) by the original authors of the paper.
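
The exact pre-training objective is given in the loss figure above. As a rough, hedged illustration only, the sketch below assumes a mean-squared error on the values of observed entries plus a binary cross-entropy on predicted presence, combined with tunable weights. The function name `pretrain_loss` and the weights `w_value` and `w_presence` are hypothetical; `duett.py` remains the reference for the authors' actual formulation.

```python
import math
from typing import Sequence


def pretrain_loss(pred_values: Sequence[float], true_values: Sequence[float],
                  pred_presence: Sequence[float], observed: Sequence[int],
                  w_value: float = 1.0, w_presence: float = 1.0, eps: float = 1e-7) -> float:
    """Hypothetical weighted value + presence loss over a flattened set of (time, event) cells.

    Assumptions (not taken from the paper): the value term is MSE restricted to observed
    cells, and the presence term is binary cross-entropy on predicted observation probabilities.
    """
    # Value loss: only cells that were actually measured carry a meaningful target.
    sq_errors = [(p - t) ** 2 for p, t, o in zip(pred_values, true_values, observed) if o]
    value_loss = sum(sq_errors) / len(sq_errors) if sq_errors else 0.0

    # Presence loss: every cell has a target (observed = 1, missing = 0).
    bce_terms = []
    for p, o in zip(pred_presence, observed):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        bce_terms.append(-(o * math.log(p) + (1 - o) * math.log(1.0 - p)))
    presence_loss = sum(bce_terms) / len(bce_terms)

    return w_value * value_loss + w_presence * presence_loss


if __name__ == "__main__":
    loss = pretrain_loss(pred_values=[1.0, 2.0, 0.0], true_values=[1.5, 2.0, 9.0],
                         pred_presence=[0.8, 0.6, 0.2], observed=[1, 1, 0])
    print(round(loss, 4))
```

Note that the unobserved third cell contributes to the presence term but not to the value term, which is how a loss of this shape can reflect both the measured values and the sparsity pattern of EHR data.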

## Training

### Training Setup
```
1x A10 (24 GB PCIe)
30 vCPUs, 200 GiB RAM, 1.4 TiB SSD
```
available on: [Lambda Labs](https://cloud.lambdalabs.com/)
NOTE: Google Colab was not used; the classic Jupyter Notebook and a terminal were used instead.

**Total cost (through Apr 14): ~3,500 USD**


### Training Code

#### Training MIMIC-IV ICU Mortality


In [1]:
import os
from getpass import getpass
# Never hard-code a personal access token in a notebook; read it from the environment or prompt for it.
if not os.environ.get('GITHUB_TOKEN'):
    os.environ['GITHUB_TOKEN'] = getpass('GitHub personal access token: ')

# Use the token securely without exposing it in the notebook
!git clone https://ryuiuc:${GITHUB_TOKEN}@github.com/ryuiuc/dlh-model DuETT

Cloning into 'DuETT'...
remote: Enumerating objects: 17, done.
remote: Counting objects: 100% (17/17), done.
remote: Compressing objects: 100% (14/14), done.
remote: Total 17 (delta 6), reused 14 (delta 3), pack-reused 0
Receiving objects: 100% (17/17), 1.55 MiB | 10.49 MiB/s, done.
Resolving deltas: 100% (6/6), done.


In [27]:
# If your current ipykernel is not running Python 3.9, run the following in a terminal first:
# > python3.9 -m venv DuETT/myenv
# > source DuETT/myenv/bin/activate
# > pip install ipykernel
# > python -m ipykernel install --user --name=DuETT --display-name "Python 3.9 (DuETT)"
# Then switch Jupyter Notebook to the kernel named "Python 3.9 (DuETT)" before continuing.

In [None]:
import sys
!{sys.executable} -m pip install -r DuETT/requirements.txt

In [2]:
import os
os.chdir('DuETT')

# Confirm the change
print("Current working directory is now:", os.getcwd())

Current working directory is now: /home/ubuntu/duett/dlh_submission/DuETT


##### Download Pre-trained Model

The pre-training code is included below but commented out.

Based on pre-training progress so far, the best MIMIC-IV 2.0 ICU mortality checkpoint is epoch 149, step 7650.

In [27]:
!wget https://uiuc-dlh.s3.amazonaws.com/epoch%3D149-step%3D7650.ckpt

--2024-04-14 05:07:53--  https://uiuc-dlh.s3.amazonaws.com/epoch%3D149-step%3D7650.ckpt
Resolving uiuc-dlh.s3.amazonaws.com (uiuc-dlh.s3.amazonaws.com)... 54.231.233.89, 54.231.129.161, 54.231.199.201, ...
Connecting to uiuc-dlh.s3.amazonaws.com (uiuc-dlh.s3.amazonaws.com)|54.231.233.89|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 152379276 (145M) [binary/octet-stream]
Saving to: ‘epoch=149-step=7650.ckpt’


2024-04-14 05:08:01 (20.5 MB/s) - ‘epoch=149-step=7650.ckpt’ saved [152379276/152379276]



##### Continue Fine-tuning and Evaluation

In [28]:
import gc

import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
import pytorch_lightning as pl
import argparse

import duett
import mimiciv_mortality

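class WarmUpCallback(pl.callbacks.Callback):
    """Linear LR warmup over `steps` batches, then inverse square-root decay; auto-detects the base LR if not given."""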
    def __init__(self, steps=1000, base_lr=None, invsqrt=True, decay=None):
        print(f'warmup_steps {steps}, base_lr {base_lr}, invsqrt {invsqrt}, decay {decay}')
        self.warmup_steps = steps
        self.decay = decay if decay is not None else steps
        self.base_lr = base_lr
        self.invsqrt = invsqrt
        self.state = {'steps': 0, 'base_lr': float(base_lr) if base_lr is not None else None}

    def set_lr(self, optimizer, lr):
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def on_train_batch_start(self, trainer, model, batch, batch_idx):
        optimizers = model.optimizers()
        if self.state['steps'] < self.warmup_steps:
            if isinstance(optimizers, list):
                if self.state['base_lr'] is None:
                    self.state['base_lr'] = [o.param_groups[0]['lr'] for o in optimizers]
                for opt, base in zip(optimizers, self.state['base_lr']):
                    self.set_lr(opt, self.state['steps'] / self.warmup_steps * base)
            else:
                if self.state['base_lr'] is None:
                    self.state['base_lr'] = optimizers.param_groups[0]['lr']
                self.set_lr(optimizers, self.state['steps'] / self.warmup_steps * self.state['base_lr'])
            self.state['steps'] += 1
        elif self.invsqrt:
            if isinstance(optimizers, list):
                if self.state['base_lr'] is None:
                    self.state['base_lr'] = [o.param_groups[0]['lr'] for o in optimizers]
                for opt, base in zip(optimizers, self.state['base_lr']):
                    lr = base * (self.decay / (self.state['steps'] - self.warmup_steps + self.decay)) ** 0.5
                    self.set_lr(opt, lr)
            else:
                if self.state['base_lr'] is None:
                    self.state['base_lr'] = optimizers.param_groups[0]['lr']
                lr = self.state['base_lr'] * (self.decay / (self.state['steps'] - self.warmup_steps + self.decay)) ** 0.5
                self.set_lr(optimizers, lr)
            self.state['steps'] += 1

    def load_state_dict(self, state_dict):
        self.state.update(state_dict)

    def state_dict(self):
        return self.state.copy()

def average_models(models):
    """Averages model weights and loads the resulting weights into the first model, returning it."""
    models = list(models)
    n = len(models)
    sds = [m.state_dict() for m in models]
    averaged = {}
    for k in sds[0]:
        averaged[k] = sum(sd[k] for sd in sds) / n
    models[0].load_state_dict(averaged)
    return models[0]

seed = 2020
pl.seed_everything(seed)
dm = mimiciv_mortality.MIMICIVDataModule(data_path='../../mimic-iv-benchmarks/data/in-hospital-mortality/', batch_size=512, num_workers=30, use_temp_cache=False)
dm.setup()

#pretrained_path = 'checkpoints-mimiciv/last.ckpt'
#pretrain_model = duett.pretrain_model(d_static_num=dm.d_static_num(),
#        d_time_series_num=dm.d_time_series_num(), d_target=dm.d_target(), pos_frac=dm.pos_frac(),
#        seed=seed)
#checkpoint = pl.callbacks.ModelCheckpoint(save_last=True, monitor='val_loss', mode='min', save_top_k=1, dirpath='checkpoints-mimiciv')
#warmup = WarmUpCallback(steps=2000)
#trainer = pl.Trainer(gpus=1, logger=False, num_sanity_val_steps=2, max_epochs=300,
#                     gradient_clip_val=1.0, callbacks=[warmup, checkpoint],
#                     resume_from_checkpoint=pretrained_path)
#trainer.fit(pretrain_model, dm)

#pretrained_path = checkpoint.best_model_path
#print('best model path of pretraining:', pretrained_path)
#del pretrain_model, trainer
#gc.collect()
pretrained_path = 'epoch=149-step=7650.ckpt'
for seed in range(2020, 2021):
    pl.seed_everything(seed)
    fine_tune_model = duett.fine_tune_model(pretrained_path,
                                            d_static_num=dm.d_static_num(),
                                            d_time_series_num=dm.d_time_series_num(),
                                            d_target=dm.d_target(),
                                            pos_frac=dm.pos_frac(),
                                            seed=seed)
    checkpoint = pl.callbacks.ModelCheckpoint(save_top_k=5,
                                               save_last=False,
                                               mode='max',
                                               monitor='val_ap',
                                               dirpath='checkpoints-mimiciv')
    warmup = WarmUpCallback(steps=1000)
    trainer = pl.Trainer(gpus=1,
                         logger=False,
                         max_epochs=30,
                         gradient_clip_val=1.0,
                         callbacks=[warmup, checkpoint])
    trainer.fit(fine_tune_model, dm)
    final_model = average_models([duett.fine_tune_model(path,
                                                        d_static_num=dm.d_static_num(),
                                                        d_time_series_num=dm.d_time_series_num(),
                                                        d_target=dm.d_target(),
                                                        pos_frac=dm.pos_frac())
                                  for path in checkpoint.best_k_models.keys()])
    trainer.test(final_model, dataloaders=dm)
    del fine_tune_model, trainer, final_model
    gc.collect()

nohup: ignoring input
Global seed set to 2020
Loading train data: 100%|██████████| 25820/25820 [15:41<00:00, 27.43it/s]
max_length = 1250, min_padding_length=86
Loading val data: 100%|██████████| 5532/5532 [03:16<00:00, 28.09it/s]
max_length = 1250, min_padding_length=20
Loading test data: 100%|██████████

#### Training PhysioNet-2012

##### Download Pre-trained Model

The pre-training code is included below but commented out.

Based on pre-training progress so far, the best PhysioNet-2012 checkpoint is epoch 235, step 4012.

In [13]:
!wget https://uiuc-dlh.s3.amazonaws.com/physionet-best-pretrain-epoch%3D235-step%3D4012.ckpt

--2024-04-13 08:44:17--  https://uiuc-dlh.s3.amazonaws.com/physionet-best-pretrain-epoch%3D235-step%3D4012.ckpt
Resolving uiuc-dlh.s3.amazonaws.com (uiuc-dlh.s3.amazonaws.com)... 52.217.197.217, 52.216.129.123, 3.5.25.221, ...
Connecting to uiuc-dlh.s3.amazonaws.com (uiuc-dlh.s3.amazonaws.com)|52.217.197.217|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 48785659 (47M) [binary/octet-stream]
Saving to: ‘physionet-best-pretrain-epoch=235-step=4012.ckpt’


2024-04-13 08:44:20 (14.7 MB/s) - ‘physionet-best-pretrain-epoch=235-step=4012.ckpt’ saved [48785659/48785659]



##### Continue Fine-tuning and Evaluation

In [15]:
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
import pytorch_lightning as pl
import argparse

import duett
import physionet

class WarmUpCallback(pl.callbacks.Callback):
    """Linear warmup over warmup_steps batches, tries to auto-detect the base lr"""
    def __init__(self, steps=1000, base_lr=None, invsqrt=True, decay=None):
        print('warmup_steps {}, base_lr {}, invsqrt {}, decay {}'.format(steps, base_lr, invsqrt, decay))
        self.warmup_steps = steps
        if decay is None:
            self.decay = steps
        else:
            self.decay = decay

        if base_lr is None:
            self.state = {'steps': 0, 'base_lr': base_lr}
        else:
            self.state = {'steps': 0, 'base_lr': float(base_lr)}

        self.invsqrt = invsqrt

    def set_lr(self, optimizer, lr):
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def on_train_batch_start(self, trainer, model, batch, batch_idx):
        optimizers = model.optimizers()

        if self.state['steps'] < self.warmup_steps:
            if isinstance(optimizers, list):  # type(x) == 'list' compares a type to a string and is always False
                if self.state['base_lr'] is None:
                    self.state['base_lr'] = [o.param_groups[0]['lr'] for o in optimizers]
                for opt,base in zip(optimizers, self.state['base_lr']):
                    self.set_lr(opt, self.state['steps']/self.warmup_steps * base)
            else:
                if self.state['base_lr'] is None:
                    self.state['base_lr'] = optimizers.param_groups[0]['lr']
                self.set_lr(optimizers, self.state['steps']/self.warmup_steps * self.state['base_lr'])
            self.state['steps'] += 1
        elif self.invsqrt:
            if isinstance(optimizers, list):
                if self.state['base_lr'] is None:
                    self.state['base_lr'] = [o.param_groups[0]['lr'] for o in optimizers]
                for opt,base in zip(optimizers, self.state['base_lr']):
                    self.set_lr(opt,base * (self.decay / (self.state['steps'] - self.warmup_steps + self.decay)) ** 0.5)
            else:
                if self.state['base_lr'] is None:
                    self.state['base_lr'] = optimizers.param_groups[0]['lr']
                self.set_lr(optimizers, self.state['base_lr'] * (
                            self.decay / (self.state['steps'] - self.warmup_steps + self.decay)) ** 0.5)
            self.state['steps'] += 1

    def load_state_dict(self, state_dict):
        self.state.update(state_dict)

    def state_dict(self):
        return self.state.copy()

def average_models(models):
    """Averages model weights and loads the resulting weights into the first model, returning it"""
    models = list(models)
    n = len(models)
    sds = [m.state_dict() for m in models]
    averaged = {}
    for k in sds[0]:
        averaged[k] = sum(sd[k] for sd in sds) / n
    models[0].load_state_dict(averaged)
    return models[0]

seed = 2020
pl.seed_everything(seed)
dm = physionet.PhysioNetDataModule(batch_size=512, num_workers=30, use_temp_cache=True)
dm.setup()
#pretrain_model = duett.pretrain_model(d_static_num=dm.d_static_num(),
#        d_time_series_num=dm.d_time_series_num(), d_target=dm.d_target(), pos_frac=dm.pos_frac(),
#        seed=seed)
#checkpoint = pl.callbacks.ModelCheckpoint(save_last=True, monitor='val_loss', mode='min', save_top_k=1, dirpath='checkpoints')
#warmup = WarmUpCallback(steps=2000)
#trainer = pl.Trainer(gpus=1, logger=False, num_sanity_val_steps=2, max_epochs=300,
#        gradient_clip_val=1.0, callbacks=[warmup, checkpoint])
#trainer.fit(pretrain_model, dm)

pretrained_path = "physionet-best-pretrain-epoch=235-step=4012.ckpt"
for seed in range(2020, 2023):
    pl.seed_everything(seed)
    fine_tune_model = duett.fine_tune_model(pretrained_path, d_static_num=dm.d_static_num(),
            d_time_series_num=dm.d_time_series_num(), d_target=dm.d_target(), pos_frac=dm.pos_frac(), seed=seed)
    checkpoint = pl.callbacks.ModelCheckpoint(save_top_k=5, save_last=False, mode='max', monitor='val_ap', dirpath='checkpoints')
    warmup = WarmUpCallback(steps=1000)
    trainer = pl.Trainer(gpus=1, logger=False, max_epochs=50, gradient_clip_val=1.0,
            callbacks=[warmup, checkpoint], progress_bar_refresh_rate=0, log_every_n_steps=100)
    trainer.fit(fine_tune_model, dm)
    final_model = average_models([duett.fine_tune_model(path, d_static_num=dm.d_static_num(),
            d_time_series_num=dm.d_time_series_num(), d_target=dm.d_target(), pos_frac=dm.pos_frac())
            for path in checkpoint.best_k_models.keys()])
    trainer.test(final_model, dataloaders=dm)


Global seed set to 2020


Validating cache...
Validating cache...
Validating cache...


Global seed set to 2020
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Loading from checkpoint
warmup_steps 1000, base_lr None, invsqrt True, decay None
Validating cache...
Validating cache...


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  super(AdamW, self).__init__(params, defaults)

   | Name                         | Type             | Params
-------------------------------------------------------------------
0  | special_embeddings           | Embedding        | 192   
1  | embedding_layers             | ModuleList       | 67.7 K
2  | n_obs_embedding              | Embedding        | 16    
3  | event_transformers           | ModuleList       | 1.8 M 
4  | full_event_embedding         | Embedding        | 29.3 K
5  | time_transformers            | ModuleList       | 2.0 M 
6  | full_time_embedding          | Sequential       | 26.8 K
7  | full_rep_embedding           | Embedding        | 888   
8  | head                         | Sequential       | 57.1 K
9  | pretrain_value_proj          | Sequential       | 32.0 K
10 | pretrain_presence_proj       | Sequential       | 32.0 K
11 | predict_events_

val_auroc tensor(0.4978, device='cuda:0') val_ap tensor(0.1556, device='cuda:0')
val_auroc tensor(0.5335, device='cuda:0') val_ap tensor(0.1659, device='cuda:0')
val_auroc tensor(0.6009, device='cuda:0') val_ap tensor(0.2029, device='cuda:0')
val_auroc tensor(0.6966, device='cuda:0') val_ap tensor(0.2792, device='cuda:0')
val_auroc tensor(0.7452, device='cuda:0') val_ap tensor(0.3362, device='cuda:0')
val_auroc tensor(0.7763, device='cuda:0') val_ap tensor(0.3814, device='cuda:0')
val_auroc tensor(0.7952, device='cuda:0') val_ap tensor(0.3995, device='cuda:0')
val_auroc tensor(0.8143, device='cuda:0') val_ap tensor(0.4416, device='cuda:0')
val_auroc tensor(0.8210, device='cuda:0') val_ap tensor(0.4404, device='cuda:0')
val_auroc tensor(0.8287, device='cuda:0') val_ap tensor(0.4754, device='cuda:0')
val_auroc tensor(0.8358, device='cuda:0') val_ap tensor(0.4882, device='cuda:0')
val_auroc tensor(0.8445, device='cuda:0') val_ap tensor(0.5064, device='cuda:0')
val_auroc tensor(0.8433, dev



Loading from checkpoint
Loading from checkpoint
Loading from checkpoint
Loading from checkpoint
Loading from checkpoint
Validating cache...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Global seed set to 2021


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         test_ap            0.5473765134811401
       test_auroc            0.869416356086731
        test_loss           0.48961365012943475
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Loading from checkpoint
warmup_steps 1000, base_lr None, invsqrt True, decay None
Validating cache...
Validating cache...


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  super(AdamW, self).__init__(params, defaults)

   | Name                         | Type             | Params
-------------------------------------------------------------------
0  | special_embeddings           | Embedding        | 192   
1  | embedding_layers             | ModuleList       | 67.7 K
2  | n_obs_embedding              | Embedding        | 16    
3  | event_transformers           | ModuleList       | 1.8 M 
4  | full_event_embedding         | Embedding        | 29.3 K
5  | time_transformers            | ModuleList       | 2.0 M 
6  | full_time_embedding          | Sequential       | 26.8 K
7  | full_rep_embedding           | Embedding        | 888   
8  | head                         | Sequential       | 57.1 K
9  | pretrain_value_proj          | Sequential       | 32.0 K
10 | pretrain_presence_proj       | Sequential       | 32.0 K
11 | predict_events_

val_auroc tensor(0.4978, device='cuda:0') val_ap tensor(0.1556, device='cuda:0')
val_auroc tensor(0.5340, device='cuda:0') val_ap tensor(0.1666, device='cuda:0')
val_auroc tensor(0.6134, device='cuda:0') val_ap tensor(0.2129, device='cuda:0')
val_auroc tensor(0.6920, device='cuda:0') val_ap tensor(0.2776, device='cuda:0')
val_auroc tensor(0.7415, device='cuda:0') val_ap tensor(0.3287, device='cuda:0')
val_auroc tensor(0.7696, device='cuda:0') val_ap tensor(0.3600, device='cuda:0')
val_auroc tensor(0.7928, device='cuda:0') val_ap tensor(0.3872, device='cuda:0')
val_auroc tensor(0.8027, device='cuda:0') val_ap tensor(0.4015, device='cuda:0')
val_auroc tensor(0.8138, device='cuda:0') val_ap tensor(0.4246, device='cuda:0')
val_auroc tensor(0.8204, device='cuda:0') val_ap tensor(0.4387, device='cuda:0')
val_auroc tensor(0.8375, device='cuda:0') val_ap tensor(0.4823, device='cuda:0')
val_auroc tensor(0.8404, device='cuda:0') val_ap tensor(0.4870, device='cuda:0')
val_auroc tensor(0.8452, dev



Loading from checkpoint
Loading from checkpoint
Loading from checkpoint
Loading from checkpoint
Loading from checkpoint
Validating cache...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Global seed set to 2022


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         test_ap             0.564518928527832
       test_auroc           0.8703322410583496
        test_loss           0.46560638726343934
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Loading from checkpoint
warmup_steps 1000, base_lr None, invsqrt True, decay None
Validating cache...
Validating cache...


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  super(AdamW, self).__init__(params, defaults)

   | Name                         | Type             | Params
-------------------------------------------------------------------
0  | special_embeddings           | Embedding        | 192   
1  | embedding_layers             | ModuleList       | 67.7 K
2  | n_obs_embedding              | Embedding        | 16    
3  | event_transformers           | ModuleList       | 1.8 M 
4  | full_event_embedding         | Embedding        | 29.3 K
5  | time_transformers            | ModuleList       | 2.0 M 
6  | full_time_embedding          | Sequential       | 26.8 K
7  | full_rep_embedding           | Embedding        | 888   
8  | head                         | Sequential       | 57.1 K
9  | pretrain_value_proj          | Sequential       | 32.0 K
10 | pretrain_presence_proj       | Sequential       | 32.0 K
11 | predict_events_

val_auroc tensor(0.4978, device='cuda:0') val_ap tensor(0.1556, device='cuda:0')
val_auroc tensor(0.5371, device='cuda:0') val_ap tensor(0.1680, device='cuda:0')
val_auroc tensor(0.6192, device='cuda:0') val_ap tensor(0.2141, device='cuda:0')
val_auroc tensor(0.6911, device='cuda:0') val_ap tensor(0.2726, device='cuda:0')
val_auroc tensor(0.7511, device='cuda:0') val_ap tensor(0.3421, device='cuda:0')
val_auroc tensor(0.7776, device='cuda:0') val_ap tensor(0.3757, device='cuda:0')
val_auroc tensor(0.7938, device='cuda:0') val_ap tensor(0.3904, device='cuda:0')
val_auroc tensor(0.8117, device='cuda:0') val_ap tensor(0.4197, device='cuda:0')
val_auroc tensor(0.8180, device='cuda:0') val_ap tensor(0.4454, device='cuda:0')
val_auroc tensor(0.8266, device='cuda:0') val_ap tensor(0.4508, device='cuda:0')
val_auroc tensor(0.8300, device='cuda:0') val_ap tensor(0.4690, device='cuda:0')
val_auroc tensor(0.8432, device='cuda:0') val_ap tensor(0.4900, device='cuda:0')
val_auroc tensor(0.8498, dev



Loading from checkpoint
Loading from checkpoint
Loading from checkpoint
Loading from checkpoint
Loading from checkpoint
Validating cache...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         test_ap            0.5617507100105286
       test_auroc            0.869026780128479
        test_loss           0.47735832650971227
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


## Evaluation
1. Metrics descriptions
- Average Precision (AP): summarizes the precision-recall curve as the recall-weighted mean of the precision achieved at each score threshold.
- AUROC: the probability that the model scores a randomly chosen positive instance higher than a randomly chosen negative one; equivalently, the area under the ROC curve traced over all thresholds.
- Average Loss: the mean of the fine-tuning loss over all test samples, giving an overall measure of prediction error (lower is better).
   
2. Implementation code (embedded in the training code above)

- Helper function:
```
def average_models(models):
    """Averages model weights and loads the resulting weights into the first model, returning it"""
    models = list(models)
    n = len(models)
    sds = [m.state_dict() for m in models]
    averaged = {}
    for k in sds[0]:
        averaged[k] = sum(sd[k] for sd in sds) / n
    models[0].load_state_dict(averaged)
    return models[0]
```
- Averaging the top-5 checkpoints (ranked by `val_ap`) saved during fine-tuning, then testing the averaged model:
```
final_model = average_models([duett.fine_tune_model(path, d_static_num=dm.d_static_num(),
            d_time_series_num=dm.d_time_series_num(), d_target=dm.d_target(), pos_frac=dm.pos_frac())
            for path in checkpoint.best_k_models.keys()])
trainer.test(final_model, dataloaders=dm)
```
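
To make the two ranking metrics concrete, the sketch below computes them from scratch in plain Python: AP as the recall-weighted sum of precisions, `AP = sum_n (R_n - R_{n-1}) * P_n` over descending score thresholds (the step-wise definition also used by scikit-learn's `average_precision_score`), and AUROC as the fraction of positive-negative pairs ranked correctly, counting ties as half. The function names are illustrative; the training code itself obtains these values from the `AUROC` and `AveragePrecision` metric modules listed in the model summary.

```python
from typing import Sequence


def average_precision(y_true: Sequence[int], scores: Sequence[float]) -> float:
    """AP = sum_n (R_n - R_{n-1}) * P_n over distinct score thresholds, highest first."""
    n_pos = sum(y_true)
    if n_pos == 0:
        raise ValueError("AP is undefined without positive samples")
    pairs = sorted(zip(scores, y_true), key=lambda p: p[0], reverse=True)
    tp = fp = 0
    prev_recall = ap = 0.0
    i = 0
    while i < len(pairs):
        threshold = pairs[i][0]
        # Samples with tied scores cross the threshold together.
        while i < len(pairs) and pairs[i][0] == threshold:
            if pairs[i][1]:
                tp += 1
            else:
                fp += 1
            i += 1
        recall, precision = tp / n_pos, tp / (tp + fp)
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap


def auroc(y_true: Sequence[int], scores: Sequence[float]) -> float:
    """P(score of a random positive > score of a random negative), ties counted as 1/2.

    O(n_pos * n_neg): fine for illustration, not for large test sets.
    """
    pos = [s for s, y in zip(scores, y_true) if y]
    neg = [s for s, y in zip(scores, y_true) if not y]
    if not pos or not neg:
        raise ValueError("AUROC needs both positive and negative samples")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


if __name__ == "__main__":
    y = [1, 0, 1, 0, 0, 1]
    s = [0.9, 0.8, 0.7, 0.3, 0.1, 0.2]
    print(f"AP={average_precision(y, s):.4f}  AUROC={auroc(y, s):.4f}")
```

Because AP only accumulates precision where recall increases, it is sensitive to how well the rare positive class (in-hospital deaths) is ranked, which is why it is much lower than AUROC on imbalanced mortality tasks.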

# Results
## 1. Results

### MIMIC-IV 2.0 ICU Mortality - Approaching Paper Benchmark
(Fine-tuned from the epoch-149 pre-training checkpoint only; pre-training is only about half done (149 of the planned 300 epochs) due to limited GPU capacity.)
```
    test_ap             0.5389033555984497
    test_auroc          0.8587720990180969
    test_loss           0.5029744141557028
```
VS.

![](https://uiuc-dlh.s3.amazonaws.com/duett_paper_mimic_result.png)

### PhysioNet-2012 Mortality - Matches Paper Benchmark
```
(Seed 2021)
    test_ap             0.564518928527832
    test_auroc          0.8703322410583496
    test_loss           0.46560638726343934
```
VS.

![](https://uiuc-dlh.s3.amazonaws.com/duett_paper_physio_result.png)
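
The result above quotes seed 2021 only, which is the best of the three PhysioNet-2012 fine-tuning runs. A short standard-library sketch that summarizes all three runs (seeds 2020, 2021 and 2022, with values copied from the test-metric tables in the training log) as mean ± sample standard deviation:

```python
from statistics import mean, stdev

# Test metrics of the three PhysioNet-2012 fine-tuning runs, copied from the training log.
runs = {
    2020: {"test_ap": 0.5473765134811401, "test_auroc": 0.869416356086731},
    2021: {"test_ap": 0.564518928527832, "test_auroc": 0.8703322410583496},
    2022: {"test_ap": 0.5617507100105286, "test_auroc": 0.869026780128479},
}


def summarize(metric: str) -> tuple:
    """Return (mean, sample standard deviation) of one metric across seeds."""
    values = [r[metric] for r in runs.values()]
    return mean(values), stdev(values)


for metric in ("test_ap", "test_auroc"):
    m, sd = summarize(metric)
    print(f"{metric}: {m:.4f} ± {sd:.4f} (n={len(runs)})")
```

Reporting the spread across seeds, rather than the best single seed, makes it clearer how much of any gap to a published number is run-to-run noise; here AUROC varies far less across seeds than AP does.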

## 2. Analysis
- The PhysioNet-2012 results have been replicated: across seeds 2020-2022, test AUROC ranges from 0.869 to 0.870 and test AP from 0.547 to 0.565.
- Because of budget constraints and limited GPU capacity, pre-training on the much larger MIMIC-IV dataset is only about half complete. The current results are still trending toward the paper's best outcomes, but more training time is needed.
- The MIMIC-IV ICU Mortality task has improved substantially through careful feature engineering: the AUROC increased from 0.82 to 0.858 as I progressively implemented and corrected further details in accordance with the paper and my own findings.
- For the MIMIC-IV ICU Mortality task, feature engineering and hyperparameter tuning (neither of which is fully disclosed in the original paper) are key to refining the model's performance, assuming that the model implementation from the paper is accurate.
## 3. Plan
- Continue pre-training on the MIMIC-IV dataset, then update the model with the best checkpoint and proceed with fine-tuning and testing.
- Continue with ablation tests; however, these plans may change depending on the increasing budgetary demands.