In [1]:
import numpy as np
import seaborn as sns
import pandas as pd
import numpy as np
import json
import torch
import os
import copy
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

from sklearn.metrics import multilabel_confusion_matrix, classification_report, roc_curve, auc, confusion_matrix, \
     RocCurveDisplay, precision_score, recall_score, average_precision_score, PrecisionRecallDisplay, precision_recall_curve, roc_auc_score
from tqdm import tqdm
from matplotlib.pyplot import figure


## Load the data

In [2]:
# Root folder that holds testing_data.csv and model_predictions/.
# Defaults to the original cluster location; set the PET_PREDICTION_DIR
# environment variable to run the notebook against a copy stored elsewhere.
DATA_DIR = os.environ.get('PET_PREDICTION_DIR', '/projectnb/vkolagrp/projects/pet_prediction')

# Test-set table; its ID, demographic (his_*) and tau_label columns are used below.
merged = pd.read_csv(os.path.join(DATA_DIR, 'testing_data.csv'))
# Stage-1 prediction table; filtered by ID in roc_pr_save.
merged_pred = pd.read_csv(os.path.join(DATA_DIR, 'model_predictions', 'stage1_predictions.csv'))

In [3]:
# Compare the number of rows in the test-set table and the prediction table.
merged.shape[0] == merged_pred.shape[0]

True

In [4]:
# Number of rows in the test-set table.
merged.shape[0]

1833

In [5]:
# Cast the ID column to string in both tables so that the ID membership test
# in roc_pr_save compares values of the same type.
for _frame in (merged, merged_pred):
    _frame['ID'] = _frame['ID'].astype(str)

In [6]:
# Names of the two target labels handled in this notebook.
cog_labels = [f'{marker}_label' for marker in ('amy', 'tau')]

## Generate predictions

In [7]:
def roc_pr_save(sub_df, fname, subgroup, figname, sub_df1=None):
    """Write the predictions of one demographic subgroup to a CSV file.

    Selects the rows of the notebook-level ``merged_pred`` table whose ``ID``
    occurs in ``sub_df['ID']``, drops the ``ID``, ``cdr_CDRGLOB`` and
    ``COHORT`` columns, and saves the remaining columns to
    ``./source_data/efig1/efig{figname}_{fname}.csv``. The number of rows in
    ``sub_df`` is printed first.

    Parameters
    ----------
    sub_df : pandas.DataFrame
        Subgroup rows; only its length and its ``ID`` column are used.
    fname : str
        Subgroup category name, used as the file-name suffix.
    subgroup : str
        Accepted so existing call sites keep working; not used here.
    figname : str
        Figure panel identifier, used in the file name.
    sub_df1 : optional
        Accepted for backward compatibility; not used here.

    Raises
    ------
    KeyError
        If ``merged_pred`` lacks one of the columns that are dropped.
    """
    print(len(sub_df))

    out_path = f"./source_data/efig1/efig{figname}_{fname}.csv"
    # Create the destination folder on first use so that a fresh checkout
    # does not fail with "No such file or directory" on the write below.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    # NOTE(review): rows are matched by ID only. If an ID can occur on more
    # than one row of merged_pred, every such row is selected -- confirm that
    # ID is unique per row.
    pred = merged_pred[merged_pred['ID'].isin(sub_df['ID'])].reset_index(drop=True)
    pred = pred.drop(columns=['ID', 'cdr_CDRGLOB', 'COHORT'])
    pred.to_csv(out_path, index=False)

### Sex

In [8]:
# Number of rows per recorded sex value.
merged.loc[:, 'his_SEX'].value_counts()

his_SEX
female    928
male      905
Name: count, dtype: int64

In [9]:
for fname in ['male', 'female']:
    # A missing value never compares equal to a label, so a single equality
    # mask built on the full frame already excludes rows without a sex value.
    # (Chaining a second mask onto an already-filtered frame forces pandas to
    # reindex the mask and emits a warning.)
    sub_df = merged[merged['his_SEX'] == fname].reset_index(drop=True)

    print(len(sub_df))

    roc_pr_save(sub_df, fname=fname, figname="1b", subgroup="sex")

905
905
928
928


### Race

In [10]:
# Count of 'whi' rows over the whole table (reported as the amyloid count) ...
amy_race_counts = merged['his_NACCNIHR'].value_counts()
print(f"Amyloid white counts: {amy_race_counts['whi']}")
# ... and over the rows that have a tau label.
tau_race_counts = merged[~merged['tau_label'].isna()]['his_NACCNIHR'].value_counts()
print(f"Tau white counts: {tau_race_counts['whi']}")

Amyloid white counts: 1612
Tau white counts: 723


In [11]:
# Rows with a recorded race other than 'whi': total of all value counts minus
# the 'whi' count, first over the whole table ...
amy_race_counts = merged['his_NACCNIHR'].value_counts()
print(f"Amyloid other counts: {amy_race_counts.sum() - amy_race_counts['whi']}")
# ... then over the rows that have a tau label.
tau_race_counts = merged[~merged['tau_label'].isna()]['his_NACCNIHR'].value_counts()
print(f"Tau other counts: {tau_race_counts.sum() - tau_race_counts['whi']}")

Amyloid other counts: 212
Tau other counts: 114


In [12]:
# Both masks are built on the full frame so they line up with `merged` row
# for row (no implicit mask reindexing).
race_col = merged['his_NACCNIHR']

# 'whi' group: a missing value never equals 'whi', so no explicit NaN filter
# is needed here.
sub_df = merged[race_col == 'whi'].reset_index(drop=True)
roc_pr_save(sub_df, fname='whi', figname="1c", subgroup="race")

# Every other recorded race. NaN != 'whi' evaluates to True, so rows with a
# missing race must be excluded explicitly.
sub_df = merged[race_col.notna() & (race_col != 'whi')].reset_index(drop=True)
roc_pr_save(sub_df, fname='oth', figname="1c", subgroup="race")

1612
212


### Age

In [13]:
# Age quartiles over the whole table; Q2 is the median and serves as the
# split point for the age subgroups below.
Q1, Q2, Q3 = (merged['his_NACCAGE'].quantile(q) for q in (0.25, 0.5, 0.75))
Q2

73.6208076659822

In [14]:
# Cross-check the 0.5 quantile against numpy's median.
age_median_np = np.median(merged['his_NACCAGE'])
print(f"Median: {age_median_np}")

Median: 73.6208076659822


In [15]:
# Group sizes on either side of the age median (counting True entries of each mask).
print((merged['his_NACCAGE'] < Q2).sum())
print((merged['his_NACCAGE'] >= Q2).sum())

916
917


In [16]:
# Age-group sizes restricted to rows that have a tau label. The two conditions
# are combined into one mask on the full frame instead of indexing an
# already-filtered frame with a full-length mask.
has_tau = merged['tau_label'].notna()
print(len(merged[has_tau & (merged['his_NACCAGE'] < Q2)]))
print(len(merged[has_tau & (merged['his_NACCAGE'] >= Q2)]))

415
428


In [17]:
# Split at the age median Q2. Comparisons against a missing age are False, so
# rows without an age fall into neither group and no separate NaN filter is
# needed; each mask is built on the full frame so it aligns with `merged`.
age_col = merged['his_NACCAGE']
for fname, in_group in (
    ("age_below_median", age_col < Q2),
    ("age_above_median", age_col >= Q2),
):
    sub_df = merged[in_group].reset_index(drop=True)
    print(fname)
    roc_pr_save(sub_df, fname=fname, figname="1a", subgroup="age")

age_below_median
916
age_above_median
917


### Education

In [18]:
# Education quartiles over the whole table. This rebinds Q1/Q2/Q3, which held
# the age quartiles until now; Q2 is the median used for the education split.
Q1, Q2, Q3 = (merged['his_EDUC'].quantile(q) for q in (0.25, 0.5, 0.75))
Q2

16.0

In [19]:
# Cross-check the 0.5 quantile with numpy's median over the non-missing values.
np.median(merged['his_EDUC'].dropna())

16.0

In [20]:
# Group sizes on either side of the education median (counting True entries of each mask).
print((merged['his_EDUC'] < Q2).sum())
print((merged['his_EDUC'] >= Q2).sum())

570
1261


In [21]:
# Education-group sizes restricted to rows that have a tau label. The two
# conditions are combined into one mask on the full frame instead of indexing
# an already-filtered frame with a full-length mask.
has_tau = merged['tau_label'].notna()
print(len(merged[has_tau & (merged['his_EDUC'] < Q2)]))
print(len(merged[has_tau & (merged['his_EDUC'] >= Q2)]))

252
590


In [22]:
# Split at the education median Q2. Comparisons against a missing value are
# False, so rows without an education value fall into neither group and no
# separate NaN filter is needed; each mask is built on the full frame so it
# aligns with `merged`.
educ_col = merged['his_EDUC']
for fname, in_group in (
    ("educ_below_median", educ_col < Q2),
    ("educ_above_median", educ_col >= Q2),
):
    sub_df = merged[in_group].reset_index(drop=True)
    print(fname)
    roc_pr_save(sub_df, fname=fname, figname="1d", subgroup="educ")

educ_below_median
570
educ_above_median
1261


## Create data for the figure

In [23]:
# Target labels read by roc_pr when it builds the label/probability column names.
cog_labels = [f'{marker}_label' for marker in ('amy', 'tau')]

In [24]:
def gen_roc_pr(y_true, y_pred, cog_labels, subgroup, fname):
    """Compute AUROC and average precision per label and record them.

    For each entry of ``cog_labels`` the matching column of ``y_true`` and
    ``y_pred`` is taken, positions whose true value is NaN are discarded, and
    ``roc_auc_score`` / ``average_precision_score`` are evaluated on the rest.
    Both scores are rounded to two decimals and stored under the keys
    ``'AUROC'`` and ``'AUPR'`` in ``perf_dict_amy[subgroup][fname]`` when the
    label name contains ``"amy"``, otherwise in
    ``perf_dict_tau[subgroup][fname]``.

    Parameters
    ----------
    y_true, y_pred : 2-D arrays indexed as ``[:, i]``
        Column ``i`` belongs to ``cog_labels[i]``.
    cog_labels : sequence of str
        Label names, in column order.
    subgroup, fname : str
        Keys selecting the result slot in the performance dictionaries.

    Returns
    -------
    None
        Results are written to the notebook-level dictionaries.
    """
    for col, label_name in enumerate(cog_labels):
        truth = np.array(y_true[:, col])
        scores = np.array(y_pred[:, col])

        # Keep only the positions where the ground truth is present.
        keep = np.array([0 if np.isnan(value) else 1 for value in truth]) == 1
        truth_kept = truth[keep]
        scores_kept = scores[keep]

        auc_score = roc_auc_score(truth_kept, scores_kept)
        aupr_score = average_precision_score(truth_kept, scores_kept)

        # Any label without "amy" in its name is recorded in the tau dictionary.
        target = perf_dict_amy if "amy" in label_name else perf_dict_tau
        target[subgroup][fname]['AUROC'] = round(auc_score, 2)
        target[subgroup][fname]['AUPR'] = round(aupr_score, 2)
            

In [25]:
def roc_pr(sub_df, fname, subgroup):
    """Extract label and probability columns from ``sub_df`` and score them.

    For every entry ``lab`` of the notebook-level ``cog_labels`` list, the
    columns ``f'{lab}_label'`` and ``f'{lab}_prob'`` are read from ``sub_df``
    (with ``cog_labels`` as defined in this notebook, that is e.g.
    ``'amy_label_label'`` and ``'amy_label_prob'``). The two resulting arrays
    are passed to ``gen_roc_pr`` together with ``subgroup`` and ``fname``.
    """
    truth_columns = [f'{lab}_label' for lab in cog_labels]
    prob_columns = [f'{lab}_prob' for lab in cog_labels]

    truth_matrix = sub_df[truth_columns].to_numpy()
    prob_matrix = sub_df[prob_columns].to_numpy()

    gen_roc_pr(truth_matrix, prob_matrix, cog_labels, subgroup, fname)
    

In [26]:
# Subgroup -> category names; dict insertion order follows this layout.
_PERF_LAYOUT = (
    ("sex", ("male", "female")),
    ("age", ("age_above_median", "age_below_median")),
    ("race", ("whi", "oth")),
    ("educ", ("educ_above_median", "educ_below_median")),
)


def _blank_perf_dict():
    """Return a fresh subgroup -> category -> {} mapping with empty metric slots."""
    return {
        subgroup: {category: {} for category in categories}
        for subgroup, categories in _PERF_LAYOUT
    }


# Independent result containers for the amyloid and tau labels; filled by gen_roc_pr.
perf_dict_amy = _blank_perf_dict()
perf_dict_tau = _blank_perf_dict()

In [27]:
# Reload the per-subgroup prediction files written by roc_pr_save.
# Note: the variable prefix does not always match the file's panel id
# (e.g. efig1a_female is read from efig1b_female.csv); names are kept
# because the cells below refer to them.
basedir = "./source_data/efig1"
(
    efig1a_female,
    efig1a_male,
    efig1a_whi,
    efig1a_oth,
    efig1a_age_above_median,
    efig1a_age_below_median,
    efig1d_educ_above_median,
    efig1d_educ_below_median,
) = map(
    pd.read_csv,
    (
        f"{basedir}/efig1b_female.csv",
        f"{basedir}/efig1b_male.csv",
        f"{basedir}/efig1c_whi.csv",
        f"{basedir}/efig1c_oth.csv",
        f"{basedir}/efig1a_age_above_median.csv",
        f"{basedir}/efig1a_age_below_median.csv",
        f"{basedir}/efig1d_educ_above_median.csv",
        f"{basedir}/efig1d_educ_below_median.csv",
    ),
)

In [28]:
# Score every subgroup table; each call fills one slot of perf_dict_amy and perf_dict_tau.
for _frame, _category, _subgroup in (
    (efig1a_female, "female", "sex"),
    (efig1a_male, "male", "sex"),
    (efig1a_whi, "whi", "race"),
    (efig1a_oth, "oth", "race"),
    (efig1a_age_above_median, "age_above_median", "age"),
    (efig1a_age_below_median, "age_below_median", "age"),
    (efig1d_educ_above_median, "educ_above_median", "educ"),
    (efig1d_educ_below_median, "educ_below_median", "educ"),
):
    roc_pr(_frame, _category, _subgroup)

In [29]:
# Display the AUROC / AUPR values recorded for the amyloid label, per subgroup and category.
perf_dict_amy

{'sex': {'male': {'AUROC': 0.79, 'AUPR': 0.76},
  'female': {'AUROC': 0.79, 'AUPR': 0.8}},
 'age': {'age_above_median': {'AUROC': 0.76, 'AUPR': 0.78},
  'age_below_median': {'AUROC': 0.8, 'AUPR': 0.78}},
 'race': {'whi': {'AUROC': 0.79, 'AUPR': 0.79},
  'oth': {'AUROC': 0.79, 'AUPR': 0.74}},
 'educ': {'educ_above_median': {'AUROC': 0.78, 'AUPR': 0.75},
  'educ_below_median': {'AUROC': 0.8, 'AUPR': 0.84}}}

In [30]:
# Display the AUROC / AUPR values recorded for the tau label, per subgroup and category.
perf_dict_tau

{'sex': {'male': {'AUROC': 0.79, 'AUPR': 0.48},
  'female': {'AUROC': 0.87, 'AUPR': 0.72}},
 'age': {'age_above_median': {'AUROC': 0.78, 'AUPR': 0.56},
  'age_below_median': {'AUROC': 0.88, 'AUPR': 0.67}},
 'race': {'whi': {'AUROC': 0.83, 'AUPR': 0.61},
  'oth': {'AUROC': 0.91, 'AUPR': 0.58}},
 'educ': {'educ_above_median': {'AUROC': 0.8, 'AUPR': 0.49},
  'educ_below_median': {'AUROC': 0.91, 'AUPR': 0.79}}}

In [31]:
def _perf_rows(perf_dict, label):
    """Flatten a subgroup -> category -> metrics mapping into one record per category."""
    return [
        {'Subgroup': subgroup, 'Category': category, 'AUROC': metrics['AUROC'], 'AUPR': metrics['AUPR'], "Label" : label}
        for subgroup, categories in perf_dict.items()
        for category, metrics in categories.items()
    ]


# One frame per label, amyloid first, then tau.
df1 = pd.DataFrame(_perf_rows(perf_dict_amy, "amy_label"))

rows = _perf_rows(perf_dict_tau, "tau_label")
df2 = pd.DataFrame(rows)

# Stack both frames and renumber the rows from 0.
data = pd.concat([df1, df2], axis=0).reset_index(drop=True)

In [32]:
# Persist the combined summary table without the index column.
data.to_csv(path_or_buf="./source_data/efig1.csv", index=False)

In [33]:
# Display the combined summary table (one row per label, subgroup and category).
data

Unnamed: 0,Subgroup,Category,AUROC,AUPR,Label
0,sex,male,0.79,0.76,amy_label
1,sex,female,0.79,0.8,amy_label
2,age,age_above_median,0.76,0.78,amy_label
3,age,age_below_median,0.8,0.78,amy_label
4,race,whi,0.79,0.79,amy_label
5,race,oth,0.79,0.74,amy_label
6,educ,educ_above_median,0.78,0.75,amy_label
7,educ,educ_below_median,0.8,0.84,amy_label
8,sex,male,0.79,0.48,tau_label
9,sex,female,0.87,0.72,tau_label
