In [1]:
import pandas as pd
import numpy as np
import random
from tqdm import tqdm
from datetime import datetime, timedelta
import pickle
import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
def medcode_to_read(medcode, lkp_med_code_dic):
    """Look up `medcode` in `lkp_med_code_dic`.

    Returns the mapped value when the lookup succeeds; when the lookup
    raises KeyError the medcode itself is handed back unchanged.
    """
    try:
        translated = lkp_med_code_dic[medcode]
    except KeyError:
        # no mapping available - pass the code through untouched
        translated = medcode
    return translated

In [3]:
def medcode_to_term(medcode, lkp_med_term_dic):
    """Look up `medcode` in `lkp_med_term_dic`.

    Returns the mapped value, or the medcode itself when the lookup
    raises KeyError.
    """
    try:
        term = lkp_med_term_dic[medcode]
    except KeyError:
        # unmapped code: fall back to the raw medcode
        return medcode
    else:
        return term

In [4]:
def read_to_medcode(readcode, lkp_med):
    """Reverse lookup: find the medcode stored against a read code.

    `lkp_med` is indexed by the columns 'readcode' and 'medcode'. The read
    code is compared as a string. If the match is not exactly one row,
    `.item()` raises ValueError and the original `readcode` is returned.
    """
    try:
        row_matches = lkp_med['readcode'] == str(readcode)
        matched_medcodes = lkp_med.loc[row_matches, 'medcode']
        return matched_medcodes.item()
    except ValueError:
        # zero or several matching rows - return the input unchanged
        return readcode

In [5]:
def cprd_to_bnf(cprdcode, lkp_bnf):
    """Return the 'bnf' value of the row whose 'bnfcode' equals str(cprdcode).

    Raises ValueError (from `.item()`) unless exactly one row matches.
    """
    code_as_text = str(cprdcode)
    row_matches = lkp_bnf['bnfcode'] == code_as_text
    return lkp_bnf.loc[row_matches, 'bnf'].item()

In [6]:
def bnf_to_chapter(bnfcode, lkp_prd):
    """Return the distinct 'bnfchapter' values recorded against `bnfcode`.

    Rows of `lkp_prd` are selected where 'bnfcode' equals `bnfcode`; the
    result is the array produced by `np.unique` (possibly empty).
    """
    row_matches = lkp_prd['bnfcode'] == bnfcode
    chapters = lkp_prd.loc[row_matches, 'bnfchapter']
    return np.unique(chapters)

In [3]:
def _patient_snapshot(patid, df_t, df_c_patient, eventdate, dob):
    """Summarise one patient at one event date.

    Returns a `(meds, readcodes, age)` tuple where
    * meds      - unique 'bnfcode' values in `df_t` for this patid,
    * readcodes - unique 'readcode' values in the patient's clinical rows
                  whose 'eventdate' is <= `eventdate`,
    * age       - days between `dob` and `eventdate`, divided by 365.
    """
    meds = list(df_t[df_t['patid'] == patid]['bnfcode'].unique())
    # NOTE(review): this is a comparison against the 'YYYY-MM-DD' string, so it
    # assumes df_c['eventdate'] holds ISO-formatted strings too - confirm.
    history = df_c_patient[df_c_patient['eventdate'] <= eventdate]
    readcodes = list(history['readcode'].unique())
    age = (datetime.strptime(eventdate, '%Y-%m-%d') - dob).days / 365
    return meds, readcodes, age


def build_stat_dic(patids, med_str, df_c, df_t_min, df_t_max_6, df_f, df_p, target_dic, tdqm_display=True):
    """Build a per-patient dictionary of statistics for one drug.

    Parameters
    ----------
    patids : iterable
        Patient ids to process.
    med_str : str
        Drug name used to select the `df_f` columns 'min_eventdate_{med_str}',
        'max_eventdate_{med_str}_less_6_month' and 'max_eventdate_{med_str}'.
    df_c : DataFrame
        Clinical data; columns 'patid', 'eventdate' and 'readcode' are read.
    df_t_min, df_t_max_6 : DataFrame
        Medication rows for the min / max_6 event dates; columns 'patid' and
        'bnfcode' are read.
    df_f : DataFrame
        One row per patient holding the event dates as 'YYYY-MM-DD' strings.
    df_p : DataFrame
        One row per patient with 'patid', 'yob' and 'gender'.
    target_dic : dict
        Maps a target name to a container of patids (tested with `in`).
    tdqm_display : bool
        When truthy the patient loop is wrapped in a tqdm progress bar.

    Returns
    -------
    dict
        `{patid: {...}}` with the approximate dob, gender, the min / max_6
        snapshots (event date, drug duration in days, meds, number of meds,
        read codes, age), the max event date and duration, and a 'target'
        dict of single-element lists `[0]` / `[1]`.

    Raises
    ------
    ValueError
        From `.item()` when a patid does not match exactly one row of `df_p`
        or `df_f`.
    """
    stat_dic = {}

    for patid in (tqdm(patids) if tdqm_display else patids):

        # locate the patient's rows once instead of re-filtering for every field
        patient_row = df_p[df_p['patid'] == patid]
        drug_row = df_f[df_f['patid'] == patid]

        # approximate dob: 1 July of the year of birth. Built directly so that
        # leap years also give 1/7 (1 Jan + 181 days lands on 30 June then).
        dob = datetime(int(patient_row['yob'].item()), 7, 1)

        # restrict the clinical data to this patient before the date filters
        df_c_patient = df_c[df_c['patid'] == patid]

        eventdate_min = drug_row[f'min_eventdate_{med_str}'].item()
        eventdate_max_6 = drug_row[f'max_eventdate_{med_str}_less_6_month'].item()
        eventdate_max = drug_row[f'max_eventdate_{med_str}'].item()

        meds_min, readcodes_min, age_min = _patient_snapshot(
            patid, df_t_min, df_c_patient, eventdate_min, dob)
        meds_max_6, readcodes_max_6, age_max_6 = _patient_snapshot(
            patid, df_t_max_6, df_c_patient, eventdate_max_6, dob)

        # time on drug, measured from the first prescription (eventdate_min)
        start = datetime.strptime(eventdate_min, '%Y-%m-%d')
        delta_max_6 = datetime.strptime(eventdate_max_6, '%Y-%m-%d') - start
        delta_max = datetime.strptime(eventdate_max, '%Y-%m-%d') - start

        stat_dic[patid] = {'approx. dob': dob.strftime('%Y-%m-%d'),
                           'gender': patient_row['gender'].item(),

                           'eventdate_min': eventdate_min,
                           # drug duration is zero at the first instance
                           'drug_duration_min': 0,
                           'meds_min': meds_min,
                           'num_med_min': len(meds_min),
                           'readcodes_min': readcodes_min,
                           'age_min': age_min,

                           'eventdate_max_6': eventdate_max_6,
                           'drug_duration_max_6': delta_max_6.days,
                           'meds_max_6': meds_max_6,
                           'num_med_max_6': len(meds_max_6),
                           'readcodes_max_6': readcodes_max_6,
                           'age_max_6': age_max_6,

                           'eventdate_max': eventdate_max,
                           'drug_duration_max': delta_max.days,

                           # single-element lists are kept: downstream code indexes [0]
                           'target': {key: [1 if patid in target_dic[key] else 0]
                                      for key in target_dic}}

    return stat_dic

In [7]:
def add_neg_patids(patids_pos, patids_neg_dic, version, random_seed=0):
    """Combine the positive patids with an equally sized random negative sample.

    Negatives are drawn without replacement from `patids_neg_dic[version]`.
    The positives and the sampled negatives are merged through a set union
    and returned as a shuffled list. Both the global numpy RNG and the
    global `random` RNG are seeded with `random_seed`.
    """
    # draw as many negatives as there are positives
    np.random.seed(random_seed)
    n_required = len(patids_pos)
    sampled_neg = np.random.choice(patids_neg_dic[version], size=n_required, replace=False)

    pos_set = set(patids_pos)
    neg_set = set(sampled_neg)
    combined = list(pos_set.union(neg_set))

    # shuffle in place so positives and negatives are interleaved
    random.seed(random_seed)
    random.shuffle(combined)
    return combined

In [9]:
def _as_column(values):
    """Stack a list of per-patient values into an array of shape (n, 1)."""
    return np.reshape(np.array(values), (len(values), 1))


def build_data(stat_dic, patid_list, feature_dic, disease_dic, lkp_bnf, lkp_prd, eventdate, target, tdqm_display=True):
    """Assemble the feature matrix, labels and column index for a patient list.

    Parameters
    ----------
    stat_dic : dict
        Per-patient statistics; for each patid the keys 'target', 'gender'
        and the '<name>_min' / '<name>_max_6' entries are read.
    patid_list : sequence
        Patients to include, in output row order.
    feature_dic : dict
        Feature switches: 'patient_profile', 'drug_metrics_profile',
        'bnf_profile', 'disease_profile' (truthy to enable) and
        'bnf_pca_profile' = {'include': flag, 'pca': object with
        `.transform`}.
    disease_dic : dict
        Read-code collections under the keys 'db', 'hf', 'tm', 'ld', 'rd', 'ht'.
    lkp_bnf, lkp_prd : lookup tables
        `len(lkp_bnf['bnfcode'])` fixes the one-hot width; both tables are
        passed to `cprd_to_bnf` / `bnf_to_chapter` to name the BNF columns.
    eventdate : str
        'eventdate_min' or 'eventdate_max_6' - selects which snapshot of
        `stat_dic` is used. Any other value collects no features, so the
        final concatenation fails if a profile is enabled.
    target : str
        Key into each patient's 'target' dict.
    tdqm_display : bool
        When truthy the patient loop is wrapped in a tqdm progress bar.

    Returns
    -------
    X : ndarray of shape (len(patid_list), n_features)
    labels : ndarray built from the per-patient target entries
    dic : dict mapping column index of X -> feature description
    """
    # stat_dic keys end in '_min' or '_max_6' depending on the snapshot wanted
    suffix = {'eventdate_min': 'min', 'eventdate_max_6': 'max_6'}.get(eventdate)

    use_patient = feature_dic['patient_profile']
    use_drug = feature_dic['drug_metrics_profile']
    use_bnf = feature_dic['bnf_profile']
    use_pca = feature_dic['bnf_pca_profile']['include']
    use_disease = feature_dic['disease_profile']

    # (disease_dic key, column description) in output column order
    disease_columns = (('db', 'diabetes'), ('hf', 'heartfailure'), ('tm', 'thrombosis'),
                       ('ld', 'liverdisease'), ('rd', 'renaldisease'), ('ht', 'hypertension'))

    n_bnf = len(lkp_bnf['bnfcode'])

    labels = []
    genders, ages = [], []
    drug_durations, num_meds = [], []
    bnf_rows = []
    disease_flags = {key: [] for key, _ in disease_columns}
    # build the read-code sets once instead of once per patient
    disease_sets = {key: set(disease_dic[key]) for key, _ in disease_columns} if use_disease else {}

    # build up data a patient at a time, according to the user selected features
    for patid in (tqdm(patid_list) if tdqm_display else patid_list):
        record = stat_dic[patid]
        labels.append(record['target'][target])

        if suffix is None:
            continue

        if use_patient:
            # NOTE(review): presumably gender is coded 1/2 so this yields 0/1 - confirm
            genders.append(record['gender'] - 1)
            ages.append(record[f'age_{suffix}'])

        if use_drug:
            drug_durations.append(record[f'drug_duration_{suffix}'])
            num_meds.append(record[f'num_med_{suffix}'])

        if use_bnf or use_pca:
            # one-hot over codes 1..n_bnf; set membership avoids a list scan per code
            current_meds = set(record[f'meds_{suffix}'])
            bnf_rows.append([1 if code in current_meds else 0 for code in range(1, n_bnf + 1)])

        if use_disease:
            history = set(record[f'readcodes_{suffix}'])
            for key, _ in disease_columns:
                disease_flags[key].append(0 if history.isdisjoint(disease_sets[key]) else 1)

    # if pca transformation is required, apply it once to the whole one-hot matrix
    if use_pca:
        bnf_pca = np.array(feature_dic['bnf_pca_profile']['pca'].transform(bnf_rows))

    # collect the column groups and their descriptions in matching order
    blocks, names = [], []

    if use_patient:
        blocks += [_as_column(genders), _as_column(ages)]
        names += ['gender', 'age']

    if use_drug:
        blocks += [_as_column(drug_durations), _as_column(num_meds)]
        names += ['drug_duration', 'num_med']

    if use_bnf:
        blocks.append(np.array(bnf_rows))
        used_names = set(names)
        unknown_idx = 0
        for cprdcode in range(1, n_bnf + 1):
            # !note there are duplicate mappings, so a suffix keeps each name unique
            try:
                mapping = bnf_to_chapter(cprd_to_bnf(cprdcode, lkp_bnf), lkp_prd).item()
            except ValueError:
                # no single chapter could be resolved: treat as unknown
                mapping = '-'
            if mapping == '-':
                name = f'Unknown Chapter {unknown_idx}'
                unknown_idx += 1
            elif mapping in used_names:
                name = f'{mapping} (repeated)'  # max duplication is 2, so no index needed
            else:
                name = mapping
            names.append(name)
            used_names.add(name)

    if use_pca:
        blocks.append(bnf_pca)
        # Count components from the transformed data rather than estimator
        # attributes: PCA/SparsePCA expose n_components_, KernelPCA does not,
        # and its alphas_ attribute no longer exists in recent scikit-learn.
        names += [f'principalcomponent{pc + 1}' for pc in range(bnf_pca.shape[1])]

    if use_disease:
        for key, description in disease_columns:
            blocks.append(_as_column(disease_flags[key]))
            names.append(description)

    # a zero-width float block fixes the row count and dtype without needing
    # a dummy column that is deleted afterwards
    X = np.concatenate([np.zeros((len(patid_list), 0))] + blocks, axis=1)
    dic = dict(enumerate(names))

    return X, np.array(labels), dic

In [None]:
def get_stat_dic(drug, base_dir=None):
    """Load the pickled stat_dic for `drug` and derive helper lists from it.

    Parameters
    ----------
    drug : str
        Drug name; selects '<drug>_analysis/stat_dic_<drug>.p'.
    base_dir : str, optional
        Directory containing the '<drug>_analysis' folders. When omitted the
        original hard-coded S: drive location is used.

    Returns
    -------
    stat_dic : dict
        Statistics for all patients sampled for this drug.
    target_list : list
        Target names, inferred from the first entry of `stat_dic`.
    patids_neg_dic : dict
        {'v1': [patids whose 'patids_v1_neg' target flag is 1]}.
    filter_6_mnth : list
        Patids with 'drug_duration_max' < 185.
    filter_12_mnth : list
        Patids with 'drug_duration_max' < 365.

    Raises
    ------
    IndexError
        If the loaded stat_dic is empty.
    """
    if base_dir is None:
        filepath = f'S:\\CALIBER_17_205R\\MSc\\Oliver\\Python Code\\{str(drug)}_analysis\\stat_dic_{str(drug)}.p'
    else:
        import os
        filepath = os.path.join(base_dir, f'{drug}_analysis', f'stat_dic_{drug}.p')

    # NOTE: unpickling can execute arbitrary code - only load files you trust.
    # The context manager guarantees the file handle is closed after loading.
    with open(filepath, 'rb') as handle:
        stat_dic = pickle.load(handle)

    # infer the targets available from the first entry of stat_dic
    first_patid = list(stat_dic)[0]
    target_list = list(stat_dic[first_patid]['target'].keys())

    # get negative patids for each version - ONLY CONSIDERING VERSION 1 IN TESTING BETWEEN DRUGS
    patids_neg_dic = {'v1': [patid for patid, stats in stat_dic.items()
                             if stats['target']['patids_v1_neg'][0] == 1]}

    # patients whose total time on the drug is under the 6 month cut-off (185 days)
    filter_6_mnth = [patid for patid, stats in stat_dic.items()
                     if stats['drug_duration_max'] < 185]

    # patients whose total time on the drug is under the 12 month cut-off (365 days)
    filter_12_mnth = [patid for patid, stats in stat_dic.items()
                      if stats['drug_duration_max'] < 365]

    return stat_dic, target_list, patids_neg_dic, filter_6_mnth, filter_12_mnth

In [None]:
# Function provided by Dennis Trimarchi, Github profile: DTrimarchi10
# url: github.com/DTrimarchi10/confusion_matrix/blob/master/cf_matrix.py

def make_confusion_matrix(cf,
                          group_names=None,
                          categories='auto',
                          count=True,
                          percent=True,
                          cbar=True,
                          xyticks=True,
                          xyplotlabels=True,
                          sum_stats=True,
                          figsize=None,
                          cmap='Blues',
                          title=None):
    '''
    Plot a confusion matrix as an annotated Seaborn heatmap.

    Arguments
    ---------
    cf:            confusion matrix to be passed in (2-D numpy array).
    group_names:   List of strings that represent the labels row by row to be shown in each square.
                   Ignored unless its length equals the number of cells in cf.
    categories:    List of strings containing the categories to be displayed on the x,y axis. Default is 'auto'
    count:         If True, show the raw number in the confusion matrix. Default is True.
    percent:       If True, show each cell as a proportion of the total of cf. Default is True.
    cbar:          If True, show the color bar. The cbar values are based off the values in the confusion matrix.
                   Default is True.
    xyticks:       If True, show x and y ticks. Default is True.
    xyplotlabels:  If True, show 'True Label' and 'Predicted Label' on the figure. Default is True.
    sum_stats:     If True, display summary statistics below the figure. Default is True.
                   For a 2x2 matrix precision, recall and F1 are shown as well; a metric
                   whose denominator is zero is reported as 0 instead of nan.
    figsize:       Tuple representing the figure size. Default will be the matplotlib rcParams value.
    cmap:          Colormap of the values displayed from matplotlib.pyplot.cm. Default is 'Blues'
                   See http://matplotlib.org/examples/color/colormaps_reference.html

    title:         Title for the heatmap. Default is None.
    '''

    # CODE TO GENERATE TEXT INSIDE EACH SQUARE
    blanks = [''] * cf.size

    if group_names and len(group_names) == cf.size:
        group_labels = ["{}\n".format(value) for value in group_names]
    else:
        group_labels = blanks

    if count:
        group_counts = ["{0:0.0f}\n".format(value) for value in cf.flatten()]
    else:
        group_counts = blanks

    if percent:
        group_percentages = ["{0:.2%}".format(value) for value in cf.flatten()/np.sum(cf)]
    else:
        group_percentages = blanks

    # strip() removes the trailing newline when later parts of the label are blank
    box_labels = [f"{v1}{v2}{v3}".strip() for v1, v2, v3 in zip(group_labels, group_counts, group_percentages)]
    box_labels = np.asarray(box_labels).reshape(cf.shape[0], cf.shape[1])

    # CODE TO GENERATE SUMMARY STATISTICS & TEXT FOR SUMMARY STATS
    if sum_stats:
        # Accuracy is sum of diagonal divided by total observations
        accuracy = np.trace(cf) / float(np.sum(cf))

        # if it is a binary confusion matrix, show some more stats
        if len(cf) == 2:
            # Metrics for Binary Confusion Matrices. Each denominator can be
            # zero (no predicted or no actual positives); report 0 in that
            # case rather than dividing by zero.
            predicted_positive = sum(cf[:, 1])
            actual_positive = sum(cf[1, :])
            precision = cf[1, 1] / predicted_positive if predicted_positive else 0.0
            recall = cf[1, 1] / actual_positive if actual_positive else 0.0
            pr_total = precision + recall
            f1_score = 2*precision*recall / pr_total if pr_total else 0.0
            stats_text = "\n\nAccuracy={:0.3f}\nPrecision={:0.3f}\nRecall={:0.3f}\nF1 Score={:0.3f}".format(
                accuracy, precision, recall, f1_score)
        else:
            stats_text = "\n\nAccuracy={:0.3f}".format(accuracy)
    else:
        stats_text = ""

    # SET FIGURE PARAMETERS ACCORDING TO OTHER ARGUMENTS
    if figsize is None:
        # Get default figure size if not set
        figsize = plt.rcParams.get('figure.figsize')

    if not xyticks:
        # Do not show categories if xyticks is False
        categories = False

    # MAKE THE HEATMAP VISUALIZATION
    plt.figure(figsize=figsize)
    sns.heatmap(cf, annot=box_labels, fmt="", cmap=cmap, cbar=cbar, xticklabels=categories, yticklabels=categories)

    if xyplotlabels:
        plt.ylabel('True label')
        plt.xlabel('Predicted label' + stats_text)
    else:
        plt.xlabel(stats_text)

    if title:
        plt.title(title)