In [1]:
import math
import pickle
from operator import itemgetter

In [2]:
# Read files
# LEAP seq2seq recommendation results. Later cells index each entry as
# visit[0] / visit[1] / visit[2] (presumably diagnoses, ground-truth drugs and
# recommended drugs -- inferred from how they are used; confirm with the producer).
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
leap_result_path = 'mimic_2_freq_result_seq2seq_e84_s0_jacc0.16841_acc0.00097.pkl'
with open(leap_result_path, 'rb') as result_file:
    rec_leap = pickle.load(result_file)

In [26]:
# Mapping used by the report writer to translate diagnosis codes to ICD-10.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
icd10_map_path = 'clean_data_leap/disease_to_icd10.pkl'
with open(icd10_map_path, 'rb') as mapping_file:
    disease_to_icd10 = pickle.load(mapping_file)

In [20]:
# Measure validity of recommendation: NDCG algorithm
def dcg(result, ref):
    """Discounted cumulative gain of a ranked list with binary relevance.

    An item of ``result`` at 0-based rank ``r`` has gain 1 if it appears in
    ``ref`` and 0 otherwise. Its contribution is ``(2**gain - 1) / log2(r + 2)``.

    Args:
        result: ranked sequence of recommended items.
        ref: collection of relevant (ground-truth) items.

    Returns:
        The summed discounted gain (0 for an empty ``result``).
    """
    total = 0
    for rank, item in enumerate(result):
        gain = 1 if item in ref else 0
        total += (2 ** gain - 1) / math.log(rank + 2, 2)
    return total

def idcg(result, ref):
    """Ideal DCG for a ranking the same length as ``result``.

    Assumes the best possible ordering: the first ``len(ref)`` ranks are all
    relevant (gain 1) and any remaining ranks are irrelevant (gain 0).

    Args:
        result: ranked sequence; only its length is used.
        ref: collection of relevant items; only its length is used.

    Returns:
        The ideal discounted gain (0 for an empty ``result``).
    """
    total = 0
    relevant_count = len(ref)
    for rank in range(len(result)):
        gain = 1 if rank < relevant_count else 0
        total += (2 ** gain - 1) / math.log(rank + 2, 2)
    return total

def ndcg(result, ref):
    """Normalized DCG: ``dcg(result, ref) / idcg(result, ref)``.

    Returns 0 when the ideal DCG is not positive (empty ``result`` or empty
    ``ref``), so the division is never attempted.
    """
    actual = dcg(result, ref)
    ideal = idcg(result, ref)
    return actual / ideal if ideal > 0 else 0

In [24]:
# Mean NDCG of the LEAP recommendations over every visit in rec_leap.
per_visit_ndcg = [ndcg(list(visit[2]), visit[1]) for visit in rec_leap]
sum_ndcg = sum(per_visit_ndcg)
mean_ndcg = sum_ndcg / len(per_visit_ndcg)
print(mean_ndcg)

0.305405937342264


In [12]:
# Group each visit's Jaccard similarity (ground-truth vs. recommended drugs) by
# the number of diagnoses in the visit, so accuracy can be compared across
# visit sizes. Result: {diagnosis count: [jaccard, ...]}.
diag_len_jaccard = {}
for visit in rec_leap:
    drug_list = set(visit[1])
    recommendation = visit[2]

    intersection = drug_list.intersection(recommendation)
    union = drug_list.union(recommendation)
    # A visit where both drug sets are empty would divide by zero; score it 0.0,
    # matching ndcg()'s 0 result when there is nothing to rank.
    jaccard = len(intersection) / float(len(union)) if union else 0.0

    diag_list = visit[0]
    diag_len = len(diag_list)
    diag_len_jaccard.setdefault(diag_len, []).append(jaccard)


In [14]:
# Collapse each diagnosis count's list of Jaccard scores into its mean.
# NOTE: the lists are replaced by floats, so this cell cannot be re-run without
# first re-running the cell that builds diag_len_jaccard.
diag_len_jaccard = {
    length: sum(scores) / len(scores)
    for length, scores in diag_len_jaccard.items()
}

In [15]:
# (diagnosis count, average Jaccard) pairs ordered by diagnosis count.
length_items = list(diag_len_jaccard.items())
length_items.sort(key=itemgetter(0))
sorted_len_jacaard = length_items
print(sorted_len_jacaard)

[(1, 0.032078503314777775), (2, 0.053592606208891454), (3, 0.07656724869953788), (4, 0.12097220323725622), (5, 0.14390405052713634), (6, 0.16918761785144265), (7, 0.16169166525502252), (8, 0.16772838001089413), (9, 0.15699161340681309), (10, 0.1731417769571914), (11, 0.1658214867907907), (12, 0.17555354510117854), (13, 0.1715552775969158), (14, 0.17367915793325475), (15, 0.17439515306874145), (16, 0.1899685157136412), (17, 0.1917558339017834), (18, 0.20280995247765027), (19, 0.2191713178933478), (20, 0.2136701627481766), (21, 0.22360469888124154), (22, 0.22438346850792956), (23, 0.22715995651655183), (24, 0.24663996115432074), (25, 0.23143953363772465), (26, 0.2592273014165731), (27, 0.2361670405391722), (28, 0.2617100607018859), (29, 0.24123552460977582), (30, 0.29023276989993924), (31, 0.26577015469504744), (32, 0.25469244816821013), (33, 0.2568534783036851), (34, 0.2604956085694013), (35, 0.26888873707928646), (36, 0.2606012975863163), (37, 0.2676564495530013), (38, 0.31355932203389

In [18]:
# One line per diagnosis count (in sorted order): its average Jaccard, 4 decimals.
with open('recommendation_leap/leap_len_jaccard.txt', "w") as out_file:
    out_file.writelines('%-.4f\n' % avg_jaccard for _, avg_jaccard in sorted_len_jacaard)

In [32]:
# Measure jaccard coefficient of recommendation
def jaccard_coef(result, ref):
    """Jaccard similarity ``|A & B| / |A | B|`` between two collections.

    Both arguments are converted to sets, so order and duplicates are ignored
    (items must be hashable).

    Args:
        result: recommended items.
        ref: ground-truth items.

    Returns:
        A float in [0, 1]. Returns 0.0 when both inputs are empty: the ratio is
        undefined there, and previously raised ZeroDivisionError.
    """
    result, ref = set(result), set(ref)
    union = result | ref
    if not union:
        return 0.0
    return len(result & ref) / float(len(union))

In [33]:
# Running totals for the per-visit report written in the next cell.
sum_recall = 0
sum_ndcg = 0
sum_jaccard = 0
avg_length = 0  # accumulates total recommendation length; averaged at the end
index = 0
count = 0

In [34]:
# Write files
# Per-visit report of the LEAP recommendations, followed by a summary table of
# the averaged metrics. Each visit lists:
#   - its diagnoses (mapped to ICD-10 where a mapping exists),
#   - the MIMIC ground-truth drugs,
#   - the recommended drugs,
#   - recall / NDCG / Jaccard.
# The accumulators are reset here so that re-running this cell on its own does
# not double-count into totals left over from a previous run.
sum_recall = sum_ndcg = sum_jaccard = 0
avg_length = 0
index = count = 0

with open('recommendation_leap/rec_test_LEAP.txt', "w") as f:
    for visit in rec_leap:
        # Diagnosis codes; fall back to the raw code when no ICD-10 mapping exists.
        # Values are wrapped in 1-tuples so '%' formatting never unpacks a tuple.
        for diag in visit[0]:
            code = disease_to_icd10[diag] if diag in disease_to_icd10 else diag
            f.write("%s " % (code,))

        real_drugs = list(visit[1])
        f.write("\nMIMIC: ")
        for drug in real_drugs:
            f.write("%s " % (drug,))

        drug_rec = list(visit[2])
        f.write("\nLEAP Recommendation: ")
        for drug in drug_rec:
            f.write("%s / " % (drug,))

        # Recall: recommended entries found in the ground truth, divided by the
        # ground-truth size. A set lookup avoids an O(len(rec) * len(real)) scan.
        # A visit with no ground-truth drugs scores 0.0 instead of raising
        # ZeroDivisionError.
        real_drug_set = set(real_drugs)
        recall_LEAP = sum(1 for drug in drug_rec if drug in real_drug_set)
        recall_rate = recall_LEAP / float(len(real_drugs)) if real_drugs else 0.0
        f.write("\nRecall Rate: ")
        f.write("%.4f\t" % recall_rate)
        sum_recall += recall_rate

        f.write("NDCG: ")
        ndcg_LEAP = ndcg(drug_rec, real_drugs)
        f.write("%.4f\t" % ndcg_LEAP)
        sum_ndcg += ndcg_LEAP

        f.write("Jaccard: ")
        jaccard_LEAP = jaccard_coef(drug_rec, real_drugs)
        f.write("%.4f\t" % jaccard_LEAP)
        sum_jaccard += jaccard_LEAP

        avg_length += len(drug_rec)

        index += 1
        count += 1
        f.write("\n---------------------------------------------------------------------------------------\n")

    # Averages over all visits. The max() guard keeps an empty rec_leap from
    # dividing by zero; all averages are then reported as 0.
    n_visits = max(count, 1)
    f.write("%-10s%-10s" % (" ", "Single"))
    f.write("\n%-10s%-10.4f" % ("Recall", sum_recall / n_visits))
    f.write("\n%-10s%-10.4f" % ("NDCG", sum_ndcg / n_visits))
    f.write("\n%-10s%-10.4f" % ("Jaccard", sum_jaccard / n_visits))
    f.write("\n%-10s%-10.4f" % ("Length", avg_length / n_visits))