# Interactive widgets to analyze trained models: plot learning (training/validation) curves and ROC curves

March 18, 2020


## Steps:
- For a subset of models, read all data
- Store it in a summary dictionary
- Read from the dictionary for a specific model
- Plot learning curves and ROC curves, and print a model summary

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import subprocess as sp
import pickle
from ipywidgets import interact, interact_manual,fixed, SelectMultiple
import time

In [2]:
%matplotlib widget

In [3]:
## M-L modules
# import tensorflow.keras
# from tensorflow.keras import layers, models, optimizers, callbacks  # or tensorflow.keras as keras
# import tensorflow as tf

from sklearn.utils import shuffle
from sklearn.metrics import roc_curve, auc, precision_recall_curve, precision_recall_fscore_support, roc_auc_score
from tensorflow.python.keras.models import load_model


### Some basic definitions for reference
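True positive rate: $\mathrm{tpr} = \frac{tp}{tp+fn}$

False positive rate: $\mathrm{fpr} = \frac{fp}{fp+tn}$

Missed detection rate: $\mathrm{mdr} = \frac{fn}{tp+fn} = 1-\mathrm{tpr}$

Precision: $\mathrm{precision} = \frac{tp}{tp+fp}$

Recall: $\mathrm{recall} = \mathrm{tpr}$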

## Modules

In [20]:

class trained_model:
    '''
    Container for the stored artefacts of one trained model.

    Attributes set by f_read_stored_model:
        model   : Keras model loaded from 'model_<name>.h5'
        history : object unpickled from 'history_<name>.pickle'
                  (indexed elsewhere as a dict of per-epoch metric lists)
        y_pred  : array loaded from 'ypred_<name>.test'
        y_test  : array loaded from 'ytest_<name>.test'

    Attributes set by f_compute_preds:
        fpr, tpr, threshold           : output of sklearn's roc_curve
        precision, recall, threshold2 : output of sklearn's precision_recall_curve
                                        (thresholds2 is kept as an alias of threshold2)
        auc1 : area under the tpr-vs-fpr curve
        auc2 : area under the precision-vs-recall curve

    Methods: f_read_stored_model, f_compute_preds
    Example model names: '1', '2', etc.
    '''

    def __init__(self,model_name,model_save_dir,results_save_dir):

        ### Placeholders, filled in by f_compute_preds
        self.tpr,self.fpr,self.threshold,self.auc1=[],[],[],None
        self.precision,self.recall,self.threshold2,self.fscore,self.auc2=[],[],[],[],None

        ### Read stored model, history, predictions and labels from files
        self.f_read_stored_model(model_name,model_save_dir,results_save_dir)

    def f_read_stored_model(self,model_name,model_save_dir,results_save_dir):
        '''
        Read model, training history, test predictions and true test labels.

        Paths are built by plain string concatenation, so model_save_dir and
        results_save_dir must end with a path separator.
        '''

        fname_model='model_{0}.h5'.format(model_name)
        fname_history='history_{0}.pickle'.format(model_name)

        # Load model and history
        print(model_save_dir+fname_model)
        self.model=load_model(model_save_dir+fname_model)

        # NOTE: pickle.load executes arbitrary code from the file; only use on trusted results
        with open(model_save_dir+fname_history,'rb') as f:
            self.history= pickle.load(f)

        # Load predictions
        fname_ypred=results_save_dir+'ypred_{0}.test'.format(model_name)
        self.y_pred=np.loadtxt(fname_ypred)

        # Load true labels
        fname_ytest=results_save_dir+'ytest_{0}.test'.format(model_name)
        self.y_test=np.loadtxt(fname_ytest)


    def f_compute_preds(self):
        '''
        Compute ROC and precision-recall curves and their areas from the
        stored predictions (y_pred) and true labels (y_test).

        TPR = tp/(tp+fn)
        FPR = fp/(fp+tn)
        precision = tp/(tp+fp)
        recall = tp/(tp+fn) = TPR = sensitivity
        Missed detection rate: mdr = fn/(tp+fn) = 1 - TPR
        '''

        y_pred=self.y_pred
        test_y=self.y_test

        ## Calculate tpr,fpr
        self.fpr,self.tpr,self.threshold=roc_curve(test_y,y_pred,pos_label=1)
        # Precision-recall curve.
        # NOTE(review): pos_label=0 here but pos_label=1 for roc_curve above, with the
        # same scores; confirm which class y_pred scores and which is the positive one.
        self.precision, self.recall, self.threshold2 = precision_recall_curve(test_y, y_pred,pos_label=0)
        # Alias under the old (misspelled) attribute name for any code that reads it
        self.thresholds2=self.threshold2

        # AUC1: area under tpr vs fpr
        self.auc1= auc(self.fpr, self.tpr)
        # AUC2: area under precision vs recall
        self.auc2= auc(self.recall, self.precision)
        

In [21]:

def f_plot_learning(history):
    '''
    Plot learning curves in a new figure: accuracy (top) and loss (bottom)
    per epoch, for both the training and the validation set.

    history : dict of per-epoch metric lists as recorded by Keras
              (History.history). Accuracy may be stored as 'acc'/'val_acc'
              (older Keras) or 'accuracy'/'val_accuracy' (TF >= 2.0);
              both naming schemes are accepted.
    '''
    # Older Keras names the metric 'acc'; TF2 Keras names it 'accuracy'
    acc_key='acc' if 'acc' in history else 'accuracy'
    n_epochs=len(history[acc_key])

    fig=plt.figure()
    # Training & validation accuracy
    fig.add_subplot(2,1,1)
    plt.plot(history[acc_key],label='Train',marker='o')
    plt.plot(history['val_'+acc_key],label='Validation',marker='*')
    plt.ylabel('Accuracy')
    plt.xticks(np.arange(0,n_epochs,5))

    # Training & validation loss
    fig.add_subplot(2,1,2)
    plt.plot(history['loss'],label='Train',marker='o')
    plt.plot(history['val_loss'],label='Validation',marker='*')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.xticks(np.arange(0,n_epochs,5))

    plt.legend(loc='best')


def f_plot_roc_1(x,y):
    '''
    Draw y against x as small '*' markers (no connecting line) on the current
    axes, with a logarithmic x-axis.
    Typical use: x=fpr, y=tpr (a ROC curve).
    '''
    point_style=dict(linestyle='',markersize=2,marker='*')
    plt.semilogx(x,y,**point_style)


def f_plot_roc_2(x,y,xlim=(0,0.1),ylim=(0,0.05)):
    '''
    Draw y against x as small '*' markers (no connecting line) on the current
    axes, on linear axes zoomed to a small region.
    Used for the fpr-vs-mdr plot.

    xlim, ylim : axis limits; the defaults reproduce the original zoom
                 (x in [0, 0.1], y in [0, 0.05]). Pass None to leave that
                 axis auto-scaled.
    '''
    plt.plot(x,y,linestyle='',markersize=2,marker='*')
    ## Zooms
    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)


## Read stored model

In [26]:
## Since reading data takes a bit of time, we first read a subset of models, analyze them and store essential data for plots

def f_read_all_data(model_save_dir,results_save_dir,model_name_list):
    '''
    Load each model in model_name_list and compute its curve data.

    Returns a dict mapping every model name to its trained_model object,
    on which f_compute_preds() has already been called. Nothing is plotted.
    '''
    def _load_and_compute(name):
        # Read the stored files, then derive roc / precision-recall arrays
        loaded=trained_model(name,model_save_dir,results_save_dir)
        loaded.f_compute_preds()
        return loaded

    return {name:_load_and_compute(name) for name in model_name_list}


def f_analyze_model(model_name,dict_summary,learning_curve=True,plot_roc=True,plot_pred=True,summary=True):
    '''
    Plot diagnostics for one model stored in dict_summary and print its summary.

    model_name     : key of dict_summary (e.g. '1')
    dict_summary   : dict mapping model name -> trained_model object on which
                     f_compute_preds() has already been run (see f_read_all_data)
    learning_curve : plot training/validation accuracy and loss per epoch
    plot_roc       : plot roc, precision-recall and fpr-vs-mdr curves; print both AUCs
    plot_pred      : plot prediction histograms and the raw tpr/fpr/precision/recall arrays
    summary        : print max tpr / max fpr and the Keras model summary

    The boolean flags appear as checkboxes when this function is wrapped with
    ipywidgets' interact / interact_manual.
    '''

    ### Pick up data stored in summary dictionary
    obj=dict_summary[model_name]
    y_pred,history,test_y=obj.y_pred,obj.history,obj.y_test

    ### Split predictions by true label for the comparison histogram
    pred_at_sig=y_pred[np.where(test_y==1.0)[0]]
    pred_at_bkg=y_pred[np.where(test_y==0.0)[0]]

    if learning_curve:
        f_plot_learning(history)

    ## Plot roc curves
    if plot_roc:
        fig=plt.figure(figsize=(10,5))

        ### Tpr vs fpr
        fig.add_subplot(1,3,1)
        f_plot_roc_1(x=obj.fpr,y=obj.tpr)
        plt.title('Roc curve')
        plt.xlabel('fpr')
        plt.ylabel('tpr')

        ### Precision vs recall
        fig.add_subplot(1,3,2)
        f_plot_roc_1(x=obj.recall,y=obj.precision)
        plt.title('Precision-recall curve')
        plt.xlabel('recall')
        plt.ylabel('precision')

        ### Fpr vs mdr (mdr: missed detection rate = fn/(tp+fn) = 1-tpr)
        # obj.fpr/obj.tpr are computed with label 1 as positive. If signal is label 0
        # (inverted labels), then fpr is the signal mdr and 1-tpr is the signal fpr,
        # which is why x=fpr is labelled 'mdr' below.
        fig.add_subplot(1,3,3)
        f_plot_roc_2(x=obj.fpr,y=1-obj.tpr)

        ### Reference points in mdr plot in paper
        for ref_x,ref_y in ((0.03,0.038),(0.04,0.024),(0.05,0.016)):
            plt.plot(ref_x,ref_y,marker='s',markersize=8,color='k')
        plt.xlabel('mdr')
        plt.ylabel('fpr')

        plt.tight_layout()

        print('Auc 1:',obj.auc1)
        print('Auc 2:',obj.auc2)

    if plot_pred:

        ### Plot prediction histograms
        fig=plt.figure()
        fig.add_subplot(1,2,1)
        plt.hist(y_pred, density=None, bins=50,color='brown')
        plt.xlim(0,1)
        plt.title('Prediction histogram')

        fig.add_subplot(1,2,2)
        # NOTE(review): 'sig' here is true label 1, whereas the mdr plot above treats
        # label 0 as signal; confirm which convention is correct.
        plt.hist([pred_at_sig,pred_at_bkg],bins=20,label=['sig','background'])
        plt.legend(loc='best')
        plt.title('Prediction distributions')

        plt.tight_layout()

        ### Plot curves for tpr,fpr etc against array index
        fig=plt.figure(figsize=(6,3))

        fig.add_subplot(1,2,1)
        plt.plot(obj.fpr,color='r',label='fpr')
        plt.plot(obj.tpr,color='b',label='tpr')
        plt.plot((1-obj.tpr),color='y',label='mdr')  ### mdr=1-tpr
        # roc_curve's first threshold is an artificial value above every score; skip it
        plt.plot(obj.threshold[1:],label='threshold')
        plt.legend(loc='best')
        plt.title('FPR, TPR and threshold')

        fig.add_subplot(1,2,2)
        plt.plot(obj.precision,label='precision')
        plt.plot(obj.recall,label='recall')
        # f_compute_preds may store the PR thresholds under 'thresholds2', leaving
        # 'threshold2' as the empty placeholder from __init__; prefer the filled one.
        threshold2=getattr(obj,'thresholds2',obj.threshold2)
        plt.plot(threshold2,label='threshold2')
        plt.legend(loc='best')
        plt.title('Precision, recall and threshold')

    ## Model summary
    if summary:
        print('Max tpr, max fpr:',np.max(obj.tpr),np.max(obj.fpr))
        # Keras' summary() prints the table itself and returns None, so don't print its result
        obj.model.summary()


def f_compare_rocs(model_name,dict_summary):
    '''
    Overlay roc, precision-recall and fpr-vs-mdr curves of several models
    in one new figure, and print each model's two AUC values.

    model_name   : iterable of model names (keys of dict_summary), e.g. the
                   tuple produced by an ipywidgets SelectMultiple. A single
                   name passed as a plain string is also accepted.
    dict_summary : dict mapping model name -> trained_model object on which
                   f_compute_preds() has already been run (see f_read_all_data)
    '''
    # A bare string would otherwise be iterated character by character
    if isinstance(model_name,str):
        model_name=[model_name]

    _,(ax1,ax2,ax3)=plt.subplots(1,3,figsize=(8,8))

    ax1.set_title('Roc curve')
    ax1.set_xlabel('fpr')
    ax1.set_ylabel('tpr')
    ax2.set_title('Precision-recall curve')
    ax2.set_xlabel('recall')
    ax2.set_ylabel('precision')
    ax3.set_title('FPR-MDR curve')
    ax3.set_xlabel('mdr')
    ax3.set_ylabel('fpr')

    for model_num in model_name:
        ### Pick up data stored in summary dictionary
        obj=dict_summary[model_num]
        style=dict(linestyle='',label='model: {0}'.format(model_num),markersize=2,marker='*')

        ### Tpr vs fpr
        ax1.semilogx(obj.fpr,obj.tpr,**style)
        ### Precision vs recall
        ax2.semilogy(obj.recall,obj.precision,**style)
        ### Fpr vs mdr (mdr: missed detection rate = fn/(tp+fn) = 1-tpr)
        # With inverted labels (signal = 0), fpr acts as mdr and 1-tpr as fpr
        ax3.plot(obj.fpr,1-obj.tpr,**style)

        print("Auc scores: ",model_num,obj.auc1,obj.auc2)

    ### Reference points in mdr plot in paper
    for ref_x,ref_y in ((0.03,0.038),(0.04,0.024),(0.05,0.016)):
        ax3.plot(ref_x,ref_y,marker='s',markersize=8,color='k')

    ax2.set_xlim(0.95,1.0)
    ax3.set_xlim(0,0.1)
    ax3.set_ylim(0,0.05)

    for ax in (ax1,ax2,ax3):
        ax.legend(loc='best')

### First store data for a subset of models

In [27]:
# Root folder of the stored results; trained models and inference outputs are in sub-folders.
# Keep the trailing '/': trained_model builds file paths by plain string concatenation.
main_dir='/global/cfs/cdirs/dasrepo/vpa/supernova_cnn/data/results_data/results/final_summary_data_folder/'
model_save_dir=main_dir+'saved_models/'
results_save_dir=main_dir+'results_inference/'

In [28]:
# Model numbers to load; converted to strings because names are used in the stored file names
lst=[1]
model_sublist=[str(i) for i in lst]
# Load each listed model and compute its curves (slow), keyed by model name
dict_summary=f_read_all_data(model_save_dir,results_save_dir,model_sublist)

/global/cfs/cdirs/dasrepo/vpa/supernova_cnn/data/results_data/results/final_summary_data_folder/saved_models/model_1.h5


In [29]:
# Check which models were loaded
print(dict_summary.keys())
# print(dict_summary)

dict_keys(['1'])


### Generate plots and summary 
Read from dictionary **dict_summary**

In [31]:
# Interactive UI: a dropdown for model_name and a checkbox for each boolean flag of
# f_analyze_model; dict_summary is held fixed so it gets no widget. interact_manual
# only runs f_analyze_model when its run button is pressed.
interact_manual(f_analyze_model,dict_summary=fixed(dict_summary),model_name=model_sublist)

interactive(children=(Dropdown(description='model_name', options=('1',), value='1'), Checkbox(value=True, desc…

<function __main__.f_analyze_model(model_name, dict_summary, learning_curve=True, plot_roc=True, plot_pred=True, summary=True)>

### Compare roc curves

In [33]:
# interact_manual(f_compare_rocs,model_name=SelectMultiple(options=model_sublist),dict_summary=fixed(dict_summary))