In [68]:
import numpy as np
import matplotlib.pyplot as plt

class sigbg_model:
    """Bundle a model with a train/test split for binary (0/1) classification.

    ``train`` and ``test`` are 2-D arrays whose last column is the label and
    whose remaining columns are the features handed to the model.  The model
    must expose ``fit(X, y)`` and ``predict(X)``.  The output of ``predict``
    is cut at ``threshold`` to obtain 0/1 classes, so it is presumably a
    per-row score (e.g. a regressor output) -- TODO confirm for the models
    actually used.
    """

    # Class-wide cut: model outputs strictly above it map to 1, others to 0.
    threshold = 0.5

    def __init__(self, model, train, test):
        self.model = model
        self.train = train
        self.test = test
        # Split off the last column as the label; the rest are features.
        self.train_pred = train[:, :-1]
        self.train_label = train[:, -1]
        self.test_pred = test[:, :-1]
        self.test_label = test[:, -1]

    def get_model(self):
        """Return the wrapped model."""
        return self.model

    def get_data(self):
        """Return ``[train, test]`` exactly as passed to the constructor."""
        return [self.train, self.test]

    def fit(self):
        """Fit the wrapped model on the training features and labels."""
        self.model.fit(self.train_pred, self.train_label)

    def predict(self, data=None, threshold=None):
        """Return 0/1 predictions for ``data`` (default: the test features).

        The model output is compared with ``threshold`` (default: the
        class-wide ``sigbg_model.threshold``).  Values strictly greater map
        to 1 and everything else to 0, keeping the dtype of the model output.
        """
        # ``data or ...`` raises on a multi-element ndarray, so test for None.
        if data is None:
            data = self.test_pred
        if threshold is None:
            threshold = sigbg_model.threshold
        predictions = self.model.predict(data)
        return np.where(predictions > threshold,
                        np.ones_like(predictions),
                        np.zeros_like(predictions))

    def evaluate(self):
        """Print the percentage of test rows whose prediction equals the label."""
        predictions = self.predict()
        accuracy = float(np.mean(predictions == self.test_label))
        # Round after scaling so e.g. 0.57 prints as 57.0, not 56.99999999999999.
        print('Accuracy: {}%'.format(round(accuracy * 100, 2)))

    def fit_eval(self):
        """Fit on the training split, then print the test accuracy."""
        self.fit()
        self.evaluate()

    def ROC_points(self, num_thresholds=30):
        """Return ``(false_pos, true_pos)`` rate lists for an ROC curve.

        The model output on the test features is cut at ``num_thresholds``
        evenly spaced values in [0, 1], using the same rule as ``predict``.
        Label 1 is treated as the positive class (presumably signal --
        confirm against the data source).  The class-wide ``threshold`` is
        not touched.

        Raises:
            ValueError: if the test labels do not contain both classes, since
                one of the two rates would then divide by zero.
        """
        positive = self.test_label == 1
        num_sig = int(np.count_nonzero(positive))
        num_bg = len(self.test_label) - num_sig
        if num_sig == 0 or num_bg == 0:
            raise ValueError('ROC needs test labels from both classes')
        # The model output does not depend on the cut, so compute it once.
        scores = self.model.predict(self.test_pred)
        false_pos, true_pos = [], []
        for cut in np.linspace(0, 1, num_thresholds):
            # Same rule as predict(): strictly above the cut means class 1.
            predicted = scores > cut
            # TPR = TP / #positives, FPR = FP / #negatives.
            true_pos.append(np.count_nonzero(predicted & positive) / num_sig)
            false_pos.append(np.count_nonzero(predicted & ~positive) / num_bg)
        return false_pos, true_pos

    def ROC_plot(self, num_thresholds=30):
        """Scatter-plot the ROC points from ``ROC_points`` on the current axes."""
        false_pos, true_pos = self.ROC_points(num_thresholds)
        plt.scatter(false_pos, true_pos)
        plt.xlabel('False positive rate')
        plt.ylabel('True positive rate')