# Code with interactive widgets to analyze trained models and plot validation and ROC curves
Sept 3, 2019


## Steps:
- For a subset of models, read all data
- Store it in a summary dictionary
- Read from the dictionary for a specific model
- Plot learning curve, ROC curves and print summary

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import h5py

import subprocess as sp
import pickle
from ipywidgets import interact, interact_manual,fixed, SelectMultiple
import time

In [2]:
## M-L modules
# import tensorflow.keras
# from tensorflow.keras import layers, models, optimizers, callbacks  # or tensorflow.keras as keras
# import tensorflow as tf

from sklearn.utils import shuffle
from sklearn.metrics import roc_curve, auc, roc_auc_score, precision_recall_curve, precision_recall_fscore_support
from tensorflow.keras.models import load_model


In [3]:
%matplotlib widget

## Modules

In [4]:
class data_set:
    '''
    Simple class to store the data set
    variables: filename, images, labels, weights
    functions: f_get_data
    Example objects: train data, test data, validation data
    '''
    
    def __init__(self, filename):
        '''
        Read the hdf5 file `filename` and store its contents on the object.
        '''
        self.filename=filename
        self.f_get_data()
        print("Created object from file ",filename)
        
    def f_get_data(self):
        '''
        Function to get data from hdf5 files into images, labels and weights.

        Reads the datasets 'hist', 'y' and 'weight' of the group 'all_events':
          images  : 'hist' with one extra trailing axis appended
          labels  : 'y' as stored
          weights : log(weight+1)

        Raises SystemError (chained to the original exception) if the file cannot be opened.
        '''
        try: 
            # Open read-only: this class never writes, and an explicit mode avoids
            # depending on the h5py default
            hf = h5py.File(self.filename,'r')

        except Exception as e:
            print(e)
            print("Name of file",self.filename)
            raise SystemError from e

        idx=None  ### Index for the case when the data is too large and you want to read in a slice
        # Slicing copies the data into memory, so the file can be closed as soon as the reads are done
        with hf:
            events=hf['all_events']
            self.images = np.expand_dims(events['hist'][:idx], -1)
            self.labels = events['y'][:idx]
            weights = events['weight'][:idx]
        self.weights = np.log(weights+1)


class trained_model:
    '''
    Class to extract data of trained model
    variables: model, history, y_pred (predictions of labels),
               fpr, tpr, threshold, auc (roc curve),
               precision, recall, threshold2, auc2 (precision-recall curve)
    functions: f_read_stored_model, f_compute_preds
    Example objects :  (models numbers) '1', '2', etc.
    '''
    
    def __init__(self,model_name,model_save_dir):
        '''
        model_name     : label used in the stored file names, e.g. '1'
        model_save_dir : directory prefix; it is concatenated directly with the
                         file names, so it must end with a path separator
        '''
        
        # Curve data stays empty until f_compute_preds is called
        self.tpr,self.fpr,self.threshold,self.auc=[],[],[],None
        self.precision,self.recall,self.threshold2,self.fscore,self.auc2=[],[],[],[],None
        self.f_read_stored_model(model_name,model_save_dir)
        
    def f_read_stored_model(self,model_name,model_save_dir):
        '''
        Read model, history and predictions from
        model_<name>.h5, history_<name>.pickle and ypred_<name>.test
        '''
        
        fname_model='model_{0}.h5'.format(model_name)
        fname_history='history_{0}.pickle'.format(model_name)

        # Load model and history
        self.model=load_model(model_save_dir+fname_model)
        
        # NOTE: unpickling runs arbitrary code; only read history files you produced yourself
        with open(model_save_dir+fname_history,'rb') as f:
            self.history= pickle.load(f)
        
        # Load predictions
        fname_ypred=model_save_dir+'ypred_{0}.test'.format(model_name)
        self.y_pred=np.loadtxt(fname_ypred)
    
    def f_compute_preds(self,test_data):
        '''
        Module to compute the weighted roc and precision-recall curves of the
        stored predictions against the labels of test_data.

        test_data : object with attributes `labels` and `weights`
        Fills in  : fpr, tpr, threshold, auc, precision, recall, threshold2, auc2
        '''
        
        y_pred=self.y_pred
        test_y,test_wts=test_data.labels,test_data.weights

        ## roc curve
        self.fpr,self.tpr,self.threshold=roc_curve(test_y,y_pred,sample_weight=test_wts)
        # AUC 
        self.auc= auc(self.fpr, self.tpr)
        
        # calculate precision-recall curve
        # Store the thresholds in `threshold2`, the attribute initialised in __init__
        self.precision, self.recall, self.threshold2 = precision_recall_curve(test_y, y_pred, sample_weight=test_wts)
        # Keep the previously used (misspelt) attribute name as an alias
        self.thresholds2 = self.threshold2
        
        # AUC2
        self.auc2= auc(self.recall, self.precision)
        

In [5]:

def f_plot_learning(history):
    '''
    Plot learning curves : Accuracy (top panel) and Loss (bottom panel)

    history : dictionary of per-epoch values with keys 'loss', 'val_loss' and
              either 'acc'/'val_acc' or 'accuracy'/'val_accuracy'
    '''
    # The accuracy metric is stored as 'acc' by some Keras versions and as 'accuracy' by others
    acc_key='acc' if 'acc' in history else 'accuracy'
    val_acc_key='val_'+acc_key

    fig=plt.figure()
    # Plot training & validation accuracy values
    fig.add_subplot(2,1,1)
    xlim=len(history[acc_key])
    
    plt.plot(history[acc_key],label='Train',marker='o')
    plt.plot(history[val_acc_key],label='Validation',marker='*')
    plt.ylabel('Accuracy')
    plt.xticks(np.arange(0,xlim,5))
    
    # Plot loss values
    fig.add_subplot(2,1,2)
    plt.plot(history['loss'],label='Train',marker='o')
    plt.plot(history['val_loss'],label='Validation',marker='*')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.xticks(np.arange(0,xlim,5))

    plt.legend(loc='best')


def f_plot_roc_curve(fpr,tpr):
    '''
    Draw tpr against fpr on the current axes as unconnected markers with a
    log-scaled x axis, and overlay the y=x line for comparison.
    '''
    plt.semilogx(fpr, tpr, marker='*', markersize=2, linestyle='')

    # Fixed viewing window
    plt.xlim([1e-6,1.0])
    plt.ylim([0,1.0])

    # y=x reference line
    diagonal=np.linspace(0,1,num=500)
    plt.plot(diagonal,diagonal)


## Read stored model

In [17]:
## Since reading data takes a bit of time, we first read a subset of models, analyze them and store essential data for plots

def f_real_all_data(model_save_dir,model_name_list,test_data):
    '''
    Build the summary dictionary: for every name in model_name_list, create a
    trained_model object from model_save_dir, compute its curves against
    test_data and store the object under that name.

    Returns the dictionary {model name: trained_model object}
    '''
    
    # Start with every requested name mapped to None
    summary=dict.fromkeys(model_name_list)
    
    for name in model_name_list:
        model_obj=trained_model(name,model_save_dir)
        model_obj.f_compute_preds(test_data)
        summary[name]=model_obj
        
    return summary


def f_analyze_model(model_name,dict_summary,test_data,learning_curve=True,plot_roc=True,plot_pred=False,summary=False):
    '''
    Plot and/or print the stored results of one trained model.

    model_name     : key into dict_summary
    dict_summary   : dictionary of trained_model objects, as returned by f_real_all_data
    test_data      : object with attribute `labels` (0.0 = background, 1.0 = signal)
    learning_curve : plot accuracy/loss history and save it as 'learning_curve.pdf'
    plot_roc       : plot roc curve, precision-recall curve and prediction histograms; print both AUCs
    plot_pred      : plot fpr/tpr/threshold, signal/background indices and precision/recall/threshold
    summary        : print max(tpr), max(fpr) and the model summary
    '''

    ### Pick up data stored in summary dictionary
    obj=dict_summary[model_name]
    
    y_pred,history=obj.y_pred,obj.history
    fpr,tpr,threshold=obj.fpr,obj.tpr,obj.threshold
    test_y=test_data.labels

    # f_compute_preds may store the precision-recall thresholds under `thresholds2`,
    # leaving `threshold2` empty; fall back to that attribute so the curve is not blank
    pr_threshold=obj.threshold2
    if len(pr_threshold)==0:
        pr_threshold=getattr(obj,'thresholds2',pr_threshold)
    
    ####################################
    # Plot tested model
    ### Get data for prediction comparison curves
    bkg_loc=np.where(test_y==0.0)[0]
    sig_loc=np.where(test_y==1.0)[0]
    pred_at_sig=y_pred[sig_loc]
    pred_at_bkg=y_pred[bkg_loc]
    
    if learning_curve: 
        f_plot_learning(history)
        plt.savefig('learning_curve.pdf')
        
    ## Plot roc curve
    if plot_roc:
        fig=plt.figure()
        
        fig.add_subplot(2,2,1)
        f_plot_roc_curve(fpr,tpr)
        plt.title('Roc curve')
        
        # Same plotting helper, so precision is on the (log) x axis and recall on the y axis
        fig.add_subplot(2,2,2)
        f_plot_roc_curve(obj.precision,obj.recall)
        plt.title('Precision-recall curve')
        
        print('Auc 1:',obj.auc)
        print('Auc 2:',obj.auc2)
        
        fig.add_subplot(2,2,3)
        plt.hist(y_pred, density=None, bins=50)
        plt.xlim(0,1)
        plt.title('Prediction histogram')
        
        fig.add_subplot(2,2,4)
        plt.hist([pred_at_sig,pred_at_bkg],bins=20,label=['sig','background'])
        
        plt.legend(loc='best')
        plt.title('Prediction distributions ')
        
        plt.tight_layout()
    
    if plot_pred:
        fig=plt.figure()
        
        fig.add_subplot(1,3,1)
        plt.plot(fpr,color='r',label='fpr')
        plt.plot(tpr,color='b',label='tpr')
        # First threshold is skipped -- presumably the artificial leading value added by roc_curve; confirm
        plt.plot(threshold[1:],label='threshold')
        plt.legend(loc='best')
        plt.title('FPR, TPR and threshold')
        
        fig.add_subplot(1,3,2)
        plt.plot(sig_loc,marker='*',label='signal')
        plt.plot(bkg_loc,marker='D',label='background')
        plt.legend(loc='best')
        plt.title('ypred vs ytest')
        
        fig.add_subplot(1,3,3)
        plt.plot(obj.precision,label='precision')
        plt.plot(obj.recall,label='recall')
        plt.plot(pr_threshold,label='threshold2')
        plt.legend(loc='best')
        plt.title('Precision, recall and threshold')
        
    ## Model summary
    if summary: 
        print(np.max(tpr),np.max(fpr))
        # summary() prints the table itself; wrapping it in print() only adds a stray 'None'
        obj.model.summary()


def f_compare_rocs(model_name,dict_summary):
    '''
    Overlay the roc curves of several models in one figure and print each AUC.

    model_name   : iterable of keys into dict_summary (e.g. the tuple given by a
                   SelectMultiple widget); a single string is treated as one key
    dict_summary : dictionary of objects with attributes fpr, tpr and auc
    '''
    
    # A bare string would otherwise be iterated character by character ('10' -> '1','0')
    if isinstance(model_name,str):
        model_name=(model_name,)
    model_names=tuple(model_name)

    plt.figure()
    
    for model_num in model_names:
        ### Pick up data stored in summary dictionary
        obj=dict_summary[model_num]
        print(obj.auc)
        ## Plot roc curve
        plt.semilogx(obj.fpr, obj.tpr,linestyle='',label='model: '+model_num,markersize=4,marker='*')

    # One legend call after all curves are drawn; skipped when nothing was selected
    if model_names:
        plt.legend(loc='best')


### First store data for a subset of models

In [7]:
# Read the test set (test.h5) once; the same object is passed to every analysis call below
data_dir='/global/project/projectdirs/dasrepo/vpa/atlas_cnn/data/RPVSusyData/'
test_data=data_set(data_dir+'test.h5')

Created object from file  /global/project/projectdirs/dasrepo/vpa/atlas_cnn/data/RPVSusyData/test.h5


In [8]:
# Directory holding the stored models, histories and predictions.
# It must end with '/' because file names are appended to it by plain string concatenation.
# model_save_dir='/global/project/projectdirs/dasrepo/vpa/atlas_cnn/results/2_runs_Sept13_modified/'
# model_save_dir='/global/project/projectdirs/dasrepo/vpa/atlas_cnn/results/3_runs_Oct1_models_with_strides/'
# model_save_dir='/global/project/projectdirs/dasrepo/vpa/atlas_cnn/results/4_old_test_set_best_models/'
# model_save_dir='/global/project/projectdirs/dasrepo/vpa/atlas_cnn/results/5_final_set_best_models_fail/'
model_save_dir='/global/project/projectdirs/dasrepo/vpa/atlas_cnn/results/8_final_set_best_models/'

# Model numbers to read; they are converted to strings because the stored file names
# and the dictionary keys use the model number as text
# lst=[1,2,3,4,5,6,7,8,9,14,15,16]
lst=[0,1,2,3,4,5,6,10,11]
# lst=[1,2]
model_sublist=[str(i) for i in lst ]

# Read every selected model and compute its curves once, up front
dict_summary=f_real_all_data(model_save_dir,model_sublist,test_data)

W1016 16:43:41.008540 46912496621568 deprecation.py:506] From /global/homes/v/vpa/.conda/envs/v_py3/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:97: calling GlorotUniform.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W1016 16:43:41.009476 46912496621568 deprecation.py:506] From /global/homes/v/vpa/.conda/envs/v_py3/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W1016 16:43:41.010192 46912496621568 deprecation.py:506] From /global/homes/v/vpa/.conda/envs/v_py3/lib/python3.6/site-packages/tensorflow/python/ops/init_ops

In [9]:
# Check which models were read into the summary dictionary
print(dict_summary.keys())
# dir(dict_summary.keys())
# #print(dict_summary)


dict_keys(['0', '1', '2', '3', '4', '5', '6', '10', '11'])


### Generate plots and summary 
Read from dictionary **dict_summary**

In [10]:
# f_analyze_model('1',dict_summary,test_data,learning_curve=False,summary=False,plot_roc=True,plot_pred=True)

In [18]:
# Widget front-end for f_analyze_model: model_name is chosen from model_sublist,
# while dict_summary and test_data are passed through unchanged via fixed()
interact_manual(f_analyze_model,dict_summary=fixed(dict_summary),model_name=model_sublist,test_data=fixed(test_data))

interactive(children=(Dropdown(description='model_name', options=('0', '1', '2', '3', '4', '5', '6', '10', '11…

<function __main__.f_analyze_model(model_name, dict_summary, test_data, learning_curve=True, plot_roc=True, plot_pred=False, summary=False)>

### Compare roc curves

In [16]:
# Widget front-end for f_compare_rocs: several models can be selected at once from model_sublist,
# while dict_summary is passed through unchanged via fixed()
# f_compare_rocs(('1','2'),dict_summary)
interact_manual(f_compare_rocs,model_name=SelectMultiple(options=model_sublist),dict_summary=fixed(dict_summary))


interactive(children=(SelectMultiple(description='model_name', options=('0', '1', '2', '3', '4', '5', '6', '10…

<function __main__.f_compare_rocs(model_name, dict_summary)>

In [13]:
# f_analyze_model('1',dict_summary)
# f_compare_rocs(('1','2'),dict_summary)