In [None]:
# Mount Google Drive so files stored there are reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')
# Install the transformers package into the runtime (note: the version is not pinned).
!pip install transformers

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Automatically reload imported modules before running code, so edits to local packages are picked up.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
# Project root. The Colab location is kept commented out; the active value is an
# absolute local path and has to be edited when running on another machine.
# BASE_PATH = "/content/drive/My Drive/collab/"
BASE_PATH = "/Users/samir/Dev/projects/MIMIC/"
# Directory holding the input CSV files (passed as data_path to the analyses).
input_path = BASE_PATH+"MIMIC/DATA/input/"
# Directory where the aggregated "*_all_res.pkl" result files are written.
output_path = BASE_PATH+"MIMIC/DATA/results/"
# Directory used as the cache for extracted features and per-seed results.
tmp_path = BASE_PATH+"MIMIC/DATA/processed/"
# Make the local TADAT package importable.
sys.path.append(BASE_PATH+"TADAT/") 
# Number of random seeds (0..N_SEEDS-1) each classifier is trained with.
N_SEEDS=4
# Metrics shown in the density and delta plots.
PLOT_VARS=["auroc","auprc","sensitivity","specificity"]

In [None]:
from datetime import datetime
import fnmatch
import matplotlib.pyplot as plt
import numpy as np
import os
from pdb import set_trace
import pandas as pd
import pickle
from sklearn.linear_model import SGDClassifier 
from sklearn.metrics import f1_score, confusion_matrix, roc_auc_score, auc, precision_recall_curve
from sklearn.metrics import precision_recall_fscore_support as score
import seaborn as sns
import warnings

#local
from tadat.pipeline import plots
from tadat.core import data, vectorizer, features, helpers, embeddings, berter, transformer_lms

warnings.filterwarnings("ignore")
sns.set(style="darkgrid")

In [None]:
def read_cache(path):
    """Load a pickled object from ``path``.

    Returns the unpickled object, or ``None`` when the file does not exist.
    Any other error (corrupt pickle, permission problem, ...) propagates.

    NOTE: ``pickle.load`` can execute arbitrary code; only point this at
    cache files produced by this notebook.
    """
    try:
        with open(path, "rb") as cached:
            return pickle.load(cached)
    except FileNotFoundError:
        # a missing cache entry is the expected "not cached yet" case
        return None

def write_cache(path, o):
    """Pickle object ``o`` into the file at ``path``, replacing any existing content."""
    sink = open(path, "wb")
    with sink:
        pickle.dump(o, sink)

def clear_cache(cache_path, model="*", dataset="*", group="*", ctype="*"):
    """Delete cached pickle files in ``cache_path`` that match the given filters.

    File names are matched against the lower-cased glob
    ``<dataset>_<model>_<group>_*_<ctype>.pkl``; every filter defaults to ``*``.
    ``ctype`` must be one of ``"*"``, ``"res*"`` (per-seed results) or
    ``"feats"`` (extracted features). Each removed file name is printed.
    """
    assert ctype in ["*","res*","feats"]
    glob = "{d}_{m}_{g}_*_{c}.pkl".format(d=dataset, m=model, g=group, c=ctype).lower()
    # fnmatch.filter keeps the names that fnmatch.fnmatch would accept
    doomed = fnmatch.filter(os.listdir(cache_path), glob)
    for fname in doomed:
        os.remove(cache_path+"/"+fname)
        print("cleared file: {}".format(fname))
        
def plot_cached_results(cache_path, dataset, model):
    """Re-draw the plots for every saved ``*_all_res.pkl`` file of a dataset/model pair.

    Scans ``cache_path`` for files matching the lower-cased glob
    ``<dataset>_<model>_*_all_res.pkl``, loads each one and dispatches to the
    plotting function of the analysis encoded in the file name.

    The path is built with ``os.path.join`` so ``cache_path`` works with or
    without a trailing separator.
    """
    pattern = "{}_{}_*_all_res.pkl".format(dataset, model).lower()
    for fname in os.listdir(cache_path):
        if not fnmatch.fnmatch(fname, pattern):
            continue
        R = list(read_cache(os.path.join(cache_path, fname)))
        if "gender" in fname:
            gender_plots(*R)
        # "ethnicity_binary" must be tested before "ethnicity": the latter is a
        # substring of the former
        elif "ethnicity_binary" in fname:
            ethnicity_binary_plots(*R)
        elif "ethnicity" in fname:
            ethnicity_plots(*R)

def get_deltas(results_G, results_O):
    """Return per-seed metric differences between a group and the others.

    Both arguments are lists of result records (one per seed). The first two
    columns of the "others" frame are dropped before subtracting, and the
    ``model`` and ``seed`` columns of the result are taken from the group frame.
    """
    group_df = pd.DataFrame(results_G)
    # skip the two leading columns of the "others" results so only the
    # remaining columns take part in the subtraction
    other_metrics = pd.DataFrame(results_O).iloc[:,2:]
    delta = group_df.sub(other_metrics)
    # restore the identifying columns from the group results
    return delta.assign(model=group_df["model"], seed=group_df["seed"])

def plot_densities(df, ax, title):
    """Draw a kernel density estimate of every metric in ``PLOT_VARS`` on ``ax``.

    Plotting is best-effort: if the estimate for a metric cannot be drawn
    that metric is skipped and the reason is printed, instead of being
    silently discarded.
    """
    ax.set_title(title)
    for y in PLOT_VARS:
        try:
            df.plot.kde(ax=ax, x="seed", y=y)
        except Exception as err:
            # only ordinary errors are tolerated; KeyboardInterrupt/SystemExit
            # still propagate
            print("skipped density plot for {}: {}".format(y, err))
    
def plot_performance(df, title):
    """Plot overall performance per seed (left) and metric densities (right).

    Both panels show the metrics listed in ``PLOT_VARS`` so they cannot drift
    apart when the list of plotted metrics is changed.
    """
    fig, ax = plt.subplots(1,2, figsize=(18,5))
    # left panel: metric values per random seed
    plots.plot_df(df=df,ax=ax[0],x="seed",ys=list(PLOT_VARS), annotation_size=10)

    fig.suptitle(title ,y=1.02)
    # right panel: distribution of each metric across seeds
    plot_densities(df, ax[1], "")
    ax[0].legend(loc='best')
    ax[1].legend(loc='best')
    plt.tight_layout()

def read_dataset(path, dataset_name):
    """Read the train/test/val splits of a dataset and attach the clinical notes.

    Reads the tab-separated files ``notes.csv`` and
    ``<dataset_name>_{train,test,val}.csv`` from ``path`` and inner-joins
    each split with the notes on ``SUBJECT_ID``.

    ``DataFrame.join(other, on=...)`` matches the caller's column against the
    *index* of ``other``. The notes frame is therefore indexed by its
    ``SUBJECT_ID`` column first; without this the ids would be matched against
    the positional row index of the notes file. If the notes file has no
    ``SUBJECT_ID`` column its index is used unchanged.

    Returns:
        (df_train, df_test, df_val)
    """
    def _read(fname):
        # all files share the same layout: tab separated, header on first row
        return pd.read_csv("{}/{}".format(path, fname), sep="\t", header=0)

    df_notes = _read("notes.csv")
    if "SUBJECT_ID" in df_notes.columns:
        df_notes = df_notes.set_index("SUBJECT_ID")

    joined = []
    for part in ("train", "test", "val"):
        df_part = _read("{}_{}.csv".format(dataset_name, part))
        # lsuffix disambiguates any other column present in both frames
        joined.append(df_part.join(df_notes, on="SUBJECT_ID", how="inner", lsuffix="N_"))
    df_train, df_test, df_val = joined

    return df_train, df_test, df_val

In [None]:
def get_features(data, vocab_size, feature_type, word_vectors=None):
    """Build the feature matrix for ``data`` according to ``feature_type``.

    Supported types: ``BOW-BIN`` and ``BOW-FREQ`` (sparse bag-of-words, need
    ``vocab_size``) and ``BOE-BIN`` and ``BOE-SUM`` (bag-of-embeddings, need
    ``word_vectors``). Any other value raises ``NotImplementedError``.
    """
    # bag-of-words variants
    if feature_type == "BOW-BIN":
        return features.BOW(data, vocab_size,sparse=True)
    if feature_type == "BOW-FREQ":
        return features.BOW_freq(data, vocab_size,sparse=True)
    # bag-of-embeddings variants
    if feature_type == "BOE-BIN":
        return features.BOE(data, word_vectors,"bin")
    if feature_type == "BOE-SUM":
        return features.BOE(data, word_vectors,"sum")
    raise NotImplementedError
    
def get_BERT_embedding(X, feature_type, batchsize=200, device="cuda"):
    """Encode ``X`` with the transformer encoder and return one of its two outputs.

    Args:
        X: input passed unchanged to ``transformer_lms.transformer_encode_batches``.
        feature_type: ``"BERT-POOL"`` selects the pooled output, ``"BERT-CLS"``
            the CLS output.
        batchsize: encoding batch size (default 200, the previous fixed value).
        device: device name handed to the encoder (default ``"cuda"``, the
            previous fixed value).

    Raises:
        NotImplementedError: for any other ``feature_type``. The check happens
            before encoding so an unsupported type fails fast instead of
            returning ``None`` after the expensive encoding step.
    """
    if feature_type not in ("BERT-POOL", "BERT-CLS"):
        raise NotImplementedError("unsupported BERT feature type: {}".format(feature_type))
    X_cls, X_pool = transformer_lms.transformer_encode_batches(X, batchsize=batchsize, device=device)
    if feature_type == "BERT-POOL":
        return X_pool
    return X_cls


def featurize(df_train, df_test, feature_type, group_label, subgroup):
    """Extract features and label indices for training and the three test views.

    The test set is evaluated as a whole and split into two views: rows where
    ``df_test[group_label] == subgroup`` (suffix ``_G``) and all other rows
    (suffix ``_O``).

    Args:
        df_train, df_test: data frames with ``TEXT`` and ``Y`` columns;
            ``df_test`` must also contain the ``group_label`` column.
        feature_type: a type containing ``"BOW"`` (handled by ``get_features``)
            or ``"BERT"`` (handled by ``get_BERT_embedding``).
        group_label: name of the demographic column used for the split.
        subgroup: value of ``group_label`` that defines the group of interest.

    Returns:
        ``(train_feats, train_Y, test_feats, test_Y, test_feats_G, test_Y_G,
        test_feats_O, test_Y_O, label_vocab)``

    Raises:
        NotImplementedError: if ``feature_type`` contains neither ``"BOW"``
            nor ``"BERT"``.
    """
    # split the test set into the group of interest and everyone else
    df_test_G = df_test[df_test[group_label] == subgroup]
    df_test_O = df_test[df_test[group_label] != subgroup]
    
    print("{}: {} | others: {}".format(subgroup,len(df_test_G),len(df_test_O)))
    #transform the data into the right format
    train = data.read_dataframe(df_train, "TEXT", "Y")
    test_G = data.read_dataframe(df_test_G, "TEXT", "Y")
    test_O = data.read_dataframe(df_test_O, "TEXT", "Y")
    test = data.read_dataframe(df_test, "TEXT", "Y")

    #get vectorized train/test data 
    train_X = data.getX(train)
    test_X_G = data.getX(test_G)
    test_X_O = data.getX(test_O)
    test_X = data.getX(test)
    
    # the vocabulary comes from the training documents only and is reused for
    # every test view
    # NOTE(review): this indexing also runs for BERT feature types, where the
    # indexed documents are not used below
    train_X, word_vocab = vectorizer.docs2idx(train_X)
    test_X_G,_ = vectorizer.docs2idx(test_X_G, word_vocab)
    test_X_O,_ = vectorizer.docs2idx(test_X_O, word_vocab)
    test_X,_ = vectorizer.docs2idx(test_X, word_vocab)
    
    #vectorize labels
    train_Y = data.getY(train)
    test_Y_G = data.getY(test_G) 
    test_Y_O = data.getY(test_O)   
    test_Y = data.getY(test)   
    
    # label vocabulary is built from the train and full test labels together
    label_vocab = vectorizer.get_labels_vocab(train_Y+test_Y)
    train_Y,_ = vectorizer.label2idx(train_Y, label_vocab)
    test_Y,_ = vectorizer.label2idx(test_Y, label_vocab)
    test_Y_G,_ = vectorizer.label2idx(test_Y_G, label_vocab)
    test_Y_O,_ = vectorizer.label2idx(test_Y_O, label_vocab)
    
    if "BOW" in feature_type:
        #extract features
        train_feats = get_features(train_X, len(word_vocab), feature_type)
        test_feats_G = get_features(test_X_G, len(word_vocab), feature_type)
        test_feats_O = get_features(test_X_O, len(word_vocab), feature_type)        
        test_feats = get_features(test_X, len(word_vocab), feature_type)        
    elif "BERT" in feature_type:
        # NOTE(review): the output of data.read_dataframe (not the documents
        # from data.getX) is passed to the encoder here - confirm this is the
        # input transformer_encode_batches expects
        train_feats = get_BERT_embedding(train, feature_type)
        test_feats_G = get_BERT_embedding(test_G, feature_type)
        test_feats_O = get_BERT_embedding(test_O, feature_type)
        test_feats = get_BERT_embedding(test, feature_type)
    else:
        # note: the "BOE-*" types known to get_features are not reachable from here
        raise NotImplementedError    

    return train_feats, train_Y, test_feats, test_Y, test_feats_G, test_Y_G, test_feats_O, test_Y_O, label_vocab

def _logistic_sgd(seed):
    """Create an SGD-trained logistic regression for the installed scikit-learn.

    The logistic loss is called ``"log_loss"`` from scikit-learn 1.1 on and
    the old name ``"log"`` was removed in 1.3; older releases only know ``"log"``.
    """
    import sklearn
    try:
        major, minor = (int(part) for part in sklearn.__version__.split(".")[:2])
    except ValueError:
        # unparsable version string: assume a recent release
        major, minor = 1, 1
    loss = "log_loss" if (major, minor) >= (1, 1) else "log"
    return SGDClassifier(loss=loss, random_state=seed)


def run(data_path, dataset, feature_type, group_label, subgroup, split=0.8, cache_path=None):
    """Train and evaluate a linear classifier for every seed in ``range(N_SEEDS)``.

    Features and per-seed results are cached as pickles in ``cache_path`` when
    it is given. Only the file name part of a cache path is lower-cased; the
    directory is used exactly as passed so it also works on case-sensitive
    file systems. Feature types containing ``"FINE-BERT"`` are delegated to
    ``run_finebert``.

    Args:
        data_path: directory with the dataset files (see ``read_dataset``).
        dataset: dataset name.
        feature_type: feature representation (see ``featurize``).
        group_label: demographic column used to split the test set.
        subgroup: value of ``group_label`` defining the group of interest.
        split: unused here; forwarded to ``run_finebert``.
        cache_path: cache directory (with trailing separator) or ``None`` to
            disable caching.

    Returns:
        ``(results, results_g, results_o)``: one result record per seed for the
        whole test set, the group of interest and the remaining test rows.
    """
    if "FINE-BERT" in feature_type:
        return run_finebert(data_path, dataset, feature_type, group_label, subgroup, split, cache_path)
    # lower-case only the file name stem, never the directory
    run_tag = "{}_{}_{}_{}".format(dataset, feature_type, group_label, subgroup).lower()
    feats_fname = "{}{}_feats.pkl".format(cache_path, run_tag)
    X=None
    #check if the features were already computed and cached
    if cache_path: X = read_cache(feats_fname)
    #if features were not cached, read the data and extract features
    if not X:
        df_train, df_test, df_val = read_dataset(data_path, dataset)
        X = featurize(df_train, df_test, feature_type, group_label, subgroup)
        #cache current features
        if cache_path: write_cache(feats_fname, X)
    else:
        print("loaded cached features")
    train_feats, train_Y, test_feats, test_Y, test_feats_G, test_Y_G, test_feats_O, test_Y_O, label_vocab = X
    print("train set size: ", train_feats.shape[0])
    #train/test classifier for each random seed
    results = []
    results_g = []
    results_o = []

    for seed in range(N_SEEDS):
        res_fname = "{}{}_res{}.pkl".format(cache_path, run_tag, seed)
        R=None
        #look for cached results
        if cache_path: R = read_cache(res_fname)
        if not R:
            model = _logistic_sgd(seed)
            model.fit(train_feats, train_Y)
            res = evaluate_classifier(model, test_feats, test_Y, label_vocab, feature_type, seed)
            res_g = evaluate_classifier(model, test_feats_G, test_Y_G, label_vocab, feature_type, seed)
            res_o = evaluate_classifier(model, test_feats_O, test_Y_O, label_vocab, feature_type, seed)
            #cache results
            if cache_path: write_cache(res_fname, [res, res_g, res_o])
        else:
            print("loaded cached results | seed: {}".format(seed))
            res, res_g, res_o = R
        results.append(res)
        results_g.append(res_g)
        results_o.append(res_o)
    return results, results_g, results_o

def bert_featurize(df_train, df_test, df_val, group_label, subgroup, split):
    """Build the data loaders and label indices used to fine-tune the BERT classifier.

    The test set is provided as a whole and as two views: rows where
    ``df_test[group_label] == subgroup`` (suffix ``_G``) and all other rows
    (suffix ``_O``).

    Args:
        df_train, df_test, df_val: data frames with ``TEXT`` and ``Y`` columns;
            ``df_test`` must also contain the ``group_label`` column.
        group_label: name of the demographic column used for the split.
        subgroup: value of ``group_label`` that defines the group of interest.
        split: not used in this function.

    Returns:
        ``(train_loader, train_Y, val_loader, val_Y, test_loader, test_Y,
        test_loader_G, test_Y_G, test_loader_O, test_Y_O, label_vocab)``
    """
    #split data into "group" and "others"
    df_test_G = df_test[df_test[group_label] == subgroup]
    df_test_O = df_test[df_test[group_label] != subgroup]    
    print("{}: {} | OTHERS: {}".format(subgroup,len(df_test_G),len(df_test_O)))
    #transform the data into the right format
    train = data.read_dataframe(df_train, "TEXT", "Y")
    test_G = data.read_dataframe(df_test_G, "TEXT", "Y")
    test_O = data.read_dataframe(df_test_O, "TEXT", "Y")
    test = data.read_dataframe(df_test,  "TEXT", "Y")
    val = data.read_dataframe(df_val,  "TEXT", "Y")    
    #get instances...
    train_X = data.getX(train)
    test_X_G = data.getX(test_G)
    test_X_O = data.getX(test_O)
    test_X = data.getX(test)
    val_X = data.getX(val)  
    #...and labels
    train_Y = data.getY(train)
    test_Y_G = data.getY(test_G) 
    test_Y_O = data.getY(test_O)   
    test_Y = data.getY(test)  
    val_Y = data.getY(val)      
    
    # the maximum length is fixed instead of being derived from the data
    # (the data-driven computation is kept below, commented out)
    # #find max sentence length
    # all_docs = np.concatenate([train_X,test_X,val_X])
    # max_len = berter.max_doc_len(all_docs)
    # print('Max length: ', max_len)   
    max_len = 512
    #vectorize labels
    # label vocabulary is built from the train, test and validation labels together
    label_vocab = vectorizer.get_labels_vocab(train_Y+test_Y+val_Y)
    train_Y,_ = vectorizer.label2idx(train_Y, label_vocab)
    val_Y,_ = vectorizer.label2idx(val_Y, label_vocab)
    test_Y,_ = vectorizer.label2idx(test_Y, label_vocab)
    test_Y_G,_ = vectorizer.label2idx(test_Y_G, label_vocab)
    test_Y_O,_ = vectorizer.label2idx(test_Y_O, label_vocab)        
    #vectorize data
    train_inputs, train_masks, train_labels = berter.vectorize(train_X, train_Y, max_len)
    val_inputs, val_masks, val_labels = berter.vectorize(val_X, val_Y, max_len)
    test_inputs, test_masks, test_labels = berter.vectorize(test_X, test_Y, max_len)
    test_inputs_G, test_masks_G, test_labels_G = berter.vectorize(test_X_G, test_Y_G, max_len)
    test_inputs_O, test_masks_O, test_labels_O = berter.vectorize(test_X_O, test_Y_O, max_len)

    BATCH_SIZE = 32
    # Create the DataLoader for training and validation sets
    # (training uses the random-sample loader, validation and all test views
    # use the sequential one)
    train_loader = berter.get_random_sample_loader(train_inputs, train_masks, train_labels, BATCH_SIZE)
    val_loader = berter.get_sequential_sample_loader(val_inputs, val_masks, val_labels, BATCH_SIZE)
    test_loader = berter.get_sequential_sample_loader(test_inputs, test_masks, test_labels, BATCH_SIZE)
    test_loader_G = berter.get_sequential_sample_loader(test_inputs_G, test_masks_G, test_labels_G, BATCH_SIZE)
    test_loader_O = berter.get_sequential_sample_loader(test_inputs_O, test_masks_O, test_labels_O, BATCH_SIZE)  

    return train_loader, train_Y, val_loader, val_Y, test_loader, test_Y, test_loader_G, test_Y_G, test_loader_O, test_Y_O, label_vocab

def run_finebert(data_path, dataset, feature_type, group_label, subgroup, split=0.8, cache_path=None):
    """Train and evaluate the BERT classifier for every seed in ``range(N_SEEDS)``.

    Data loaders and per-seed results are cached as pickles in ``cache_path``
    when it is given. Only the file name part of a cache path is lower-cased;
    the directory is used exactly as passed.

    Args:
        data_path: directory with the dataset files (see ``read_dataset``).
        dataset: dataset name.
        feature_type: feature type; ``"POOL"`` in the name selects the pooled
            representation, otherwise the CLS one is used.
        group_label: demographic column used to split the test set.
        subgroup: value of ``group_label`` defining the group of interest.
        split: forwarded to ``bert_featurize``.
        cache_path: cache directory (with trailing separator) or ``None`` to
            disable caching.

    Returns:
        ``(results, results_g, results_o)`` with exactly one result record per
        seed for the whole test set, the group of interest and the remaining
        test rows.
    """
    print("FINE BERT")
    # lower-case only the file name stem, never the directory
    run_tag = "{}_{}_{}_{}".format(dataset, feature_type, group_label, subgroup).lower()
    feats_fname = "{}{}_feats.pkl".format(cache_path, run_tag)
    X=None
    #check if the features were already computed and cached
    if cache_path: X = read_cache(feats_fname)
    #if features were not cached, read the data and extract features
    if not X:
        df_train, df_test, df_val = read_dataset(data_path, dataset)
        # arguments follow bert_featurize's signature:
        # (df_train, df_test, df_val, group_label, subgroup, split)
        X = bert_featurize(df_train, df_test, df_val, group_label, subgroup, split)
        #cache current features
        if cache_path: write_cache(feats_fname, X)
    else:
        print("loaded cached features")
    train_loader, train_Y, val_loader, val_Y, test_loader, test_Y, test_loader_G, test_Y_G, test_loader_O, test_Y_O, label_vocab = X
    print("train set size: ",  len(train_loader))
    #train/test classifier for each random seed
    results = []
    results_g = []
    results_o = []

    #pool features (vs CLS)
    pool = "POOL" in feature_type

    for seed in range(N_SEEDS):
        res_fname = "{}{}_res{}.pkl".format(cache_path, run_tag, seed)
        R = None
        #look for cached results
        if cache_path: R = read_cache(res_fname)
        if not R:
            model = berter.BertClassifier(freeze_bert=True, pool=pool)
            model.fit(train_loader, val_loader, epochs=1, validation=True, seed=seed)
            res = evaluate_classifier(model, test_loader, test_Y, label_vocab, feature_type, seed)
            res_g = evaluate_classifier(model, test_loader_G, test_Y_G, label_vocab, feature_type, seed)
            res_o = evaluate_classifier(model, test_loader_O, test_Y_O, label_vocab, feature_type, seed)
            #cache results
            if cache_path: write_cache(res_fname, [res, res_g, res_o])
        else:
            print("loaded cached results | seed: {}".format(seed))
            res, res_g, res_o = R
        # collected once per seed, for fresh and cached results alike
        results.append(res)
        results_g.append(res_g)
        results_o.append(res_o)
    return results, results_g, results_o

def evaluate_classifier(model, X_test, Y_test,
                   labels, model_name, random_seed, res_path=None):
    """Evaluate a fitted binary classifier and return a record of its metrics.

    Args:
        model: object exposing ``predict`` and ``predict_proba``.
        X_test: test inputs passed to the model.
        Y_test: true label indices.
        labels: label vocabulary; ``labels[1]`` is used as the column of
            ``predict_proba`` holding the positive class probability.
        model_name: stored under ``"model"`` in the result.
        random_seed: stored under ``"seed"`` in the result.
        res_path: when given, the record is also saved through
            ``helpers.save_results``.

    Returns:
        dict with ``model``, ``seed`` and the metrics ``microF1``, ``macroF1``,
        ``auroc``, ``auprc``, ``specificity`` and ``sensitivity`` rounded to
        three decimals.
    """
    Y_hat = model.predict(X_test)
    Y_hat_prob = model.predict_proba(X_test)
    #get probabilities for the positive class
    Y_hat_prob = Y_hat_prob[:,labels[1]]

    microF1 = f1_score(Y_test, Y_hat, average="micro")
    macroF1 = f1_score(Y_test, Y_hat, average="macro")
    aurocc = roc_auc_score(Y_test, Y_hat_prob)
    # the precision/recall curve must be traced over the continuous scores;
    # with hard 0/1 predictions it collapses to a single operating point
    prec, rec, _ = precision_recall_curve(Y_test, Y_hat_prob)
    auprc = auc(rec, prec)
    # ravel order of a 2x2 confusion matrix: tn, fp, fn, tp
    tn, fp, fn, tp = confusion_matrix(Y_test, Y_hat).ravel()
    specificity = tn / (tn+fp)
    sensitivity = tp / (fn+tp)

    res = {"model":model_name,
            "seed":random_seed,
            "microF1":round(microF1,3),
            "macroF1":round(macroF1,3),
            "auroc":round(aurocc,3),
            "auprc":round(auprc,3),
            "specificity":round(specificity,3),
            "sensitivity":round(sensitivity,3)
            }

    if res_path is not None:
        helpers.save_results(res, res_path, sep="\t")
    return res

# Analyses

## Ethnicity 

In [None]:
def ethnicity_plot_deltas(df_delta_W,df_delta_N,df_delta_A,df_delta_H, title):
    """Plot the group-vs-others metric deltas, one panel per ethnicity group.

    The four delta frames are stacked, reshaped to long format over
    ``PLOT_VARS`` and drawn with one column per ``group`` value. Every panel
    gets a dashed zero line and the same symmetric y-range.
    """
    stacked = pd.concat([df_delta_W,df_delta_N,df_delta_A,df_delta_H])
    #transform results into "long format"
    long_df = stacked.melt(id_vars=["seed","model","group"], value_vars=PLOT_VARS,
                           var_name="metric", value_name="delta")
    grid = sns.catplot(x="metric", y="delta", data=long_df,
                       col="group",sharey=True,legend=False)
    # exactly four panels are expected, one per group
    first, second, third, fourth = grid.axes[0]
    panels = (first, second, third, fourth)
    for panel in panels:
        panel.axhline(0, ls='--',c="r")
    # symmetric range around zero with a small margin
    lim = max(long_df["delta"].abs()) + 0.05
    for panel in panels:
        panel.set_ylim([-lim,lim])
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

def ethnicity_plot_densities(df_W, df_N, df_A, df_H, title):
    """Draw the metric densities of the four ethnicity groups side by side."""
    fig, ax = plt.subplots(1,4, sharey=True, sharex=True, figsize=(18,5))
    # (results, panel title) for each of the four panels, left to right
    panels = [(df_W, "White"), (df_N, "Black"), (df_A, "Asian"), (df_H, "Hispanic")]
    for idx, (group_df, label) in enumerate(panels):
        plot_densities(group_df, ax[idx], label)
    fig.suptitle(title,  y=1.02)
    plt.tight_layout()

def ethnicity_plots(df_res, df_res_W, df_res_N, df_res_A, df_res_H, df_res_delta_W, 
                      df_res_delta_N,df_res_delta_A, df_res_delta_H, title):
    """Draw all ethnicity figures: overall performance, per-group densities and deltas."""
    group_frames = (df_res_W, df_res_N, df_res_A, df_res_H)
    delta_frames = (df_res_delta_W, df_res_delta_N, df_res_delta_A, df_res_delta_H)
    plot_performance(df_res, title)
    ethnicity_plot_densities(*group_frames, title)
    ethnicity_plot_deltas(*delta_frames, title)

def ethnicity_outcomes(data_path, dataset, feature_type, cache_path=None):
    """Run the experiment once per ethnicity group and collect the result frames.

    Returns a 9-tuple: the overall results (taken from the WHITE run), the
    per-group results for White, Black, Asian and Hispanic, and the matching
    group-vs-others delta frames, each tagged with a ``group`` column.
    """
    # (subgroup value, label written to the "group" column)
    subgroups = [("WHITE", "White v Others"),
                 ("BLACK", "Black v Others"),
                 ("ASIAN", "Asian v Others"),
                 ("HISPANIC", "Hispanic v Others")]

    runs = []
    for code, _ in subgroups:
        overall, in_group, others = run(data_path, dataset, feature_type, "ETHNICITY_LABEL",
                                        code, split=0.8, cache_path=cache_path)
        runs.append((overall, in_group, others))

    #results
    deltas = [get_deltas(in_group, others) for _, in_group, others in runs]

    # overall performance is reported from the first (WHITE) run
    df_res = pd.DataFrame(runs[0][0])
    group_frames = [pd.DataFrame(in_group) for _, in_group, _ in runs]

    for delta, (_, tag) in zip(deltas, subgroups):
        delta["group"] = [tag]*len(delta)

    return (df_res, *group_frames, *deltas)

def ethnicity_analysis(data_path, dataset, feature_type, output_path, cache_path=None, plots=True):
    """Run the per-ethnicity analysis, save the result frames and optionally plot them.

    The frames and the plot title are pickled to
    ``<output_path><dataset>_<feature_type>_ethnicity_all_res.pkl``. Only the
    file name is lower-cased; ``output_path`` is used exactly as given so the
    file lands in the directory that ``plot_cached_results`` scans, also on
    case-sensitive file systems.

    Args:
        data_path: directory with the dataset files.
        dataset: dataset name.
        feature_type: feature representation passed on to ``run``.
        output_path: output directory (with trailing separator).
        cache_path: cache directory for features and per-seed results, or ``None``.
        plots: when true, the figures are drawn after saving.
    """
    R  = ethnicity_outcomes(data_path, dataset, feature_type, cache_path)
    df_res, df_res_W, df_res_N, df_res_A, df_res_H, df_res_delta_W, df_res_delta_N,df_res_delta_A, df_res_delta_H = R
    #save results
    title="{} x ethnicity x {}".format(dataset, feature_type).lower()
    # lower-case the file name only, not the directory
    fname = "{}{}".format(output_path,
                          "{}_{}_ethnicity_all_res.pkl".format(dataset, feature_type).lower())
    with open(fname, "wb") as fo:
        pickle.dump([df_res, df_res_W, df_res_N, df_res_A, df_res_H, df_res_delta_W,
                     df_res_delta_N,df_res_delta_A, df_res_delta_H, title], fo)
    if plots:
        ethnicity_plots(df_res, df_res_W, df_res_N, df_res_A, df_res_H,
                          df_res_delta_W, df_res_delta_N,df_res_delta_A, df_res_delta_H, title)


## Ethnicity Binary

In [None]:
def ethnicity_binary_plot_deltas(df_delta_W,df_delta_N, title):
    """Plot the group-vs-others metric deltas for the White / Non-White split.

    The two delta frames are stacked, reshaped to long format over
    ``PLOT_VARS`` and drawn with one column per ``group`` value. Both panels
    get a dashed zero line and the same symmetric y-range.
    """
    stacked = pd.concat([df_delta_W,df_delta_N])
    #transform results into "long format"
    long_df = stacked.melt(id_vars=["seed","model","group"], value_vars=PLOT_VARS,
                           var_name="metric", value_name="delta")

    grid = sns.catplot(x="metric", y="delta", data=long_df,
                       col="group",sharey=True,legend=False)
    # exactly two panels are expected, one per group
    left, right = grid.axes[0]
    panels = (left, right)
    for panel in panels:
        panel.axhline(0, ls='--',c="r")
    # symmetric range around zero with a small margin
    lim = max(long_df["delta"].abs()) + 0.05
    for panel in panels:
        panel.set_ylim([-lim,lim])
    plt.suptitle(title,y=1.02)
    plt.tight_layout()
    plt.show()
    
def ethnicity_binary_plot_densities(df_W, df_N, title):
    """Draw the metric densities of the White and Non-White groups side by side."""
    fig, ax = plt.subplots(1,2, sharey=True, sharex=True, figsize=(18,5))
    # (results, panel title) for the left and right panel
    panels = [(df_W, "White"), (df_N, "Non-White")]
    for idx, (group_df, label) in enumerate(panels):
        plot_densities(group_df, ax[idx], label)
    fig.suptitle(title ,y=1.02)
    plt.tight_layout()

def ethnicity_binary_plots(df_res, df_res_W, df_res_N, df_res_delta_W, df_res_delta_N, title):
    """Draw all binary-ethnicity figures: overall performance, densities and deltas."""
    group_frames = (df_res_W, df_res_N)
    delta_frames = (df_res_delta_W, df_res_delta_N)
    plot_performance(df_res, title)
    ethnicity_binary_plot_densities(*group_frames, title)
    ethnicity_binary_plot_deltas(*delta_frames, title)

def ethnicity_binary_outcomes(data_path, dataset, feature_type, cache_path=None):
    """Run the experiment for the White and Non-White groups and collect the frames.

    Returns a 5-tuple: the overall results (taken from the WHITE run), the
    per-group results for White and Non-White, and the matching
    group-vs-others delta frames, each tagged with a ``group`` column.
    """
    # (subgroup value, label written to the "group" column)
    subgroups = [("WHITE", "White v Others"), ("NON-WHITE", "Non-White v Others")]

    runs = []
    for code, _ in subgroups:
        overall, in_group, others = run(data_path, dataset, feature_type, "ETHNICITY_BINARY",
                                        code, split=0.8, cache_path=cache_path)
        runs.append((overall, in_group, others))

    #results
    deltas = [get_deltas(in_group, others) for _, in_group, others in runs]
    for delta, (_, tag) in zip(deltas, subgroups):
        delta["group"] = [tag]*len(delta)

    # overall performance is reported from the first (WHITE) run
    df_res = pd.DataFrame(runs[0][0])
    group_frames = [pd.DataFrame(in_group) for _, in_group, _ in runs]

    return (df_res, *group_frames, *deltas)
   
def ethnicity_binary_analysis(data_path, dataset, feature_type, output_path, cache_path=None, plots=True):
    """Run the White / Non-White analysis, save the result frames and optionally plot them.

    The frames and the plot title are pickled to
    ``<output_path><dataset>_<feature_type>_ethnicity_binary_all_res.pkl``.
    Only the file name is lower-cased; ``output_path`` is used exactly as
    given so the file lands in the directory that ``plot_cached_results``
    scans, also on case-sensitive file systems.

    Args:
        data_path: directory with the dataset files.
        dataset: dataset name.
        feature_type: feature representation passed on to ``run``.
        output_path: output directory (with trailing separator).
        cache_path: cache directory for features and per-seed results, or ``None``.
        plots: when true, the figures are drawn after saving.
    """
    df_res, df_res_W, df_res_N, df_res_delta_W, df_res_delta_N = ethnicity_binary_outcomes(data_path, dataset,
                                                                                           feature_type, cache_path)
    #save results
    title="{} x ethnicity-binary x {}".format(dataset, feature_type).lower()
    # lower-case the file name only, not the directory
    fname = "{}{}".format(output_path,
                          "{}_{}_ethnicity_binary_all_res.pkl".format(dataset, feature_type).lower())
    with open(fname, "wb") as fo:
        pickle.dump([df_res, df_res_W, df_res_N, df_res_delta_W, df_res_delta_N, title], fo)

    if plots:
        ethnicity_binary_plots(df_res, df_res_W, df_res_N, df_res_delta_W, df_res_delta_N,title)
    
    

## Gender 

In [None]:
def gender_plot_deltas(df_delta, title):
    """Plot the male-vs-female metric deltas in a single panel.

    The delta frame is reshaped to long format over ``PLOT_VARS``; the panel
    gets a dashed zero line and a symmetric y-range.
    """
    #transform results into "long format"
    long_df = df_delta.melt(id_vars=["seed","model"], value_vars=PLOT_VARS,
                            var_name="metric", value_name="delta")

    # symmetric range around zero with a small margin
    lim = max(long_df["delta"].abs()) + 0.05
    grid = sns.catplot(x="metric", y="delta",  data=long_df, sharey=True,legend=False)
    panel = grid.axes[0][0]
    panel.axhline(0, ls='--',c="r")
    panel.set_ylim([-lim,lim])
    plt.suptitle(title,y=1.02)
    plt.tight_layout()
    plt.show()

def gender_plot_densities(df_M, df_F, title):
    """Draw the metric densities of the male and female groups side by side."""
    fig, ax = plt.subplots(1,2, sharey=True, sharex=True, figsize=(18,5))
    # (results, panel title) for the left and right panel
    panels = [(df_M, "Male"), (df_F, "Female")]
    for idx, (group_df, label) in enumerate(panels):
        plot_densities(group_df, ax[idx], label)
    fig.suptitle(title, y=1.02)
    plt.tight_layout()

def gender_outcomes(data_path, dataset, feature_type, cache_path):
    """Run the experiment with the test set split by gender and collect the frames.

    The group of interest is ``GENDER == "M"``; the remaining test rows form
    the second group. Returns ``(df_res, df_res_M, df_res_F, df_delta)``.
    """
    overall, male, female = run(data_path, dataset, feature_type,
                                "GENDER", "M", split=0.8, cache_path=cache_path)
    #results
    delta = get_deltas(male, female)
    frames = tuple(pd.DataFrame(records) for records in (overall, male, female))

    return frames + (delta,)

def gender_plots(df_res, df_res_M, df_res_F, df_res_delta, title):
    """Draw all gender figures: overall performance, per-group densities and deltas."""
    group_frames = (df_res_M, df_res_F)
    plot_performance(df_res, title)
    gender_plot_densities(*group_frames, title)
    gender_plot_deltas(df_res_delta, title)
    
def gender_analysis(data_path, dataset, feature_type, output_path, cache_path=None, plots=True):
    """Run the gender analysis, save the result frames and optionally plot them.

    The frames and the plot title are pickled to
    ``<output_path><dataset>_<feature_type>_gender_all_res.pkl``. Only the
    file name is lower-cased; ``output_path`` is used exactly as given so the
    file lands in the directory that ``plot_cached_results`` scans, also on
    case-sensitive file systems.

    Args:
        data_path: directory with the dataset files.
        dataset: dataset name.
        feature_type: feature representation passed on to ``run``.
        output_path: output directory (with trailing separator).
        cache_path: cache directory for features and per-seed results, or ``None``.
        plots: when true, the figures are drawn after saving.
    """
    df_res, df_res_M, df_res_F, df_res_delta = gender_outcomes(data_path, dataset, feature_type, cache_path)
    #save results
    title="{} x gender x {}".format(dataset, feature_type).lower()
    # lower-case the file name only, not the directory
    fname = "{}{}".format(output_path,
                          "{}_{}_gender_all_res.pkl".format(dataset, feature_type).lower())
    with open(fname, "wb") as fo:
        pickle.dump([df_res, df_res_M, df_res_F, df_res_delta, title], fo)

    if plots:
        gender_plots(df_res, df_res_M, df_res_F, df_res_delta, title)
        

# Outcomes

In [None]:
# Feature representation to evaluate; passed as feature_type to the analyses below.
model="BOW-BIN"
# Base dataset name.
dataset="CAAOHD"
# Switch to the "mini_" variant: the prefix becomes part of the split file
# names ("<dataset>_train.csv", ...) read from the input directory.
dataset="mini_"+dataset

In [None]:
# Gender analysis: results are saved to output_path, features and per-seed results are cached in tmp_path.
gender_analysis(input_path, dataset, model, output_path, tmp_path)

In [None]:
# White vs. Non-White analysis: results are saved to output_path, intermediate data is cached in tmp_path.
ethnicity_binary_analysis(input_path, dataset, model, output_path, tmp_path)

In [None]:
# Per-ethnicity analysis (White, Black, Asian, Hispanic): results are saved to output_path, intermediate data is cached in tmp_path.
ethnicity_analysis(input_path, dataset, model, output_path, tmp_path)

In [None]:
# Re-draw the figures from the "*_all_res.pkl" files saved in output_path for this dataset/model, without re-running anything.
plot_cached_results(output_path, dataset, model)

In [None]:
# Caution: with all filters left at their "*" defaults this deletes every cached
# feature and result pickle in tmp_path matching the cache naming scheme.
clear_cache(tmp_path)