In [1]:
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import numpy as np
import os
import toml
import scipy
import pickle

from tqdm import tqdm
import json
# from adrd.data import _conf
import adrd.utils.misc
import torch
import monai
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, roc_auc_score, balanced_accuracy_score, average_precision_score, multilabel_confusion_matrix, classification_report, roc_curve, auc, RocCurveDisplay, precision_score, recall_score, PrecisionRecallDisplay, precision_recall_curve
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import KFold, StratifiedKFold
from icecream import ic
ic.disable()

from data.dataset_csv import CSVDataset
from adrd.model import ADRDModel
from adrd.utils.misc import get_and_print_metrics_multitask
from adrd.utils.misc import get_metrics, print_metrics, print_metrics_multitask

## Get subgroup results

In [2]:
# --- Paths ---------------------------------------------------------------
basedir = ".."
cnf_file = "toml_files/config_0224_amy_tau_no_csf_no_plasma.toml"
ckpt_path = "../ckpt/model_stage_1.pt"

# --- Run settings --------------------------------------------------------
device = 'cuda:0'
# Previously used value: "no_fhs_mris_freeze_10_epochs_early_stopping_no_RL"
file_name = "plot"

# --- Loaded inputs -------------------------------------------------------
# NOTE: despite its name, `dat_file` holds the loaded DataFrame, not a path.
dat_file = pd.read_csv('/projectnb/vkolagrp/varuna/mri_pet/adrd_tool/data_varuna/data/0225/ADNI_HABS_NACCTEST_HARMONIZED.csv')

# Parsed TOML configuration; later cells read its 'feature' and 'label' tables.
config = toml.load(cnf_file)


In [3]:

# Restore the trained model from the checkpoint onto `device`.
mdl = ADRDModel.from_ckpt(ckpt_path, device=device)

# The checkpoint is opened a second time only to report the epoch it was saved at.
# map_location='cpu' keeps this throwaway copy of the checkpoint off the GPU.
print(f"Epoch: {torch.load(ckpt_path, map_location='cpu')['epoch']}")
# NOTE(review): this message is printed unconditionally; no key check happens in this cell.
print("All keys matched")

# Imaging backbone name and image mode used when building the test datasets below.
img_net="SwinUNETREMB"
img_mode=1

# To run without MRIs
# img_net="NonImg"
# img_mode=-1

cuda:0
Downsample layers:  2
Epoch: 84
All keys matched


In [4]:
# Show the keys of the loaded network's `modules_emb_src` mapping as the cell output.
mdl.net_.modules_emb_src.keys()

odict_keys(['FS_MTL_VOLUME', 'FS_TEMPORAL_VOLUME', 'FS_PARIETAL_VOLUME', 'FS_OCCIPITAL_VOLUME', 'FS_FRONTAL_VOLUME', 'FS_3rd_ventricle_volume', 'FS_4th_ventricle_volume', 'FS_brain_stem_volume', 'FS_csf_volume', 'FS_left_accumbens_area_volume', 'FS_left_amygdala_volume', 'FS_left_caudate_volume', 'FS_left_cerebellum_cortex_volume', 'FS_left_cerebellum_white_matter_volume', 'FS_left_cerebral_white_matter_volume', 'FS_left_hippocampus_volume', 'FS_left_inf_lat_vent_volume', 'FS_left_lateral_ventricle_volume', 'FS_left_pallidum_volume', 'FS_left_putamen_volume', 'FS_left_thalamus_volume', 'FS_left_ventraldc_volume', 'FS_left_choroid_plexus_volume', 'FS_right_accumbens_area_volume', 'FS_right_amygdala_volume', 'FS_right_caudate_volume', 'FS_right_cerebellum_cortex_volume', 'FS_right_cerebellum_white_matter_volume', 'FS_right_cerebral_white_matter_volume', 'FS_right_hippocampus_volume', 'FS_right_inf_lat_vent_volume', 'FS_right_lateral_ventricle_volume', 'FS_right_pallidum_volume', 'FS_righ

In [5]:
# Distinct feature-name prefixes: the text before the first underscore of every
# key in the config's 'feature' table.
prefix = {name.split('_')[0] for name in config['feature']}

In [6]:
# Print the set of distinct feature-name prefixes collected above.
print(prefix)

{'WB', 'cdr', 'faq', 'npiq', 'exam', 'blood', 'bat', 'gds', 'apoe', 'med', 'FS', 'cd', 'ph', 'his'}


In [7]:
# Number of distinct feature-name prefixes.
len(prefix)

14

In [8]:
# Feature-name prefixes that make up each subgroup. 'All' lists no prefixes, so it
# maps to an empty feature list below. (A 'CSF': ['csf'] group is currently left out.)
subgroup_prefixes = {
    'All': [],
    'History': ['his', 'ph', 'med'],
    'Neurological/Physical': ['exam'],
    'MRI': ['WB', 'FS'],
    'FAQ': ['faq'],
    'Neuropsych Battery': ['bat', 'npiq', 'gds'],
    'CDR': ['cd', 'cdr'],
    'Plasma': ['blood'],
    'APoE e4': ['apoe'],
}

# For every subgroup, collect the configured feature names whose prefix
# (text before the first underscore) is listed for that subgroup.
subgroups = {
    group: [name for name in config['feature'] if name.split('_')[0] in prefixes]
    for group, prefixes in subgroup_prefixes.items()
}

In [9]:
def generate_performance_report(dat_tst, y_pred, scores_proba):
    """Compute per-label metrics for one set of model predictions.

    Ground truth is read from ``dat_tst.labels``, a sequence of per-sample
    dicts. A label value of ``None`` is treated as missing: it becomes 0 in the
    ground truth and is flagged with 0 in the mask handed to ``get_metrics``
    (present values are flagged with 1).

    Args:
        dat_tst: dataset object exposing a ``labels`` sequence of dicts.
        y_pred: per-sample mappings indexed by label name.
        scores_proba: per-sample mappings indexed by label name.

    Returns:
        dict mapping each label name (keys of the first ``labels`` entry) to the
        result of ``get_metrics`` with its 'Confusion Matrix' entry removed.
    """
    def _columns(rows):
        # Transpose a list of per-sample dicts into {key: [value per sample]},
        # taking the key set from the first sample.
        return {key: [row[key] for row in rows] for key in rows[0]}

    truth_rows = [
        {name: (0 if val is None else int(val)) for name, val in sample.items()}
        for sample in dat_tst.labels
    ]
    present_rows = [
        {name: (0 if val is None else 1) for name, val in sample.items()}
        for sample in dat_tst.labels
    ]

    truth_by_label = _columns(truth_rows)
    pred_by_label = _columns(y_pred)
    proba_by_label = _columns(scores_proba)
    present_by_label = _columns(present_rows)

    report = {}
    for name in dat_tst.labels[0].keys():
        label_metrics = get_metrics(
            np.array(truth_by_label[name]),
            np.array(pred_by_label[name]),
            np.array(proba_by_label[name]),
            np.array(present_by_label[name]),
        )
        # The confusion matrix is dropped from the stored report.
        label_metrics.pop('Confusion Matrix')
        report[name] = label_metrics

    return report

In [10]:
# Names used to select the label/probability columns for the ROC analysis below.
labels = ['amy_label', 'tau_label']

# Validation table; despite its name, `vld_file` holds the loaded DataFrame, not a path.
vld_file = pd.read_csv("/projectnb/vkolagrp/skowshik/pet_project/mri_pet/adrd_tool/data_varuna/data/0225/val_0225_new_harmonization.csv")

In [11]:
def roc_auc_scores(y_true, y_pred, features):
    """Compute a ROC curve and its AUC for each column of the inputs.

    Column ``i`` of ``y_true`` / ``y_pred`` belongs to ``features[i]``. Rows whose
    ground-truth entry is NaN are left out of that column's curve. For each
    column the number of retained rows and the rounded ROC AUC are printed.

    Args:
        y_true: 2-D array of ground-truth values, indexed as ``y_true[:, i]``.
        y_pred: 2-D array of scores, indexed as ``y_pred[:, i]``.
        features: names, one per column.

    Returns:
        Tuple ``(fpr, tpr, auc_scores, thresholds)`` of dicts keyed by feature name.
    """
    fpr, tpr, auc_scores, thresholds = {}, {}, {}, {}

    for col, name in enumerate(features):
        truth_col = np.array(y_true[:, col])
        score_col = np.array(y_pred[:, col])

        # Keep only the rows where the ground truth is present (not NaN).
        keep = np.array([0 if np.isnan(val) else 1 for val in truth_col]) == 1
        truth_kept = truth_col[keep]
        score_kept = score_col[keep]

        print(len(truth_kept), len(score_kept))
        print(round(roc_auc_score(truth_kept, score_kept), 4))

        fpr[name], tpr[name], thresholds[name] = roc_curve(
            y_true=truth_kept,
            y_score=score_kept,
            pos_label=1,
            drop_intermediate=True,
        )
        auc_scores[name] = auc(fpr[name], tpr[name])

    return fpr, tpr, auc_scores, thresholds

In [None]:
print('Done.\nLoading validation dataset ...')
# NOTE(review): `mri_type` is not assigned in any cell visible in this notebook; unless it
# is defined elsewhere this call fails with NameError on a fresh kernel — TODO define it
# next to the other settings.
# NOTE(review): img_mode is hard-coded to 1 here while the later test-set cells pass the
# `img_mode` variable — confirm that is intended.
dat_vld = CSVDataset(dat_file=vld_file, cnf_file=cnf_file, mode=0, img_mode=1, mri_type=mri_type, stripped='_stripped_MNI')
print('Done.')

# Only the predicted probabilities (second return value) are used below.
scores_vld, scores_proba_vld, y_pred_vld, _ = mdl.predict(dat_vld.features, _batch_size=128, img_transform=None)

# Ground truth per sample; a missing label (None) becomes NaN so it can be masked out later.
# np.nan is used rather than the np.NaN alias, which NumPy 2.0 removed.
y_true = [{k: int(v) if v is not None else np.nan for k, v in entry.items()} for entry in dat_vld.labels]

# Column-wise ground truth, restricted to labels that are also columns of the validation table.
y_true_ = {f'{k}_label': [smp[k] for smp in y_true] for k in y_true[0] if k in vld_file.columns}

# Column-wise probabilities, set to NaN wherever the matching ground truth is missing
# (i.e. is not an int after the conversion above).
scores_proba_ = {f'{k}_prob': [smp[k] if isinstance(y_true[i][k], int) else np.nan for i, smp in enumerate(scores_proba_vld)] for k in scores_proba_vld[0] if k in vld_file.columns}

y_true_df = pd.DataFrame(y_true_)
scores_proba_df = pd.DataFrame(scores_proba_)
df = pd.concat([y_true_df, scores_proba_df], axis=1)

# 2-D arrays with one column per entry of `labels`, in the same order.
y_true_ar = np.array(df[[f'{lab}_label' for lab in labels]])
scores_proba_ar = np.array(df[[f'{lab}_prob' for lab in labels]])

# ROC curves on the validation set; fpr/tpr/thresholds are handed to mdl.predict in the cells below.
fpr, tpr, auc_scores, thresholds = roc_auc_scores(y_true=y_true_ar, y_pred=scores_proba_ar, features=labels)
print('Done.')

## Run the model with each feature subgroup removed

In [None]:
# Metrics obtained when one feature subgroup at a time is removed from the input table.
# 'All' has an empty feature list, so that entry is evaluated on the full table.
# NOTE(review): `mri_type` is not assigned in any cell visible in this notebook — TODO confirm
# where it is defined.
met_list_delete = {}
for k, v in subgroups.items():
    # Every column of the full table except those belonging to subgroup k.
    excluded = set(v)
    feature_list = [fea for fea in dat_file.columns if fea not in excluded]
    df = dat_file[feature_list]

    print(f'Loading testing dataset for subgroup {k} without keys {v}...')
    dat_tst = CSVDataset(
        dat_file=df,
        cnf_file=cnf_file,
        mode=0,
        img_mode=img_mode,
        mri_type=mri_type,
        stripped='_stripped_MNI',
    )
    print('Done.')

    # Predict using the ROC information computed on the validation set.
    print('Generating model predictions')
    scores, scores_proba, y_pred, outputs = mdl.predict(
        dat_tst.features,
        fpr=fpr,
        tpr=tpr,
        thresholds=thresholds,
        _batch_size=1024,
        img_transform=None,
    )
    print('Done.')

    print('Generating performance reports')
    met = generate_performance_report(dat_tst, y_pred, scores_proba)
    met_list_delete[k] = met

In [14]:
# Persist the subgroup-removal metrics as the source data for the figure.
out_path = '../figures/source_data/fig2d.pickle'
# Create the output folder if it is missing, so the results computed above are not
# lost to a FileNotFoundError at write time.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, 'wb') as handle:
    pickle.dump(met_list_delete, handle, protocol=pickle.HIGHEST_PROTOCOL)

## Run the model with feature subgroups added cumulatively

In [None]:
# Metrics obtained while feature subgroups are added one after another: each iteration
# evaluates on the columns accumulated so far.
# NOTE(review): `mri_type` is not assigned in any cell visible in this notebook — TODO confirm
# where it is defined.
met_list_add = {}
# Columns that are always kept: 'ID' plus every label named in the config.
order = ['ID'] + list(config['label'].keys())
for k, v in subgroups.items():
    if not v:
        # A subgroup without features (e.g. 'All') adds nothing; skip it.
        continue
    order.extend(v)
    print(order)

    # Columns of the full table that have been accumulated so far, in table order.
    wanted = set(order)
    feature_list = [fea for fea in dat_file.columns if fea in wanted]
    df = dat_file[feature_list]

    print(f'Loading testing dataset with keys {order}...')
    dat_tst = CSVDataset(
        dat_file=df,
        cnf_file=cnf_file,
        mode=0,
        img_mode=img_mode,
        mri_type=mri_type,
        stripped='_stripped_MNI',
    )
    print('Done.')

    # Predict using the ROC information computed on the validation set.
    print('Generating model predictions')
    scores, scores_proba, y_pred, outputs = mdl.predict(
        dat_tst.features,
        fpr=fpr,
        tpr=tpr,
        thresholds=thresholds,
        _batch_size=1024,
        img_transform=None,
    )
    print('Done.')

    print('Generating performance reports')
    met = generate_performance_report(dat_tst, y_pred, scores_proba)
    met_list_add[k] = met

In [17]:
# Persist the cumulative-addition metrics as the source data for the figure.
out_path = '../figures/source_data/fig2c.pickle'
# Create the output folder if it is missing, so the results computed above are not
# lost to a FileNotFoundError at write time.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, 'wb') as handle:
    pickle.dump(met_list_add, handle, protocol=pickle.HIGHEST_PROTOCOL)