# Code to identify mis-classified points and explore their features

July 21, 2020


In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import pandas as pd

import subprocess as sp
import pickle
import ipywidgets as widgets
import time


In [2]:
from sklearn.metrics import roc_curve

In [3]:
%matplotlib widget

## Read test data and IDs

In [4]:
# Build df_preds: one row per test sample holding its true label ('test') and its ID.
df_preds=pd.DataFrame()
# Pick any test data, all are same
# (per this note, every run's ytest_*/id_test_* files describe the same test set, so run 3's are used).
# NOTE(review): the two paths use different roots (/global/project/projectdirs vs
# /global/cfs/cdirs) -- confirm they resolve to the same results directory.
test_fname='/global/project/projectdirs/dasrepo/vpa/supernova_cnn/data/results_data/results/ytest_3.test'
ID_fname='/global/cfs/cdirs/dasrepo/vpa/supernova_cnn/data/results_data/results/id_test_3.test'

# Both columns are filled positionally, so the label file and the ID file are
# assumed to list the test samples in the same order.
df_preds['test']=np.loadtxt(test_fname,dtype=np.int16)
df_preds['ID']=np.loadtxt(ID_fname,dtype=np.int32)

df_preds.head()

Unnamed: 0,test,ID
0,1,11192965
1,1,11776878
2,1,10829754
3,1,10887873
4,0,9554136


### Explore misclassified points

In [5]:
## Category index -> name. Names read <true class>_<predicted class>[_<strength>],
## where rows with test==1 are called 'sig' and rows with test==0 'bkg'.
## The index is what f_classify_pred returns.
classification_key={1:'sig_bkg_strong',2:'sig_bkg_weak', 3:'sig_sig',4:'bkg_bkg',5:'bkg_sig_weak',6:'bkg_sig_strong'}

def f_classify_pred(series,col,strong_sig_cut=0.1,decision_cut=0.5,strong_bkg_cut=0.9):
    '''Return the classification_key category (1-6) of one row of predictions.

    Parameters
    ----------
    series : pandas.Series
        One row holding the true label under 'test' (1 or 0) and a model
        score under ``col``. Scores are expected to be high for label 1.
    col : str
        Name of the score column, e.g. 'm_3'.
    strong_sig_cut, decision_cut, strong_bkg_cut : float
        Score boundaries. The defaults reproduce the original hard-coded cuts:
        label 1 -> 1 if score <= 0.1, 2 if 0.1 < score < 0.5, 3 if score >= 0.5;
        label 0 -> 4 if score <= 0.5, 5 if 0.5 < score < 0.9, 6 if score >= 0.9.

    Returns
    -------
    int
        A key of ``classification_key``.

    Raises
    ------
    ValueError
        If the label is neither 0 nor 1, or the score fits no bin (e.g. NaN).
    '''
    label,score=series['test'],series[col]
    if label==1:
        if score<=strong_sig_cut:
            return 1
        if score<decision_cut:
            return 2
        if score>=decision_cut:
            return 3
    elif label==0:
        if score<=decision_cut:
            return 4
        if score<strong_bkg_cut:
            return 5
        if score>=strong_bkg_cut:
            return 6
    # Reached for an unexpected label or a NaN score (all comparisons False).
    # SystemError, used previously, is meant for interpreter-internal faults and
    # carried no message; report the offending values instead.
    raise ValueError('Cannot classify row: test={0!r}, {1}={2!r}'.format(label,col,score))


In [6]:
# For each model number: load its scores into column m_<num> and the per-row
# category (1-6, see classification_key) into column pred_m_<num>.
for model_num in ([3,8,9,16]):
    col='m_'+str(model_num)
    fname='/global/project/projectdirs/dasrepo/vpa/supernova_cnn/data/results_data/results/ypred_{0}.test'.format(model_num)
    # Assigned positionally: assumes ypred_<num>.test lists samples in the same order as the test/ID files.
    df_preds[col]=np.loadtxt(fname)

    # Row-wise apply; the lambda runs immediately, so capturing the loop variable col is safe here.
    pred_class=df_preds.apply(lambda row: f_classify_pred(row,col),axis=1).values
    new_col='pred_'+col
    df_preds[new_col]=pred_class


In [7]:
# df_preds.hist(column='pred_m_8',bins=20)

In [8]:
df_preds.head(20)

Unnamed: 0,test,ID,m_3,pred_m_3,m_8,pred_m_8,m_9,pred_m_9,m_16,pred_m_16
0,1,11192965,0.999642,3,0.999418,3,0.999994,3,0.998871,3
1,1,11776878,0.997763,3,0.999965,3,0.999748,3,0.999251,3
2,1,10829754,0.98338,3,0.993709,3,0.997372,3,0.972569,3
3,1,10887873,0.999104,3,0.999811,3,0.99991,3,0.994881,3
4,0,9554136,0.011799,4,0.000692,4,0.000773,4,0.001254,4
5,1,10710441,0.999612,3,0.999733,3,0.999978,3,0.999652,3
6,1,7728956,0.241116,2,0.014566,1,0.01161,1,0.142679,2
7,1,10876213,0.996556,3,0.993969,3,0.9972,3,0.981536,3
8,0,9175547,0.04113,4,7.1e-05,4,0.042915,4,0.001054,4
9,1,8189645,0.999638,3,0.999919,3,0.99993,3,0.999063,3


### Compare models

In [None]:
df=df_preds.copy()

In [None]:
# df

In [None]:
### Fraction of test points in each model-3 category (keys of classification_key).
# value_counts(normalize=True) divides each count by the number of rows, in one pass
# instead of one boolean filter per category; sort_index keeps the ascending category order.
for i,frac in df.pred_m_3.value_counts(normalize=True).sort_index().items():
    print(i,frac)

In [None]:
## Histogram of the categories the other models assign to the rows where one
## model's category column equals a given value.
def f_hist_compare(df,col1,value):
    '''Plot histograms of every other 'pred*' column for rows where df[col1] == value.

    Parameters
    ----------
    df : DataFrame containing category columns whose names start with 'pred'.
    col1 : category column used to select rows; it is excluded from the plot.
    value : category value to select (e.g. 1 for 'sig_bkg_strong').

    Prints the number of selected rows. Plotting is skipped when no rows are
    selected or there are no other 'pred*' columns, since pandas cannot
    histogram an empty selection. Returns None.
    '''
    # Filter once; the row count printed below reuses the same subset.
    subset=df[df[col1]==value]
    # Every category column except the one used to select rows; str() guards
    # against non-string column labels.
    cols=[c for c in df.columns if str(c).startswith('pred') and c!=col1]

    if not subset.empty and cols:
        subset.plot(kind='hist',y=cols,subplots=True,grid=True,bins=12)
    print('Total points in category',subset.shape[0])

# Categories the other models give to the points model 3 places in category 1 ('sig_bkg_strong').
f_hist_compare(df,'pred_m_3',1)


In [None]:
## Shape of the subset that all four models (3, 8, 9, 16) place in category 1 ('sig_bkg_strong')
df[(df.pred_m_3==1)&(df.pred_m_8==1)&(df.pred_m_16==1)&(df.pred_m_9==1)].shape

### View features of misclassified points

In [None]:
### Get features of IDs of test dataset
f2='/global/project/projectdirs/dasrepo/vpa/supernova_cnn/data/gathered_data/autoscan_features.3.csv'
df_features=pd.read_csv(f2,sep=',',comment='#')
# Keep only the feature rows whose ID appears in the test set held in df.
df_features=df_features[df_features.ID.isin(df.ID.values)]

In [None]:
df.head(5)

In [None]:
# pd.concat([df.set_index('ID'),df_features.set_index('ID')],join='outer').reset_index()
# pd.merge defaults to an inner join: only IDs present in both df and df_features are kept.
df_merged=pd.merge(df,df_features,on='ID')

In [None]:
# Rebind df to the merged rows model 3 places in category 1 ('sig_bkg_strong').
# Later cells re-copy df from df_preds before the ROC plots.
df=df_merged[df_merged.pred_m_3==1]

In [None]:
df.head()

In [None]:
# df_merged.columns
# df_merged.describe()

In [None]:
def f_plot_col(df,xcol,ycol='pred_m_3'):
    '''Plot ycol against feature xcol as unconnected red '*' markers; returns None.'''
    # An empty linestyle turns the 'line' plot into a marker-only scatter.
    marker_style={'linestyle':'','marker':'*','color':'r'}
    df.plot(x=xcol,y=ycol,kind='line',**marker_style)


In [None]:
# Full list of candidate feature columns, kept for reference; it is
# overwritten by the selected subset below.
cols=['OBJECT_TYPE','AMP', 'A_IMAGE',
       'A_REF', 'BAND', 'B_IMAGE', 'B_REF', 'CCDID', 'COLMEDS', 'DIFFSUMRN',
       'ELLIPTICITY', 'FLAGS', 'FLUX_RATIO', 'GAUSS', 'GFLUX', 'L1',
       'LACOSMIC', 'MAG', 'MAGDIFF', 'MAGLIM', 'MAG_FROM_LIMIT', 'MAG_REF',
       'MAG_REF_ERR', 'MASKFRAC', 'MIN_DISTANCE_TO_EDGE_IN_NEW', 'N2SIG3',
       'N2SIG3SHIFT', 'N2SIG5', 'N2SIG5SHIFT', 'N3SIG3', 'N3SIG3SHIFT',
       'N3SIG5', 'N3SIG5SHIFT', 'NN_DIST_RENORM', 'NUMNEGRN', 'SCALE', 'SNR',
       'SPREADERR_MODEL', 'SPREAD_MODEL']

### Selected columns 
cols=['AMP','A_REF','B_IMAGE', 'B_REF', 'COLMEDS', 'DIFFSUMRN','ELLIPTICITY', 
          'L1', 'SCALE', 'SPREADERR_MODEL', 'SPREAD_MODEL']
# Interactive plot: one toggle button per selected feature for xcol. df_merged is
# held fixed and ycol keeps f_plot_col's default ('pred_m_3').
widgets.interact(f_plot_col,df=widgets.fixed(df_merged),xcol=widgets.ToggleButtons(options=cols,disabled=False))

In [None]:
# widgets.interact(f_plot_col,df=widgets.fixed(df_merged),xcol=widgets.SelectionSlider(options=cols,disabled=False))

In [None]:
### Comparing 2 runs
# df_merged[(df_merged.pred_m_3!=df_merged.pred_m_8)]


In [None]:
# plt.figure()
# plt.plot(df_merged.pred_m_3,df_merged.pred_m_8,linestyle='',marker='o')

In [None]:
# H,x_edges,y_edges=np.histogram2d(df_merged.pred_m_3,df_merged.pred_m_8)
# plt.figure()
# plt.imshow(H,origin=(0,0))

### Plot ROC curve

In [9]:
df=df_preds.copy()

In [10]:
def f_roc(df,label,col='m_3'):
    '''Plot MDR (x) vs FPR (y) for one model's scores, treating class 0 as positive.

    Parameters
    ----------
    df : DataFrame with true labels (0/1) in 'test' and model scores in ``col``.
    label : legend label for the plotted curve.
    col : score column to evaluate. Defaults to 'm_3', the column that was
        previously hard-coded.

    roc_curve ranks samples assuming a larger score means "more likely
    pos_label". The score columns are high for test==1 rows. Passing them
    unchanged with pos_label=0 therefore reversed the ranking and drew an
    inverted curve, so the scores are negated here.

    With pos_label=0, 1 - tpr is the fraction of class-0 rows missed and fpr is
    the fraction of class-1 rows accepted as class 0.
    NOTE(review): these equal the paper's MDR/FPR only if class 0 is the class
    being detected -- confirm the label convention.
    '''
    # Negation is strictly decreasing, so it reverses the ranking without
    # assuming the scores lie in [0, 1].
    fpr,tpr,_=roc_curve(df.test,-df[col],pos_label=0)
    mdr=1-tpr
    plt.plot(mdr,fpr,linestyle='',label=label,markersize=2,marker='*')
    plt.xlabel('MDR')
    plt.ylabel('FPR')
    

In [11]:
df.head(5)

Unnamed: 0,test,ID,m_3,pred_m_3,m_8,pred_m_8,m_9,pred_m_9,m_16,pred_m_16
0,1,11192965,0.999642,3,0.999418,3,0.999994,3,0.998871,3
1,1,11776878,0.997763,3,0.999965,3,0.999748,3,0.999251,3
2,1,10829754,0.98338,3,0.993709,3,0.997372,3,0.972569,3
3,1,10887873,0.999104,3,0.999811,3,0.99991,3,0.994881,3
4,0,9554136,0.011799,4,0.000692,4,0.000773,4,0.001254,4


In [22]:
def f_roc(df,label,col='m_3'):
    '''Plot a detection trade-off curve for one model and overlay three reference points.

    Parameters
    ----------
    df : DataFrame with true labels (0/1) in 'test' and model scores in ``col``.
    label : legend label for the plotted curve.
    col : score column to evaluate. Defaults to 'm_3', the column that was
        previously hard-coded.

    The positive label is roc_curve's default (1 for 0/1 labels), and the curve
    is drawn as x = fpr, y = 1 - tpr.
    - x is the fraction of class-0 rows scored at or above the threshold.
    - y is the fraction of class-1 rows scored below it.
    The axes are titled 'MDR' (x) and 'FPR' (y).
    NOTE(review): those titles fit only if class 0 is the class being detected
    (i.e. label 1 marks the rejected class) -- confirm the label convention.

    Prints the threshold array returned by roc_curve. Returns None.
    '''
    fpr,tpr,threshold=roc_curve(df.test,df[col],pos_label=None)
    x,y=fpr,(1-tpr)
    plt.plot(x,y,linestyle='',label=label,markersize=2,marker='*')

    # Reference points in the MDR plot of the paper, drawn as black squares.
    for ref_x,ref_y in ((0.03,0.038),(0.04,0.024),(0.05,0.016)):
        plt.plot(ref_x,ref_y,marker='s',markersize=8,color='k')

    plt.xlabel('MDR')
    plt.ylabel('FPR')
    plt.xlim(0,0.1)
    plt.ylim(0,0.05)
    print(threshold)



### ROC curve with all points of a category dropped
# Overlay the raw curve with curves that exclude model 3's category 6
# ('bkg_sig_strong'), category 1 ('sig_bkg_strong'), or both.
plt.figure()
f_roc(df,'raw')
f_roc(df[(df.pred_m_3!=6)],'drop bkg-sig')
f_roc(df[(df.pred_m_3!=1)],'drop sig-bkgnd')
f_roc(df[(df.pred_m_3!=1)&(df.pred_m_3!=6)],'drop-both')
# f_roc(df[(df.pred_m_3!=5)],'drop-weak-bkg-sig')
# f_roc(df[(df.pred_m_3!=2)],'drop-weak-sig-bkg')

plt.legend()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[1.99999994e+00 9.99999940e-01 9.99999881e-01 ... 1.69873238e-06
 1.54972076e-06 3.27825546e-07]
[1.99999994e+00 9.99999940e-01 9.99999881e-01 ... 1.69873238e-06
 1.54972076e-06 3.27825546e-07]
[1.99999994e+00 9.99999940e-01 9.99999881e-01 ... 1.69873238e-06
 1.54972076e-06 3.27825546e-07]
[1.99999994e+00 9.99999940e-01 9.99999881e-01 ... 1.69873238e-06
 1.54972076e-06 3.27825546e-07]


<matplotlib.legend.Legend at 0x2aaadfab5198>

In [20]:
plt.figure()

# Exploratory single-curve plot. Only the last (x,y) assignment below takes effect.
# NOTE(review): with pos_label=0, roc_curve expects scores that rise with class 0,
# but m_3 is high for test==1 rows, so these rates come from a reversed ranking -- confirm intent.
fpr,tpr,threshold=roc_curve(df.test,df.m_3,pos_label=0)
x,y=1-tpr,fpr
x,y=tpr,fpr
x,y=fpr,(1-tpr)
label='a'
plt.plot(x, y,linestyle='',label=label,markersize=2,marker='*')

# plt.plot(fpr,color='r',label='fpr')
# plt.plot(tpr,color='b',label='tpr')
# plt.plot(threshold[1:],label='threshold')

# Only the y axis is labelled in this cell.
plt.ylabel('FPR')
plt.legend()


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.legend.Legend at 0x2aaadf6100f0>

In [None]:
tpr

In [None]:
fpr

In [21]:
### ROC curves with a fraction of one category's points dropped
def f_get_frac_points(df,col,val,frac=0.50,random_state=None):
    '''Keep every row with df[col] != val and a random ``frac`` of the rows with df[col] == val.

    Parameters
    ----------
    df : DataFrame to subsample.
    col : column compared against ``val``.
    val : category whose rows are subsampled.
    frac : fraction of the ``val`` rows to keep (passed to DataFrame.sample).
    random_state : seed passed to DataFrame.sample. None (the default) keeps
        the previous unseeded behaviour.

    Returns
    -------
    DataFrame
        The non-``val`` rows followed by the sampled ``val`` rows, with the
        original index labels preserved.
    '''
    others=df[df[col]!=val]
    sampled=df[df[col]==val].sample(frac=frac,random_state=random_state)
    # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0;
    # pd.concat is the documented replacement and gives the same result here.
    return pd.concat([others,sampled])

# dfa: all rows outside category 6 plus a random half of category 6 ('bkg_sig_strong').
# Sampling is unseeded, so the shapes printed below vary from run to run.
dfa=f_get_frac_points(df,'pred_m_3',6,0.5)
# dfb: all rows outside category 1 plus a random half of category 1 ('sig_bkg_strong').
dfb=f_get_frac_points(df,'pred_m_3',1,0.5)
## Both 
# dfc: rows outside categories 1 and 6 (present in both dfa and dfb, de-duplicated)
# plus the sampled halves of category 6 (from dfa) and category 1 (from dfb).
dfc=pd.concat([dfa[dfa['pred_m_3']!=1],dfb[dfb['pred_m_3']!=6]]).drop_duplicates().reset_index(drop=True)
print(df.shape,dfa.shape,dfb.shape,dfc.shape)
plt.figure()
f_roc(df,'raw')
f_roc(dfa,'drop 6: bkg-sig-50pct')
f_roc(dfb,'drop 1: sig-bkg-50pct')
f_roc(dfc,'drop both')
plt.legend()


(44948, 10) (44769, 10) (44853, 10) (44674, 10)


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[2.00000000e+00 1.00000000e+00 9.99999940e-01 ... 5.96046448e-08
 2.98023224e-08 0.00000000e+00]
[2.00000000e+00 1.00000000e+00 9.99999940e-01 ... 5.96046448e-08
 2.98023224e-08 0.00000000e+00]
[2.00000000e+00 1.00000000e+00 9.99999940e-01 ... 5.96046448e-08
 2.98023224e-08 0.00000000e+00]
[2.00000000e+00 1.00000000e+00 9.99999940e-01 ... 5.96046448e-08
 2.98023224e-08 0.00000000e+00]


<matplotlib.legend.Legend at 0x2aaadfa64208>

### View strongly misclassified images

In [None]:
def f_plot_grid(arr,cols=16,fig_size=(15,5),extent=(0,128,0,128)):
    ''' Plot a grid of images.

    Parameters
    ----------
    arr : array of images; arr.shape[0] is the number of images and each
        arr[i] is passed to imshow.
    cols : number of grid columns; the number of rows is ceil(n_images/cols).
    fig_size : figure size passed to plt.subplots.
    extent : imshow ``extent``. The default is the value that was previously
        hard-coded.

    Images fill the grid row by row with no spacing between panels. X tick
    labels are hidden except on the bottom row, and y tick labels except on
    the first column. Prints the grid shape. Returns None.
    '''
    size=arr.shape[0]
    if size==0:
        # plt.subplots(0, cols) would raise; there is nothing to draw.
        print('No images to plot')
        return
    rows=int(np.ceil(size/cols))
    print(rows,cols)

    # squeeze=False always returns a 2-D array of Axes. This replaces the manual
    # reshape that was needed when rows==1 or cols==1.
    _,axarr=plt.subplots(rows,cols,figsize=fig_size,squeeze=False,
                         gridspec_kw={'wspace':0,'hspace':0})

    # rows*cols >= size by construction, so every image gets its own panel.
    for i in range(size):
        row,col=divmod(i,cols)
        try:
            axarr[row,col].imshow(arr[i],origin='lower',extent=list(extent))
        except Exception as e:
            # Best effort: report an image imshow cannot draw and keep going.
            print('Exception:',e)

    # Hide inner tick labels once, after all panels are drawn. This used to be
    # redone on every loop iteration, touching every Axes each time.
    plt.setp([a.get_xticklabels() for a in axarr[:-1,:].flatten()],visible=False)
    plt.setp([a.get_yticklabels() for a in axarr[:,1:].flatten()],visible=False)
        
        

In [None]:
### Extract file name
# Summary table used below to look up image files per ID (columns 'ID',
# 'filename' and 'file path' are read from it later).
f2='/global/project/projectdirs/dasrepo/vpa/supernova_cnn/data/gathered_data/summary_label_files.csv'
df2=pd.read_csv(f2,sep=',',comment='#')
# df2.head(20)

In [None]:
# fig=plt.figure()
#     # Plot training & validation accuracy values

# for count1,iD in enumerate(IDs[:1]):
#     for count2,key in enumerate(keys):
#         print(iD,key)
#         df_temp=df3[(df3.ID==iD)&(df3.filename.str.startswith(key))]
#         fle=df_temp['file path'].values[0]
#         img=plt.imread(fle)
# #         display(df_temp)
#         idx1,idx2=count1+1,count2+1
#         print(idx1,idx2)
#         fig.add_subplot(idx1*idx2,idx1,idx2)
#         plt.imshow(img)


In [None]:
df_preds.shape

In [None]:
# Pick up to num_images IDs that model 3 places in `category` and load their images.
category=6
num_images=100

IDs=df_preds[df_preds.pred_m_3==category].ID.values
print(IDs.shape)
# Unseeded in-place shuffle: a different subset of IDs is chosen on each run.
np.random.shuffle(IDs)
IDs=IDs[:num_images]

df3=df2[df2.ID.isin(IDs)]
df3.shape   # not the last expression of the cell, so Jupyter does not display it
# del(df2)


# One image per (ID, key): the row whose filename starts with the key is looked up and
# the first matching 'file path' is read (IndexError if an ID/key pair has no match).
# img is ordered ID-major: (ID0,temp), (ID0,srch), (ID0,diff), (ID1,temp), ...
keys=['temp','srch','diff']
img=np.array([plt.imread(df3[(df3.ID==iD)&(df3.filename.str.startswith(key))]['file path'].values[0]) for iD in IDs for key in keys ])
print(img.shape)
df3.head()

In [None]:
# [(iD,key) for iD in IDs[:10] for key in keys ]

In [None]:
# f_plot_grid(img,cols=9,fig_size=(9,5))
# Draw the images three per row. Since img is ordered ID-major over the three keys,
# each row shows temp/srch/diff for one ID. Then save to category<N>.pdf, close the
# figure, and print the elapsed seconds.
t1=time.time()
f_plot_grid(img,cols=3,fig_size=(3,100))
fname='category{0}.pdf'.format(str(category))
plt.savefig(fname)
plt.close()
t2=time.time()
print(t2-t1)


In [None]:
df_merged.columns

In [None]:
df_merged[df_merged.ID.isin(IDs)][['OBJECT_TYPE','m_3','pred_m_3']]