In [None]:
# default_exp multi_model

In [None]:
# hide
from nbdev.showdoc import *

# multi_model

> Class representing a list of `bcml_object` models, each paired with a numeric index value that parameterizes the list

In [None]:
# export
import numpy as np
import scipy.interpolate

In [None]:
# export
class multi_model:
    """
    Represents a list of `bcml_object` models, parameterized by a number.

    Each model in `models` is paired with the entry of `index` at the same
    position. The `index2*F` methods evaluate a quantity for every model and
    return it as a cubic `scipy.interpolate.interp1d` function of the index.

    NOTE(review): SciPy's cubic interpolation needs at least four sample
    points, so the `index2*F` methods presumably require four or more models.
    """

    def __init__(self, models, index=None):
        """
        Store the models together with their numeric index.

        Parameters
        ----------
        models : sequence
            Model objects. The methods of this class call `significance`,
            `best_threshold` and `sorted_feature_importance` on them.
        index : sequence of numbers, optional
            One value per model. Defaults to `0, 1, ..., len(models) - 1`.
        """
        self.models = models
        self.index = index if index is not None else list(range(len(models)))

    def get_sigs(self, signals, background, tpr=None, fpr=None, sepbg=False):
        """
        Return `model.significance(...)` for every model.

        `signals` is zipped with `self.models`, so the i-th signal goes to
        the i-th model; `background`, `tpr`, `fpr` and `sepbg` are forwarded
        unchanged to every call. The result is a list with one entry per
        (model, signal) pair.
        """
        return [
            model.significance(signal, background, tpr=tpr, fpr=fpr, sepbg=sepbg)
            for model, signal in zip(self.models, signals)]

    def index2logsigF(self, signals, background, tpr=None, fpr=None, sepbg=False):
        """
        Return a cubic interpolant mapping index -> log10(significance).

        Arguments have the same meaning as in `get_sigs`.
        """
        # Forward the caller's settings; hard-coding None/None/False here
        # would silently ignore tpr, fpr and sepbg.
        sigs = self.get_sigs(signals, background, tpr=tpr, fpr=fpr, sepbg=sepbg)
        return scipy.interpolate.interp1d(self.index, np.log10(sigs), kind='cubic')

    def index2thresh_opt_improvementF(self, signal, background, tpr=None, fpr=None, sepbg=False, preds=None, labels=None):
        """
        Return a cubic interpolant mapping index -> improvement ratio.

        For each model the ratio is the value returned by
        `model.best_threshold(...)` divided by the value returned by
        `model.significance(...)`.

        NOTE(review): `signal` is passed to `get_sigs` (where it is zipped
        element-wise with the models) and, unchanged, to every model's
        `best_threshold`; confirm that this is the intended shape. This also
        assumes `best_threshold` returns a single number comparable to a
        significance -- TODO confirm against the model class.
        """
        sigs = self.get_sigs(signal, background, tpr=tpr, fpr=fpr, sepbg=sepbg)
        opt_thresh_sigs = [
            model.best_threshold(signal, background, sepbg=sepbg, preds=preds, labels=labels)
            for model in self.models]
        # Optimised value over baseline value, one ratio per model.
        improvement = [
            opt_thresh_sig / sig
            for opt_thresh_sig, sig in zip(opt_thresh_sigs, sigs)]
        return scipy.interpolate.interp1d(self.index, improvement, kind='cubic')

    def index2feature_importanceFs(self, features, num_features):
        """
        Return importance-vs-index interpolants for the shared top features.

        Each model's `sorted_feature_importance(features)` result is cut to
        its first `num_features` rows; `row[0]` is used as the feature name
        and `row[1]` as its importance. Only features that appear in the cut
        list of every model are kept.

        Returns
        -------
        list of [feature, interp1d]
            One pair per shared feature, in the order the features appear in
            the first model's list. Empty when there are no models.
        """
        top_rows_per_model = [
            model.sorted_feature_importance(features)[:num_features]
            for model in self.models]
        if not top_rows_per_model:
            return []

        # Per-model name -> importance lookup; setdefault keeps the first
        # occurrence if a name is repeated.
        importance_lookups = []
        for rows in top_rows_per_model:
            lookup = {}
            for row in rows:
                lookup.setdefault(row[0], row[1])
            importance_lookups.append(lookup)

        # Features present in the top "num_features" of each model, ordered
        # as in the first model so the output is deterministic.
        first_model_names = dict.fromkeys(row[0] for row in top_rows_per_model[0])
        present_features = [
            name for name in first_model_names
            if all(name in lookup for lookup in importance_lookups[1:])]

        return [
            [feature,
             scipy.interpolate.interp1d(
                 self.index,
                 [lookup[feature] for lookup in importance_lookups],
                 kind='cubic')]
            for feature in present_features]