In [1]:
import torch
import feature_groups
import pickle
import argparse
import numpy as np
import torch.nn as nn
from torchmetrics import AUROC
from sklearn.metrics import accuracy_score, roc_auc_score
from torch.utils.data import DataLoader, random_split
import os
from os import path
from dime.data_utils import ROSMAPDataset, get_group_matrix, get_xy, MaskLayerGrouped, data_split
from dime.masking_pretrainer import MaskingPretrainer
from dime.greedy_model_pl import GreedyCMIEstimatorPL
from dime.utils import accuracy, auc, normalize, selection_without_cost
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Load Dataset

In [None]:
# Select the compute device. GPU index 1 is preferred (the original setup);
# fall back to GPU 0 or the CPU so the notebook still runs on machines that
# expose fewer CUDA devices.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda', 1)
elif torch.cuda.is_available():
    device = torch.device('cuda', 0)
else:
    device = torch.device('cpu')

# When use_apoe is False the two apoe4_* columns are removed from the feature
# names and from the cost table below, and the flag is forwarded to the dataset.
use_apoe = False
apoe_columns = ['apoe4_1copy', 'apoe4_2copies']

rosmap_feature_names = feature_groups.rosmap_feature_names
rosmap_feature_groups = feature_groups.rosmap_feature_groups

if not use_apoe:
    rosmap_feature_names = [name for name in rosmap_feature_names if name not in apoe_columns]

# Group structure is derived from the (possibly APOE-filtered) feature names.
feature_groups_dict, feature_groups_mask = get_group_matrix(rosmap_feature_names, rosmap_feature_groups)
num_groups = len(feature_groups_mask)

# cols_to_drop lists column positions as strings (it is compared against
# str(position) below). Positions come from enumerate() rather than
# list.index(), which is O(n) per lookup and would map duplicated names to
# their first occurrence.
cols_to_drop = []
if cols_to_drop is not None:
    dropped_positions = set(cols_to_drop)
    rosmap_feature_names = [
        name
        for position, name in enumerate(rosmap_feature_names)
        if str(position) not in dropped_positions
    ]

# Load dataset
train_dataset = ROSMAPDataset('./data', split='train', cols_to_drop=cols_to_drop, use_apoe=use_apoe)
d_in = train_dataset.X.shape[1]              # number of input columns in X
d_out = len(np.unique(train_dataset.Y))      # number of distinct label values

val_dataset = ROSMAPDataset('./data', split='val', cols_to_drop=cols_to_drop, use_apoe=use_apoe)
test_dataset = ROSMAPDataset('./data', split='test', cols_to_drop=cols_to_drop, use_apoe=use_apoe)

# Feature costs: the CSV has no header row; column 0 is matched against the
# feature names and column 1 is read as the cost.
df = pd.read_csv("./data/rosmap_feature_costs.csv", header=None)
if use_apoe:
    feature_costs = df[1].tolist()
else:
    feature_costs = df[~df[0].isin(apoe_columns)][1].tolist()

# Set up Networks

In [None]:
# Set up architecture
hidden = 128
dropout = 0.3
d_in = train_dataset.X.shape[1]  # original author's note: 121
d_out = len(np.unique(train_dataset.Y))  # original author's note: 2
print(d_out)


def _build_mlp(in_features, out_features, hidden_units, dropout_rate, *output_layers):
    """Build the three-layer MLP shared by the predictor and the value network.

    Layer order is Linear-ReLU-Dropout, Linear-ReLU-Dropout, Linear, followed
    by any extra ``output_layers``. The positions of the modules inside the
    returned ``nn.Sequential`` are the same as in the hand-written version, so
    state-dict keys (``0.weight``, ``3.weight``, ``6.weight``) are unchanged.
    """
    return nn.Sequential(
        nn.Linear(in_features, hidden_units),
        nn.ReLU(),
        nn.Dropout(dropout_rate),
        nn.Linear(hidden_units, hidden_units),
        nn.ReLU(),
        nn.Dropout(dropout_rate),
        nn.Linear(hidden_units, out_features),
        *output_layers)


# Both networks take d_in + num_groups inputs (the mask layer below is built
# with append=True).

# Outcome Predictor: one output per class.
predictor = _build_mlp(d_in + num_groups, d_out, hidden, dropout).to(device)

# CMI Predictor: one sigmoid-squashed output per feature group.
value_network = _build_mlp(d_in + num_groups, num_groups, hidden, dropout, nn.Sigmoid()).to(device)

# Tie weights
# value_network[0] =  predictor[0]
# value_network[3] = predictor[3]

# Evaluation loaders keep the final partial batch (drop_last=False): dropping
# it would silently exclude up to batch_size - 1 samples from every metric.
test_dataloader = DataLoader(
        test_dataset, batch_size=128, shuffle=False, pin_memory=True,
        drop_last=False, num_workers=4)

val_dataloader = DataLoader(
        val_dataset, batch_size=128, shuffle=False, pin_memory=True,
        drop_last=False, num_workers=4)

mask_layer = MaskLayerGrouped(append=True, group_matrix=torch.tensor(feature_groups_mask))


# Evaluate Penalized Policy

In [None]:
# Evaluate the trained estimator with the 'lamda-penalty' evaluation mode over
# a sweep of lambda values, and pickle a {avg. number of selected entries: AUC}
# mapping for each trial.
for trial in range(1):
    results_dict = {"acc": {}}

    # Placeholder: replace with the checkpoint path for this trial.
    trained_model_path = f"<path_to_trained_model>"
    greedy_cmi_estimator = GreedyCMIEstimatorPL.load_from_checkpoint(trained_model_path,
                                                                     value_network=value_network,
                                                                     predictor=predictor,
                                                                     mask_layer=mask_layer,
                                                                     lr=1e-3,
                                                                     max_features=15,
                                                                     eps=0.05,
                                                                     loss_fn=nn.CrossEntropyLoss(reduction='none'),
                                                                     val_loss_fn=auc,
                                                                     eps_decay=True,
                                                                     eps_decay_rate=0.2,
                                                                     patience=5,
                                                                     feature_costs=None,
                                                                     use_entropy=True
                                                                     ).to(device)

    avg_num_features_lambda = []
    accuracy_scores_lambda = []
    all_masks_lambda = []

    # Evaluation Mode lambda penalty
    lamda_values = [0.000001, 0.00001, 0.00007, 0.0003, 0.0005] + [0.004, 0.016, 0.07]

    for lamda in lamda_values:
        metric_dict_lambda = greedy_cmi_estimator.evaluate(
            test_dataloader, performance_func=auc,
            feature_costs=None, use_entropy=True,
            evaluation_mode='lamda-penalty', lamda=lamda)

        # Named auc_score (not accuracy_score) so the sklearn function imported
        # under that name at the top of the notebook is not overwritten.
        auc_score = metric_dict_lambda['performance']
        final_masks_lambda = metric_dict_lambda['final_masks']

        # Mean over samples of the per-row mask sum; computed once and reused.
        avg_num_selected = np.mean(np.sum(final_masks_lambda, axis=1))

        accuracy_scores_lambda.append(auc_score)
        results_dict['acc'][avg_num_selected] = auc_score
        avg_num_features_lambda.append(avg_num_selected)
        print(f"Lambda={lamda}, AUC={auc_score}, Avg. num features={avg_num_selected}")

        all_masks_lambda.append(final_masks_lambda)

    # Create the output directory if needed; open(..., 'wb') does not.
    os.makedirs('results', exist_ok=True)
    with open(f'results/rosmap_lambda_ours_trial_{trial}.pkl', 'wb') as f:
        pickle.dump(results_dict, f)

# Evaluate Budget Constrained Policy

In [None]:
# Evaluate the trained estimator with the 'fixed-budget' evaluation mode over
# a sweep of budgets, applying feature costs at inference time only (the
# checkpoint is loaded with feature_costs=None), and pickle a
# {avg. acquisition cost: AUC} mapping for each trial.
for trial in range(0, 5):
    results_dict = {"acc": {}}

    # Placeholder: replace with the checkpoint path for this trial.
    trained_model_path = f"<path_to_trained_model>"

    # Load from trained_model_path: the previously referenced
    # trained_model_path_no_costs is never defined in this notebook and raised
    # NameError on a fresh kernel.
    greedy_cmi_estimator = GreedyCMIEstimatorPL.load_from_checkpoint(trained_model_path,
                                                                     value_network=value_network,
                                                                     predictor=predictor,
                                                                     mask_layer=mask_layer,
                                                                     lr=1e-3,
                                                                     max_features=15,
                                                                     eps=0.05,
                                                                     loss_fn=nn.CrossEntropyLoss(reduction='none'),
                                                                     val_loss_fn=auc,
                                                                     eps_decay=True,
                                                                     eps_decay_rate=0.2,
                                                                     patience=5,
                                                                     feature_costs=None,
                                                                     use_entropy=True
                                                                     ).to(device)
    avg_num_features_budget = []
    accuracy_scores_budget = []
    all_masks_budget = []
    max_budget_values = list(range(1, 15, 1))

    for budget in max_budget_values:
        metric_dict_budget = greedy_cmi_estimator.evaluate(
            test_dataloader, performance_func=auc,
            feature_costs=feature_costs, use_entropy=True,
            evaluation_mode='fixed-budget', budget=budget)  # , selection_func=selection_without_cost)

        # Named auc_score (not accuracy_score) so the sklearn function imported
        # under that name at the top of the notebook is not overwritten.
        auc_score = metric_dict_budget['performance']
        final_masks_budget = metric_dict_budget['final_masks']

        # Mean over samples of the cost-weighted mask sum, i.e. the average
        # cost actually spent; computed once and reused below.
        avg_cost = np.mean(np.sum(final_masks_budget * feature_costs, axis=1))

        accuracy_scores_budget.append(auc_score)
        # The list keeps its original name, but the values stored are costs.
        avg_num_features_budget.append(avg_cost)
        print(f"Budget={budget}, AUC={auc_score}, Avg. cost={avg_cost}")
        results_dict['acc'][avg_cost] = auc_score
        all_masks_budget.append(final_masks_budget)

    # Create the output directory if needed; open(..., 'wb') does not.
    os.makedirs('results', exist_ok=True)
    with open(f'results/rosmap_ours_costs_inference_trial_{trial}.pkl', 'wb') as f:
        pickle.dump(results_dict, f)
