In [None]:
import os
import sys

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
import h5py
import scipy
import scipy.integrate  # explicit: `import scipy` alone does not load submodules on older SciPy versions

import damselfly as df

# Date: 6/25/2021
# Description: plots the ROC curves computed for the DFCNN models trained on
# data generated at different noise temperatures.

# Project root. Defaults to the original cluster location; set the
# DAMSELFLY_PROJECT_PATH environment variable to run from another checkout.
PATH = os.environ.get('DAMSELFLY_PROJECT_PATH', '/storage/home/adz6/group/project')
RESULTPATH = os.path.join(PATH, 'results/damselfly')
PLOTPATH = os.path.join(PATH, 'plots/damselfly')
DATAPATH = os.path.join(PATH, 'damselfly/data/datasets')

In [None]:
# ROC result archives written by np.savez. The list is sorted because os.listdir
# order is arbitrary, which would otherwise make the DataFrame row order
# nondeterministic. Non-.npz entries (e.g. .ipynb_checkpoints) are skipped
# because np.load cannot read them as archives.
result_list = sorted(
    fname for fname in os.listdir(os.path.join(RESULTPATH, 'roc'))
    if fname.endswith('.npz')
)

In [None]:
def load_result(file):
    """Load one ROC result archive and parse its metadata from the file name.

    Parameters
    ----------
    file : str
        Path to an ``np.savez`` archive. The name is expected to contain a
        ``temp<T>_`` token and a ``dset_name_<name>_`` token.

    Returns
    -------
    temp : float
        Noise temperature: the text after the last ``temp`` up to the next ``_``.
    dset : str
        Dataset name: the text after the last ``dset_name_`` up to the next ``_``.
    result : dict
        Maps every array name in the archive to its loaded array.
    """
    # Parse only the file name so directory components cannot match the tokens.
    name = os.path.basename(file)
    temp = float(name.split('temp')[-1].split('_')[0])
    dset = name.split('dset_name_')[-1].split('_')[0]
    # Using the NpzFile as a context manager closes the underlying file handle.
    # Indexing it reads each array into memory, so the dict stays valid after close.
    with np.load(file) as result_data:
        result = {key: result_data[key] for key in result_data}
    return temp, dset, result

def load_all_results(result_files):
    """Build a DataFrame with one row per ROC result file.

    Each name in ``result_files`` is resolved under ``RESULTPATH/roc`` and read
    with ``load_result``. The row holds the parsed noise temperature (``'T'``),
    the dataset name (``'dset'``), and then every array stored in the file,
    keyed by its archive name.
    """
    rows = []
    for fname in result_files:
        temp, dset, arrays = load_result(os.path.join(RESULTPATH, 'roc', fname))
        row = {'T': temp, 'dset': dset}
        row.update(arrays)
        rows.append(row)
    return pd.DataFrame(rows)

# Collect every ROC result file into one DataFrame: one row per file holding the
# parsed noise temperature 'T', the dataset name 'dset', and every array in the file.
all_results = load_all_results(result_list)

In [None]:
# Split the per-class ROC arrays out of `all_results`, keyed by noise temperature
# and then by class index: tpr[temp][iclass], fpr[temp][iclass].
temps = np.sort(all_results['T'].unique())
# Class labels are taken from the first result row.
# NOTE(review): assumes every file stores the same class_ind — confirm.
classes = np.sort(all_results['class_ind'].iloc[0])
tpr, fpr = {}, {}

for temp in temps:
    # Select this temperature's row once, rather than once per class and per metric.
    rows = all_results.loc[all_results['T'] == temp]
    if len(rows) != 1:
        raise ValueError(
            f'expected exactly one ROC result for T={temp}, found {len(rows)}'
        )
    row = rows.iloc[0]
    # The stored 'tpr'/'fpr' objects are indexed by class value.
    tpr[temp] = {iclass: row['tpr'][iclass] for iclass in classes}
    fpr[temp] = {iclass: row['fpr'][iclass] for iclass in classes}

In [None]:
# Left: class-averaged ROC curve per noise temperature. Right: its AUC vs temperature.
figname = '210625_plot_dfcnn_roc_vs_noise_temp.png'
# Classes whose ROC curves are averaged point-wise into one curve per temperature.
ROC_CLASSES = (0, 1)

sns.set_theme(context='talk', style='whitegrid', palette='mako')

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(17, 6))

auc = []
for temp in temps:
    # Point-wise mean across classes. This assumes each class's fpr/tpr arrays are
    # aligned (same length / same thresholds).
    mean_fpr = np.mean([fpr[temp][c] for c in ROC_CLASSES], axis=0)
    mean_tpr = np.mean([tpr[temp][c] for c in ROC_CLASSES], axis=0)
    # abs(): Simpson integration over a decreasing x gives a negative value,
    # presumably why the sign is dropped (fpr ordering not confirmed here).
    temp_auc = abs(scipy.integrate.simpson(mean_tpr, x=mean_fpr))
    auc.append(temp_auc)

    # Both axes advance their own color cycle once per temperature, so each
    # AUC marker shares the color of its ROC curve.
    ax1.plot(mean_fpr, mean_tpr, '-', label=f'{temp}')
    ax2.plot(temp, temp_auc, 'o', label=f'{temp}')

# Chance-level diagonal.
ax1.plot([0, 1], [0, 1], '--', color='tab:gray')
ax1.legend(loc='lower right', title='Noise Temp (K)')
ax1.set(title='ROC Curve', xlabel='False Positive Rate', ylabel='True Positive Rate')

ax2.set(title='Area Under the Curve', xlabel='Noise Temp (K)', ylabel='AUC')

# Create the output directory on a fresh checkout instead of failing on save.
os.makedirs(PLOTPATH, exist_ok=True)
fig.savefig(os.path.join(PLOTPATH, figname))

plt.show()
