In [1]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm, trange
import torch
import random
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, confusion_matrix
from sklearn.metrics import roc_auc_score, f1_score, recall_score, precision_score, average_precision_score


metrics = ['AUC', 'F1 score', 'Recall', 'NMI']
seeds = [1001, 1012, 1134, 2475, 6138, 7415, 1663, 7205, 9253, 1782]



In [2]:
#!/usr/bin/env python3
# INSERT FILE DESCRIPTION

"""
Util functions to run Model. Includes Data loading, etc...
"""
import os
from typing import List, Union

import numpy as np
import pandas as pd

from tqdm import tqdm
import datetime as dt

tqdm.pandas()


def _compute_last_target_id(df: pd.DataFrame, time_col: str = "intime", mode: str = "max") -> pd.DataFrame:
    """Identify last ids given df according to time given by time_col column. Mode determines min or max."""
    if mode == "max":
        time = df[time_col].max()
    elif mode == "min":
        time = df[time_col].min()
    else:
        raise ValueError("mode must be one of ['min', 'max']. Got {}".format(mode))

    last_ids = df[df[time_col] == time]

    return last_ids


def _rows_are_in(df1: pd.DataFrame, df2: pd.DataFrame, matching_columns: Union[List[str], str]) -> pd.DataFrame:
    """
    Checks if values present in row of df1 exist for all columns in df2. Note that this does not necessarily mean
    the whole row of df1 is in df2, but is good enough for application.

    Returns: array of indices indicating the relevant rows of df1.
    """
    if isinstance(matching_columns, str):
        matching_columns = [matching_columns]

    # Iterate through each column
    matching_ids = np.ones(df1.shape[0])
    for col in tqdm(matching_columns):
        col_matching = df1[col].isin(df2[col].values).values  # where df1 col is subset of df2 col
        matching_ids = np.logical_and(matching_ids, col_matching)  # match with columns already looked at

    return matching_ids


def _compute_second_transfer_info(df: pd.DataFrame, time_col, target_cols):
    """
    Given transfer data for a unique id, compute the second transfer as given by time_col.

    return: pd.Series with corresponding second transfer info.
    """
    time_info = df[time_col]
    second_transfer_time = time_info[time_info != time_info.min()].min()

    # Identify second transfer info - can be empty, unique, or repeated instances
    second_transfer = df[df[time_col] == second_transfer_time]

    if second_transfer.empty:
        output = [df.name, df["hadm_id"].iloc[0], df["transfer_id"].iloc[0]] + [np.nan] * (len(target_cols) - 3)
        return pd.Series(data=output, index=target_cols)

    elif second_transfer.shape[0] == 1:
        return pd.Series(data=second_transfer.squeeze().values, index=target_cols)

    else:  # There should be NONE
        print(second_transfer)
        raise ValueError("Something's gone wrong! No expected repeated second transfers with the same time.")


def convert_columns_to_dt(df: pd.DataFrame, columns: Union[str, List[str]]):
    """Convert columns of dataframe to datetime format, as per given"""
    if isinstance(columns, str):
        columns = [columns]

    for col in columns:
        df[col] = pd.to_datetime(df.loc[:, col].values)

    return df


def subsetted_by(df1: pd.DataFrame, df2: pd.DataFrame, matching_columns: Union[List[str], str]) -> pd.DataFrame:
    """
    Subset df1 based on matching_columns, according to values existing in df2.

    Returns: pd.DataFrame subset of df1 for which rows are a subset of df2
    """

    return df1.iloc[_rows_are_in(df1, df2, matching_columns), :]


def endpoint_target_ids(df: pd.DataFrame, identifier: str, time_col: str = "intime", mode: str = "max") -> pd.DataFrame:
    """
    Given identifier target ("id"), compute the endpoint associated with time column.

    Returns: pd.DataFrame with ids and associated endpoint information.
    """
    last_ids = df.groupby(identifier, as_index=False).progress_apply(
        lambda x: _compute_last_target_id(x, time_col=time_col, mode=mode))

    return last_ids.reset_index(drop=True)


def compute_second_transfer(df: pd.DataFrame, identifier: str, time_col: str, target_cols: pd.Index) -> pd.DataFrame:
    """
    Given transfer data represented by unique identifier ("id"), compute the second transfer of the admission.
    Second Transfer defined as second present intime in the date (if multiple, this is flagged). If there are
    no transfers after, then return np.nan. target_cols is the target information columns.

    This function checks the second transfer intime is after outtime of first transfer record.

    Returns: pd.DataFrame with id and associated second transfer information (intime/outtime, unit, etc...)
    """
    second_transfer_info = df.groupby(identifier, as_index=False).progress_apply(
        lambda x: _compute_second_transfer_info(x, time_col, target_cols))

    return second_transfer_info.reset_index(drop=True)


def _has_many_nas(df: pd.DataFrame, targets: Union[List[str], str], min_count: int, min_frac: float) -> bool:
    """
    For a given admission/stay with corresponding vital sign information, return boolean indicating whether low
    missingness conditions are satisfied. These are:
    a) At least min_count observations.
    b) Proportion of missing values smaller than min_frac for ALL targets.

    returns: boolean indicating admission should be kept.
    """
    if isinstance(targets, str):
        targets = [targets]

    has_minimum_counts = df.shape[0] > min_count
    has_less_NA_than_frac = df[targets].isna().sum() <= min_frac * df.shape[0]

    return has_minimum_counts and has_less_NA_than_frac.all()


def remove_adms_high_missingness(df: pd.DataFrame, targets: Union[List[str], str],
                                 identifier: str, min_count: int, min_frac: float) -> pd.DataFrame:
    """
    Given vital sign data, remove admissions with too little information. This is defined as either:
    a) Number of observations smaller than allowed min_count.
    b) Proportion of missing values in ANY of the targets is higher than min_frac.

    Returns: pd.DataFrame - Vital sign data of the same type, except only admissions with enough information are kept.
    """
    output = df.groupby(identifier, as_index=False).filter(
        lambda x: _has_many_nas(x, targets, min_count, min_frac))

    return output.reset_index(drop=True)


def _resample_adm(df: pd.DataFrame, rule: str, time_id: str,
                  time_vars: Union[List[str], str], static_vars: Union[List[str], str]) -> pd.DataFrame:
    """
    For a particular stay with vital sign data as per df, resample trajectory data (subsetted to time_vars),
    according to index given by time_to_end and as defined by rule. It is important that time_to_end decreases
    throughout admissions and hits 0 at the end - this is for resampling purposes.

    Params:
    df: pd.Dataframe, containing trajectory and static data for each admission.
    rule: str, indicates the resampling rule (to be fed to pd.DataFrame.resample())

    static_vars is a list of relevant identifier information

    returns: Resampled admission data. Furthermore, two more info columns are indicated (chartmax and chartmin).
    """
    if isinstance(time_vars, str):
        time_vars = [time_vars]

    if isinstance(static_vars, str):
        static_vars = [static_vars]

    # Add fake observation (with missing values) so that resampling starts at end of admission
    df_inter = df[time_vars + ["time_to_end"]]
    df_inter = df_inter.append(pd.Series(data=[np.nan] * len(time_vars) + [dt.timedelta(seconds=0)],
                                         index=df_inter.columns), ignore_index=True)

    # resample on time_to_end axis
    output = df_inter.sort_values(by="time_to_end", ascending=False).resample(
        on="time_to_end",
        rule=rule, closed="left", label="left").mean()

    # Compute static ids manually and add information about max and min time id values
    output[static_vars] = df[static_vars].iloc[0, :].values
    output[time_id + "_min"] = df[time_id].min()
    output[time_id + "_max"] = df[time_id].max()

    # Reset index to obtain resampled values
    output.index.name = f"sampled_time_to_end({rule})"
    output.reset_index(drop=False, inplace=True)

    return output


def compute_time_to_end(df: pd.DataFrame, id_key: str, time_id: str, end_col: str):
    """
    Compute time to end of admission for a given observation associated with a particular admission id.

    df: pd.DataFrame with trajectory information.
    id_key: str - column of df representing the unique id admission identifier.
    time_id: str - column of df indicating time observations was taken.
    end_col: str - column of df indicating, for each observation, the end time of the corresponding admission.

    returns: sorted pd.DataFrame with an extra column indicating time to end of admission. This will be used for
    resampling.
    """
    df_inter = df.copy()
    df_inter["time_to_end"] = df_inter[end_col] - df_inter[time_id]
    df_inter.sort_values(by=[id_key, "time_to_end"], ascending=[True, False], inplace=True)

    return df_inter


def conversion_to_block(df: pd.DataFrame, id_key: str, rule: str,
                        time_vars: Union[List[str], str], static_vars: Union[List[str], str]) -> pd.DataFrame:
    """
    Given trajectory data over multiple admissions (as specified by id), resample each admission according to time
    until the end of the admission. Resampling according to rule and apply to_time_vars.

    df: pd.DataFrame containing trajectory and static data.
    id_key: str, unique identifier per admission
    rule: str, indicates resampling rule (to be fed to pd.DataFrame.resample())
    time_vars: list of str, indicates columns of df to be resampled.
    static_vars: list of str, indicates columns of df which are static, and therefore not resampled.

    return: Dataframe with resampled vital sign data.
    """
    if "time_to_end" not in df.columns:
        raise ValueError("'time_to_end' not found in columns of dataframe. Run 'compute_time_to_end' function first.")
    assert df[id_key].is_monotonic and df.groupby(id_key).apply(
        lambda x: x["time_to_end"].is_monotonic_decreasing).all()

    # Resample admission according to time_to_end
    output = df.groupby(id_key).progress_apply(lambda x: _resample_adm(x, rule, "time_to_end", time_vars, static_vars))

    return output.reset_index(drop=True)


def convert_to_timedelta(df: pd.DataFrame, *args) -> pd.DataFrame:
    """Convert all given cols of dataframe to timedelta."""
    output = df.copy()
    for arg in args:
        output[arg] = pd.to_timedelta(df.loc[:, arg])

    return output


def _check_all_tables_exist(folder_path: str):
    """TO MOVE TO TEST"""
    try:
        assert os.path.exists(folder_path)
    except Exception:
        raise ValueError("Folder path does not exist - Input {}".format(folder_path))


def select_death_icu_acute(df, admissions_df, timedt):
    """
    Identify outcomes based on severity within the consequent 12 hours:
    a) Death
    b) Entry to ICU Careunit
    c) Transfer to hospital ward
    d) Discharge

    Params:
    - df - transfers dataframe corresponding to a particular admission.
    - timedt - datetime timedelta indicating range window of prediction

    Returns categorical encoding of the corresponding admission.
    Else returns 0,0,0,0 if a mistake is found.
    """
    # Check admission contains only one such row
    assert admissions_df.hadm_id.eq(df.name).sum() <= 1

    # Identify Last observed vitals for corresponding admission
    hadm_information = admissions_df.query("hadm_id==@df.name").iloc[0, :]
    window_start_point = hadm_information.loc["outtime"]

    # First check if death exists
    hadm_information = admissions_df.query("hadm_id==@df.name")
    if not hadm_information.empty and not hadm_information.dod.isna().all():
        time_of_death = hadm_information.dod.min()
        time_from_start_point = (time_of_death - window_start_point)

        # try:
        #     assert time_from_vitals >= dt.timedelta(seconds=0)
        #
        # except AssertionError:
        #     return pd.Series(data=[0, 0, 0, 0, time_of_death], index=["De", "I", "W", "Di", "time"])

        # Check death within time window
        if time_from_start_point < timedt:
            return pd.Series(data=[1, 0, 0, 0, time_of_death], index=["De", "I", "W", "Di", "time"])

    # Otherwise, consider other transfers
    transfers_within_window = df[df["intime"].between(window_start_point, window_start_point + timedt)]

    # Consider icu transfers within window
    icu_cond1 = transfers_within_window.careunit.str.contains("(?i)ICU", na=False)  # regex ignore lowercase
    icu_cond2 = transfers_within_window.careunit.str.contains("(?i)Neuro Stepdown", na=False)
    has_icus = (icu_cond1 | icu_cond2)

    if has_icus.sum() > 0:
        icu_transfers = transfers_within_window[has_icus]
        return pd.Series(data=[0, 1, 0, 0, icu_transfers.intime.min()],
                         index=["De", "I", "W", "Di", "time"])

    # Check to see if discharge has taken
    discharges = transfers_within_window.eventtype.str.contains("discharge", na=False)
    if discharges.sum() > 0:
        return pd.Series(data=[0, 0, 0, 1, transfers_within_window[discharges].intime.min()],
                         index=["De", "I", "W", "Di", "time"]
                         )
    else:
        return pd.Series(data=[0, 0, 1, 0, transfers_within_window.intime.min()],
                         index=["De", "I", "W", "Di", "time"]
                         )


In [9]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader



# ---------------------------------------------------------------------------------------
"Global variables for specific dataset information loading."

MIMIC_PARSE_TIME_VARS = ["intime", "outtime", "chartmax"]
MIMIC_PARSE_TD_VARS = [
    "sampled_time_to_end(1H)", "time_to_end", "time_to_end_min", "time_to_end_max"]
MIMIC_VITALS = ["TEMP", "HR", "RR", "SPO2", "SBP", "DBP"]
MIMIC_STATIC = ["age", "gender", "ESI"]
MIMIC_OUTCOME_NAMES = ["De", "I", "W", "Di"]

# Identifiers for main ids.
MAIN_ID_LIST = ["subject_id", "hadm_id", "stay_id", "patient_id", "pat_id"]

# ----------------------------------------------------------------------------------------


class CustomDataset(Dataset):

    def __init__(self, data_name="MIMIC", target_window=12, feat_set='vit-sta', time_range=(0, 6), parameters=None):
        if parameters is None:
            self.data_name = data_name
            self.target_window = target_window
            self.feat_set = feat_set
            self.time_range = time_range
            self.id_col = None
            self.time_col = None
            self.needs_time_to_end_computation = False
            self.min = None
            self.max = None

            # Load and process data
            self.id_col, self.time_col, self.needs_time_to_end_computation = self.get_ids(
                self.data_name)
            self.x, self.y, self.mask, self.pat_time_ids, self.features, self.outcomes, self.x_subset, self.y_data = self.load_transform()
        else:
            self.x, self.y, self.mask, self.pat_time_ids, self.features, self.outcomes, self.x_subset, self.y_data, self.id_col, self.time_col, self.needs_time_to_end_computation, self.data_name, self.feat_set, self.time_range, self.target_window, self.min, self.max = parameters

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        # Extract data for given index
        x = self.x[idx, :, :]
        y = self.y[idx, :]
        mask = self.mask[idx, :, :]
        pat_time_ids = self.pat_time_ids[idx, :, :]
        features = self.features
        outcomes = self.outcomes
        x_subset = self.x_subset[idx, :]
        y_data = self.y_data[idx, :]
        id_col = self.id_col
        time_col = self.time_col
        needs_time_to_end_computation = self.needs_time_to_end_computation
        data_name = self.data_name
        feat_set = self.feat_set
        time_range = self.time_range
        target_window = self.target_window
        min = self.min
        max = self.max
        return x, y, mask, pat_time_ids, features, outcomes, x_subset, y_data, id_col, time_col, needs_time_to_end_computation, data_name, feat_set, time_range, target_window, min, max

    def get_subset(self, idx):
        return CustomDataset(parameters=self[idx])

    def load_transform(self):
        """Load dataset and transform to input format"""

        # Load data
        data = self._load(self.data_name, window=self.target_window)

        # Get data info
        self.id_col, self.time_col, self.needs_time_to_end_computation = self.get_ids(
            self.data_name)

        # Add time to end and truncate if needed
        # print(data[0].shape, '0')
        x_inter = self._add_time_to_end(data[0])
        # print(x_inter.shape, '1')
        x_inter = self._truncate(x_inter)
        # print(x_inter.shape, '2')
        self._check_correct_time_conversion(x_inter)

        # Subset to relevant features (keeps self.id_col and self.time_col still)
        # print(x_inter.shape, '3')
        x_subset, features = self.subset_to_features(x_inter)

        # Convert to 3D array
        # print(x_inter.shape, '4')
        x_inter, pat_time_ids = self.convert_to_3darray(x_subset)
        x_subset = x_subset.to_numpy().astype(np.float32)

        # Normalise array
        # print(x_inter.shape, '5')
        x_inter = self.normalise(x_inter)

        # Impute missing values
        # print(x_inter.shape, '6')
        x_out, mask = self.impute(x_inter)
        # print(x_out.shape, '7')

        # Do things to y
        outcomes = self._get_outcome_names(self.data_name)
        y_data = data[1][outcomes]
        y_out = y_data.to_numpy().astype("float32")
        y_data = y_data.to_numpy().astype("float32")

        # Check data loaded correctly
        self._check_input_format(x_out, y_out)

        return x_out, y_out, mask, pat_time_ids, features, outcomes, x_subset, y_data

    def _load(self, data_name, window=4):
        """Load Trajectory, Target data jointly given data folder."""

        # Make data folder
        # data_fd = f"/kaggle/input/mimic-processed/"
        data_fd = f"/kaggle/input/mimic-processed/"
        try:
            os.path.exists(data_fd)
        except AssertionError:
            print(data_fd)

        if "MIMIC" in data_name:

            # Load Data
            X = pd.read_csv(data_fd + "vitals_process.csv",
                            parse_dates=MIMIC_PARSE_TIME_VARS, header=0, index_col=0)
            y = pd.read_csv(
                data_fd + f"outcomes_{window}h_process.csv", index_col=0)
            # for Kaggle:
#             X = pd.read_csv("vitals_process.csv", parse_dates=MIMIC_PARSE_TIME_VARS, header=0, index_col=0)
#             y = pd.read_csv(f"outcomes_{window}h_process.csv", index_col=0)

            # Convert columns to timedelta
            X = convert_to_timedelta(X, *MIMIC_PARSE_TD_VARS)

        elif "SAMPLE" in data_name:

            # Load data
            X = None
            y = None

        else:
            raise ValueError(
                f"Data Name does not match available datasets. Input Folder provided {data_fd}")
        return X, y

    def get_ids(self, data_name):
        """
        Get input id information.

        Params:
        - data_folder: str, folder of dataset, or name of dataset.

        Returns:
            - Tuple of id col, time col and whether time to end needs computation.
        """

        if "MIMIC" in data_name:
            id_col, time_col, needs_time_to_end = "hadm_id", "sampled_time_to_end(1H)", False

        elif "SAMPLE" in data_name:
            id_col, time_col, needs_time_to_end = None, None, None

        else:
            raise ValueError(
                f"Data Name does not match available datasets. Input Folder provided {data_name}")

        return id_col, time_col, needs_time_to_end

    def impute(self, X):
        """
        Imputation of 3D array accordingly with time as dimension 1:
        1st - forward value propagation,
        2nd - backwards value propagation,
        3rd - median value imputation.

        Mask returned at the end, corresponding to original missing values.
        """
        impute_step1 = self._numpy_forward_fill(X)
        impute_step2 = self._numpy_backward_fill(impute_step1)
        impute_step3 = self._median_fill(impute_step2)

        # Compute mask
        mask = np.isnan(X)

        return impute_step3, mask

    def convert_datetime_to_hour(self, series):
        """Convert pandas Series of datetime values to float Series with corresponding hour values"""
        seconds_per_hour = 3600

        return series.dt.total_seconds() / seconds_per_hour

    def _get_features(self, key, data_name="MIMIC"):
        """
        Compute list of features to keep given key. Key can be one of:
        - str, where the corresponding features are selected according to the fn below.
        - list, where the corresponding features are the original list.
        """
        if isinstance(key, list):
            return key

        elif isinstance(key, str):
            if data_name == "MIMIC":
                vitals = MIMIC_VITALS
                static = MIMIC_STATIC
                vars1, vars2 = None, None

            elif data_name == "SAMPLE":
                vitals, vars1, vars2, static = None, None, None, None

            else:
                raise ValueError(
                    f"Data Name does not match available datasets. Input provided {data_name}")

            # Add features given substrings of key. We initialise set in case of repetition (e.g. 'vars1-lab')
            features = set([])
            if "vit" in key.lower():
                features.update(vitals)

            if "vars1" in key.lower():
                features.update(vars1)

            if "vars2" in key.lower():
                features.update(vars2)

            if "lab" in key.lower():
                features.update(vars1)
                features.update(vars2)

            if "sta" in key.lower():
                features.update(static)

            if "all" in key.lower():
                features = self._get_features("vit-lab-sta", data_name)

            # sorted returns a list of features.
            sorted_features = sorted(features)
            print(
                f"\n{data_name} data has been subsettted to the following features: \n {sorted_features}.")

            return sorted_features

        else:
            raise TypeError(
                f"Argument key must be one of type str or list, type {type(key)} was given.")

    def _numpy_forward_fill(self, array):
        """Forward Fill a numpy array. Time index is axis = 1."""
        array_mask = np.isnan(array)
        array_out = np.copy(array)

        # Add time indices where not masked, and propagate forward
        inter_array = np.where(~ array_mask, np.arange(
            array_mask.shape[1]).reshape(1, -1, 1), 0)
        np.maximum.accumulate(inter_array, axis=1,
                              out=inter_array)  # For each (n, t, d) missing value, get the previously accessible mask value

        # Index matching for output. For n, d sample as previously, use inter_array for previous time id
        array_out = array_out[np.arange(array_out.shape[0])[:, None, None],
                              inter_array,
                              np.arange(array_out.shape[-1])[None, None, :]]

        return array_out

    def _numpy_backward_fill(self, array):
        """Backward Fill a numpy array. Time index is axis = 1"""
        array_mask = np.isnan(array)
        array_out = np.copy(array)

        # Add time indices where not masked, and propagate backward
        inter_array = np.where(~ array_mask, np.arange(
            array_mask.shape[1]).reshape(1, -1, 1), array_mask.shape[1] - 1)
        inter_array = np.minimum.accumulate(
            inter_array[:, ::-1], axis=1)[:, ::-1]
        array_out = array_out[np.arange(array_out.shape[0])[:, None, None],
                              inter_array,
                              np.arange(array_out.shape[-1])[None, None, :]]

        return array_out

    def _median_fill(self, array):
        """Median fill a numpy array. Time index is axis = 1"""
        array_mask = np.isnan(array)
        array_out = np.copy(array)

        # Compute median and impute
        array_med = np.nanmedian(np.nanmedian(
            array, axis=0, keepdims=True), axis=1, keepdims=True)
        array_out = np.where(array_mask, array_med, array_out)

        return array_out

    def _get_outcome_names(self, data_name):
        """Return the corresponding outcome columns given dataset name."""
        if data_name == "MIMIC":
            return MIMIC_OUTCOME_NAMES

        elif data_name == "SAMPLE":
            return None

    def _check_input_format(self, X, y):
        """Check conditions to confirm model input."""

        try:
            # Length and shape conditions
            # print(X.shape, y.shape)
            cond1 = X.shape[0] == y.shape[0]
            cond2 = len(X.shape) == 3
            cond3 = len(y.shape) == 2

            # Check non-missing values
            cond4 = np.sum(np.isnan(X)) + np.sum(np.isnan(y)) == 0

            # Check y output is one hot encoded
            cond5 = np.all(np.sum(y, axis=1) == 1)

            assert cond1
            assert cond2
            assert cond3
            assert cond4
            assert cond5

        except Exception as e:
            print(e)
            raise AssertionError("One of the check conditions has failed.")

    def _subset_to_balanced(X, y, mask, ids):
        """Subset samples so dataset is more well sampled."""
        class_numbers = np.sum(y, axis=0)
        largest_class, target_num_samples = np.argmax(
            class_numbers), np.sort(class_numbers)[-2]
        print("\nSubsetting class {} from {} to {} samples.".format(largest_class, class_numbers[largest_class],
                                                                    target_num_samples))

        # Select random
        largest_class_ids = np.arange(y.shape[0])[y[:, largest_class] == 1]
        class_ids_samples = np.random.choice(
            largest_class_ids, size=target_num_samples, replace=False)
        ids_to_remove_ = np.setdiff1d(largest_class_ids, class_ids_samples)

        # Remove relevant ids
        X_out = np.delete(X, ids_to_remove_, axis=0)
        y_out = np.delete(y, ids_to_remove_, axis=0)
        mask_out = np.delete(mask, ids_to_remove_, axis=0)
        ids_out = np.delete(ids, ids_to_remove_, axis=0)

        return X_out, y_out, mask_out, ids_out

    def _add_time_to_end(self, X):
        """Add new column to dataframe - this computes time to end of grouped observations, if needed."""
        x_inter = X.copy(deep=True)

        # if time to end has not been computed
        if self.needs_time_to_end_computation is True:

            # Compute datetime values for time until end of group of observations
            times = X.groupby(self.id_col).apply(
                lambda x: x.loc[:, self.time_col].max() - x.loc[:, self.time_col])

            # add column to dataframe after converting to hourly times.
            x_inter["time_to_end"] = self.convert_datetime_to_hour(
                times).values

        else:
            x_inter["time_to_end"] = x_inter[self.time_col].values
            x_inter["time_to_end"] = self.convert_datetime_to_hour(
                x_inter.loc[:, "time_to_end"])

        # Sort data
        self.time_col = "time_to_end"
        x_out = x_inter.sort_values(
            by=[self.id_col, "time_to_end"], ascending=[True, False])

        return x_out

    def _truncate(self, X):
        """Truncate dataset on time to end column according to self.time_range."""
        try:
            min_time, max_time = self.time_range
            # print(self.time_range)
            return X[X['time_to_end'].between(min_time, max_time, inclusive="left")]

        except Exception:
            raise ValueError(
                f"Could not truncate to {self.time_range} time range successfully")

    def _check_correct_time_conversion(self, X):
        """Check addition and truncation of time index worked accordingly."""

        cond1 = X[self.id_col].is_monotonic_increasing
        cond2 = X.groupby(self.id_col).apply(
            lambda x: x["time_to_end"].is_monotonic_decreasing).all()

        min_time, max_time = self.time_range
        cond3 = X["time_to_end"].between(
            min_time, max_time, inclusive='left').all()

        assert cond1 is True
        assert cond2 == True
        assert cond3 == True

    def subset_to_features(self, X):
        """Subset only to variables which were selected"""
        features = [self.id_col, "time_to_end"] + \
            self._get_features(self.feat_set, self.data_name)

        return X[features], features

    def convert_to_3darray(self, X):
        """Convert a pandas dataframe to 3D numpy array of shape (num_samples, num_timestamps, num_variables)."""

        # Obtain relevant shape sizes
        max_time_length = X.groupby(self.id_col).count()["time_to_end"].max()
        num_ids = X[self.id_col].nunique()

        # Other basic definitions
        feats = [col for col in X.columns if col not in [
            self.id_col, "time_to_end"]]
        list_ids = X[self.id_col].unique()

        # Initialise output array and id-time array
        out_array = np.empty(shape=(num_ids, max_time_length, len(feats)))
        out_array[:] = np.nan

        # Make a parallel array indicating id and corresponding time
        id_times_array = np.empty(shape=(num_ids, max_time_length, 2))

        # Set ids in this newly generated array
        id_times_array[:, :, 0] = np.repeat(np.expand_dims(
            list_ids, axis=-1), repeats=max_time_length, axis=-1)

        # Iterate through ids
        for id_ in tqdm(list_ids):
            # Subset data to where matches respective id
            index_ = np.where(list_ids == id_)[0]
            x_id = X[X[self.id_col] == id_]

            # Compute negative differences instead of keeping the original times.
            x_id_copy = x_id.copy()
            x_id_copy["time_to_end"] = - x_id["time_to_end"].diff().values

            # Update target output array and time information array
            out_array[index_, :x_id_copy.shape[0], :] = x_id_copy[feats].values
            id_times_array[index_, :x_id_copy.shape[0],
                           1] = x_id["time_to_end"].values

        return out_array.astype("float32"), id_times_array.astype("float32")

    def normalise(self, X):
        """Given 3D array, normalise according to min-max method."""
        self.min = np.nanmin(X, axis=0, keepdims=True)
        self.max = np.nanmax(X, axis=0, keepdims=True)

        return np.divide(X - self.min, self.max - self.min)

    def apply_normalisation(self, X):
        """Apply normalisation with current parameters to another dataset."""
        if self.min is None or self.max is None:
            raise ValueError(
                f"Attributes min and/or max are not yet computed. Run 'normalise' method instead.")

        else:
            return np.divide(X - self.min, self.max - self.min)


# Custom Dataloader
def collate_fn(data):
    x, y, mask, pat_time_ids, features, outcomes, x_subset, y_data, id_col, time_col, needs_time_to_end_computation, data_name, feat_set, time_range, target_window, min, max = zip(
        *data)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data_config = {"data_name": data_name, "feat_set": feat_set,
                   "time_range (h)": time_range, "target_window": target_window}
    data_properties = {"feats": features, "id_col": id_col, "time_col": time_col,
                       "norm_min": min, "norm_max": max, "outc_names": outcomes}

    x = torch.tensor(np.array(x))
    y = torch.tensor(np.array(y))
    x = x.to(device)
    y = y.to(device)
    # mask = torch.tensor(mask)
    # pat_time_ids = torch.tensor(pat_time_ids)
    # x_subset = torch.tensor(x_subset)
    # y_data = torch.tensor(y_data)

    return x, y


def load_data(train_dataset, val_dataset, test_dataset):
    """
    Return a DataLoader instance basing on a Dataset instance, with batch_size specified.
    set shuffle=???
    """

    batch_size = 64
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    return train_loader, val_loader, test_loader

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatTimeAttention(nn.Module):
    def __init__(self, latent_dim, input_shape):
        super().__init__()

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        self.latent_dim = latent_dim
        T, D_f = input_shape
        # Define Kernel and Bias for Feature Projection
        self.kernel = torch.zeros(
            (1, 1, D_f, self.latent_dim), requires_grad=True).to(self.device)
        nn.init.xavier_uniform_(self.kernel)
        self.bias = torch.zeros(
            (1, 1, D_f, self.latent_dim), requires_grad=True).to(self.device)
        nn.init.uniform_(self.bias)

        # Define Time aggregation weights for averaging over time.
        self.unnorm_beta = torch.zeros((1, T, 1), requires_grad=True)
        nn.init.uniform_(self.unnorm_beta)

    def forward(self, x, latent):
        o_hat, _ = self.generate_latent_approx(x, latent)
        weights = self.calc_weights(self.unnorm_beta)
        # print(o_hat.shape, weights.shape)
        return torch.sum(torch.mul(o_hat.to(self.device), weights.to(self.device)), dim=1)

    def generate_latent_approx(self, x, latent):
        features = torch.mul(x.unsqueeze(-1), self.kernel) + self.bias
        features = F.relu(features)

        # calculate the score
        X_T, X = features, features.transpose(2, 3)
        # print(X_T.shape, X.shape)
        X_T_X_inv = torch.inverse(torch.matmul(X_T, X))
        # print(X_T.shape, latent.unsqueeze(-1).shape)
        X_T_y = torch.matmul(X_T, latent.unsqueeze(-1))

        score_hat = torch.matmul(X_T_X_inv, X_T_y)
        scores = torch.squeeze(score_hat)

        # print(scores.unsqueeze(-1).shape, features.shape)
        o_hat = torch.sum(torch.mul(scores.unsqueeze(-1), features), dim=2)

        return o_hat, scores

    def calc_weights(self, x):
        abs_x = torch.abs(x)
        return abs_x / torch.sum(abs_x, dim=1)


class Encoder(nn.Module):
    def __init__(self, input_shape, attention_hidden_dim, latent_dim, dropout):
        super().__init__()
        self.lstm1 = nn.LSTM(input_size=input_shape[1],
                             hidden_size=attention_hidden_dim,
                             num_layers=2,
                             dropout=dropout,
                             batch_first=True)
        self.lstm2 = nn.LSTM(input_size=attention_hidden_dim,
                             hidden_size=latent_dim,
                             num_layers=1,
                             batch_first=True)
        self.attention = FeatTimeAttention(latent_dim, input_shape)

    def forward(self, x):
        latent_rep, _ = self.lstm1(x)
        latent_rep, _ = self.lstm2(latent_rep)
        output = self.attention(x, latent_rep)
        return output


class Identifier(nn.Module):
    def __init__(self, input_dim, mlp_hidden_dim, dropout, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, mlp_hidden_dim)
        self.sigmoid1 = nn.Sigmoid()

        self.fc2 = nn.Linear(mlp_hidden_dim, mlp_hidden_dim)
        self.sigmoid2 = nn.Sigmoid()
        self.dropout1 = nn.Dropout(dropout)

        self.fc4 = nn.Linear(mlp_hidden_dim, output_dim)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.sigmoid1(x)

        x = self.fc2(x)
        x = self.sigmoid2(x)
        x = self.dropout1(x)

        x = self.fc4(x)
        x = self.softmax(x)
        return x


class Predictor(nn.Module):
    def __init__(self, input_dim, mlp_hidden_dim, dropout, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, mlp_hidden_dim)
        self.sigmoid1 = nn.Sigmoid()

        self.fc2 = nn.Linear(mlp_hidden_dim, mlp_hidden_dim)
        self.sigmoid2 = nn.Sigmoid()
        self.dropout1 = nn.Dropout(dropout)

        self.fc3 = nn.Linear(mlp_hidden_dim, mlp_hidden_dim)
        self.sigmoid3 = nn.Sigmoid()
        self.dropout2 = nn.Dropout(dropout)

        self.fc4 = nn.Linear(mlp_hidden_dim, output_dim)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.sigmoid1(x)

        x = self.fc2(x)
        x = self.sigmoid2(x)
        x = self.dropout1(x)

        x = self.fc3(x)
        x = self.sigmoid3(x)
        x = self.dropout2(x)

        x = self.fc4(x)
        x = self.softmax(x)
        return x


class MyLRScheduler():
    def __init__(self, optimizer, patience, min_lr, factor):
        self.optimizer = optimizer
        self.patience = patience
        self.min_lr = min_lr
        self.factor = factor
        self.wait = 0
        self.best_loss = float('inf')

    def step(self, val_loss):
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.wait = 0
                for param_group in self.optimizer.param_groups:
                    old_lr = param_group['lr']
                    new_lr = max(old_lr * self.factor, self.min_lr)
                    param_group['lr'] = new_lr


def calc_l1_l2_loss(part=None, layers=None):
    para = []
    if part:
        for parameter in part.parameters():
            para.append(parameter.view(-1))
        parameters = torch.cat(para)
    if layers:
        for layer in layers:
            para.extend(layer.parameters())
        parameters = torch.cat([p.view(-1) for p in para])
    return 1e-30 * torch.abs(parameters).sum() + 1e-30 * torch.square(parameters).sum()


In [5]:
import numpy as np
import torch


def np_log(x):
    return np.log(x + 1e-8)


def torch_log(x):
    return torch.log(x + 1e-8)


def calc_pred_loss(y_true, y_pred, weights=None):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if weights is None:
        weights = torch.ones(y_true.shape) / y_true.shape[-1]
    return - torch.mean(torch.sum(weights.to(device) * y_true.to(device) * torch_log(y_pred).to(device), axis=-1))


def calc_dist_loss(probs):
    avg_prob = torch.mean(probs, dim=-1)
    return torch.sum(avg_prob * torch.log(avg_prob))


def calc_clus_loss(clusters):
    pairewise_loss = - \
        torch.sum((clusters.unsqueeze(1) - clusters.unsqueeze(0)) ** 2, dim=-1)
    loss = torch.sum(pairewise_loss)

    K = clusters.shape[0]
    return (loss / (K * (K-1) / 2)).float()


In [7]:
from sklearn.cluster import KMeans
import numpy as np
from tqdm import trange

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset



SEED = 12345
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def class_weight(y):
    class_numbers = torch.sum(y, dim=0)

    # Check no class is missing
    if not torch.all(class_numbers > 0):
        class_numbers += 1
    inv_class_num = 1 / class_numbers
    return inv_class_num / torch.sum(inv_class_num)


class CamelotModel(nn.Module):
    def __init__(self, input_shape, num_clusters=10, latent_dim=128, seed=SEED, output_dim=4,
                 alpha=0.01, beta=0.001, regularization=(0.01, 0.01), dropout=0.0,
                 cluster_rep_lr=0.001, weighted_loss=True, attention_hidden_dim=16,
                 mlp_hidden_dim=30):

        super().__init__()
        self.seed = seed

        self.input_shape = input_shape
        self.num_clusters = num_clusters
        self.latent_dim = latent_dim
        self.output_dim = output_dim
        self.alpha = alpha
        self.beta = beta
        self.regularization = regularization
        self.dropout = dropout
        self.cluster_rep_lr = cluster_rep_lr
        self.weighted_loss = weighted_loss
        self.attention_hidden_dim = attention_hidden_dim
        self.mlp_hidden_dim = mlp_hidden_dim

        # three newtorks
        self.Encoder = Encoder(
            self.input_shape, self.attention_hidden_dim, self.latent_dim, self.dropout)
        self.Identifier = Identifier(
            self.latent_dim, self.mlp_hidden_dim, self.dropout, self.num_clusters)
        self.Predictor = Predictor(
            self.latent_dim, self.mlp_hidden_dim, self.dropout, self.output_dim)

        # Cluster Representation params
        self.cluster_rep_set = torch.zeros(
            size=[self.num_clusters, self.latent_dim], dtype=torch.float32, requires_grad=True)

        self.loss_weights = None

    def forward(self, x):
        z = self.Encoder(x)
        probs = self.Identifier(z)
        samples = self.get_sample(probs)
        representations = self.get_representations(samples)
        return self.Predictor(representations)

    def forward_pass(self, x):
        z = self.Encoder(x)
        probs = self.Identifier(z)
        # print(probs.shape)
        # samples = self.get_sample(probs)
        # # print(samples.shape)
        # representations = self.get_representations(samples)
        # # print(representations.shape)
        # return self.Predictor(representations), probs
        clus_phens = self.Predictor(self.cluster_rep_set.to(device))
        y_pred = torch.matmul(probs, clus_phens)

        return y_pred, probs

    def get_sample(self, probs):
        logits = - torch.log(probs.reshape(-1, self.num_clusters))
        samples = torch.multinomial(logits, num_samples=1)
        return samples.squeeze()

    def get_representations(self, samples):
        mask = F.one_hot(samples, num_classes=self.num_clusters).to(
            torch.float32)
        return torch.matmul(mask.to(device), self.cluster_rep_set.to(device))

    def calc_pis(self, X):
        return self.Identifier(self.Encoder(X)).numpy()

    def get_cluster_reps(self):
        return self.cluster_rep_set.numpy()

    def assignment(self, X):
        pi = self.Identifier(self.Encoder(X)).numpy()
        return torch.argmax(pi, dim=1)

    def compute_cluster_phenotypes(self):
        return self.Predictor(self.cluster_rep_set).numpy()

    # def compute_unnorm_attention_weights(self, inputs):
    #     # no idea
    #     return self.Encoder.compute_unnorm_scores(inputs, cluster_reps=self.cluster_rep_set)

    # def compute_norm_attention_weights(self, inputs):
    #     # no idea
    #     return self.Encoder.compute_norm_scores(inputs, cluster_reps=self.cluster_rep_set)

    def initialize(self, train_data, val_data):
        x_train, y_train = train_data
        x_val, y_val = val_data
        self.loss_weights = class_weight(y_train)

        # initialize encoder
        self.initialize_encoder(x_train, y_train, x_val, y_val)

        # initialize cluster
        clus_train, clus_val = self.initialize_cluster(x_train, x_val)
        self.clus_train = clus_train
        self.x_train = x_train

        # initialize identifier
        self.initialize_identifier(x_train, clus_train, x_val, clus_val)

    def initialize_encoder(self, x_train, y_train, x_val, y_val, epochs=100, batch_size=64):
        temp = DataLoader(
            dataset=TensorDataset(x_train, y_train),
            shuffle=True,
            batch_size=batch_size
        )

        iden_loss = torch.full((epochs,), float('nan'))
        initialize_optim = torch.optim.Adam(
            self.Encoder.parameters(), lr=0.001)

        for i in trange(epochs):
            epoch_loss = 0
            for _, (x_batch, y_batch) in enumerate(temp):
                initialize_optim.zero_grad()

                z = self.Encoder(x_batch)
                y_pred = self.Predictor(z)
                loss = calc_pred_loss(
                    y_batch, y_pred, self.loss_weights) + calc_l1_l2_loss(part=self.Encoder)

                loss.backward()
                initialize_optim.step()

                epoch_loss += loss.item()

            with torch.no_grad():
                z = self.Encoder(x_val)
                y_pred_val = self.Predictor(z)
                loss_val = calc_pred_loss(y_val, y_pred_val, self.loss_weights)

            iden_loss[i] = loss_val.item()
            if torch.le(iden_loss[-50:], loss_val.item() + 0.001).any():
                break

        print('Encoder initialization done!')

    def initialize_cluster(self, x_train, x_val):
        z = self.Encoder(x_train).cpu().detach().numpy()
        kmeans = KMeans(self.num_clusters, random_state=self.seed, n_init='auto')
        kmeans.fit(z)
        print('Kmeans initialization done!')

        self.cluster_rep_set = torch.tensor(
            kmeans.cluster_centers_, dtype=torch.float32, requires_grad=True)
        train_cluster = torch.eye(self.num_clusters)[
            kmeans.predict(z)]
        val_cluster = torch.eye(self.num_clusters)[kmeans.predict(
            self.Encoder(x_val).cpu().detach().numpy())]

        print('Cluster initialization done!')
        return train_cluster, val_cluster

    def initialize_identifier(self, x_train, clus_train, x_val, clus_val, epochs=100, batch_size=64):
        temp = DataLoader(
            dataset=TensorDataset(x_train, clus_train),
            shuffle=True,
            batch_size=batch_size
        )

        iden_loss = torch.full((epochs,), float('nan'))
        initialize_optim = torch.optim.Adam(
            self.Identifier.parameters(), lr=0.001)

        for i in trange(epochs):
            epoch_loss = 0
            for step_, (x_batch, clus_batch) in enumerate(temp):
                initialize_optim.zero_grad()

                clus_pred = self.Identifier(self.Encoder(x_batch))
                loss = calc_pred_loss(clus_batch, clus_pred) + \
                    calc_l1_l2_loss(layers=[self.Identifier.fc2])

                loss.backward()
                initialize_optim.step()

                epoch_loss += loss.item()

            with torch.no_grad():
                clus_pred_val = self.Identifier(self.Encoder(x_val))
                loss_val = calc_pred_loss(clus_val, clus_pred_val)

            iden_loss[i] = loss_val.item()
            if torch.le(iden_loss[-50:], loss_val.item() + 0.001).any():
                break

        print('Identifier initialization done!')


In [13]:
results = np.zeros((len(seeds), 4))
for index, SEED in enumerate(seeds):
    torch.random.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = CustomDataset(time_range=(0, 10))

    # Stratified Sampling for train and val
    train_idx, test_idx = train_test_split(np.arange(len(dataset)),
                                                test_size=0.4,
                                                random_state=SEED,
                                                shuffle=True,
                                                stratify=np.argmax(dataset.y,axis=-1))

    # Subset dataset for train and val
    train_val_dataset = dataset.get_subset(train_idx)
    test_dataset = dataset.get_subset(test_idx)

    train_idx,  val_idx = train_test_split(np.arange(len(train_val_dataset)),
                                                test_size=0.4,
                                                random_state=SEED,
                                                shuffle=True,
                                                stratify=np.argmax(train_val_dataset.y,axis=-1))

    train_dataset = train_val_dataset.get_subset(train_idx)
    val_dataset = train_val_dataset.get_subset(val_idx)

    train_loader, val_loader, test_loader = load_data(train_dataset, val_dataset, test_dataset)

    model = CamelotModel(input_shape=(train_dataset.x.shape[1], train_dataset.x.shape[2]), seed=SEED, num_clusters=10, latent_dim=64)
    model = model.to(device)

    train_x = torch.tensor(train_dataset.x).to(device)
    train_y = torch.tensor(train_dataset.y).to(device)
    val_x = torch.tensor(val_dataset.x).to(device)
    val_y = torch.tensor(val_dataset.y).to(device)

    model.initialize((train_x, train_y), (val_x, val_y))

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    cluster_optim = torch.optim.Adam([model.cluster_rep_set], lr=0.001)

    lr_scheduler = MyLRScheduler(optimizer, patience=15, min_lr=0.00001, factor=0.25)
    cluster_lr_scheduler = MyLRScheduler(cluster_optim, patience=15, min_lr=0.00001, factor=0.25)

    loss_mat = np.zeros((100, 4, 2))

    best_loss = 1e5
    count = 0
    for i in trange(100):
        for step, (x_train, y_train) in enumerate(train_loader):
            optimizer.zero_grad()
            cluster_optim.zero_grad()

            y_pred, probs = model.forward_pass(x_train)

            loss_weights = class_weight(y_train)

            common_loss = calc_pred_loss(y_train, y_pred, loss_weights)

            enc_loss = common_loss + model.alpha * calc_dist_loss(probs) + \
             + calc_l1_l2_loss(part=model.Encoder) 
            enc_loss.backward(retain_graph=True, inputs=list(model.Encoder.parameters()))

            idnetifier_loss = common_loss + model.alpha * calc_dist_loss(probs) + \
            + calc_l1_l2_loss(layers=[model.Identifier.fc2])
            idnetifier_loss.backward(retain_graph=True, inputs=list(model.Identifier.parameters()))

            pred_loss = common_loss + calc_l1_l2_loss(layers=[model.Predictor.fc2, model.Predictor.fc3])
            pred_loss.backward(retain_graph=True, inputs=list(model.Predictor.parameters()))

            clus_loss = common_loss + model.beta * calc_clus_loss(model.cluster_rep_set)
            clus_loss.backward(inputs=model.cluster_rep_set)

            optimizer.step()
            cluster_optim.step()

            loss_mat[i, 0, 0] += enc_loss.item()
            loss_mat[i, 1, 0] += idnetifier_loss.item()
            loss_mat[i, 2, 0] += pred_loss.item()
            loss_mat[i, 3, 0] += clus_loss.item()

        with torch.no_grad():
            for step, (x_val, y_val) in enumerate(val_loader):
                y_pred, probs = model.forward_pass(x_val)

                loss_weights = class_weight(y_val)

                common_loss = calc_pred_loss(y_val, y_pred, loss_weights)

                enc_loss = common_loss + model.alpha * calc_dist_loss(probs) + \
                 + calc_l1_l2_loss(part=model.Encoder) 

                idnetifier_loss = common_loss + model.alpha * calc_dist_loss(probs) + \
                + calc_l1_l2_loss(layers=[model.Identifier.fc2])

                pred_loss = common_loss + calc_l1_l2_loss(layers=[model.Predictor.fc2, model.Predictor.fc3])

                clus_loss = common_loss + model.beta * calc_clus_loss(model.cluster_rep_set)

                loss_mat[i, 0, 1] += enc_loss.item()
                loss_mat[i, 1, 1] += idnetifier_loss.item()
                loss_mat[i, 2, 1] += pred_loss.item()
                loss_mat[i, 3, 1] += clus_loss.item()

            if i >= 30:
                if loss_mat[i, 0, 1] < best_loss:
                    count = 0
                    best_loss = loss_mat[i, 0, 1]
                    torch.save(model.state_dict(), './best_model')
                else:
                    count += 1
                    if count >= 50:
                        model.load_state_dict(torch.load('./best_model'))
        lr_scheduler.step(loss_mat[i, 0, 1])
        cluster_lr_scheduler.step(loss_mat[i, 0, 1])

#     print(calc_l1_l2_loss(layers=[model.Predictor.fc2, model.Predictor.fc3]), calc_l1_l2_loss(part=model.Encoder) + calc_l1_l2_loss(layers=[model.Identifier.fc2]))

    model.load_state_dict(torch.load('./best_model'))

    real, preds = [], []
    with torch.no_grad():
        for _, (x, y) in enumerate(test_loader):
            y_pred, _ = model.forward_pass(x)
            preds.extend(list(y_pred.cpu().detach().numpy()))
            real.extend(list(y.cpu().detach().numpy()))

    auc = roc_auc_score(real, preds, average=None)

    labels_true, labels_pred = np.argmax(real, axis=1), np.argmax(preds, axis=1)

    # Compute F1
    f1 = f1_score(labels_true, labels_pred, average=None)

    # Compute Recall
    rec = recall_score(labels_true, labels_pred, average=None)

    # Compute NMI
    nmi = normalized_mutual_info_score(labels_true, labels_pred)

    print(f'AUCROC: \t{auc.mean():.5f}, \t{auc}')
    print(f'F1-score: \t{f1.mean():.5f}, \t{f1}')
    print(f'Recall: \t{rec.mean():.5f}, \t{rec}')
    print(f'NMI: \t\t{nmi:.5f}')
    
    results[index, 0] = auc.mean()
    results[index, 1] = f1.mean()
    results[index, 2] = rec.mean()
    results[index, 3] = nmi


MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:12<00:00, 632.52it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.58it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:14<00:14,  3.46it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:56<00:00,  1.78it/s]


AUCROC: 	0.77337, 	[0.87801372 0.78423927 0.75752452 0.67370918]
F1-score: 	0.28072, 	[0.         0.47262248 0.65024631 0.        ]
Recall: 	0.34648, 	[0.         0.8803681  0.50553191 0.        ]
NMI: 		0.09244

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:13<00:00, 587.82it/s]
 50%|█████     | 50/100 [00:15<00:15,  3.15it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:15<00:15,  3.31it/s]


Identifier initialization done!


100%|██████████| 100/100 [01:05<00:00,  1.52it/s]


AUCROC: 	0.78254, 	[0.88453937 0.78543456 0.74934337 0.71084645]
F1-score: 	0.33899, 	[0.         0.50728863 0.84865209 0.        ]
Recall: 	0.34610, 	[0.         0.53374233 0.8506383  0.        ]
NMI: 		0.11785

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:12<00:00, 630.28it/s]
 50%|█████     | 50/100 [00:14<00:14,  3.55it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:12<00:12,  3.91it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.89it/s]


AUCROC: 	0.79506, 	[0.8972068  0.79897557 0.76723754 0.71681118]
F1-score: 	0.34045, 	[0.         0.52589641 0.83591872 0.        ]
Recall: 	0.35535, 	[0.         0.60736196 0.81404255 0.        ]
NMI: 		0.11889

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 727.86it/s]
 50%|█████     | 50/100 [00:14<00:14,  3.50it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:12<00:12,  3.91it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:53<00:00,  1.88it/s]


AUCROC: 	0.76984, 	[0.88031689 0.79648584 0.77335099 0.62918821]
F1-score: 	0.31292, 	[0.         0.50793651 0.74375624 0.        ]
Recall: 	0.36086, 	[0.         0.80981595 0.63361702 0.        ]
NMI: 		0.10792

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 724.27it/s]
 50%|█████     | 50/100 [00:14<00:14,  3.57it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:12<00:12,  3.90it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.89it/s]


AUCROC: 	0.77188, 	[0.90392029 0.77566572 0.75097826 0.65695633]
F1-score: 	0.33523, 	[0.         0.49857955 0.84235294 0.        ]
Recall: 	0.34405, 	[0.         0.53834356 0.83787234 0.        ]
NMI: 		0.10775

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:11<00:00, 697.72it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.66it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:13<00:13,  3.77it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:53<00:00,  1.87it/s]


AUCROC: 	0.70530, 	[0.79433192 0.72972417 0.71580348 0.58132733]
F1-score: 	0.29846, 	[0.         0.4903975  0.70343392 0.        ]
Recall: 	0.35434, 	[0.         0.84202454 0.57531915 0.        ]
NMI: 		0.10191

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 722.75it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.65it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:13<00:13,  3.74it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.91it/s]


AUCROC: 	0.78450, 	[0.91149951 0.78422632 0.75711587 0.68514229]
F1-score: 	0.29423, 	[0.         0.48701013 0.68992655 0.        ]
Recall: 	0.35193, 	[0.         0.84815951 0.55957447 0.        ]
NMI: 		0.10370

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 710.62it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.64it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:13<00:13,  3.77it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.91it/s]


AUCROC: 	0.77756, 	[0.91736361 0.77453451 0.76048549 0.6578481 ]
F1-score: 	0.30760, 	[0.         0.50094162 0.72947714 0.        ]
Recall: 	0.35760, 	[0.         0.81595092 0.61446809 0.        ]
NMI: 		0.10547

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:11<00:00, 689.38it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.68it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:13<00:13,  3.72it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:53<00:00,  1.86it/s]


AUCROC: 	0.78638, 	[0.91034793 0.79900083 0.77334488 0.66282572]
F1-score: 	0.34174, 	[0.         0.53725736 0.82969238 0.        ]
Recall: 	0.36247, 	[0.         0.65797546 0.79191489 0.        ]
NMI: 		0.12944

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:11<00:00, 696.55it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.63it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:13<00:13,  3.75it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:53<00:00,  1.89it/s]

AUCROC: 	0.76785, 	[0.90296472 0.77128549 0.74813051 0.64903701]
F1-score: 	0.32843, 	[0.         0.50537634 0.80834656 0.        ]
Recall: 	0.35177, 	[0.         0.64877301 0.75829787 0.        ]
NMI: 		0.10231





In [14]:
for m, u, std in zip(metrics, results.mean(axis=0), results.std(axis=0)):
    print(f'{m}: {u:.3f} ({std:.3f})')

AUC: 0.771 (0.023)
F1 score: 0.318 (0.021)
Recall: 0.353 (0.006)
NMI: 0.109 (0.010)


In [15]:
results = np.zeros((len(seeds), 4))
for index, SEED in enumerate(seeds):
    torch.random.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = CustomDataset(time_range=(0, 10))

    # Stratified Sampling for train and val
    train_idx, test_idx = train_test_split(np.arange(len(dataset)),
                                                test_size=0.4,
                                                random_state=SEED,
                                                shuffle=True,
                                                stratify=np.argmax(dataset.y,axis=-1))

    # Subset dataset for train and val
    train_val_dataset = dataset.get_subset(train_idx)
    test_dataset = dataset.get_subset(test_idx)

    train_idx,  val_idx = train_test_split(np.arange(len(train_val_dataset)),
                                                test_size=0.4,
                                                random_state=SEED,
                                                shuffle=True,
                                                stratify=np.argmax(train_val_dataset.y,axis=-1))

    train_dataset = train_val_dataset.get_subset(train_idx)
    val_dataset = train_val_dataset.get_subset(val_idx)

    train_loader, val_loader, test_loader = load_data(train_dataset, val_dataset, test_dataset)

    model = CamelotModel(input_shape=(train_dataset.x.shape[1], train_dataset.x.shape[2]), seed=SEED, num_clusters=10, latent_dim=64, beta=0)
    model = model.to(device)

    train_x = torch.tensor(train_dataset.x).to(device)
    train_y = torch.tensor(train_dataset.y).to(device)
    val_x = torch.tensor(val_dataset.x).to(device)
    val_y = torch.tensor(val_dataset.y).to(device)

    model.initialize((train_x, train_y), (val_x, val_y))

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    cluster_optim = torch.optim.Adam([model.cluster_rep_set], lr=0.001)

    lr_scheduler = MyLRScheduler(optimizer, patience=15, min_lr=0.00001, factor=0.25)
    cluster_lr_scheduler = MyLRScheduler(cluster_optim, patience=15, min_lr=0.00001, factor=0.25)

    loss_mat = np.zeros((100, 4, 2))

    best_loss = 1e5
    count = 0
    for i in trange(100):
        for step, (x_train, y_train) in enumerate(train_loader):
            optimizer.zero_grad()
            cluster_optim.zero_grad()

            y_pred, probs = model.forward_pass(x_train)

            loss_weights = class_weight(y_train)

            common_loss = calc_pred_loss(y_train, y_pred, loss_weights)

            enc_loss = common_loss + model.alpha * calc_dist_loss(probs) + \
             + calc_l1_l2_loss(part=model.Encoder) 
            enc_loss.backward(retain_graph=True, inputs=list(model.Encoder.parameters()))

            idnetifier_loss = common_loss + model.alpha * calc_dist_loss(probs) + \
            + calc_l1_l2_loss(layers=[model.Identifier.fc2])
            idnetifier_loss.backward(retain_graph=True, inputs=list(model.Identifier.parameters()))

            pred_loss = common_loss + calc_l1_l2_loss(layers=[model.Predictor.fc2, model.Predictor.fc3])
            pred_loss.backward(retain_graph=True, inputs=list(model.Predictor.parameters()))

            clus_loss = common_loss + model.beta * calc_clus_loss(model.cluster_rep_set)
            clus_loss.backward(inputs=model.cluster_rep_set)

            optimizer.step()
            cluster_optim.step()

            loss_mat[i, 0, 0] += enc_loss.item()
            loss_mat[i, 1, 0] += idnetifier_loss.item()
            loss_mat[i, 2, 0] += pred_loss.item()
            loss_mat[i, 3, 0] += clus_loss.item()

        with torch.no_grad():
            for step, (x_val, y_val) in enumerate(val_loader):
                y_pred, probs = model.forward_pass(x_val)

                loss_weights = class_weight(y_val)

                common_loss = calc_pred_loss(y_val, y_pred, loss_weights)

                enc_loss = common_loss + model.alpha * calc_dist_loss(probs) + \
                 + calc_l1_l2_loss(part=model.Encoder) 

                idnetifier_loss = common_loss + model.alpha * calc_dist_loss(probs) + \
                + calc_l1_l2_loss(layers=[model.Identifier.fc2])

                pred_loss = common_loss + calc_l1_l2_loss(layers=[model.Predictor.fc2, model.Predictor.fc3])

                clus_loss = common_loss + model.beta * calc_clus_loss(model.cluster_rep_set)

                loss_mat[i, 0, 1] += enc_loss.item()
                loss_mat[i, 1, 1] += idnetifier_loss.item()
                loss_mat[i, 2, 1] += pred_loss.item()
                loss_mat[i, 3, 1] += clus_loss.item()

            if i >= 30:
                if loss_mat[i, 0, 1] < best_loss:
                    count = 0
                    best_loss = loss_mat[i, 0, 1]
                    torch.save(model.state_dict(), './best_model')
                else:
                    count += 1
                    if count >= 50:
                        model.load_state_dict(torch.load('./best_model'))
        lr_scheduler.step(loss_mat[i, 0, 1])
        cluster_lr_scheduler.step(loss_mat[i, 0, 1])

#     print(calc_l1_l2_loss(layers=[model.Predictor.fc2, model.Predictor.fc3]), calc_l1_l2_loss(part=model.Encoder) + calc_l1_l2_loss(layers=[model.Identifier.fc2]))

    model.load_state_dict(torch.load('./best_model'))

    real, preds = [], []
    with torch.no_grad():
        for _, (x, y) in enumerate(test_loader):
            y_pred, _ = model.forward_pass(x)
            preds.extend(list(y_pred.cpu().detach().numpy()))
            real.extend(list(y.cpu().detach().numpy()))

    auc = roc_auc_score(real, preds, average=None)

    labels_true, labels_pred = np.argmax(real, axis=1), np.argmax(preds, axis=1)

    # Compute F1
    f1 = f1_score(labels_true, labels_pred, average=None)

    # Compute Recall
    rec = recall_score(labels_true, labels_pred, average=None)

    # Compute NMI
    nmi = normalized_mutual_info_score(labels_true, labels_pred)

    print(f'AUCROC: \t{auc.mean():.5f}, \t{auc}')
    print(f'F1-score: \t{f1.mean():.5f}, \t{f1}')
    print(f'Recall: \t{rec.mean():.5f}, \t{rec}')
    print(f'NMI: \t\t{nmi:.5f}')
    
    results[index, 0] = auc.mean()
    results[index, 1] = f1.mean()
    results[index, 2] = rec.mean()
    results[index, 3] = nmi


MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:11<00:00, 662.78it/s]
 50%|█████     | 50/100 [00:14<00:14,  3.49it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:12<00:12,  3.88it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:53<00:00,  1.88it/s]


AUCROC: 	0.76733, 	[0.86754329 0.77836003 0.75103647 0.67239117]
F1-score: 	0.33383, 	[0.         0.49459265 0.8407155  0.        ]
Recall: 	0.34152, 	[0.         0.52607362 0.84       0.        ]
NMI: 		0.10798

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 726.57it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.65it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:13<00:13,  3.80it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:53<00:00,  1.86it/s]


AUCROC: 	0.77330, 	[0.8613035  0.78128133 0.74634194 0.7042872 ]
F1-score: 	0.28195, 	[0.         0.47196653 0.65583536 0.        ]
Recall: 	0.34509, 	[0.         0.86503067 0.51531915 0.        ]
NMI: 		0.09085

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 709.26it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.64it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:13<00:13,  3.75it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:54<00:00,  1.84it/s]


AUCROC: 	0.79171, 	[0.89135903 0.79909175 0.76810752 0.70826369]
F1-score: 	0.33788, 	[0.         0.51111111 0.84040491 0.        ]
Recall: 	0.34866, 	[0.         0.56441718 0.83021277 0.        ]
NMI: 		0.11117

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 718.98it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.65it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:13<00:13,  3.75it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.92it/s]


AUCROC: 	0.75960, 	[0.87720516 0.76079271 0.77068894 0.62970981]
F1-score: 	0.32774, 	[0.         0.52421959 0.78674556 0.        ]
Recall: 	0.36354, 	[0.         0.74693252 0.70723404 0.        ]
NMI: 		0.11623

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 708.20it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.66it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:13<00:13,  3.80it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.90it/s]


AUCROC: 	0.76947, 	[0.91386802 0.76688569 0.74496173 0.65214977]
F1-score: 	0.32637, 	[0.         0.50293772 0.80255649 0.        ]
Recall: 	0.35113, 	[0.         0.65644172 0.74808511 0.        ]
NMI: 		0.10022

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:11<00:00, 691.61it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.64it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:13<00:13,  3.75it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.91it/s]


AUCROC: 	0.75183, 	[0.90486769 0.74914568 0.73146433 0.62185218]
F1-score: 	0.28803, 	[0.         0.46264626 0.68945869 0.        ]
Recall: 	0.33868, 	[0.         0.78834356 0.56638298 0.        ]
NMI: 		0.07525

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 702.60it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.65it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:13<00:13,  3.82it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.91it/s]


AUCROC: 	0.74883, 	[0.83957857 0.75443043 0.73177082 0.66955322]
F1-score: 	0.33783, 	[0.         0.49573974 0.85559265 0.        ]
Recall: 	0.34078, 	[0.         0.49079755 0.87234043 0.        ]
NMI: 		0.11604

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 700.60it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.66it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:12<00:12,  3.91it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.89it/s]


AUCROC: 	0.73254, 	[0.85329958 0.73570033 0.71219897 0.6289807 ]
F1-score: 	0.26343, 	[0.         0.45153756 0.60216278 0.        ]
Recall: 	0.33495, 	[0.         0.88957055 0.45021277 0.        ]
NMI: 		0.07883

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 707.57it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.67it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:12<00:12,  3.91it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:53<00:00,  1.89it/s]


AUCROC: 	0.78649, 	[0.92156158 0.78833882 0.76922461 0.66684708]
F1-score: 	0.34197, 	[0.         0.53537285 0.83252105 0.        ]
Recall: 	0.36094, 	[0.         0.64417178 0.79957447 0.        ]
NMI: 		0.13199

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 726.08it/s]
 50%|█████     | 50/100 [00:14<00:14,  3.49it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:12<00:12,  3.90it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.90it/s]

AUCROC: 	0.77283, 	[0.88434335 0.78302534 0.74520563 0.67876533]
F1-score: 	0.33541, 	[0.         0.50140845 0.84023161 0.        ]
Recall: 	0.34491, 	[0.         0.54601227 0.83361702 0.        ]
NMI: 		0.10777





In [16]:
for m, u, std in zip(metrics, results.mean(axis=0), results.std(axis=0)):
    print(f'{m}: {u:.3f} ({std:.3f})')

AUC: 0.765 (0.017)
F1 score: 0.317 (0.027)
Recall: 0.347 (0.009)
NMI: 0.104 (0.017)


In [17]:
results = np.zeros((len(seeds), 4))
for index, SEED in enumerate(seeds):
    torch.random.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = CustomDataset(time_range=(0, 10))

    # Stratified Sampling for train and val
    train_idx, test_idx = train_test_split(np.arange(len(dataset)),
                                                test_size=0.4,
                                                random_state=SEED,
                                                shuffle=True,
                                                stratify=np.argmax(dataset.y,axis=-1))

    # Subset dataset for train and val
    train_val_dataset = dataset.get_subset(train_idx)
    test_dataset = dataset.get_subset(test_idx)

    train_idx,  val_idx = train_test_split(np.arange(len(train_val_dataset)),
                                                test_size=0.4,
                                                random_state=SEED,
                                                shuffle=True,
                                                stratify=np.argmax(train_val_dataset.y,axis=-1))

    train_dataset = train_val_dataset.get_subset(train_idx)
    val_dataset = train_val_dataset.get_subset(val_idx)

    train_loader, val_loader, test_loader = load_data(train_dataset, val_dataset, test_dataset)

    model = CamelotModel(input_shape=(train_dataset.x.shape[1], train_dataset.x.shape[2]), seed=SEED, num_clusters=10, latent_dim=64, beta=0, alpha=0)
    model = model.to(device)

    train_x = torch.tensor(train_dataset.x).to(device)
    train_y = torch.tensor(train_dataset.y).to(device)
    val_x = torch.tensor(val_dataset.x).to(device)
    val_y = torch.tensor(val_dataset.y).to(device)

    model.initialize((train_x, train_y), (val_x, val_y))

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    cluster_optim = torch.optim.Adam([model.cluster_rep_set], lr=0.001)

    lr_scheduler = MyLRScheduler(optimizer, patience=15, min_lr=0.00001, factor=0.25)
    cluster_lr_scheduler = MyLRScheduler(cluster_optim, patience=15, min_lr=0.00001, factor=0.25)

    loss_mat = np.zeros((100, 4, 2))

    best_loss = 1e5
    count = 0
    for i in trange(100):
        for step, (x_train, y_train) in enumerate(train_loader):
            optimizer.zero_grad()
            cluster_optim.zero_grad()

            y_pred, probs = model.forward_pass(x_train)

            loss_weights = class_weight(y_train)

            common_loss = calc_pred_loss(y_train, y_pred, loss_weights)

            enc_loss = common_loss + model.alpha * calc_dist_loss(probs) + \
             + calc_l1_l2_loss(part=model.Encoder) 
            enc_loss.backward(retain_graph=True, inputs=list(model.Encoder.parameters()))

            idnetifier_loss = common_loss + model.alpha * calc_dist_loss(probs) + \
            + calc_l1_l2_loss(layers=[model.Identifier.fc2])
            idnetifier_loss.backward(retain_graph=True, inputs=list(model.Identifier.parameters()))

            pred_loss = common_loss + calc_l1_l2_loss(layers=[model.Predictor.fc2, model.Predictor.fc3])
            pred_loss.backward(retain_graph=True, inputs=list(model.Predictor.parameters()))

            clus_loss = common_loss + model.beta * calc_clus_loss(model.cluster_rep_set)
            clus_loss.backward(inputs=model.cluster_rep_set)

            optimizer.step()
            cluster_optim.step()

            loss_mat[i, 0, 0] += enc_loss.item()
            loss_mat[i, 1, 0] += idnetifier_loss.item()
            loss_mat[i, 2, 0] += pred_loss.item()
            loss_mat[i, 3, 0] += clus_loss.item()

        with torch.no_grad():
            for step, (x_val, y_val) in enumerate(val_loader):
                y_pred, probs = model.forward_pass(x_val)

                loss_weights = class_weight(y_val)

                common_loss = calc_pred_loss(y_val, y_pred, loss_weights)

                enc_loss = common_loss + model.alpha * calc_dist_loss(probs) + \
                 + calc_l1_l2_loss(part=model.Encoder) 

                idnetifier_loss = common_loss + model.alpha * calc_dist_loss(probs) + \
                + calc_l1_l2_loss(layers=[model.Identifier.fc2])

                pred_loss = common_loss + calc_l1_l2_loss(layers=[model.Predictor.fc2, model.Predictor.fc3])

                clus_loss = common_loss + model.beta * calc_clus_loss(model.cluster_rep_set)

                loss_mat[i, 0, 1] += enc_loss.item()
                loss_mat[i, 1, 1] += idnetifier_loss.item()
                loss_mat[i, 2, 1] += pred_loss.item()
                loss_mat[i, 3, 1] += clus_loss.item()

            if i >= 30:
                if loss_mat[i, 0, 1] < best_loss:
                    count = 0
                    best_loss = loss_mat[i, 0, 1]
                    torch.save(model.state_dict(), './best_model')
                else:
                    count += 1
                    if count >= 50:
                        model.load_state_dict(torch.load('./best_model'))
        lr_scheduler.step(loss_mat[i, 0, 1])
        cluster_lr_scheduler.step(loss_mat[i, 0, 1])

#     print(calc_l1_l2_loss(layers=[model.Predictor.fc2, model.Predictor.fc3]), calc_l1_l2_loss(part=model.Encoder) + calc_l1_l2_loss(layers=[model.Identifier.fc2]))

    model.load_state_dict(torch.load('./best_model'))

    real, preds = [], []
    with torch.no_grad():
        for _, (x, y) in enumerate(test_loader):
            y_pred, _ = model.forward_pass(x)
            preds.extend(list(y_pred.cpu().detach().numpy()))
            real.extend(list(y.cpu().detach().numpy()))

    auc = roc_auc_score(real, preds, average=None)

    labels_true, labels_pred = np.argmax(real, axis=1), np.argmax(preds, axis=1)

    # Compute F1
    f1 = f1_score(labels_true, labels_pred, average=None)

    # Compute Recall
    rec = recall_score(labels_true, labels_pred, average=None)

    # Compute NMI
    nmi = normalized_mutual_info_score(labels_true, labels_pred)

    print(f'AUCROC: \t{auc.mean():.5f}, \t{auc}')
    print(f'F1-score: \t{f1.mean():.5f}, \t{f1}')
    print(f'Recall: \t{rec.mean():.5f}, \t{rec}')
    print(f'NMI: \t\t{nmi:.5f}')
    
    results[index, 0] = auc.mean()
    results[index, 1] = f1.mean()
    results[index, 2] = rec.mean()
    results[index, 3] = nmi


MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:12<00:00, 640.32it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.65it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:13<00:13,  3.76it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.91it/s]


AUCROC: 	0.75846, 	[0.86568115 0.77298751 0.75408039 0.64108964]
F1-score: 	0.29019, 	[0.         0.47993095 0.68082847 0.        ]
Recall: 	0.34957, 	[0.         0.85276074 0.54553191 0.        ]
NMI: 		0.09296

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:11<00:00, 692.29it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.64it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:13<00:13,  3.70it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.90it/s]


AUCROC: 	0.76823, 	[0.8677883  0.75090831 0.74081293 0.7134292 ]
F1-score: 	0.32112, 	[0.         0.4909621  0.79349817 0.        ]
Recall: 	0.34579, 	[0.         0.64570552 0.73744681 0.        ]
NMI: 		0.09323

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:11<00:00, 698.65it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.68it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:12<00:12,  3.92it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:53<00:00,  1.88it/s]


AUCROC: 	0.78961, 	[0.88619732 0.79715642 0.76605815 0.70903487]
F1-score: 	0.33811, 	[0.         0.50681981 0.845629   0.        ]
Recall: 	0.34631, 	[0.         0.54141104 0.84382979 0.        ]
NMI: 		0.11115

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 731.20it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.57it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:12<00:12,  3.86it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:53<00:00,  1.88it/s]


AUCROC: 	0.76238, 	[0.88584613 0.76743124 0.75367494 0.64255067]
F1-score: 	0.32349, 	[0.         0.51809124 0.77586207 0.        ]
Recall: 	0.36176, 	[0.         0.75766871 0.6893617  0.        ]
NMI: 		0.11198

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 726.84it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.57it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:12<00:12,  3.90it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.89it/s]


AUCROC: 	0.77184, 	[0.91246325 0.77239491 0.74987746 0.6526265 ]
F1-score: 	0.32470, 	[0.         0.50343249 0.79538639 0.        ]
Recall: 	0.35212, 	[0.         0.67484663 0.73361702 0.        ]
NMI: 		0.10018

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 721.34it/s]
 50%|█████     | 50/100 [00:14<00:14,  3.51it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:12<00:12,  3.92it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.90it/s]


AUCROC: 	0.74490, 	[0.87101437 0.74903076 0.73178479 0.62778046]
F1-score: 	0.31412, 	[0.         0.49534643 0.76114726 0.        ]
Recall: 	0.35164, 	[0.         0.73466258 0.67191489 0.        ]
NMI: 		0.09452

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 729.75it/s]
 50%|█████     | 50/100 [00:14<00:14,  3.55it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:12<00:12,  3.88it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:53<00:00,  1.86it/s]


AUCROC: 	0.75511, 	[0.86612218 0.75408566 0.73148441 0.66873437]
F1-score: 	0.33747, 	[0.         0.49041534 0.85944939 0.        ]
Recall: 	0.33857, 	[0.         0.4708589  0.88340426 0.        ]
NMI: 		0.11696

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 728.05it/s]
 50%|█████     | 50/100 [00:14<00:14,  3.56it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:12<00:12,  3.91it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.91it/s]


AUCROC: 	0.78484, 	[0.93335511 0.77795118 0.76374945 0.6643148 ]
F1-score: 	0.33817, 	[0.         0.53524492 0.81741892 0.        ]
Recall: 	0.36348, 	[0.         0.68711656 0.76680851 0.        ]
NMI: 		0.13032

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 724.72it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.65it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:13<00:13,  3.84it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.89it/s]


AUCROC: 	0.78245, 	[0.92308069 0.78952907 0.77596793 0.64122705]
F1-score: 	0.33205, 	[0.         0.52249135 0.80570246 0.        ]
Recall: 	0.36008, 	[0.         0.69478528 0.74553191 0.        ]
NMI: 		0.11790

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:11<00:00, 689.16it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.64it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:13<00:13,  3.81it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.89it/s]

AUCROC: 	0.76318, 	[0.90655831 0.75849588 0.72001746 0.6676463 ]
F1-score: 	0.31040, 	[0.         0.51346274 0.66734694 0.0608    ]
Recall: 	0.37687, 	[0.         0.62883436 0.55659574 0.3220339 ]
NMI: 		0.09700





In [18]:
for m, u, std in zip(metrics, results.mean(axis=0), results.std(axis=0)):
    print(f'{m}: {u:.3f} ({std:.3f})')

AUC: 0.768 (0.013)
F1 score: 0.323 (0.014)
Recall: 0.355 (0.010)
NMI: 0.107 (0.012)


In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatTimeAttention(nn.Module):
    def __init__(self, latent_dim, input_shape):
        super().__init__()

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        self.latent_dim = latent_dim
        T, D_f = input_shape
        # Define Kernel and Bias for Feature Projection
        self.kernel = torch.zeros(
            (1, 1, D_f, self.latent_dim), requires_grad=True).to(self.device)
        nn.init.xavier_uniform_(self.kernel)
        self.bias = torch.zeros(
            (1, 1, D_f, self.latent_dim), requires_grad=True).to(self.device)
        nn.init.uniform_(self.bias)

        # Define Time aggregation weights for averaging over time.
        self.unnorm_beta = torch.zeros((1, T, 1), requires_grad=True)
        nn.init.uniform_(self.unnorm_beta)

    def forward(self, x, latent):
        o_hat, _ = self.generate_latent_approx(x, latent)
        weights = self.calc_weights(self.unnorm_beta)
        # print(o_hat.shape, weights.shape)
        return torch.sum(torch.mul(o_hat.to(self.device), weights.to(self.device)), dim=1)

    def generate_latent_approx(self, x, latent):
        features = torch.mul(x.unsqueeze(-1), self.kernel) + self.bias
        features = F.relu(features)

        # calculate the score
        X_T, X = features, features.transpose(2, 3)
        # print(X_T.shape, X.shape)
        X_T_X_inv = torch.inverse(torch.matmul(X_T, X))
        # print(X_T.shape, latent.unsqueeze(-1).shape)
        X_T_y = torch.matmul(X_T, latent.unsqueeze(-1))

        score_hat = torch.matmul(X_T_X_inv, X_T_y)
        scores = torch.squeeze(score_hat)

        # print(scores.unsqueeze(-1).shape, features.shape)
        o_hat = torch.sum(torch.mul(scores.unsqueeze(-1), features), dim=2)

        return o_hat, scores

    def calc_weights(self, x):
        abs_x = torch.abs(x)
        return abs_x / torch.sum(abs_x, dim=1)


class Encoder(nn.Module):
    def __init__(self, input_shape, attention_hidden_dim, latent_dim, dropout):
        super().__init__()
        self.lstm1 = nn.LSTM(input_size=input_shape[1],
                             hidden_size=attention_hidden_dim,
                             num_layers=2,
                             dropout=dropout,
                             batch_first=True)
        self.lstm2 = nn.LSTM(input_size=attention_hidden_dim,
                             hidden_size=latent_dim,
                             num_layers=1,
                             batch_first=True)
        self.attention = FeatTimeAttention(latent_dim, input_shape)
        self.fc = nn.Linear(10*(input_shape[1]+latent_dim), latent_dim)

    def forward(self, x):
        latent_rep, _ = self.lstm1(x)
        latent_rep, _ = self.lstm2(latent_rep)
        x_latent = torch.cat((x, latent_rep),dim=2)
        x_latent_flat = torch.flatten(x_latent, start_dim=1)
        output = self.fc(x_latent_flat)
        return output


class Identifier(nn.Module):
    def __init__(self, input_dim, mlp_hidden_dim, dropout, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, mlp_hidden_dim)
        self.sigmoid1 = nn.Sigmoid()

        self.fc2 = nn.Linear(mlp_hidden_dim, mlp_hidden_dim)
        self.sigmoid2 = nn.Sigmoid()
        self.dropout1 = nn.Dropout(dropout)

        self.fc4 = nn.Linear(mlp_hidden_dim, output_dim)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.sigmoid1(x)

        x = self.fc2(x)
        x = self.sigmoid2(x)
        x = self.dropout1(x)

        x = self.fc4(x)
        x = self.softmax(x)
        return x


class Predictor(nn.Module):
    def __init__(self, input_dim, mlp_hidden_dim, dropout, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, mlp_hidden_dim)
        self.sigmoid1 = nn.Sigmoid()

        self.fc2 = nn.Linear(mlp_hidden_dim, mlp_hidden_dim)
        self.sigmoid2 = nn.Sigmoid()
        self.dropout1 = nn.Dropout(dropout)

        self.fc3 = nn.Linear(mlp_hidden_dim, mlp_hidden_dim)
        self.sigmoid3 = nn.Sigmoid()
        self.dropout2 = nn.Dropout(dropout)

        self.fc4 = nn.Linear(mlp_hidden_dim, output_dim)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.sigmoid1(x)

        x = self.fc2(x)
        x = self.sigmoid2(x)
        x = self.dropout1(x)

        x = self.fc3(x)
        x = self.sigmoid3(x)
        x = self.dropout2(x)

        x = self.fc4(x)
        x = self.softmax(x)
        return x


class MyLRScheduler():
    def __init__(self, optimizer, patience, min_lr, factor):
        self.optimizer = optimizer
        self.patience = patience
        self.min_lr = min_lr
        self.factor = factor
        self.wait = 0
        self.best_loss = float('inf')

    def step(self, val_loss):
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.wait = 0
                for param_group in self.optimizer.param_groups:
                    old_lr = param_group['lr']
                    new_lr = max(old_lr * self.factor, self.min_lr)
                    param_group['lr'] = new_lr


def calc_l1_l2_loss(part=None, layers=None):
    para = []
    if part:
        for parameter in part.parameters():
            para.append(parameter.view(-1))
        parameters = torch.cat(para)
    if layers:
        for layer in layers:
            para.extend(layer.parameters())
        parameters = torch.cat([p.view(-1) for p in para])
    return 1e-30 * torch.abs(parameters).sum() + 1e-30 * torch.square(parameters).sum()

In [20]:
results = np.zeros((len(seeds), 4))
for index, SEED in enumerate(seeds):
    torch.random.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = CustomDataset(time_range=(0, 10))

    # Stratified Sampling for train and val
    train_idx, test_idx = train_test_split(np.arange(len(dataset)),
                                                test_size=0.4,
                                                random_state=SEED,
                                                shuffle=True,
                                                stratify=np.argmax(dataset.y,axis=-1))

    # Subset dataset for train and val
    train_val_dataset = dataset.get_subset(train_idx)
    test_dataset = dataset.get_subset(test_idx)

    train_idx,  val_idx = train_test_split(np.arange(len(train_val_dataset)),
                                                test_size=0.4,
                                                random_state=SEED,
                                                shuffle=True,
                                                stratify=np.argmax(train_val_dataset.y,axis=-1))

    train_dataset = train_val_dataset.get_subset(train_idx)
    val_dataset = train_val_dataset.get_subset(val_idx)

    train_loader, val_loader, test_loader = load_data(train_dataset, val_dataset, test_dataset)

    model = CamelotModel(input_shape=(train_dataset.x.shape[1], train_dataset.x.shape[2]), seed=SEED, num_clusters=10, latent_dim=64)
    model = model.to(device)

    train_x = torch.tensor(train_dataset.x).to(device)
    train_y = torch.tensor(train_dataset.y).to(device)
    val_x = torch.tensor(val_dataset.x).to(device)
    val_y = torch.tensor(val_dataset.y).to(device)

    model.initialize((train_x, train_y), (val_x, val_y))

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    cluster_optim = torch.optim.Adam([model.cluster_rep_set], lr=0.001)

    lr_scheduler = MyLRScheduler(optimizer, patience=15, min_lr=0.00001, factor=0.25)
    cluster_lr_scheduler = MyLRScheduler(cluster_optim, patience=15, min_lr=0.00001, factor=0.25)

    loss_mat = np.zeros((100, 4, 2))

    best_loss = 1e5
    count = 0
    for i in trange(100):
        for step, (x_train, y_train) in enumerate(train_loader):
            optimizer.zero_grad()
            cluster_optim.zero_grad()

            y_pred, probs = model.forward_pass(x_train)

            loss_weights = class_weight(y_train)

            common_loss = calc_pred_loss(y_train, y_pred, loss_weights)

            enc_loss = common_loss + model.alpha * calc_dist_loss(probs) + \
             + calc_l1_l2_loss(part=model.Encoder) 
            enc_loss.backward(retain_graph=True, inputs=list(model.Encoder.parameters()))

            idnetifier_loss = common_loss + model.alpha * calc_dist_loss(probs) + \
            + calc_l1_l2_loss(layers=[model.Identifier.fc2])
            idnetifier_loss.backward(retain_graph=True, inputs=list(model.Identifier.parameters()))

            pred_loss = common_loss + calc_l1_l2_loss(layers=[model.Predictor.fc2, model.Predictor.fc3])
            pred_loss.backward(retain_graph=True, inputs=list(model.Predictor.parameters()))

            clus_loss = common_loss + model.beta * calc_clus_loss(model.cluster_rep_set)
            clus_loss.backward(inputs=model.cluster_rep_set)

            optimizer.step()
            cluster_optim.step()

            loss_mat[i, 0, 0] += enc_loss.item()
            loss_mat[i, 1, 0] += idnetifier_loss.item()
            loss_mat[i, 2, 0] += pred_loss.item()
            loss_mat[i, 3, 0] += clus_loss.item()

        with torch.no_grad():
            for step, (x_val, y_val) in enumerate(val_loader):
                y_pred, probs = model.forward_pass(x_val)

                loss_weights = class_weight(y_val)

                common_loss = calc_pred_loss(y_val, y_pred, loss_weights)

                enc_loss = common_loss + model.alpha * calc_dist_loss(probs) + \
                 + calc_l1_l2_loss(part=model.Encoder) 

                idnetifier_loss = common_loss + model.alpha * calc_dist_loss(probs) + \
                + calc_l1_l2_loss(layers=[model.Identifier.fc2])

                pred_loss = common_loss + calc_l1_l2_loss(layers=[model.Predictor.fc2, model.Predictor.fc3])

                clus_loss = common_loss + model.beta * calc_clus_loss(model.cluster_rep_set)

                loss_mat[i, 0, 1] += enc_loss.item()
                loss_mat[i, 1, 1] += idnetifier_loss.item()
                loss_mat[i, 2, 1] += pred_loss.item()
                loss_mat[i, 3, 1] += clus_loss.item()

            if i >= 30:
                if loss_mat[i, 0, 1] < best_loss:
                    count = 0
                    best_loss = loss_mat[i, 0, 1]
                    torch.save(model.state_dict(), './best_model')
                else:
                    count += 1
                    if count >= 50:
                        model.load_state_dict(torch.load('./best_model'))
        lr_scheduler.step(loss_mat[i, 0, 1])
        cluster_lr_scheduler.step(loss_mat[i, 0, 1])

#     print(calc_l1_l2_loss(layers=[model.Predictor.fc2, model.Predictor.fc3]), calc_l1_l2_loss(part=model.Encoder) + calc_l1_l2_loss(layers=[model.Identifier.fc2]))

    model.load_state_dict(torch.load('./best_model'))

    real, preds = [], []
    with torch.no_grad():
        for _, (x, y) in enumerate(test_loader):
            y_pred, _ = model.forward_pass(x)
            preds.extend(list(y_pred.cpu().detach().numpy()))
            real.extend(list(y.cpu().detach().numpy()))

    auc = roc_auc_score(real, preds, average=None)

    labels_true, labels_pred = np.argmax(real, axis=1), np.argmax(preds, axis=1)

    # Compute F1
    f1 = f1_score(labels_true, labels_pred, average=None)

    # Compute Recall
    rec = recall_score(labels_true, labels_pred, average=None)

    # Compute NMI
    nmi = normalized_mutual_info_score(labels_true, labels_pred)

    print(f'AUCROC: \t{auc.mean():.5f}, \t{auc}')
    print(f'F1-score: \t{f1.mean():.5f}, \t{f1}')
    print(f'Recall: \t{rec.mean():.5f}, \t{rec}')
    print(f'NMI: \t\t{nmi:.5f}')
    
    results[index, 0] = auc.mean()
    results[index, 1] = f1.mean()
    results[index, 2] = rec.mean()
    results[index, 3] = nmi


MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 724.94it/s]
 50%|█████     | 50/100 [00:11<00:11,  4.51it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:10<00:10,  4.78it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:48<00:00,  2.06it/s]


AUCROC: 	0.74739, 	[0.82877328 0.77482939 0.7145033  0.67147136]
F1-score: 	0.33127, 	[0.         0.52077238 0.80429813 0.        ]
Recall: 	0.35776, 	[0.         0.68251534 0.74851064 0.        ]
NMI: 		0.11608

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:11<00:00, 680.76it/s]
 50%|█████     | 50/100 [00:11<00:11,  4.49it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:10<00:10,  4.90it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:50<00:00,  1.99it/s]


AUCROC: 	0.74947, 	[0.84624306 0.7394545  0.69009023 0.72210288]
F1-score: 	0.32191, 	[0.         0.48447961 0.80315315 0.        ]
Recall: 	0.34229, 	[0.         0.61042945 0.7587234  0.        ]
NMI: 		0.08950

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 710.55it/s]
 50%|█████     | 50/100 [00:11<00:11,  4.37it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:10<00:10,  4.98it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:49<00:00,  2.01it/s]


AUCROC: 	0.79067, 	[0.8895214  0.7793479  0.76392875 0.72987919]
F1-score: 	0.30577, 	[0.         0.53669222 0.59365347 0.09271523]
Recall: 	0.44864, 	[0.         0.75153374 0.44978723 0.59322034]
NMI: 		0.11039

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 707.15it/s]
 50%|█████     | 50/100 [00:11<00:11,  4.49it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:10<00:10,  4.77it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:49<00:00,  2.03it/s]


AUCROC: 	0.78227, 	[0.9286671  0.77446379 0.78896237 0.63697574]
F1-score: 	0.36807, 	[0.22916667 0.49012658 0.75297619 0.        ]
Recall: 	0.48457, 	[0.55       0.74233129 0.64595745 0.        ]
NMI: 		0.13074

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:11<00:00, 689.20it/s]
 50%|█████     | 50/100 [00:11<00:11,  4.51it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:10<00:10,  4.96it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:49<00:00,  2.01it/s]


AUCROC: 	0.74888, 	[0.89362137 0.76307975 0.65459935 0.6842337 ]
F1-score: 	0.33317, 	[0.         0.48291572 0.8497692  0.        ]
Recall: 	0.33736, 	[0.         0.48773006 0.86170213 0.        ]
NMI: 		0.10325

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 721.88it/s]
 50%|█████     | 50/100 [00:11<00:11,  4.35it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:10<00:10,  4.91it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:49<00:00,  2.01it/s]


AUCROC: 	0.79460, 	[0.92165959 0.81997912 0.79530314 0.64146541]
F1-score: 	0.34689, 	[0.         0.56100478 0.82656994 0.        ]
Recall: 	0.37377, 	[0.         0.71932515 0.77574468 0.        ]
NMI: 		0.15096

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 704.38it/s]
 50%|█████     | 50/100 [00:11<00:11,  4.51it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:10<00:10,  4.77it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:49<00:00,  2.01it/s]


AUCROC: 	0.79979, 	[0.91344332 0.80936385 0.77216637 0.70417223]
F1-score: 	0.35637, 	[0.11904762 0.53914067 0.69051322 0.07677543]
Recall: 	0.46456, 	[0.25       0.70245399 0.56680851 0.33898305]
NMI: 		0.12990

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:11<00:00, 680.38it/s]
 50%|█████     | 50/100 [00:11<00:11,  4.48it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:10<00:10,  4.95it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:50<00:00,  2.00it/s]


AUCROC: 	0.77607, 	[0.92494283 0.75360546 0.76150304 0.66422787]
F1-score: 	0.31610, 	[0.07174888 0.42314436 0.76948909 0.        ]
Recall: 	0.40944, 	[0.4        0.55521472 0.68255319 0.        ]
NMI: 		0.11068

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 709.67it/s]
 50%|█████     | 50/100 [00:11<00:11,  4.32it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:10<00:10,  4.94it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:49<00:00,  2.02it/s]


AUCROC: 	0.51137, 	[0.48918654 0.52599343 0.49838927 0.53191847]
F1-score: 	0.21635, 	[0.         0.         0.86540232 0.        ]
Recall: 	0.25000, 	[0. 0. 1. 0.]
NMI: 		0.00000

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:10<00:00, 712.24it/s]
 50%|█████     | 50/100 [00:11<00:11,  4.52it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:10<00:10,  4.80it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:49<00:00,  2.02it/s]

AUCROC: 	0.78165, 	[0.9263476  0.78865138 0.75518235 0.65642632]
F1-score: 	0.32128, 	[0.         0.5242047  0.69802956 0.06289308]
Recall: 	0.38081, 	[0.         0.58128834 0.60297872 0.33898305]
NMI: 		0.11034





In [21]:
for m, u, std in zip(metrics, results.mean(axis=0), results.std(axis=0)):
    print(f'{m}: {u:.3f} ({std:.3f})')

AUC: 0.748 (0.081)
F1 score: 0.322 (0.039)
Recall: 0.385 (0.066)
NMI: 0.105 (0.039)
