In [66]:
import numpy as np
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import joblib

class sigbg_model:
    """Bundle an estimator with its train/test split for signal/background studies.

    ``train`` and ``test`` are 2-D arrays whose last column is the label and
    whose remaining columns are the features. Labels are treated as 0/1, with
    0 = signal and 1 = background (``ROC_plot`` counts background rows as
    ``sum(labels)``, and ``ppv``/``tpr`` treat index 0 as signal). When
    ``isClassifier`` is False, the model's raw ``predict`` output is cut at
    ``threshold`` to obtain 0/1 labels.
    """

    # Default cut applied to regressor output. Reading it through ``self``
    # keeps ``sigbg_model.threshold = x`` working and also allows a
    # per-instance override.
    threshold = 0.5

    def __init__(self, model, train, test, isClassifier=True):
        self.model = model
        self.train = train
        self.test = test
        # The last column is the label; everything before it is a feature.
        self.train_pred = train[:, :-1]
        self.train_label = train[:, -1]
        self.test_pred = test[:, :-1]
        self.test_label = test[:, -1]
        self.isClassifier = isClassifier
        self.acc = 0  # last accuracy computed by accuracy(), in percent

    def get_model(self):
        """Return the wrapped estimator."""
        return self.model

    def get_data(self):
        """Return ``[train, test]`` exactly as passed to the constructor."""
        return [self.train, self.test]

    def fit(self):
        """Fit the wrapped estimator on the training features and labels."""
        self.model.fit(self.train_pred, self.train_label)

    def _to_labels(self, raw, threshold=None):
        """Turn raw model output into labels (0/1 for regressors, as-is for classifiers)."""
        if self.isClassifier:
            return raw
        raw = np.asarray(raw)
        cut = self.threshold if threshold is None else threshold
        # Same values and dtype as np.where(raw > cut, ones_like, zeros_like).
        return (raw > cut).astype(raw.dtype)

    def predict(self, pred=None, threshold=None):
        """Predict labels for a feature matrix.

        Args:
            pred: feature matrix to predict on; defaults to the test features.
            threshold: cut applied to regressor output; defaults to
                ``self.threshold``. Ignored when ``isClassifier`` is True.

        Returns:
            Array of predicted labels (0/1 for regressors).
        """
        # Compare with None explicitly: ``if pred`` is ambiguous for arrays.
        features = self.test_pred if pred is None else pred
        return self._to_labels(self.model.predict(features), threshold)

    def accuracy(self, pred=None, display=True):
        """Compute the percentage of test rows whose predicted label is correct.

        Args:
            pred: optional feature matrix whose rows line up with the test
                labels; defaults to the test features.
            display: when True, print the result.

        Returns:
            The accuracy in percent. It is also stored on ``self.acc``.
        """
        predictions = self.predict(pred)
        fraction = np.mean(predictions == self.test_label)
        self.acc = 100 * round(fraction, 4)
        if display:
            print('Accuracy: {}%'.format(self.acc))
        return self.acc

    def fit_eval(self):
        """Fit the model, then print and return its test-set accuracy."""
        self.fit()
        return self.accuracy()

    def confusion_matrix(self, data=None):
        """Return the confusion matrix as percentages of the evaluated rows.

        Follows scikit-learn's layout: rows are true labels and columns are
        predicted labels, so ``[0, 0]`` is true signal predicted as signal.

        Args:
            data: optional array laid out like ``test`` (label in the last
                column); defaults to the test set.
        """
        if data is None:
            features, truth = self.test_pred, self.test_label
        else:
            features, truth = data[:, :-1], data[:, -1]
        # Inside a method, ``confusion_matrix`` resolves to the module-level
        # sklearn function, not to this method. sklearn expects
        # (y_true, y_pred), in that order.
        counts = confusion_matrix(truth, self.predict(features))
        return 100 * counts / len(truth)

    def ppv(self):
        """positive predictive value: correctly identified signal / all identified signal"""
        conf_matrix = self.confusion_matrix()
        # Column 0 holds every row predicted as signal.
        return round(conf_matrix[0, 0] / np.sum(conf_matrix[:, 0]), 3)

    def tpr(self):
        """true positive rate: correctly identified signal / all signal"""
        conf_matrix = self.confusion_matrix()
        # Row 0 holds every row that truly is signal.
        return round(conf_matrix[0, 0] / np.sum(conf_matrix[0]), 3)

    def ams(self, signal, background):
        """Return ``signal * ppv / sqrt(background * (1 - ppv))``.

        NOTE(review): the author marked this as a heuristic. It diverges as
        ppv approaches 1.
        """
        purity = self.ppv()
        return signal * purity / np.sqrt(background * (1 - purity))

    # NOTE: lumi_req and significance are heuristic formulas that the original
    # author marked as "just a guess"; their units are not established here.
    def lumi_req(self, sig_cs, bg_cs, significance=5):
        """Return ``(sig_cs + bg_cs) * (significance / (ppv * tpr * sig_cs))**2 / 1e15``."""
        efficiency = self.ppv() * self.tpr()
        return (sig_cs + bg_cs) * (significance / (efficiency * sig_cs))**2 / 10**15

    def significance(self, lumi, sig_cs, bg_cs):
        """Return ``ppv * tpr * sig_cs * sqrt(lumi / (sig_cs + bg_cs))``."""
        efficiency = self.ppv() * self.tpr()
        return efficiency * sig_cs * np.sqrt(lumi / (sig_cs + bg_cs))

    def ROC_plot(self, n_thresholds=30):
        """Scatter-plot a ROC curve over regressor thresholds spread evenly in [0, 1].

        Signal (label 0) is the positive class, consistent with ``ppv``/``tpr``.
        The y-axis is the fraction of signal labelled signal; the x-axis is the
        fraction of background labelled signal. For classifiers the threshold
        has no effect, so every point coincides.

        Args:
            n_thresholds: number of thresholds to sample (default 30).

        Returns:
            ``(false_pos, true_pos)``: lists of rates, one entry per threshold.
        """
        thresholds = np.linspace(0, 1, n_thresholds)
        num_bg = np.sum(self.test_label)
        num_sig = len(self.test_label) - num_bg
        is_sig = self.test_label == 0
        # Query the model once; only the cut changes between thresholds, and
        # no shared class state is mutated.
        raw = self.model.predict(self.test_pred)
        true_pos, false_pos = [], []
        for cut in thresholds:
            called_sig = self._to_labels(raw, cut) == 0
            true_pos.append(np.sum(called_sig & is_sig) / num_sig)
            false_pos.append(np.sum(called_sig & ~is_sig) / num_bg)
        plt.scatter(false_pos, true_pos)
        return false_pos, true_pos

    def save_model(self, filename):
        """Dump the wrapped estimator to ``<filename>.joblib`` using joblib."""
        joblib.dump(self.model, str(filename) + '.joblib')

In [7]:
def refresh_model(model):
    """Build a new sigbg_model from an existing instance's estimator and data.

    The estimator object and the train/test arrays are passed through as-is
    (no copies), so only constructor-initialised state such as ``acc`` starts
    fresh. NOTE(review): presumably used to rebind an instance to a re-run
    class definition in the notebook — confirm.
    """
    return sigbg_model(
        model=model.model,
        train=model.train,
        test=model.test,
        isClassifier=model.isClassifier,
    )