In [1]:
import torch
import feature_groups
import pickle
import argparse
import numpy as np
import torch.nn as nn
from torchmetrics import AUROC
from torch.utils.data import DataLoader, random_split
import os
from os import path
from dime.data_utils import DenseDatasetSelected, get_group_matrix, get_xy, MaskLayerGrouped, data_split
from dime import MaskingPretrainer
from dime import GreedyCMIEstimatorPL
from dime.utils import normalize, selection_without_cost
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import pickle

# Load Dataset

In [None]:
# AUROC for the binary intubation outcome, evaluated as a 2-class multiclass problem.
auc_metric = AUROC(task='multiclass', num_classes=2)

intub_feature_names = feature_groups.intub_feature_names
intub_feature_groups = feature_groups.intub_feature_groups
# NOTE(review): hard-coded to GPU index 1; adjust for the machine running this notebook.
device = torch.device('cuda', 1)

# Column positions (as strings) to exclude; an empty list keeps every feature.
cols_to_drop = []
if cols_to_drop is not None:
    # Filter by position in a single pass. The previous list.index() lookup was O(n^2)
    # and mapped duplicate names to the index of their first occurrence.
    intub_feature_names = [name for idx, name in enumerate(intub_feature_names)
                           if str(idx) not in cols_to_drop]

# Load dataset
dataset = DenseDatasetSelected('data/intub.csv', cols_to_drop=cols_to_drop)
d_in = dataset.X.shape[1]  # 121
d_out = len(np.unique(dataset.Y))  # 2
feature_groups_dict, feature_groups_mask = get_group_matrix(intub_feature_names, intub_feature_groups)
num_groups = len(feature_groups_mask)
# Fixed generator seed so the train/val/test split (and therefore the test set) is reproducible.
train_dataset, val_dataset, test_dataset = random_split(dataset, [0.8, 0.1, 0.1], generator=torch.Generator().manual_seed(0))

# Normalization statistics are computed on the training split only.
x, y = get_xy(train_dataset)
mean = np.mean(x, axis=0)
# Per-feature std of the *inputs* (previously taken from the labels `y` and never applied).
# Clipped so constant columns do not cause division by zero.
std = np.clip(np.std(x, axis=0), 1e-3, None)

# Normalize via the original dataset: the random_split subsets index into `dataset`,
# so every split sees the normalized values.
# NOTE(review): this must match the preprocessing used when the checkpoint was trained.
dataset.X = (dataset.X - mean) / std


# Set up networks

In [None]:
# Set up architecture
hidden = 128
dropout = 0.3
# Input width is d_in + num_groups: presumably the masked features plus the group mask
# appended by MaskLayerGrouped(append=True) below. d_in / d_out come from the dataset cell.

# Outcome Predictor
predictor = nn.Sequential(
    nn.Linear(d_in + num_groups, hidden),
    nn.ReLU(),
    nn.Dropout(dropout),
    nn.Linear(hidden, hidden),
    nn.ReLU(),
    nn.Dropout(dropout),
    nn.Linear(hidden, d_out)).to(device)

# CMI Predictor: one sigmoid output per feature group.
value_network = nn.Sequential(
    nn.Linear(d_in + num_groups, hidden),
    nn.ReLU(),
    nn.Dropout(dropout),
    nn.Linear(hidden, hidden),
    nn.ReLU(),
    nn.Dropout(dropout),
    nn.Linear(hidden, num_groups),
    nn.Sigmoid()).to(device)

# Evaluation loaders keep every sample. With drop_last=True, the final partial batch
# (up to 127 samples) was silently excluded from the reported metrics.
test_dataloader = DataLoader(
    test_dataset, batch_size=128, shuffle=False, pin_memory=True,
    drop_last=False, num_workers=4)

val_dataloader = DataLoader(
    val_dataset, batch_size=128, shuffle=False, pin_memory=True,
    drop_last=False, num_workers=4)

mask_layer = MaskLayerGrouped(append=True, group_matrix=torch.tensor(feature_groups_mask))

# NOTE(review): `Trainer` is not imported anywhere in this notebook, so this line raises
# NameError on a fresh kernel. Given GreedyCMIEstimatorPL it is presumably the PyTorch
# Lightning Trainer; add the matching import to the imports cell.
trainer = Trainer(
    accelerator='gpu',
    devices=[device.index],
    precision=16,
)

# Evaluate Penalized Policy

In [None]:
# Penalty values swept for the lambda-penalized policy (loop-invariant, so defined once).
lamda_values = [0.000001, 0.00001, 0.00007, 0.0003, 0.0005] + [0.0012, 0.016]

for trial in range(5):
    results_dict = {"acc": {}}
    # NOTE(review): this path does not depend on `trial`, so every trial evaluates the same
    # checkpoint. Substitute a per-trial path if each trial has its own model.
    trained_model_path = "<path_to_trained_model>"
    greedy_cmi_estimator = GreedyCMIEstimatorPL.load_from_checkpoint(trained_model_path,
                                                                     value_network=value_network,
                                                                     predictor=predictor,
                                                                     mask_layer=mask_layer,
                                                                     lr=1e-3,
                                                                     max_features=30,
                                                                     eps=0.1,
                                                                     loss_fn=nn.CrossEntropyLoss(reduction='none'),
                                                                     val_loss_fn=auc_metric,
                                                                     eps_decay=0.2,
                                                                     eps_steps=10,
                                                                     patience=5,
                                                                     feature_costs=None,
                                                                     use_entropy=True
                                                                     ).to(device)
    # Per-lambda results for this trial. These names were previously appended to under a
    # misspelled `_lamda` suffix, which raised NameError.
    avg_num_features_lambda = []
    accuracy_scores_lambda = []
    all_masks_lambda = []

    # Evaluation Mode lambda penalty
    for lamda in lamda_values:
        metric_dict = greedy_cmi_estimator.inference(trainer, test_dataloader, feature_costs=None, lam=lamda)

        y = metric_dict['y']
        pred = metric_dict['pred']
        auc_score = auc_metric(pred, y)
        final_masks = np.array(metric_dict['mask'])
        # Mean number of selected features; assumes each mask row is one test sample.
        avg_num_features = np.mean(np.sum(final_masks, axis=1))

        accuracy_scores_lambda.append(auc_score)
        avg_num_features_lambda.append(avg_num_features)
        all_masks_lambda.append(final_masks)
        # Record this cell's own AUROC. The undefined `accuracy_score` used before could
        # silently pick up a stale value from the budget cell.
        # Keyed by average feature count, so equal counts overwrite earlier entries.
        results_dict['acc'][avg_num_features] = auc_score

        print(f"Lambda={lamda}, AUROC={auc_score}, Avg. num features={avg_num_features}")

    # Ensure the output directory exists on a fresh checkout.
    os.makedirs('results', exist_ok=True)
    with open(f'results/intub_lamda_ours_trial_{trial}.pkl', 'wb') as f:
        pickle.dump(results_dict, f)


# Evaluate Budget-Constrained Policy

In [None]:
# Maximum feature budgets evaluated for the budget-constrained policy.
max_budget_values = [1, 3, 5, 10, 15, 20, 25]

for trial in range(5):
    results_dict = {"acc": {}}
    # NOTE(review): this path does not depend on `trial`, so every trial evaluates the same
    # checkpoint. Substitute a per-trial path if each trial has its own model.
    trained_model_path = "<path_to_trained_model>"
    greedy_cmi_estimator = GreedyCMIEstimatorPL.load_from_checkpoint(trained_model_path,
                                                                     value_network=value_network,
                                                                     predictor=predictor,
                                                                     mask_layer=mask_layer,
                                                                     lr=1e-3,
                                                                     max_features=40,
                                                                     eps=0.05,
                                                                     loss_fn=nn.CrossEntropyLoss(reduction='none'),
                                                                     val_loss_fn=auc_metric,
                                                                     eps_decay=0.2,
                                                                     eps_steps=10,
                                                                     patience=5,
                                                                     feature_costs=None,
                                                                     use_entropy=True
                                                                     ).to(device)
    avg_num_features_budget = []
    accuracy_scores_budget = []
    all_masks_budget = []

    for budget in max_budget_values:
        metric_dict_budget = greedy_cmi_estimator.inference(trainer, test_dataloader, feature_costs=None, budget=budget)

        y = metric_dict_budget['y']
        pred = metric_dict_budget['pred']
        # The metric is AUROC, not accuracy.
        auc_score = auc_metric(pred, y)
        final_masks = np.array(metric_dict_budget['mask'])
        # Mean number of selected features; assumes each mask row is one test sample.
        avg_num_features = np.mean(np.sum(final_masks, axis=1))

        accuracy_scores_budget.append(auc_score)
        avg_num_features_budget.append(avg_num_features)
        all_masks_budget.append(final_masks)
        # Keyed by average feature count, so equal counts overwrite earlier entries.
        results_dict['acc'][avg_num_features] = auc_score

        print(f"Budget={budget}, AUROC={auc_score}, Avg. num features={avg_num_features}")

    # Ensure the output directory exists on a fresh checkout.
    os.makedirs('results', exist_ok=True)
    with open(f'results/intub_ours_trial_{trial}.pkl', 'wb') as f:
        pickle.dump(results_dict, f)