In [None]:
import torch
import feature_groups
import pickle
import argparse
import numpy as np
import torch.nn as nn
from torchmetrics import AUROC, Accuracy
from torch.utils.data import DataLoader, random_split
import os
from os import path
from dime.data_utils import ROSMAPDataset, get_group_matrix, get_xy, data_split
from dime import MaskingPretrainer, CMIEstimator
from pytorch_lightning import Trainer
from dime.utils import MaskLayerGrouped
import matplotlib.pyplot as plt
import pandas as pd

# Load Dataset

In [None]:
# Metrics: the two-class outcome is scored through the 'multiclass' task.
auc_metric = AUROC(task='multiclass', num_classes=2)
acc_metric = Accuracy(task='multiclass', num_classes=2)

# Every model in this notebook is placed on the second CUDA device.
device = torch.device('cuda', 1)

use_apoe = False
# Feature columns that are excluded whenever `use_apoe` is off.
APOE_FEATURES = ['apoe4_1copy', 'apoe4_2copies']

rosmap_feature_groups = feature_groups.rosmap_feature_groups
rosmap_feature_names = feature_groups.rosmap_feature_names
if not use_apoe:
    rosmap_feature_names = [
        name for name in rosmap_feature_names if name not in APOE_FEATURES
    ]

# The group matrix is built from the (possibly APOE-filtered) feature names;
# its length is the number of feature groups.
feature_groups_dict, feature_groups_mask = get_group_matrix(
    rosmap_feature_names, rosmap_feature_groups
)
num_groups = len(feature_groups_mask)

# `cols_to_drop` holds stringified positions of features to remove
# (empty here, so the filter below keeps every name).
cols_to_drop = []
if cols_to_drop is not None:
    kept_names = []
    for name in rosmap_feature_names:
        position = str(rosmap_feature_names.index(name))
        if position not in cols_to_drop:
            kept_names.append(name)
    rosmap_feature_names = kept_names

# Train / validation / test splits, all loaded with identical options.
train_dataset = ROSMAPDataset(
    './data', split='train', cols_to_drop=cols_to_drop, use_apoe=use_apoe
)
d_in = train_dataset.X.shape[1]
d_out = len(np.unique(train_dataset.Y))

val_dataset = ROSMAPDataset(
    './data', split='val', cols_to_drop=cols_to_drop, use_apoe=use_apoe
)
test_dataset = ROSMAPDataset(
    './data', split='test', cols_to_drop=cols_to_drop, use_apoe=use_apoe
)

# Cost table without a header row: column 0 is compared against feature
# names, column 1 is read as the cost.
df = pd.read_csv("./data/rosmap_feature_costs.csv", header=None)
cost_rows = df if use_apoe else df[~df[0].isin(APOE_FEATURES)]
feature_costs = cost_rows[1].tolist()

# Set up Networks

In [None]:
# Architecture hyperparameters shared by both networks.
hidden = 128
dropout = 0.3
d_in = train_dataset.X.shape[1]
d_out = len(np.unique(train_dataset.Y))
print(d_out)


def _build_mlp(in_features, out_features):
    """Return a two-hidden-layer MLP (Linear -> ReLU -> Dropout, twice, then Linear).

    Hidden width and dropout rate come from the notebook-level `hidden` and
    `dropout` settings.
    """
    layers = [
        nn.Linear(in_features, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, out_features),
    ]
    return nn.Sequential(*layers)


# Both networks take d_in + num_groups inputs (the mask layer below is built
# with append=True). The outcome predictor emits one score per class; the
# CMI network emits one value per feature group.
predictor = _build_mlp(d_in + num_groups, d_out).to(device)
value_network = _build_mlp(d_in + num_groups, num_groups).to(device)

# Evaluation loaders: fixed order, no shuffling.
loader_options = dict(batch_size=128, shuffle=False, pin_memory=True, num_workers=4)
test_dataloader = DataLoader(test_dataset, **loader_options)
val_dataloader = DataLoader(val_dataset, **loader_options)

mask_layer = MaskLayerGrouped(
    append=True, group_matrix=torch.tensor(feature_groups_mask)
)

# Lightning trainer pinned to the same GPU index as `device`, 16-bit precision.
trainer = Trainer(accelerator='gpu', devices=[device.index], precision=16)

# Evaluate Penalized Policy

In [None]:
# Penalized policy: sweep the penalty `lam` and record test AUROC against the
# average number of selected feature groups.
#
# Fixes relative to the previous version of this cell (all were NameErrors on
# a fresh kernel):
#   * the checkpoint is loaded through `CMIEstimator`, the class this notebook
#     imports and uses in the other evaluation cells;
#   * the accumulator lists are referred to by one spelling (`*_lambda`);
#   * the recorded/printed score is the AUROC computed in this loop rather
#     than an `accuracy_score` variable that this cell never defines.
for trial in range(1):
    # The 'acc' key is kept for compatibility with existing result files even
    # though the stored values are AUROC scores.
    results_dict = {"acc": {}}

    trained_model_path = f"<path_to_trained_model>"
    greedy_cmi_estimator = CMIEstimator.load_from_checkpoint(trained_model_path,
                                                             value_network=value_network,
                                                             predictor=predictor,
                                                             mask_layer=mask_layer,
                                                             lr=1e-3,
                                                             max_features=15,
                                                             eps=0.05,
                                                             loss_fn=nn.CrossEntropyLoss(reduction='none'),
                                                             val_loss_fn=auc_metric,
                                                             eps_decay=0.2,
                                                             eps_steps=10,
                                                             patience=5,
                                                             feature_costs=None
                                                             ).to(device)

    avg_num_features_lambda = []
    accuracy_scores_lambda = []
    all_masks_lambda = []

    # Penalty values to evaluate, smallest first.
    lamda_values = [0.000001, 0.00001, 0.00007, 0.0003, 0.0005] + [0.004, 0.016, 0.07]

    for lamda in lamda_values:
        metric_dict = greedy_cmi_estimator.inference(trainer, test_dataloader, feature_costs=None, lam=lamda)

        y = metric_dict['y']
        pred = metric_dict['pred']
        auc_score = auc_metric(pred.float(), y)

        final_masks = np.array(metric_dict['mask'])
        # Mean over samples of the per-sample count of selected mask entries.
        avg_num_features = np.mean(np.sum(final_masks, axis=1))

        accuracy_scores_lambda.append(auc_score)
        avg_num_features_lambda.append(avg_num_features)
        results_dict['acc'][avg_num_features] = auc_score

        print(f"Lambda={lamda}, AUROC={auc_score}, Avg. num features={avg_num_features}")
        all_masks_lambda.append(final_masks)

    # Make sure the output directory exists so the sweep is not lost at the
    # very end because of a missing folder.
    os.makedirs('results', exist_ok=True)
    with open(f'results/rosmap_lambda_ours_trial_{trial}.pkl', 'wb') as f:
        pickle.dump(results_dict, f)

# Evaluate Budget Constrained Policy

In [None]:
import sklearn.metrics as metrics

# Budget-constrained policy without feature costs: sweep the budget and record
# test accuracy against the average number of selected feature groups.
#
# `freq` collects one row per budget with the fraction of test samples for
# which each mask position was selected; the heatmap cell reads it.
freq = []

for trial in range(0, 1):
    results_dict = {"acc": {}}
    trained_model_path = f"<path_to_trained_model>"

    greedy_cmi_estimator = CMIEstimator.load_from_checkpoint(trained_model_path,
                                                             value_network=value_network,
                                                             predictor=predictor,
                                                             mask_layer=mask_layer,
                                                             lr=1e-3,
                                                             max_features=15,
                                                             eps=0.05,
                                                             loss_fn=nn.CrossEntropyLoss(reduction='none'),
                                                             val_loss_fn=auc_metric,
                                                             eps_decay=0.2,
                                                             eps_steps=10,
                                                             patience=5,
                                                             feature_costs=None
                                                             ).to(device)
    avg_num_features_budget = []
    accuracy_scores_budget = []
    all_masks_budget = []
    max_budget_values = list(range(1, 15, 1))

    # All-zero leading row (nothing selected). Its width follows the number of
    # feature groups instead of a hard-coded 43.
    # NOTE(review): assumes the masks returned by inference() have num_groups
    # columns -- confirm.
    freq.append([0] * num_groups)

    for budget in max_budget_values:
        metric_dict_budget = greedy_cmi_estimator.inference(trainer, test_dataloader, feature_costs=None, budget=budget)

        y = metric_dict_budget['y']
        pred = metric_dict_budget['pred']

        # Class-balanced accuracy: mean of the per-class recall values taken
        # from the confusion matrix.
        conf_matrix = metrics.confusion_matrix(y, np.argmax(pred, axis=1))
        cls_acc = np.diag(conf_matrix) / np.sum(conf_matrix, 1)  # accuracy per class
        cls_avg = np.sum(cls_acc) / conf_matrix.shape[0]
        print("Rebalanced accuracy: {}".format(cls_avg))

        accuracy_score = acc_metric(pred.float(), y)
        final_masks = np.array(metric_dict_budget['mask'])
        # Mean over samples of the per-sample count of selected mask entries.
        avg_num_features = np.mean(np.sum(final_masks, axis=1))

        accuracy_scores_budget.append(accuracy_score)
        avg_num_features_budget.append(avg_num_features)
        results_dict['acc'][avg_num_features] = accuracy_score
        # The score comes from `acc_metric`, so it is labelled as accuracy
        # (it was previously printed under the label "AUROC").
        print(f"Budget={budget}, Accuracy={accuracy_score}, Avg. num features={avg_num_features}")

        # Column-wise selection frequency across the test samples.
        freq.append(list(final_masks.sum(axis=0) / final_masks.shape[0]))

        all_masks_budget.append(final_masks)

    # NOTE(review): the cost-aware evaluation cell later in this notebook
    # writes to the same path and will overwrite this file -- consider a
    # distinct file name for one of the two.
    os.makedirs('results', exist_ok=True)
    with open(f'results/rosmap_ours_costs_inference_trial_{trial}.pkl', 'wb') as f:
        pickle.dump(results_dict, f)


In [None]:
import seaborn as sns

# Heatmap of the selection frequencies gathered in `freq`:
# one row per entry of `freq`, one column per mask position.
# NOTE(review): rows are indexed by position in `freq`, while the y-axis label
# reads "Avg. # Features" -- confirm this is the intended labelling.
df = pd.DataFrame(np.array(freq))
plt.rcParams.update({'font.size': 17})

fig, ax = plt.subplots()
sns.heatmap(df, cmap="YlGnBu", ax=ax)
ax.set_xlabel("Feature Index")
ax.set_ylabel("Avg. # Features")
ax.set_title("ROSMAP")
fig.savefig("ROSMAP_Selection_Freq.pdf", format="pdf", bbox_inches="tight")


In [None]:
import sklearn.metrics as metrics

# Budget-constrained policy with per-group acquisition costs.
# This array replaces the `feature_costs` list read from the CSV earlier in
# the notebook; it has one entry per mask column (43 values).
feature_costs = np.array([10, 10, 10, 10, 10, 600, 600, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 60, 300, 120, 300, 180, 300, 300, 180, 180, 60, 180, 900, 450, 180, 1200, 600, 120, 300, 180, 180, 180, 60, 60, 60, 600])

for trial in range(0, 1):
    results_dict = {"acc": {}}
    trained_model_path = f"<path_to_trained_model>"

    greedy_cmi_estimator = CMIEstimator.load_from_checkpoint(trained_model_path,
                                                             value_network=value_network,
                                                             predictor=predictor,
                                                             mask_layer=mask_layer,
                                                             lr=1e-3,
                                                             max_features=15,
                                                             eps=0.05,
                                                             loss_fn=nn.CrossEntropyLoss(reduction='none'),
                                                             val_loss_fn=auc_metric,
                                                             eps_decay=0.2,
                                                             eps_steps=10,
                                                             patience=5,
                                                             feature_costs=feature_costs
                                                             ).to(device)
    avg_num_features_budget = []
    accuracy_scores_budget = []
    all_masks_budget = []
    # NOTE(review): budgets run from 1 to 14 while the smallest cost above is
    # 10 -- confirm how inference() scales `budget` against `feature_costs`.
    max_budget_values = list(range(1, 15, 1))

    for budget in max_budget_values:
        metric_dict_budget = greedy_cmi_estimator.inference(trainer, test_dataloader, feature_costs=feature_costs, budget=budget)

        y = metric_dict_budget['y']
        pred = metric_dict_budget['pred']

        # Class-balanced accuracy: mean of the per-class recall values taken
        # from the confusion matrix.
        conf_matrix = metrics.confusion_matrix(y, np.argmax(pred, axis=1))
        cls_acc = np.diag(conf_matrix) / np.sum(conf_matrix, 1)  # accuracy per class
        cls_avg = np.sum(cls_acc) / conf_matrix.shape[0]
        print("Rebalanced accuracy: {}".format(cls_avg))

        accuracy_score = acc_metric(pred.float(), y)
        final_masks = np.array(metric_dict_budget['mask'])
        # Mean per-sample count of selected entries, and mean per-sample total
        # cost of the selected entries.
        avg_num_features = np.mean(np.sum(final_masks, axis=1))
        avg_cost = np.mean(np.sum(final_masks * feature_costs, axis=1))

        accuracy_scores_budget.append(accuracy_score)
        avg_num_features_budget.append(avg_num_features)
        # Results stay keyed by average feature count, as before.
        results_dict['acc'][avg_num_features] = accuracy_score
        # Labels now match the values: the score is accuracy (not AUROC) and
        # the cost-weighted mean is reported as a cost, next to the count.
        print(f"Budget={budget}, Accuracy={accuracy_score}, Avg. num features={avg_num_features}, Avg. cost={avg_cost}")

        all_masks_budget.append(final_masks)

    # NOTE(review): the cost-free budget evaluation cell earlier in this
    # notebook writes to the same path; whichever cell runs last wins.
    os.makedirs('results', exist_ok=True)
    with open(f'results/rosmap_ours_costs_inference_trial_{trial}.pkl', 'wb') as f:
        pickle.dump(results_dict, f)
