In [170]:
import math
import pickle
from operator import itemgetter

In [87]:
# Read files
def _load_pickle(path):
    """Return the object unpickled from the file at `path`.

    NOTE: unpickling can execute arbitrary code, so only point this at
    files produced by a trusted pipeline.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


visit_mappings = _load_pickle('clean_data_leap/visit_mappings.pkl')
sorted_drug_list = _load_pickle('clean_data_leap/sorted_drug_list.pkl')
single = _load_pickle('recommendation_leap/single_rec_50.pkl')
multiple = _load_pickle('recommendation_leap/mul_rec_50.pkl')
test = _load_pickle('logistic_models_top538/test.pkl')
diag_order = _load_pickle('clean_data_leap/diag_order.pkl')

In [202]:
# Lookup used below to print a diagnosis via disease_to_icd10[diag] when present.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open('clean_data_leap/disease_to_icd10.pkl', 'rb') as icd_file:
    disease_to_icd10 = pickle.load(icd_file)

In [88]:
# Keep the first element of each of the first 50 entries of sorted_drug_list.
valid_drugs = [entry[0] for entry in sorted_drug_list[:50]]

In [89]:
# Resolve every key listed in `test` to its entry in visit_mappings, in order.
test_visits = []
for visit_key in test:
    test_visits.append(visit_mappings[visit_key])

In [92]:
# Measure validity of recommendation: NDCG algorithm
def dcg(result, ref):
    """Discounted cumulative gain of `result` with binary relevance.

    An item at 0-based position p contributes 1 / log2(p + 2) when it is
    contained in `ref`, and 0 otherwise.

    Args:
        result: ranked sequence of recommended items.
        ref: container of relevant items (tested with `in`).

    Returns:
        The accumulated gain; the integer 0 when `result` is empty.
    """
    total = 0
    for position, item in enumerate(result):
        gain = 1 if item in ref else 0
        # Accumulate term by term (no sum()) to keep the float result stable.
        total += gain / math.log(position + 2, 2)
    return total

def idcg(result, ref):
    """Ideal DCG for a ranking as long as `result`.

    The ideal ranking places one relevant item in each of the first
    len(ref) positions; only the first len(result) positions are scored.

    Args:
        result: ranked sequence of recommended items (only its length is used).
        ref: collection of relevant items (only its length is used).

    Returns:
        The accumulated ideal gain; the integer 0 when `result` is empty.
    """
    ideal = 0
    relevant_count = len(ref)
    for position in range(len(result)):
        gain = 1 if position < relevant_count else 0
        ideal += gain / math.log(position + 2, 2)
    return ideal

def ndcg(result, ref):
    """Normalised DCG: dcg(result, ref) divided by idcg(result, ref).

    Returns 0 when the ideal DCG is not positive (for example when
    `result` or `ref` is empty), which avoids a division by zero.
    """
    achieved = dcg(result, ref)
    ideal = idcg(result, ref)
    if ideal > 0:
        return achieved / ideal
    return 0

In [93]:
# Measure jaccard coefficient of recommendation
def jaccard_coef(result, ref):
    """Jaccard coefficient |result ∩ ref| / |result ∪ ref| of two collections.

    Both arguments are converted to sets first, so duplicates and ordering
    are ignored.

    Args:
        result: iterable of recommended (hashable) items.
        ref: iterable of reference (hashable) items.

    Returns:
        A float in [0, 1]. When both inputs are empty the union is empty and
        0.0 is returned instead of raising ZeroDivisionError, mirroring how
        ndcg() reports 0 for its degenerate case.
    """
    result, ref = set(result), set(ref)
    union = result | ref
    if not union:
        # Nothing recommended and nothing expected: no overlap to measure.
        return 0.0
    return len(result & ref) / float(len(union))

In [203]:
# Thresholds, later paired by position with sorted_drug_list.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open('recommendation_leap/thresholds_50_50.pkl', 'rb') as thresholds_file:
    thresholds = pickle.load(thresholds_file)

In [204]:
# Map drug name -> threshold, pairing thresholds[i] with sorted_drug_list[i][0].
# Any threshold that is not a `float` instance is replaced by 0.0.
drug_thresholds = {
    sorted_drug_list[position][0]: (value if isinstance(value, float) else 0.0)
    for position, value in enumerate(thresholds)
}

In [205]:
# Running totals of the per-visit metrics, one per recommender.
sum_recall_single, sum_recall_mul = 0, 0
sum_ndcg_single, sum_ndcg_mul = 0, 0
sum_jaccard_single, sum_jaccard_mul = 0, 0

In [206]:
# Running totals of recommendation-list lengths (divided by the visit count later).
avg_length_single, avg_length_mul = 0, 0

In [207]:
# Jaccard scores bucketed by the number of diagnoses in a visit.
mul_diag_len_jaccard = dict()
single_diag_len_jaccard = dict()

In [208]:
# Write files
def _write_metrics(out, recommended, actual):
    """Write recall, NDCG and Jaccard for one recommendation list to `out`.

    Args:
        out: open text file the metric line is appended to.
        recommended: ranked list of recommended drug names.
        actual: non-empty list of drugs really prescribed in the visit.

    Returns:
        Tuple (recall, ndcg, jaccard) for this visit.
    """
    hits = 0
    for drug in recommended:
        if drug in actual:
            hits += 1
    recall = hits * 1.0 / len(actual)
    ndcg_score = ndcg(recommended, actual)
    jaccard = jaccard_coef(recommended, actual)
    out.write("\nRecall Rate: %.4f\t" % recall)
    out.write("NDCG: %.4f\t" % ndcg_score)
    out.write("Jaccard: %.4f\t" % jaccard)
    return recall, ndcg_score, jaccard


RECORD_SEPARATOR = "\n---------------------------------------------------------------------------------------\n"

# Reset every accumulator here so the cell gives the same result when re-run
# (otherwise the sums keep growing and the bucket dicts may already hold the
# averages written by later cells).
sum_recall_single = sum_recall_mul = 0
sum_ndcg_single = sum_ndcg_mul = 0
sum_jaccard_single = sum_jaccard_mul = 0
avg_length_single = avg_length_mul = 0
mul_diag_len_jaccard = {}
single_diag_len_jaccard = {}

with open('recommendation_leap/rec_test_top50_threshold50.txt', "w") as f:
    count = 0  # number of visits that were actually evaluated
    # `index` is the position in test_visits; `multiple` is indexed by it.
    for index, visit in enumerate(test_visits):
        # Diagnoses, printed through disease_to_icd10 when a mapping exists.
        for diag in visit['diag_list']:
            label = disease_to_icd10[diag] if diag in disease_to_icd10 else diag
            f.write("%s " % label)

        # Ground truth: prescribed drugs restricted to the valid drug list.
        f.write("\nMIMIC: ")
        real_drugs = []
        for drug in visit['drug_list']:
            if drug in valid_drugs:
                f.write("%s " % drug)
                real_drugs.append(drug)

        if not real_drugs:
            # Nothing to evaluate against. Close the partial record so the
            # next visit does not continue on this visit's "MIMIC:" line.
            f.write(RECORD_SEPARATOR)
            continue

        # Single factor: for every drug keep the best score over all diagnoses.
        single_rec = {}
        for diag in visit['diag_list']:
            for drug in single[diag]:
                name = sorted_drug_list[drug[0]][0]
                if name in single_rec and drug[1] <= single_rec[name]:
                    continue
                single_rec[name] = drug[1]

        f.write("\nSingle Factor Recommendation: ")
        drug_rec = []
        for name, score in sorted(single_rec.items(), key=itemgetter(1), reverse=True):
            if name in drug_thresholds and score >= drug_thresholds[name]:
                f.write("%s, %.4f / " % (name, score))
                drug_rec.append(name)

        recall_rate_single, ndcg_single, jaccard_single = _write_metrics(f, drug_rec, real_drugs)
        sum_recall_single += recall_rate_single
        sum_ndcg_single += ndcg_single
        sum_jaccard_single += jaccard_single
        avg_length_single += len(drug_rec)

        # Multiple factor: precomputed (drug index, score) pairs for this visit.
        f.write("\nMultiple Factor Recommendation: ")
        drug_rec_mul = []
        for rec in multiple[index]:
            name = sorted_drug_list[rec[0]][0]
            if name in drug_thresholds and rec[1] >= drug_thresholds[name]:
                f.write("%s, %.4f / " % (name, rec[1]))
                drug_rec_mul.append(name)

        recall_rate_mul, ndcg_mul, jaccard_mul = _write_metrics(f, drug_rec_mul, real_drugs)
        sum_recall_mul += recall_rate_mul
        sum_ndcg_mul += ndcg_mul
        sum_jaccard_mul += jaccard_mul
        avg_length_mul += len(drug_rec_mul)

        # Bucket the Jaccard scores by the number of diagnoses in the visit.
        diag_len = len(visit['diag_list'])
        mul_diag_len_jaccard.setdefault(diag_len, []).append(jaccard_mul)
        single_diag_len_jaccard.setdefault(diag_len, []).append(jaccard_single)

        count += 1
        f.write(RECORD_SEPARATOR)

    # Float denominator: avoids integer truncation of the length averages and
    # a ZeroDivisionError when no visit could be evaluated (all sums are 0 then).
    denom = float(count) if count else 1.0
    f.write("%-10s%-10s%-10s" %(" ", "Single", "Multiple"))
    f.write("\n%-10s%-10.4f%-10.4f" %("Recall", sum_recall_single/denom, sum_recall_mul/denom))
    f.write("\n%-10s%-10.4f%-10.4f" %("NDCG", sum_ndcg_single/denom, sum_ndcg_mul/denom))
    f.write("\n%-10s%-10.4f%-10.4f" %("Jaccard", sum_jaccard_single/denom, sum_jaccard_mul/denom))
    f.write("\n%-10s%-10.4f%-10.4f" %("Length", avg_length_single/denom, avg_length_mul/denom))

In [190]:
# Replace each diagnosis-count bucket (a list of Jaccard scores) by its mean.
# Buckets that are no longer lists were already averaged by an earlier run of
# this cell and are left untouched, so re-running the cell is harmless.
for length, scores in list(mul_diag_len_jaccard.items()):
    if isinstance(scores, list):
        mul_diag_len_jaccard[length] = sum(scores) / len(scores)

In [191]:
# Order the (diagnosis count, value) pairs by diagnosis count and show them.
sorted_mul_len_jacaard = list(mul_diag_len_jaccard.items())
sorted_mul_len_jacaard.sort(key=itemgetter(0))
print(sorted_mul_len_jacaard)

[(1, 0.1874692821798085), (2, 0.2059424716250963), (3, 0.2318094603339046), (4, 0.2578659768568519), (5, 0.28738186746379585), (6, 0.28514814799472094), (7, 0.27937887302354775), (8, 0.30830278012562395), (9, 0.28732815383488636), (10, 0.35066589689132915), (11, 0.36308895582216144), (12, 0.3798562379077928), (13, 0.3796786631444079), (14, 0.37673880599515086), (15, 0.3776795613491196), (16, 0.4131026304833303), (17, 0.40095773572665816), (18, 0.405564683994855), (19, 0.4428282803662178), (20, 0.4299398514711865), (21, 0.43732310757538106), (22, 0.474949383305282), (23, 0.453037220126799), (24, 0.476696234729102), (25, 0.4459807949043059), (26, 0.5081994922507135), (27, 0.45913467693263776), (28, 0.5148535572983789), (29, 0.5427317002884986), (30, 0.5859230693105925), (31, 0.5187492778147428), (32, 0.6209771834829884), (33, 0.4896499556466945), (34, 0.5652777777777778), (35, 0.5532432024367507), (36, 0.46781708211143697), (37, 0.7095959595959596), (38, 0.6875), (39, 0.5859518450981865)

In [197]:
# One value per line, in ascending diagnosis-count order (the count itself is not written).
with open('recommendation_leap/mul_len_jaccard.txt', "w") as out:
    out.writelines('%-.4f\n' % value for _length, value in sorted_mul_len_jacaard)

In [194]:
# Replace each diagnosis-count bucket (a list of Jaccard scores) by its mean.
# Buckets that are no longer lists were already averaged by an earlier run of
# this cell and are left untouched, so re-running the cell is harmless.
for length, scores in list(single_diag_len_jaccard.items()):
    if isinstance(scores, list):
        single_diag_len_jaccard[length] = sum(scores) / len(scores)

In [195]:
# Order the (diagnosis count, value) pairs by diagnosis count and show them.
sorted_single_len_jacaard = list(single_diag_len_jaccard.items())
sorted_single_len_jacaard.sort(key=itemgetter(0))
print(sorted_single_len_jacaard)

[(1, 0.1874692821798085), (2, 0.18567572728773027), (3, 0.18402648307621683), (4, 0.19651618867656667), (5, 0.20266024960764908), (6, 0.20834630004305243), (7, 0.187437358300273), (8, 0.20476292815560426), (9, 0.20229700590760746), (10, 0.21755716070256786), (11, 0.2230874168710433), (12, 0.22835966476549124), (13, 0.23074741078663408), (14, 0.22986372539994557), (15, 0.22994141191353412), (16, 0.24830290464939608), (17, 0.23970816596737998), (18, 0.24497223467602078), (19, 0.2484644412553355), (20, 0.2512396583391832), (21, 0.24750466614066588), (22, 0.2770976790175407), (23, 0.2638967825989434), (24, 0.27493487402711764), (25, 0.26895024289120933), (26, 0.2758062712199876), (27, 0.28162454433908346), (28, 0.292451064581297), (29, 0.30779931181458786), (30, 0.3529864473645033), (31, 0.3000970738697776), (32, 0.32423962453613064), (33, 0.29510304366991913), (34, 0.2727272727272727), (35, 0.3234707171123119), (36, 0.24372412008281572), (37, 0.396640826873385), (38, 0.4897959183673469), 

In [198]:
# One value per line, in ascending diagnosis-count order (the count itself is not written).
with open('recommendation_leap/single_len_jaccard.txt', "w") as out:
    out.writelines('%-.4f\n' % value for _length, value in sorted_single_len_jacaard)