In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
import h5py
import os
import sys
import scipy
import damselfly as df
import pickle as pkl
import scipy.integrate

# Root of the project tree. Defaults to the original cluster location; set the
# DAMSELFLY_PROJECT_PATH environment variable to run the notebook against a
# different checkout without editing this cell.
PATH = os.environ.get('DAMSELFLY_PROJECT_PATH', '/storage/home/adz6/group/project')
# Directories derived from the project root.
RESULTPATH = os.path.join(PATH, 'results/damselfly')  # its 'roc' subfolder is read below
PLOTPATH = os.path.join(PATH, 'plots/damselfly')  # figures are saved here
DATAPATH = os.path.join(PATH, 'damselfly/data/datasets')
TRAINPATH = os.path.join(PATH, 'damselfly/training/checkpoints')

def df_roc(test_output, n_thresholds=1801):
    """Compute ROC-curve points by sweeping a threshold over column 2.

    Rows are split by the value in column 0: rows equal to 0 form the noise
    set and rows equal to 1 form the signal set (rows holding any other value
    are ignored). For every threshold ``T`` in
    ``np.linspace(0, 1, n_thresholds)`` a row counts as a detection when its
    column-2 value is ``>= T``.

    Parameters
    ----------
    test_output : array_like
        2-D array with at least three columns. Column 0 is compared against
        0 / 1 and column 2 is thresholded. Presumably these are the true class
        label and the classifier's signal score -- confirm against the code
        that writes the file.
    n_thresholds : int, optional
        Number of evenly spaced thresholds between 0 and 1 (inclusive).
        Defaults to 1801, the value that was previously hard-coded.

    Returns
    -------
    TPR_array, FPR_array : numpy.ndarray
        Two arrays of length ``n_thresholds``; note that the true-positive
        rate comes first. Because the threshold increases along the arrays,
        both rates are non-increasing. A rate is NaN at every threshold when
        its class (signal for TPR, noise for FPR) has no rows.
    """
    test_output = np.asarray(test_output)
    labels = test_output[:, 0]
    scores = test_output[:, 2]

    # Boolean masks keep the selections 1-D even when a class has exactly one
    # row; an argwhere(...).squeeze() index collapses to 0-d in that case and
    # breaks the column lookup.
    signal_scores = scores[labels == 1]
    noise_scores = scores[labels == 0]
    n_signal = signal_scores.size
    n_noise = noise_scores.size

    TPR_array = np.zeros(n_thresholds)
    FPR_array = np.zeros(n_thresholds)
    for i, T in enumerate(np.linspace(0, 1, n_thresholds)):
        # An empty class has no defined rate; report NaN rather than dividing by zero.
        TPR_array[i] = np.count_nonzero(signal_scores >= T) / n_signal if n_signal else np.nan
        FPR_array[i] = np.count_nonzero(noise_scores >= T) / n_noise if n_noise else np.nan

    return TPR_array, FPR_array
"""
Date: 7/23/2021
Description: plot distribution of MSE loss from autoencoder
NOTE: the cells that follow load and plot ROC curves, so this description may be stale.
"""

# plot training

In [None]:
# Load the ROC curves saved with date prefix `date`, one entry per mean-match value.
date = '210812'

# Mean-match values 0.5, 0.6, ..., 1.0. Each one is compared, as a string,
# with the text between 'mismatch' and '.npz' in the result file names.
match_array = np.arange(5, 11, 1) / 10

roc_dir = os.path.join(RESULTPATH, 'roc')
# List the directory once (it does not depend on the match value) and sort it
# so the load order is reproducible; os.listdir returns entries in arbitrary order.
roc_files = sorted(os.listdir(roc_dir))

roc_list = []
for match in match_array:
    curves = []
    for fname in roc_files:
        if fname.split('_')[0] == date and fname.split('mismatch')[-1].split('.npz')[0] == str(match):
            # Open each archive once and close it as soon as both arrays are read.
            with np.load(os.path.join(roc_dir, fname)) as roc_data:
                curves.append(roc_data['fpr'])
                curves.append(roc_data['tpr'])
    if not curves:
        # Without this check the missing entry only surfaces later as a
        # ragged-array or indexing error in the plotting cells.
        raise FileNotFoundError(
            f"no ROC file with date '{date}' and mismatch '{match}' found in {roc_dir}"
        )
    roc_list.append(curves)

# With one file per match value the shape is (n_match, 2, n_points):
# index 0 on axis 1 is FPR, index 1 is TPR.
roc_list = np.array(roc_list)
print(roc_list.shape)

In [None]:
date2 = '210802'
df_test_data_output_file = '210802_roc_210729_84_1d2sl4mt.npy'
# Load the file directly instead of scanning the directory for it: if the file
# is missing, np.load raises FileNotFoundError here, rather than this cell
# silently leaving df_test_out undefined (or stale from an earlier run).
df_test_out = np.load(os.path.join(RESULTPATH, 'roc', df_test_data_output_file))

In [None]:
# ROC points for the loaded test output; df_roc returns (TPR, FPR) in that order.
df_tpr, df_fpr = df_roc(df_test_out)

In [None]:
# Overlay the ROC curve from df_roc ('DF') on the ROC curves loaded for each mean-match value.
sns.set_theme(context='talk', style='whitegrid')
clist = sns.color_palette('mako', n_colors=match_array.size)
# The width was missing here (`figsize=(,6)` is a SyntaxError); 9 matches the
# other ROC figure in this notebook.
fig = plt.figure(figsize=(9, 6))
ax = fig.add_subplot(1,1,1)

# Area under the DF curve. df_roc sweeps the threshold upward, so FPR runs from
# high to low and the trapezoid integral comes out negative; hence the -1.
print(-1 * scipy.integrate.trapezoid(df_tpr, df_fpr))
for i in range(roc_list.shape[0]):
    # roc_list[i, 0] is FPR and roc_list[i, 1] is TPR for match_array[i].
    ax.plot(roc_list[i, 0, :], roc_list[i, 1, :], label=f'{match_array[i]}', color=clist[i])
    # Same sign flip as above -- presumably these FPR arrays are also stored
    # in decreasing order; confirm against the script that wrote them.
    print(-1 * scipy.integrate.trapezoid(roc_list[i, 1, :], roc_list[i, 0, :]))
    
ax.plot(df_fpr, df_tpr, color='tab:red', label='DF')
# Legend sits just outside the right edge of the axes.
ax.legend(loc=(1.01,0), title='Mean Match')
#ax.set_xlim(-0.0001, 0.01)

ax.set_xlabel('FPR')
ax.set_ylabel('TPR')
ax.set_title('Compare DF ROC Curve to MF with Mismatch')

plt.tight_layout()
#plt.savefig(os.path.join(PLOTPATH, '210812_compare_df_roc_to_mf_with_mismatch.png'))

In [None]:
# Two-curve ROC figure: the curve from df_roc ('CNN') against the last entry of roc_list ('MF').
sns.set_theme(context='poster', style='whitegrid')
clist = sns.color_palette('deep')

# Areas under both curves, negated as in the original trapezoid calls.
cnn_auc = -1 * scipy.integrate.trapezoid(df_tpr, df_fpr)
mf_fpr = roc_list[-1, 0, :]
mf_tpr = roc_list[-1, 1, :]
mf_auc = -1 * scipy.integrate.trapezoid(mf_tpr, mf_fpr)

fig = plt.figure(figsize=(9,6))
ax = fig.add_subplot(1,1,1)

print(cnn_auc)

# CNN is drawn first so it is also listed first in the legend.
ax.plot(
    df_fpr,
    df_tpr,
    color=clist[1],
    label=f'CNN, AUC={np.round(cnn_auc, 2)}',
)
ax.plot(
    mf_fpr,
    mf_tpr,
    label=f'MF, AUC={np.round(mf_auc, 2)}',
    color=clist[0],
)
ax.legend(loc=4)
# Optional zoom on the low-FPR region:
#ax.set_xlim(-0.0001, 0.01)

ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('ROC Curve CNN vs MF')

plt.tight_layout()

# Dashed TPR = FPR diagonal, added after the legend so it gets no legend entry.
dummy = np.linspace(0, 1, 20)
ax.plot(dummy, dummy, '--', color='grey')
plt.savefig(os.path.join(PLOTPATH, '210825_PANIC_roc_curve_updated.png'))