In [4]:
%load_ext autoreload
%autoreload 2

import numpy as np  
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder, OrdinalEncoder
from sklearn.compose import make_column_transformer
from sklearn.pipeline import make_pipeline
from sklearn.impute import SimpleImputer
from sklearn.metrics import roc_auc_score
from sklearn.calibration import calibration_curve
from sksurv.metrics import cumulative_dynamic_auc
from sksurv.nonparametric import CensoringDistributionEstimator
from itertools import combinations
import torch 
import torch.nn as nn
import torch.nn.functional as F
import pickle
import yaml
from tqdm import tqdm

import sys
sys.path.append("../")
from models import DiscreteNAM, NAM
from utils import discretize, get_dataset, get_bin_counts, get_discetized_run_data_survival, get_run_data, get_ebm_run_data, get_run_data_survival

sys.path.append("../run_scripts")
from epoch_functions import train_epoch_dys_pairs, test_epoch_dys_pairs, test_epoch_drsa, test_epoch_sa_transformer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
def get_auc(surv_preds, y_train, y_test, eval_times):
    test_times = np.linspace(
        max(y_train["time"].min(), y_test[y_test["event"] > 0]["time"].min()) + 1e-4,
        min(y_train["time"].max(), y_test[y_test["event"] > 0]["time"].max()) - 1e-4,
        1000
    )
    surv_preds = surv_preds[
        :, 
        np.clip(
            np.searchsorted(eval_times.cpu().numpy(), test_times),
            0, surv_preds.shape[1]-1
        )
    ]
    risk_preds = -1 * np.log(np.clip(surv_preds, 1e-5, 10 - 1e-5))
    # Get time-dependent AUC
    _, mean_auc = cumulative_dynamic_auc(y_train, y_test, risk_preds, test_times)
    return mean_auc

In [5]:
def get_auc_discrete_nam_(dataset, seed, split, data_dict=None, use_feature_set=True):
    
    # Read in model
    model = torch.load(f"../model_saves/discrete_nam_survival_{dataset}_seed{seed}_split{split}.pt").to(device)
    
    # Read in args using yaml
    args_id = model.params_id
    with open(f"../run_parameters/discrete_nam_survival_{dataset}_seed{seed}_split{split}_params{args_id}.yaml", "r") as f:
        args = yaml.safe_load(f)
    
    
    if data_dict is None:
        data_dict = \
            get_discetized_run_data_survival(dataset, seed=seed, split=split, max_bins=args["max_bins"], use_feature_set=use_feature_set)
        
        
    if use_feature_set:
        selected_feats = data_dict["selected_feats"]
        selected_pairs = data_dict["selected_pairs"]
        
        X_test_discrete = data_dict["X_test_discrete"].iloc[:, selected_feats]
        X_test_interactions = data_dict["X_test_discrete"].values[:, selected_pairs]
    else:
        X_test_discrete = data_dict["X_test_discrete"]
        
        active_feats = model.active_feats.cpu().numpy()
        selected_pairs = list(combinations(active_feats, 2))
        
        X_test_interactions = X_test_discrete.values[:, selected_pairs]

    batch_size = args["batch_size"]

    test_dataset = torch.utils.data.TensorDataset(
        torch.FloatTensor(X_test_discrete.values),
        torch.FloatTensor(X_test_interactions), 
        torch.BoolTensor(data_dict["y_test"]["event"]),
        torch.FloatTensor(data_dict["y_test"]["time"].copy()),
        torch.FloatTensor(data_dict["pcw_obs_times_test"])
    )
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    eval_times = data_dict["eval_times"].to(device)
    pcw_eval_times = data_dict["pcw_eval_times"].to(device)

    _, preds = test_epoch_dys_pairs(model, test_loader, eval_times, pcw_eval_times, model_mains=model)
    
    return preds, data_dict["y_train"], data_dict["y_val"], data_dict["y_test"], eval_times

In [6]:
def get_auc_discrete_nam(dataset, seed, use_feature_set=True):
    splits = [1, 2, 3, 4, 5]
    # splits = [1]
    
    # Get the data_dict for one of the splits
    # Since we only use the test set all the splits will have the same data
    data_dict = get_discetized_run_data_survival(dataset, seed=seed, split=splits[0], use_feature_set=use_feature_set)
    
    split_preds = 0
    for split in splits:
        preds, y_train, y_val, y_test, eval_times = get_auc_discrete_nam_(dataset, seed, split, data_dict=data_dict, use_feature_set=use_feature_set)
        split_preds += preds
    
    surv_preds = 1 - torch.sigmoid(split_preds / len(splits)).cpu().numpy()
    
    # y_train_auc = pd.concat([y_train, y_val])
    y_train_auc = np.concatenate([y_train, y_val])
    
    return get_auc(surv_preds, y_train_auc, y_test, eval_times)

In [7]:
def get_auc_drsa(dataset, seed, split, data_dict=None, use_feature_set=True):
        
    # Read in model
    model = torch.load(f"../model_saves/drsa_{dataset}_seed{seed}_split{split}.pt").to(device)
    
    # Read in args using yaml
    args_id = model.params_id
    with open(f"../run_parameters/drsa_{dataset}_seed{seed}_split{split}_params{args_id}.yaml", "r") as f:
        args = yaml.safe_load(f)
    
    if data_dict is None:
        data_dict = \
            get_run_data_survival(dataset, seed=seed, split=split, preprocess=True, use_feature_set=use_feature_set)
            
    X_test = data_dict["X_test"]
    y_test = data_dict["y_test"]

    batch_size = args["batch_size"]

    test_dataset = torch.utils.data.TensorDataset(
        torch.FloatTensor(X_test.values), 
        torch.BoolTensor(y_test["event"]),
        torch.FloatTensor(y_test["time"].copy())
    )
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    eval_times = data_dict["eval_times"].to(device)

    _, test_preds = test_epoch_drsa(model, test_loader, eval_times)
    
    surv_preds = torch.cumprod(1 - torch.sigmoid(test_preds).squeeze(-1), dim=1).cpu().numpy()  
    
    # y_train_auc = pd.concat([y_train, y_val])
    y_train_auc = np.concatenate([data_dict["y_train"], data_dict["y_val"]])
    
    return get_auc(surv_preds, y_train_auc, y_test, eval_times)

In [2]:
def get_auc_sa_transformer_(dataset, seed, split, data_dict=None, use_feature_set=True):
        
    # Read in model
    model = torch.load(f"../model_saves/sa_transformer_{dataset}_seed{seed}_split{split}.pt").to(device)
    
    # Read in args using yaml
    args_id = model.params_id
    with open(f"../run_parameters/sa_transformer_{dataset}_seed{seed}_split{split}_params{args_id}.yaml", "r") as f:
        args = yaml.safe_load(f)
    
    if data_dict is None:
        data_dict = \
            get_run_data_survival(dataset, seed=seed, split=split, preprocess=True, use_feature_set=use_feature_set)
            
    X_test = data_dict["X_test"]
    y_test = data_dict["y_test"]

    batch_size = args["batch_size"]

    test_dataset = torch.utils.data.TensorDataset(
        torch.FloatTensor(X_test.values), 
        torch.BoolTensor(y_test["event"]),
        torch.FloatTensor(y_test["time"].copy())
    )
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    eval_times = data_dict["eval_times"].to(device)

    _, test_preds = test_epoch_sa_transformer(model, test_loader, eval_times)
    
    return test_preds, data_dict["y_train"], data_dict["y_val"], data_dict["y_test"], eval_times

In [21]:
def get_auc_sa_transformer(dataset, seed, use_feature_set=False):
    splits = [1, 2, 3, 4, 5]
    # splits = [1]
    
    # Get the data_dict for one of the splits
    # Since we only use the test set all the splits will have the same data
    # data_dict = get_run_data_survival(dataset, seed=seed, split=splits[0], preprocess=True, use_feature_set=use_feature_set)
    
    split_preds = 0
    for split in splits:
        print("SPLIT", split)
        preds, y_train, y_val, y_test, eval_times = get_auc_sa_transformer_(dataset, seed, split, use_feature_set=use_feature_set)
        split_preds += preds

    surv_preds = torch.cumprod(split_preds / len(splits), dim=1).cpu().numpy()
    
    # y_train_auc = pd.concat([y_train, y_val])
    y_train_auc = np.concatenate([y_train, y_val])
    
    return get_auc(surv_preds, y_train_auc, y_test, eval_times)

In [9]:
def predict_all_cum_hazard_(ebm, x, times, monte_carlo=False):
    x_input = pd.DataFrame(
        np.tile(x.values.reshape(-1, 1), len(times)).T,
        columns=list(x.index)
    )
    x_input["time"] = times
    preds = ebm.predict_proba(x_input)[:, 1]
    if monte_carlo:
        return pd.Series(
            times * np.cumsum(preds) / np.arange(1, len(preds)+1)
        )
    else:
        return pd.Series(np.cumsum(preds))

In [11]:
def get_auc_ebm(dataset, seed, n_eval_times=100, use_feature_set=False):
    
    split_preds = 0
    
    with open(f"../model_saves/ebm_{dataset}_seed{seed}.pkl", "rb") as f:
        ebm = pickle.load(f)
       [] 
    data_dict = get_ebm_run_data(dataset, seed)
    X_test = data_dict["X_test"]
    y_train = data_dict["y_train"]
    y_test = data_dict["y_test"]
    
    if use_feature_set:
        # Save selected_features to yaml
        with open(f"../feature_sets/coxnet_{dataset}_seed{seed}.yaml", "r") as f:
            selected_features = yaml.safe_load(f)
        
        # selected_features = ebm.feature_names_in_
        
        X_test = X_test[selected_features]
        
    pred_times = np.linspace(
        max(y_train["time"].min(), y_test[y_test["event"]]["time"].min()) + 1e-4,
        min(y_train["time"].max(), y_test[y_test["event"]]["time"].max()) - 1e-4,
        1000
    )

    tqdm.pandas()
    cum_hazards = X_test.progress_apply(
        lambda row: predict_all_cum_hazard_(ebm, row, pred_times, monte_carlo=True), 
        axis=1
    ).values
        
    # Get evaluation times before train/val split
    quantiles = torch.quantile(
        torch.FloatTensor(y_train["time"].copy()),
        torch.linspace(0, 1, n_eval_times+2)
    )
    eval_times = quantiles[1:-1]

    surv_preds = np.exp(-cum_hazards)
    
    return get_auc(surv_preds, y_train, y_test, eval_times)

IndentationError: unindent does not match any outer indentation level (<tokenize>, line 7)

In [13]:
data_dict = get_run_data_survival("heart_failure_survival", seed=11, split=1, preprocess=True, use_feature_set=False)



In [16]:
data_dict["X_train"].shape

(170123, 3907)

In [15]:
sat_model = torch.load(f"../model_saves/sa_transformer_heart_failure_survival_seed11_split1.pt").to(device)
sat_model.embed

SrcEmbed(
  (w): Linear(in_features=3907, out_features=64, bias=True)
  (norm): LayerNorm()
)

In [22]:
models = ["SA-Transformer"]
datasets = ["heart_failure_survival"]
seeds = [10, 11, 12, 13, 14]

results = []
for model in models:
    print(f"Model: {model}")
    for dataset in datasets:
        
        if dataset == "heart_failure_survival":
            use_feature_set = True
        else:
            use_feature_set = False
        
        print(f"Dataset: {dataset}")
        for seed in seeds:
            print("Seed", seed)
            if model == "EBM":
                auc = get_auc_ebm(dataset, seed, use_feature_set=use_feature_set)
                
            elif model == "Discrete_NAM":
                auc = get_auc_discrete_nam(dataset, seed, use_feature_set=use_feature_set)
                
            elif model == "SA-Transformer":
                auc = get_auc_sa_transformer(dataset, seed)
                
            results.append([
                model, dataset, seed, auc
            ])
            
results = pd.DataFrame(results, columns=["model", "dataset", "seed", "auc"])

Model: SA-Transformer
Dataset: heart_failure_survival
Seed 10
SPLIT 1


                                                 

SPLIT 2


                                                 

SPLIT 3


                                                 

SPLIT 4


                                                 

SPLIT 5


                                                 

Seed 11
SPLIT 1


                                                 

SPLIT 2


                                                 

SPLIT 3


                                                 

SPLIT 4


                                                 

SPLIT 5


                                                 

Seed 12
SPLIT 1


                                                 

SPLIT 2


                                                 

SPLIT 3


                                                 

SPLIT 4


                                                 

SPLIT 5


                                                 

Seed 13
SPLIT 1


                                                 

SPLIT 2


                                                 

SPLIT 3


                                                 

SPLIT 4


                                                 

SPLIT 5


                                                 

Seed 14
SPLIT 1


                                                 

SPLIT 2


                                                 

SPLIT 3


                                                 

SPLIT 4


                                                 

SPLIT 5


                                                 

In [23]:
mean_df = results.groupby(["model", "dataset"])["auc"].mean()
std_df = results.groupby(["model", "dataset"])["auc"].std()

means_df = mean_df.round(3).map("{:.3f}".format)
stds_df = std_df.round(3).map("{:.3f}".format)

# Convert to string and add plus/minus in between
means_df.astype(str) + " ± " + stds_df.astype(str)

model           dataset               
SA-Transformer  heart_failure_survival    0.860 ± 0.003
Name: auc, dtype: object

In [24]:
models = ["SA-Transformer"]
datasets = ["unos"]
seeds = [10, 11, 12, 13, 14]

results = []
for model in models:
    print(f"Model: {model}")
    for dataset in datasets:
        
        if dataset == "heart_failure_survival":
            use_feature_set = True
        else:
            use_feature_set = False
        
        print(f"Dataset: {dataset}")
        for seed in seeds:
            print("Seed", seed)
            if model == "EBM":
                auc = get_auc_ebm(dataset, seed, use_feature_set=use_feature_set)
                
            elif model == "Discrete_NAM":
                auc = get_auc_discrete_nam(dataset, seed, use_feature_set=use_feature_set)
                
            elif model == "SA-Transformer":
                auc = get_auc_sa_transformer(dataset, seed)
                
            results.append([
                model, dataset, seed, auc
            ])
            
results_unos = pd.DataFrame(results, columns=["model", "dataset", "seed", "auc"])

Model: SA-Transformer
Dataset: unos
Seed 10
SPLIT 1


                                                 

SPLIT 2


                                                 

SPLIT 3


                                                 

SPLIT 4


                                                 

SPLIT 5


                                                 

Seed 11
SPLIT 1


                                                 

SPLIT 2


                                                 

SPLIT 3


                                                 

SPLIT 4


                                                 

SPLIT 5


                                                 

Seed 12
SPLIT 1


                                                 

SPLIT 2


                                                 

SPLIT 3


                                                 

SPLIT 4


                                                 

SPLIT 5


                                                 

Seed 13
SPLIT 1


                                                 

SPLIT 2


                                                 

SPLIT 3


                                                 

SPLIT 4


                                                 

SPLIT 5


                                                 

Seed 14
SPLIT 1


                                                 

SPLIT 2


                                                 

SPLIT 3


                                                 

SPLIT 4


                                                 

SPLIT 5


                                                 

In [26]:
mean_df = results_unos.groupby(["model", "dataset"])["auc"].mean()
std_df = results_unos.groupby(["model", "dataset"])["auc"].std()

means_df = mean_df.round(3).map("{:.3f}".format)
stds_df = std_df.round(3).map("{:.3f}".format)

# Convert to string and add plus/minus in between
means_df.astype(str) + " ± " + stds_df.astype(str)

model           dataset
SA-Transformer  unos       0.714 ± 0.002
Name: auc, dtype: object

In [14]:
results_df = pd.DataFrame(results, columns=["model", "dataset", "seed", "auc"])
results_df.groupby(["model", "dataset"]).agg(["mean", "std"])

Unnamed: 0_level_0,Unnamed: 1_level_0,seed,seed,auc,auc
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,mean,std
model,dataset,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2
EBM,heart_failure_survival,12.0,1.581139,0.834311,0.003518
EBM,support,12.0,1.581139,0.808115,0.007501


In [27]:
models = ["SA-Transformer"]
datasets = ["unos"]
seeds = [10, 11, 12, 13, 14]
splits = [1, 2, 3, 4, 5]

results = []
for model in models:
    print(f"Model: {model}")
    for dataset in datasets:
        
        if dataset == "heart_failure_survival":
            use_feature_set = True
        else:
            use_feature_set = False
        
        print(f"Dataset: {dataset}")
        for seed in seeds:
            if model == "EBM":
                # auc = get_auc_ebm(dataset, seed, use_feature_set=use_feature_set)
                with open(f"../model_saves/ebm_{dataset}_seed{seed}.pkl", "rb") as f:
                    ebm = pickle.load(f)
                    
                fit_time = ebm.fit_time
                
            elif model == "Discrete_NAM":
                fit_time = 0
                for split in splits:
                    nam = torch.load(f"../model_saves/discrete_nam_survival_{dataset}_seed{seed}_split{split}.pt")
                    fit_time += nam.fit_time
                    
            elif model == "SA-Transformer":
                fit_time = 0
                for split in splits:
                    nam = torch.load(f"../model_saves/sa_transformer_{dataset}_seed{seed}_split{split}.pt")
                    fit_time += nam.fit_time
                
            results.append([
                model, dataset, seed, fit_time / 60
            ])
            
results = pd.DataFrame(results, columns=["model", "dataset", "seed", "fit_time"])

Model: SA-Transformer
Dataset: unos


In [28]:
mean_df = results.groupby(["model", "dataset"])["fit_time"].mean()
std_df = results.groupby(["model", "dataset"])["fit_time"].std()

means_df = mean_df.round(3).map("{:.3f}".format)
stds_df = std_df.round(3).map("{:.3f}".format)

# Convert to string and add plus/minus in between
means_df.astype(str) + " ± " + stds_df.astype(str)

model           dataset
SA-Transformer  unos       199.160 ± 13.076
Name: fit_time, dtype: object

In [32]:
X, y = get_dataset('adult')
X

Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country
0,25,Private,226802,11th,7,Never-married,Machine-op-inspct,Own-child,Black,Male,0,0,40,United-States
1,38,Private,89814,HS-grad,9,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,50,United-States
2,28,Local-gov,336951,Assoc-acdm,12,Married-civ-spouse,Protective-serv,Husband,White,Male,0,0,40,United-States
3,44,Private,160323,Some-college,10,Married-civ-spouse,Machine-op-inspct,Husband,Black,Male,7688,0,40,United-States
4,18,,103497,Some-college,10,Never-married,,Own-child,White,Female,0,0,30,United-States
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
48837,27,Private,257302,Assoc-acdm,12,Married-civ-spouse,Tech-support,Wife,White,Female,0,0,38,United-States
48838,40,Private,154374,HS-grad,9,Married-civ-spouse,Machine-op-inspct,Husband,White,Male,0,0,40,United-States
48839,58,Private,151910,HS-grad,9,Widowed,Adm-clerical,Unmarried,White,Female,0,0,40,United-States
48840,22,Private,201490,HS-grad,9,Never-married,Adm-clerical,Own-child,White,Male,0,0,20,United-States


In [35]:
X["fnlwgt"].describe()

count    4.884200e+04
mean     1.896641e+05
std      1.056040e+05
min      1.228500e+04
25%      1.175505e+05
50%      1.781445e+05
75%      2.376420e+05
max      1.490400e+06
Name: fnlwgt, dtype: float64

In [25]:
data_dict = \
    get_run_data_survival(dataset, seed=seed, split=split, preprocess=True, use_feature_set=False)



In [32]:
sat_auc = get_auc_sa_transformer(dataset, seed, split, data_dict=data_dict, use_feature_set=False)
sat_auc

                                                 

0.8555127994793543

In [5]:
dataset = "unos"
seed = 10
split = 1

sat_auc = get_auc_sa_transformer(dataset, seed, split, use_feature_set=False)
sat_auc

                                                 

0.7133271415103789

In [12]:
dataset = "unos"
seed = 10
split = 1

dnam_auc = get_auc_discrete_nam(dataset, seed, use_feature_set=False)
dnam_auc

Number of categorical features 6


                                                 

0.7167600645391862

In [28]:
next(iter(test_loader))[2].shape

torch.Size([128])

In [18]:
dataset = "heart_failure_survival"
seed = 10
split = 1

drsa_auc = get_auc_drsa(dataset, seed, split, use_feature_set=False)
drsa_auc

                                                 

0.8164614272677884

In [22]:
dataset = "heart_failure_survival"
seed = 10
split = 1

drsa_auc = get_auc_sa_transformer(dataset, seed, split, use_feature_set=False)
drsa_auc

                                       

RuntimeError: The size of tensor a (100) must match the size of tensor b (128) at non-singleton dimension 1

In [11]:
dataset = "heart_failure_survival"
seeds = [10, 11, 12, 13, 14]
aucs = []
for seed in seeds:
    auc = get_auc_discrete_nam(dataset, seed)
    aucs.append(auc)

aucs = pd.DataFrame(aucs, columns=["AUC"])
aucs["Model"] = "DiscreteNAM"

aucs.groupby("Model").agg(["mean", "std"])

Number of categorical features 3


                                                  

Number of categorical features 2


                                                  

Number of categorical features 1


                                                  

Number of categorical features 1


                                                  

Number of categorical features 2


                                                  

Unnamed: 0_level_0,AUC,AUC
Unnamed: 0_level_1,mean,std
Model,Unnamed: 1_level_2,Unnamed: 2_level_2
DiscreteNAM,0.843652,0.002705


In [12]:
aucs

Unnamed: 0,AUC,Model
0,0.841135,DiscreteNAM
1,0.842667,DiscreteNAM
2,0.847954,DiscreteNAM
3,0.841998,DiscreteNAM
4,0.844508,DiscreteNAM


In [22]:
auc = get_auc_discrete_nam("heart_failure_survival", 10)

                                                  

In [23]:
auc

0.8407390195549266

In [7]:
ebm_auc = get_auc_ebm("heart_failure_survival", 12, use_feature_set=True)

100%|██████████| 53164/53164 [05:23<00:00, 164.36it/s]


In [8]:
ebm_auc

0.8391356798950814

In [6]:
ebm_auc

0.8428519870451628

In [7]:
ebm_auc = get_auc_ebm("support", 10, use_feature_set=False)

100%|██████████| 1821/1821 [00:40<00:00, 44.60it/s]


In [12]:
dataset = "support"
seeds = [10, 11, 12, 13, 14]
aucs = []
for seed in seeds:
    auc = get_auc_ebm(dataset, seed)
    aucs.append(auc)
    
print(np.mean(aucs), np.std(aucs))

  0%|          | 0/1821 [00:00<?, ?it/s]

100%|██████████| 1821/1821 [00:46<00:00, 39.05it/s]
100%|██████████| 1821/1821 [00:32<00:00, 56.77it/s]
100%|██████████| 1821/1821 [00:46<00:00, 38.86it/s]
100%|██████████| 1821/1821 [00:46<00:00, 38.79it/s]
100%|██████████| 1821/1821 [00:39<00:00, 45.69it/s]


0.6621206961715033 0.009496662233958645


In [11]:
dataset = "heart_failure_survival"
seeds = [10, 11, 12, 13, 14]
splits = [1, 2, 3, 4, 5]

ebm_times = []
for seed in seeds:
    with open(f"../model_saves/ebm_{dataset}_seed{seed}.pkl", "rb") as f:
        ebm = pickle.load(f)
    ebm_times.append(ebm.fit_time)
    
dnam_times = []
for seed in seeds:
    dnam_time = 0
    for split in splits:
        model = torch.load(f"../model_saves/discrete_nam_survival_{dataset}_seed{seed}_split{split}.pt")
        dnam_time += model.fit_time
    dnam_times.append(dnam_time)
    
np.mean(ebm_times), np.std(ebm_times), np.mean(dnam_times), np.std(dnam_times)

(10719.62092385292, 414.15938192756215, 2227.926358985901, 235.66472519167203)

In [34]:
dnam_auc = get_auc_discrete_nam("heart_failure_survival", 10)
dnam_auc

                                                  

0.8401130087869512

In [7]:
dataset = "heart_failure_survival"
seed = 10

with open(f"../model_saves/ebm_{dataset}_seed{seed}.pkl", "rb") as f:
    ebm = pickle.load(f)

In [9]:
len(ebm.feature_names_in_)

34

In [39]:
(ebm.fit_time // 60)

174.0

In [9]:
auc = get_auc_discrete_nam("support", 11, use_feature_set=False)
auc

Number of categorical features 3


                                               

0.8250152203486563

In [35]:
dataset = "heart_failure_survival"
seed = 10
split = 1
use_feature_set = True

data_dict = get_discetized_run_data_survival(dataset, seed=seed, split=split, max_bins=32, use_feature_set=use_feature_set)

In [19]:
data_dict["X_train_discrete"]["condition_long_term_201826"]

86427     1
146253    1
181480    1
100195    1
37023     1
         ..
210536    1
20339     1
191396    2
33482     1
116877    1
Name: condition_long_term_201826, Length: 170123, dtype: int64

In [36]:
len(data_dict["selected_feats"])

32

In [27]:
data_dict["X_train_discrete"].columns[data_dict["cat_cols_indices"]]

Index(['cat_meas:3050380_0', 'smoking'], dtype='object')

In [26]:
[c for c in data_dict["X_train_discrete"].columns if "gender" in c]

['gender']

In [29]:
data_dict["X_train_discrete"].select_dtypes(include="category").columns

Index([], dtype='object')