# Code to identify mis-classified points and explore their features

July 21, 2020


In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import pandas as pd

import subprocess as sp
import pickle
import ipywidgets as widgets
import time


In [2]:
from sklearn.metrics import roc_curve

In [3]:
%matplotlib widget

![image.png](attachment:3608f0cd-0326-41c5-a6d0-09f5e59b0a9c.png)

## Read test data and predictions

In [4]:
main_dir='/global/cfs/cdirs/dasrepo/vpa/supernova_cnn/data/results_data/results/final_summary_data_folder/'

In [5]:
f1=main_dir+'sample_test_data/temp_bigger_data/input_labels_y.txt'
f2=main_dir+'results_inference/y_large_pred.txt'

In [6]:
df=pd.DataFrame()
df['label']=np.loadtxt(f1,dtype=np.int16)
df['pred']=np.loadtxt(f2,dtype=np.float32)

print(df.shape)
df.head()


(5000, 2)


Unnamed: 0,label,pred
0,1,0.999773
1,0,3.1e-05
2,1,0.999562
3,1,0.99583
4,1,0.999947


## Histograms

In [8]:
df.plot(kind='hist',y=['label','pred'],subplots=True,grid=True,bins=12)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

array([<AxesSubplot:ylabel='Frequency'>, <AxesSubplot:ylabel='Frequency'>],
      dtype=object)

In [9]:
### Prediction histograms
#### Comparing predction

plt.figure()
column='pred'
# Predictions at points where label is 1
sig_preds=df[df.label==1][column].values
# Predictions at points where label is 0
bkg_preds=df[df.label==0][column].values
plt.hist([sig_preds,bkg_preds],bins=20,alpha=0.5,label=[column+':Artifact=1',column+':non-artifacts=0'])
plt.legend()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.legend.Legend at 0x2aaadadea8e0>

## Plot roc curve

####The ROC curve flip argument

Since True=1 in keras, but Non-Artifact is defined as 0 in the paper, 
Tp <-> Tn and Fn <-> Fp

Actual roc curve is x=mdr, y=fpr


mdr=Fn/(Tp+Fn) -> Fp/(Tn+Fp) = fpr
fpr=Fp/(Tn+Fp) -> Fn/(Tp+Fn) = mdr = fnr = 1-tpr


Hence we get an mdr <-> fpr flip


In [10]:
def f_roc(df,col,label,fig_type='mdr'):
    '''
    Function to plot the roc curve
    '''
    fpr,tpr,threshold=roc_curve(df.label,df[col],pos_label=None)
    
    if fig_type=='mdr':
#         x,y=1-tpr,fpr 
        x,y=fpr,(1-tpr)     # The roc curve flip: mdr-> fpr; fpr-> (1-tpr)
        plt.plot(x, y,linestyle='',label=label,markersize=2,marker='*')
        # ### Reference points in mdr plot in paper
        
    
        rf_2015_lst=[(0.03,0.038),(0.04,0.024),(0.05,0.016)]
        for count,a in enumerate(rf_2015_lst):
            if count==0:
                plt.plot(a[0],a[1],marker='s',markersize=8,color='k',label='RF_2015')
            else: 
                plt.plot(a[0],a[1],marker='s',markersize=8,color='k')

        plt.xlabel('MDR')
        plt.ylabel('FPR')
        plt.xlim(0,0.1)
        plt.ylim(0,0.05)
        
    if fig_type=='tpr':
        x,y=(1-tpr),(1-fpr)     # The roc curve flip: fpr-> (1-tpr); tpr-> (1-fpr)
        plt.plot(x, y,linestyle='',label=label,markersize=2,marker='*')
        plt.xlabel('FPR')
        plt.ylabel('TPR')
    
    plt.legend()
    

In [11]:
plt.figure()
f_roc(df,'pred','pred',fig_type='mdr')
plt.title('MDR roc curve')

plt.figure()
f_roc(df,'pred','pred',fig_type='tpr')
plt.xscale('log')
plt.xlim(0,0.04)
plt.title('TPR roc curve')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Invalid limit will be ignored.
  plt.xlim(0,0.04)


Text(0.5, 1.0, 'TPR roc curve')