In [None]:
%matplotlib inline
import matplotlib 
import numpy as np
import matplotlib.pyplot as plt

import pandas as pd

import os

import cv2

import matplotlib.gridspec as gridspec

# Input locations. RCNN_DIR may be overridden from the environment so the
# notebook can run away from the original author's machine; the default is the
# original path. Both paths keep a trailing separator because later cells
# build file paths by plain string concatenation (DETECTIONS + name).
RCNN_DIR = os.environ.get("RCNN_DIR", "/Users/vgenty/Desktop/rcnn")
DETECTIONS = os.path.join(RCNN_DIR, "detections_4", "")
VALID = os.path.join(RCNN_DIR, "valid", "")

# Hidden files (e.g. macOS .DS_Store) are not detection outputs and would break
# the file-name parsing in later cells; sorting makes the processing and
# plotting order reproducible (os.listdir order is arbitrary).
dets = sorted(f for f in os.listdir(DETECTIONS) if not f.startswith("."))

In [None]:
# Predicted class index (as used for top.name in the detection files) -> particle
# name. Indices start at 1, so slot k of a 4-element count list is class k + 1.
idx_to_class = dict(enumerate(['eminus', 'proton', 'pizero', 'muminus'], start=1))

In [None]:
# For each detection file keep only its highest-probability detection, filed
# under the truth label encoded at the end of the file name.
results = {'eminus' : [], 'proton' : [], 'pizero' : [], 'muminus' : []}
empty_dets = []

for det in dets:
    # File names are assumed to end in _<truth>.<ext>, e.g. ..._pizero.txt
    d_truth = os.path.splitext(det)[0].split("_")[-1]
    if d_truth not in results:
        # Not a detection output for a known truth label (stray file).
        continue

    # NOTE(review): the files appear to carry one more column than the five
    # names given, so pandas uses the first column as the row index; that index
    # label (top.name) is treated as the predicted class id downstream.
    dets_df = pd.read_csv(DETECTIONS + det,
                          names=['prob', 'xmin', 'ymin', 'xmax', 'ymax'])
    if dets_df.empty:
        # Nothing to rank; .iloc[0] would raise IndexError.
        empty_dets.append(det)
        continue

    top = dets_df.sort_values(by='prob', ascending=False).iloc[0]
    results[d_truth].append({'name': top.name, 'prob': top.prob})

if empty_dets:
    print("skipped {} detection files with no detections".format(len(empty_dets)))

In [None]:
# Threshold scan. classifieds[truth][thresh] is a 4-element count list: slot k
# holds how many images of that truth label have a top detection with
# probability >= thresh that was assigned class index k + 1.
classifieds = { 'eminus' : {},
                'proton' : {},
                'pizero' : {},
                'muminus': {}}

for thresh in np.arange(0, 1.01, 0.01):
    for particle in classifieds:
        # Always create the entry so a threshold that rejects every image
        # records zero counts instead of leaving a NaN hole in p_df.
        counts = [0, 0, 0, 0]
        for res in results[particle]:
            # Compare the probability itself (the original wrapped the
            # boolean comparison in float()).
            if float(res['prob']) < thresh:
                continue
            counts[int(res['name']) - 1] += 1
        classifieds[particle][thresh] = counts

# Rows: thresholds; columns: truth labels; cells: 4-element count lists.
p_df = pd.DataFrame(classifieds)

In [None]:
# Efficiency vs. detection-probability threshold: for each truth label, the
# percentage of all images of that label whose top detection passes the cut,
# split by the class the RCNN assigned.
matplotlib.rcParams['font.size'] = 16

for par in results:
    # Counts at threshold 0.0 -- presumably every image passes this cut
    # (probabilities assumed non-negative), so this is the sample size.
    thresh_0_sum = float(np.array(p_df[par].loc[0.0]).sum())
    if not thresh_0_sum > 0:
        # No images for this truth label (sum is 0 or NaN): nothing to plot.
        continue

    fig, ax = plt.subplots(figsize=(10, 6))
    for class_ix, class_name in sorted(idx_to_class.items()):
        # Slot class_ix - 1 of each per-threshold count list is this class.
        counts = p_df[par].str[class_ix - 1].values
        ax.plot(p_df.index.values,
                (counts / thresh_0_sum) * 100,
                '-o',
                label=class_name)

    ax.set_ylabel('Classified % by RCNN', fontweight='bold')
    ax.set_xlabel('Detection prob', fontweight='bold')
    ax.set_title("Truth: {}\nSample Size: {}".format(par, int(thresh_0_sum)))
    ax.set_ylim(0, 100.0)
    ax.legend(loc='best', fontsize=12)
    fig.savefig('eff_4_class_{}.pdf'.format(par), format='pdf', dpi=1000)
    plt.show()
    plt.close(fig)

In [None]:
# Purity vs. detection-probability threshold. For each truth label the top panel
# shows how many images survive each cut, and the bottom panel shows which class
# the RCNN assigned, as a percentage of those survivors.
matplotlib.rcParams['font.size'] = 16
matplotlib.rcParams['font.family'] = 'serif'

for par in results:
    # Images surviving each threshold (sum of the four per-class counts).
    sums = np.array([float(np.sum(j)) for j in p_df[par].values])
    if not np.nansum(sums) > 0:
        # No images for this truth label: nothing to plot.
        continue

    # Build the figure straight from the GridSpec so no stray empty axes is
    # left behind. A one-column grid accepts exactly one width ratio, so only
    # the height ratios are given.
    fig = plt.figure(figsize=(16, 8))
    gs = gridspec.GridSpec(2, 1, height_ratios=[1, 2])

    ax_n = fig.add_subplot(gs[0])
    ax_n.plot(p_df.index.values, sums, '-', color='black', lw=2)
    ax_n.set_ylabel('N Images', fontweight='bold')
    ax_n.set_title("Purity of truth: {}\n".format(par))
    ax_n.tick_params(labelbottom=False)

    ax_pur = fig.add_subplot(gs[1])
    for class_ix, class_name in sorted(idx_to_class.items()):
        counts = p_df[par].str[class_ix - 1].values
        # Thresholds with no survivors give 0/0 -> NaN (not plotted).
        with np.errstate(divide='ignore', invalid='ignore'):
            percent = (counts / sums) * 100.0
        ax_pur.plot(p_df.index.values, percent, '-o', label=class_name)

    ax_pur.set_ylabel('Classified % by RCNN', fontweight='bold')
    ax_pur.set_xlabel('Detection prob threshold', fontweight='bold')
    ax_pur.legend(loc='best', fontsize=12)

    # Lay out before saving so the written PDF matches the displayed figure.
    fig.tight_layout()
    fig.savefig('purity_4_class_{}.pdf'.format(par), format='pdf', dpi=1000)
    plt.show()
    plt.close(fig)

In [None]:
# Visual spot-check: draw the top-scoring detection on each pizero-truth
# validation image whose top detection probability is at least MIN_TOP_PROB.
MIN_TOP_PROB = 0.6

for det in dets:
    # Detection files are assumed to be named <prefix>_<image id>_<truth>.<ext>,
    # where the image id selects the JPEG in VALID -- TODO confirm naming.
    parts = os.path.splitext(det)[0].split("_")
    if len(parts) < 3 or parts[-1] != 'pizero':
        continue

    # Apply the probability cut before paying for the image read.
    dets_df = pd.read_csv(DETECTIONS + det,
                          names=['prob', 'xmin', 'ymin', 'xmax', 'ymax'])
    if dets_df.empty:
        continue
    top = dets_df.sort_values(by='prob', ascending=False).iloc[0]
    if float(top.prob) < MIN_TOP_PROB:
        continue

    im = cv2.imread(VALID + parts[1] + ".JPEG")
    if im is None:
        # cv2.imread returns None (instead of raising) for missing/unreadable files.
        print("could not read image for {}".format(det))
        continue
    im = im[:, :, (2, 1, 0)]  # OpenCV loads BGR; matplotlib expects RGB

    # NOTE(review): top.name is the row's index label, used as the predicted
    # class id (presumably the first CSV column) -- confirm.
    d_rpn = idx_to_class[int(top.name)]
    xmin, ymin, xmax, ymax = top.xmin, top.ymin, top.xmax, top.ymax

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    ax.add_patch(plt.Rectangle((xmin, ymin),
                               xmax - xmin,
                               ymax - ymin, fill=False,
                               edgecolor='red', linewidth=3))
    ax.text(xmin, ymin - 2,
            '{:s} {:.3f}'.format(d_rpn, top.prob),
            bbox=dict(facecolor='blue', alpha=0.5),
            fontsize=14, color='white')
    ax.set_title("Truth: {}".format(parts[-1]), fontweight='bold')
    plt.show()
    plt.close(fig)
    