In [1]:
import scipy
import numpy as np
import pandas as pd
from sklearn.base import clone
from sklearn.model_selection import check_cv
from sklearn.base import is_classifier
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import get_scorer, check_scoring
from sklearn.preprocessing import LabelBinarizer
from scipy.special import xlogy

In [2]:

def gen_xy(
    model,
    iv_corr,
    n_obs,
    random_state = None):
    """Simulate a regression dataset with a known population R^2.

    Ten standard-normal features are drawn: the first seven are
    equicorrelated (pairwise correlation ``iv_corr``) and the last three are
    independent of everything. ``y`` is a weighted sum of four signal terms
    plus Gaussian noise whose variance is ``1 - Var(signal)``, so the total
    variance of ``y`` (about its mean) is 1 and ``r2`` is the share explained
    by the signal.

    Parameters
    ----------
    model : str
        ``"linear"`` uses x0..x3 as the signal terms. Any other value uses
        the nonlinear terms x0, x0**2, x1**2 and x2*x3, each scaled to unit
        variance.
    iv_corr : float
        Pairwise correlation among the first seven features.
    n_obs : int
        Number of rows to simulate.
    random_state : None, int or numpy.random.RandomState, optional
        ``None`` (default) draws from NumPy's global legacy RNG, exactly as
        before this parameter existed. An int seeds a private RandomState.

    Returns
    -------
    x : ndarray of shape (n_obs, 10)
    y : ndarray of shape (n_obs,)
    r2 : float
        Population R^2 of ``y`` on the signal, ``1 - error variance``.

    Raises
    ------
    ValueError
        If the chosen ``iv_corr`` makes the signal variance exceed 1, which
        would require a negative noise variance (previously ``y`` silently
        became all-NaN).
    """
    if random_state is None:
        rng = np.random  # module-level functions use the global RandomState
    elif isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    n_ivs = 10
    n_noise = 3
    n_corr = n_ivs - n_noise
    mean = np.zeros(n_ivs)
    cov = np.block([
        [iv_corr * np.ones((n_corr, n_corr)) + (1 - iv_corr) * np.eye(n_corr),
         np.zeros((n_corr, n_noise))],
        [np.zeros((n_noise, n_corr)), np.eye(n_noise)]])

    linear = model == "linear"
    if linear:
        coef = np.array([.1, .2, .3, .4]).reshape(4, -1)
        cov_signal = cov[0:4, 0:4]
    else:
        coef = np.array([.3, .3, .3, .4]).reshape(4, -1)
        # Standard deviations of x**2 and of x_i * x_j (corr iv_corr) for
        # standard normals; dividing by them gives unit-variance terms.
        sd_quad = np.sqrt(2)
        sd_prod = np.sqrt(1 + iv_corr**2)
        a = (2 * (iv_corr**2)) / (sd_quad * sd_quad)
        b = (2 * (iv_corr**2)) / (sd_quad * sd_prod)
        cov_signal = np.array(
            [[1., 0., 0., 0.],
             [0., 1., a,  b ],
             [0., a,  1., b ],
             [0., b,  b,  1.]])

    error_var = 1 - (coef.T @ cov_signal @ coef).item()
    if error_var < -1e-12:
        raise ValueError(
            f"iv_corr={iv_corr!r} gives signal variance "
            f"{1 - error_var:.4f} > 1; no valid noise variance exists")
    # Absorb tiny negative round-off (e.g. linear model with iv_corr=1)
    error_var = max(error_var, 0.0)

    x = rng.multivariate_normal(
        mean = mean,
        cov = cov,
        size = n_obs)
    if linear:
        x_signal = x[:, 0:4]
    else:
        x_signal = np.concatenate(
            (x[:, 0:1],
             (x[:, 0:1]**2) / sd_quad,
             (x[:, 1:2]**2) / sd_quad,
             (x[:, 2:3] * x[:, 3:4]) / sd_prod),
            axis = 1)
    error = rng.normal(
        loc = 0.0,
        scale = np.sqrt(error_var),
        size = (n_obs, ))
    y = (x_signal @ coef).reshape(-1,) + error
    r2 = 1 - error_var
    return x, y, r2




In [3]:
class CrossEstimator():
    """Cross-fitted ensemble: one clone of ``estimator`` per CV split.

    After ``fit``, ``estimators_[k]`` is trained on the k-th training fold and
    ``cv_indexes_[k]`` holds that split's ``(train_index, test_index)``, so
    out-of-fold quantities (features, targets, predictions, random draws) can
    be computed per split by the ``_features``/``_targets``/``_preds``/``_rvs``
    helpers.

    Parameters
    ----------
    estimator : sklearn-style estimator
        Template estimator; cloned once per split.
    cv : int, cross-validation generator, iterable or None
        Anything accepted by ``sklearn.model_selection.check_cv``.
    """

    def __init__(
        self,
        estimator,
        cv = None
    ):
        self.estimator = estimator
        self.cv = cv

    @staticmethod
    def _take_rows(data, index):
        """Select rows by position from a NumPy array or a pandas object."""
        if hasattr(data, "iloc"):
            return data.iloc[index]
        return data[index]

    @staticmethod
    def _majority_vote(votes):
        """Most frequent label in ``votes``; ties go to the smallest label."""
        labels, counts = np.unique(votes, return_counts = True)
        return labels[np.argmax(counts)]

    def fit(
        self,
        X,
        y
    ):
        """Fit one clone of ``estimator`` on the training part of every split.

        Stores copies of ``X``/``y`` (``X_``, ``y_``), the fitted clones
        (``estimators_``), the split indices (``cv_indexes_``), their count
        (``n_splits_``) and, for classifiers, a fitted ``label_binarizer_``.
        Returns ``self``.
        """
        is_clf = is_classifier(self.estimator)
        cv = check_cv(self.cv, y, classifier = is_clf)
        estimators = []
        cv_indexes = []
        # Pass y: for classifiers check_cv builds a stratified splitter,
        # whose split() requires it (plain KFold simply ignores it).
        for train_index, test_index in cv.split(X, y):
            estimator = clone(self.estimator)
            estimator.fit(
                self._take_rows(X, train_index),
                self._take_rows(y, train_index))
            estimators.append(estimator)
            cv_indexes.append((train_index, test_index))
        self.X_, self.y_ = X.copy(), y.copy()
        self.estimators_ = estimators
        self.cv_indexes_ = cv_indexes
        self.n_splits_ = len(cv_indexes)
        if is_clf:
            self.label_binarizer_ = LabelBinarizer().fit(y)
        return self

    def predict(
        self,
        X,
        split = None
    ):
        """Predict with one split's estimator, or aggregate over all splits.

        With ``split=None``, classifiers return the per-row majority vote
        (works for any sortable label type) and regressors the mean
        prediction.
        """
        if split is not None:
            return self.estimators_[split].predict(X)
        preds = np.array(
            [estimator.predict(X) for estimator in self.estimators_])
        if is_classifier(self.estimator):
            # preds is (n_splits, n_rows); vote down each column.
            return np.array(
                [self._majority_vote(votes) for votes in preds.T])
        return preds.mean(axis = 0)

    def predict_proba(
        self,
        X,
        split = None
    ):
        """Class probabilities from one split, or averaged over all splits."""
        if split is None:
            preds = np.array(
                [estimator.predict_proba(X)
                 for estimator in self.estimators_])
            return preds.mean(axis = 0)
        return self.estimators_[split].predict_proba(X)

    def predict_log_proba(
        self,
        X,
        split = None
    ):
        """Log class probabilities from one split, or for the ensemble.

        With ``split=None`` this is the log of the averaged probabilities,
        so ``exp`` of the result agrees with ``predict_proba`` (averaging the
        per-split log-probabilities would not sum to one).
        """
        if split is None:
            with np.errstate(divide = "ignore"):
                return np.log(self.predict_proba(X))
        return self.estimators_[split].predict_log_proba(X)

    def decision_function(
        self,
        X,
        split = None
    ):
        """Decision scores from one split, or averaged over all splits."""
        if split is None:
            preds = np.array(
                [estimator.decision_function(X)
                 for estimator in self.estimators_])
            return preds.mean(axis = 0)
        return self.estimators_[split].decision_function(X)

    def sample(
        self,
        X,
        split = None,
        n_repeats = None,
        random_state = None
    ):
        """Draw random responses for the rows of ``X``.

        Classifiers: one label per row drawn from the predicted class
        probabilities (one-hot multinomial draw decoded by
        ``label_binarizer_``).

        Regressors: prediction plus a resampled out-of-fold residual. The
        residuals come from all splits when ``split`` is None, otherwise from
        that split's test fold; they are drawn without replacement unless
        there are more rows to fill than residuals.

        Returns shape ``(n_rows,)``, or ``(n_repeats, n_rows)`` when
        ``n_repeats`` is given.
        """
        rng = np.random.default_rng(random_state)
        if is_classifier(self.estimator):
            proba = self.predict_proba(X, split)
            if n_repeats is None:
                draws = rng.multinomial(1, proba)
                return self.label_binarizer_.inverse_transform(draws)
            draws = rng.multinomial(1, proba, (n_repeats, len(proba)))
            return np.array(
                [self.label_binarizer_.inverse_transform(draw)
                 for draw in draws])

        if split is None:
            residual = np.concatenate(
                [target - pred
                 for target, pred in zip(self._targets(), self._preds())])
        else:
            _, test_index = self.cv_indexes_[split]
            feature = self._take_rows(self.X_, test_index)
            target = self._take_rows(self.y_, test_index)
            residual = target - self.predict(feature, split)
        pred = self.predict(X, split)
        replace = len(pred) > len(residual)
        if n_repeats is None:
            return pred + rng.choice(residual, len(pred), replace)
        return pred + np.array(
            [rng.choice(residual, len(pred), replace)
             for _ in range(n_repeats)])

    def _features(
        self
    ):
        """Test-fold rows of the stored features, one entry per split."""
        return [self._take_rows(self.X_, test_index)
                for _, test_index in self.cv_indexes_]

    def _targets(
        self,
        binarize = None
    ):
        """Test-fold targets, one entry per split.

        With ``binarize=True`` on a classifier, targets are one-hot encoded by
        ``label_binarizer_``; binary problems get a complement column so they
        line up with a two-column ``predict_proba`` output.
        """
        targets = []
        for _, test_index in self.cv_indexes_:
            target = self._take_rows(self.y_, test_index)
            if binarize is True and is_classifier(self.estimator):
                target = self.label_binarizer_.transform(target)
                if target.shape[1] == 1:
                    target = np.hstack([1 - target, target])
            targets.append(target)
        return targets

    def _preds(
        self,
        response_method = "predict"
    ):
        """Out-of-fold outputs of ``response_method``, one entry per split."""
        predict_func = getattr(self, response_method)
        return [predict_func(feature, split)
                for split, feature in enumerate(self._features())]

    def _rvs(
        self,
        n_repeats = None,
        random_state = None
    ):
        """Out-of-fold random draws from ``sample``, one entry per split."""
        # One generator shared by all splits: re-seeding every split with the
        # same integer would replay an identical random stream in each.
        rng = np.random.default_rng(random_state)
        return [self.sample(feature, split, n_repeats, rng)
                for split, feature in enumerate(self._features())]
        


In [4]:
class CIT():
    """Conditional-independence test of one feature via a "remover" model.

    For every CV split, the learner's out-of-fold loss is compared with its
    loss after the tested column of that split's test features is replaced
    by random draws from ``remover`` (``remover._rvs``). A negative estimate
    means the learner does better with the real column than with draws.

    Parameters
    ----------
    learner : CrossEstimator-like
        Fitted model on all features; must provide ``_features``,
        ``_targets``, ``_preds`` and the prediction method matching the loss.
    remover : CrossEstimator-like
        Fitted model providing ``_rvs``; it presumably models ``column`` from
        the other features on the same splits as ``learner`` (TODO confirm at
        call sites).
    column : int or str
        Position (or name, resolved via ``learner.columns`` or
        ``learner.X_.columns`` when available) of the tested column.
    infer_method : str
        ``"perm"`` uses ``n_repeats`` draws per row; any other value uses a
        single draw and a normal approximation.
    loss_func : {"log_loss", "mean_squared_error"} or None
        Defaults by learner type (classifier -> log loss).
    n_repeats : int or None
        Number of draws; defaults to 2000 for ``"perm"``.
    random_state : None, int or numpy Generator
        Forwarded to ``remover._rvs``.
    """

    def __init__(
        self,
        learner,
        remover,
        column,
        infer_method = "perm",
        loss_func = None,
        n_repeats = None,
        random_state = None
    ):
        if loss_func is None:
            loss_func = (
                "log_loss" if is_classifier(learner.estimator)
                else "mean_squared_error")
        if isinstance(column, str):
            column = self._column_position(learner, column)
        if infer_method == "perm" and n_repeats is None:
            n_repeats = 2000

        self.learner = learner
        self.remover = remover
        self.column = column
        self.loss_func = loss_func
        self.infer_method = infer_method
        self.n_repeats = n_repeats
        self.random_state = random_state

    @staticmethod
    def _column_position(learner, column):
        """Resolve a column name to its position; unresolved names pass through."""
        columns = getattr(learner, "columns", None)
        if columns is None:
            # CrossEstimator keeps its (possibly pandas) training features in X_.
            columns = getattr(getattr(learner, "X_", None), "columns", None)
        if columns is None:
            return column
        return columns.get_loc(column)

    @staticmethod
    def _resolve_loss(loss_func):
        """Map a loss name to ``(per-row loss, binarize, response_method)``.

        Raises ValueError for unsupported names.
        """
        if loss_func == "log_loss":
            def log_loss(target, pred):
                # Per-row cross-entropy; probabilities clipped away from 0/1.
                eps = np.finfo(pred.dtype).eps
                pred = np.clip(pred, eps, 1 - eps)
                return -xlogy(target, pred).sum(axis = 1)
            return log_loss, True, "predict_proba"
        if loss_func == "mean_squared_error":
            def mean_squared_error(target, pred):
                return (target - pred)**2
            return mean_squared_error, False, "predict"
        raise ValueError(
            "loss_func must be 'log_loss' or 'mean_squared_error', "
            f"got {loss_func!r}")

    def infer(
        self):
        """Compute per-split learner and rival losses (and perm null values).

        Sets ``learner_losses_`` and ``rival_losses_`` (per-row losses, one
        array per split) and ``null_values_`` (per-repeat differences of mean
        losses for ``"perm"``, otherwise None).
        """
        learner = self.learner
        column = self.column
        loss_func, binarize, response_method = self._resolve_loss(
            self.loss_func)
        # The rival must be scored with the same response as the learner
        # (e.g. predict_proba for log loss), not always with predict().
        learner_predict = getattr(learner, response_method)

        l_features = learner._features()
        l_targets = learner._targets(binarize)
        l_preds = learner._preds(response_method)
        l_losses = [loss_func(l_target, l_pred)
                    for l_target, l_pred in zip(l_targets, l_preds)]

        r_rvs = self.remover._rvs(
            n_repeats = self.n_repeats,
            random_state = self.random_state)

        def _rival_loss(feature, split, target, draws):
            # Overwrite the tested column in place with one set of draws.
            if hasattr(feature, "iloc"):
                feature.iloc[:, column] = draws
            else:
                feature[:, column] = draws
            return loss_func(target, learner_predict(feature, split))

        r_losses = []
        if self.infer_method == "perm":
            null_values = []
            for split, (l_loss, l_target, l_feature, r_rv) in enumerate(
                zip(l_losses, l_targets, l_features, r_rvs)):
                r_feature = l_feature.copy()
                # r_rv holds one row of draws per repeat.
                r_loss = np.array(
                    [_rival_loss(r_feature, split, l_target, draws)
                     for draws in r_rv])
                null_values.append((l_loss - r_loss).mean(axis = 1))
                r_losses.append(r_loss.mean(axis = 0))
        else:
            null_values = None
            for split, (l_target, l_feature, r_rv) in enumerate(
                zip(l_targets, l_features, r_rvs)):
                r_feature = l_feature.copy()
                r_losses.append(
                    _rival_loss(r_feature, split, l_target, r_rv))

        self.learner_losses_ = l_losses
        self.rival_losses_ = r_losses
        self.null_values_ = null_values

    def summarize(
        self,
        agg_method = None
    ):
        """Per-split table of estimate, std_error and p_value.

        Raises NotImplementedError for any ``agg_method`` other than None.
        """
        if agg_method is not None:
            raise NotImplementedError(
                f"agg_method={agg_method!r} is not supported; "
                "only the per-split summary (None) is implemented")
        summary = pd.DataFrame(
            {"estimate": self._estimates(),
             "std_error": self._std_errors(),
             "p_value": self._p_values()}
        )
        summary.index.name = "split"
        return summary

    def _estimates(
        self
    ):
        """Mean learner loss minus mean rival loss, per split."""
        return [l_loss.mean() - r_loss.mean()
                for l_loss, r_loss in zip(
                    self.learner_losses_, self.rival_losses_)]

    def _std_errors(
        self
    ):
        """Std of the perm null values, or std error of the paired loss difference."""
        if self.infer_method == "perm":
            return [null_value.std() for null_value in self.null_values_]
        return [(l_loss - r_loss).std() / np.sqrt(len(l_loss))
                for l_loss, r_loss in zip(
                    self.learner_losses_, self.rival_losses_)]

    def _p_values(
        self
    ):
        """One-sided p-values (small when the learner's loss is lower)."""
        if self.infer_method == "perm":
            return [(null_value > 0).mean()
                    for null_value in self.null_values_]
        # `import scipy` alone does not guarantee scipy.stats is loaded.
        from scipy.stats import norm
        return [norm.cdf(estimate / std_error)
                for estimate, std_error in zip(
                    self._estimates(), self._std_errors())]

        

In [5]:
class RIT():
    """Paired comparison of a learner's and a rival's out-of-fold losses.

    For every CV split, the per-row losses of ``learner`` and ``rival`` on
    that split's test fold are compared. A negative estimate means the
    learner's mean loss is lower.

    Parameters
    ----------
    learner, rival : CrossEstimator-like
        Fitted models providing ``_targets`` and ``_preds``; their splits are
        paired by position, so they presumably share the same CV splits
        (TODO confirm at call sites).
    infer_method : str
        ``"perm"`` for a paired permutation test; any other value uses a
        normal approximation.
    loss_func : {"log_loss", "mean_squared_error"} or None
        Defaults by learner type (classifier -> log loss).
    n_repeats : int or None
        Permutations per split; defaults to 2000 for ``"perm"``.
    random_state : None, int or numpy Generator
        Seed for the permutation RNG.
    """

    def __init__(
        self,
        learner,
        rival,
        infer_method = "normal",
        loss_func = None,
        n_repeats = None,
        random_state = None
    ):
        if loss_func is None:
            loss_func = (
                "log_loss" if is_classifier(learner.estimator)
                else "mean_squared_error")
        if infer_method == "perm" and n_repeats is None:
            n_repeats = 2000

        self.learner = learner
        self.rival = rival
        self.loss_func = loss_func
        self.infer_method = infer_method
        self.n_repeats = n_repeats
        self.random_state = random_state

    @staticmethod
    def _resolve_loss(loss_func):
        """Map a loss name to ``(per-row loss, binarize, response_method)``.

        Raises ValueError for unsupported names.
        """
        if loss_func == "log_loss":
            def log_loss(target, pred):
                # Per-row cross-entropy; probabilities clipped away from 0/1.
                eps = np.finfo(pred.dtype).eps
                pred = np.clip(pred, eps, 1 - eps)
                return -xlogy(target, pred).sum(axis = 1)
            return log_loss, True, "predict_proba"
        if loss_func == "mean_squared_error":
            def mean_squared_error(target, pred):
                return (target - pred)**2
            return mean_squared_error, False, "predict"
        raise ValueError(
            "loss_func must be 'log_loss' or 'mean_squared_error', "
            f"got {loss_func!r}")

    def infer(
        self):
        """Compute per-split losses of both models (and perm null values).

        Sets ``learner_losses_`` and ``rival_losses_`` (per-row losses, one
        array per split) and ``null_values_`` (``n_repeats`` values per split
        for ``"perm"``, otherwise None).
        """
        loss_func, binarize, response_method = self._resolve_loss(
            self.loss_func)

        def _fold_losses(model):
            return [loss_func(target, pred)
                    for target, pred in zip(
                        model._targets(binarize),
                        model._preds(response_method))]

        l_losses = _fold_losses(self.learner)
        r_losses = _fold_losses(self.rival)

        if self.infer_method == "perm":
            null_values = []
            rng = np.random.default_rng(self.random_state)
            for l_loss, r_loss in zip(l_losses, r_losses):
                estimate = l_loss.mean() - r_loss.mean()
                paired_loss = np.column_stack([l_loss, r_loss])
                # rng.permuted(axis=1) randomly swaps each row's (learner,
                # rival) pair; np.diff of the column means is then
                # (permuted rival mean - permuted learner mean), so each null
                # value is estimate + (permuted learner - permuted rival).
                null_value = np.array([
                    estimate - np.diff(
                        rng.permuted(paired_loss, axis = 1).mean(
                            axis = 0)).item()
                    for _ in range(self.n_repeats)])
                null_values.append(null_value)
        else:
            null_values = None

        self.learner_losses_ = l_losses
        self.rival_losses_ = r_losses
        self.null_values_ = null_values

    def summarize(
        self,
        agg_method = None
    ):
        """Per-split table of estimate, std_error and p_value.

        Raises NotImplementedError for any ``agg_method`` other than None.
        """
        if agg_method is not None:
            raise NotImplementedError(
                f"agg_method={agg_method!r} is not supported; "
                "only the per-split summary (None) is implemented")
        summary = pd.DataFrame(
            {"estimate": self._estimates(),
             "std_error": self._std_errors(),
             "p_value": self._p_values()}
        )
        summary.index.name = "split"
        return summary

    def _estimates(
        self
    ):
        """Mean learner loss minus mean rival loss, per split."""
        return [l_loss.mean() - r_loss.mean()
                for l_loss, r_loss in zip(
                    self.learner_losses_, self.rival_losses_)]

    def _std_errors(
        self
    ):
        """Std of the perm null values, or std error of the paired loss difference."""
        if self.infer_method == "perm":
            return [null_value.std() for null_value in self.null_values_]
        return [(l_loss - r_loss).std() / np.sqrt(len(l_loss))
                for l_loss, r_loss in zip(
                    self.learner_losses_, self.rival_losses_)]

    def _p_values(
        self
    ):
        """One-sided p-values (small when the learner's loss is lower)."""
        if self.infer_method == "perm":
            return [(null_value > 0).mean()
                    for null_value in self.null_values_]
        # `import scipy` alone does not guarantee scipy.stats is loaded.
        from scipy.stats import norm
        return [norm.cdf(estimate / std_error)
                for estimate, std_error in zip(
                    self._estimates(), self._std_errors())]

        

In [6]:
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import KFold, RepeatedKFold
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
# Learner, rival and remover use the same repeated 2-fold splitter (fixed
# random_state), so on equally sized data they get identical test folds and
# RIT/CIT can compare their out-of-fold results split by split.
splitter = RepeatedKFold(
    n_splits = 2,
    n_repeats = 3,
    random_state = 0)


def make_cross_estimator(cv):
    """Cross-fitted random forest whose ``max_features`` is tuned by 4-fold grid search."""
    return CrossEstimator(
        GridSearchCV(
            estimator = RandomForestRegressor(),
            param_grid = {
                "max_features": [3, 6, 9]},
            cv = 4),
        cv = cv)


learner = make_cross_estimator(splitter)
rival = make_cross_estimator(splitter)
remover = make_cross_estimator(splitter)

In [7]:
# Seed NumPy's global RNG so the simulated data are reproducible. gen_xy draws
# from it, and scikit-learn estimators built with random_state=None (the
# forests above) also use it.
SEED = 0
np.random.seed(SEED)
X, y, r2 = gen_xy(
        model = "linear",
        iv_corr = .8,
        n_obs = 200)

In [8]:
# The learner sees every feature. The rival and the remover see all features
# except `column`, and the remover is trained to predict `column` from them.
column = 1
X_without_column = np.delete(X, column, axis = 1)
learner.fit(X, y)
rival.fit(X_without_column, y)
remover.fit(X_without_column, X[:, column]);

In [9]:
# Paired permutation test of learner vs rival out-of-fold losses, per split.
# A negative estimate means the learner (with `column`) has the lower loss.
rit = RIT(
    learner,
    rival,
    infer_method = "perm",
    random_state = 0)
rit.infer()
rit.summarize()

Unnamed: 0_level_0,estimate,std_error,p_value
split,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,0.002556,0.011504,0.592
1,-0.011015,0.010345,0.1465
2,-0.02782,0.010467,0.0045
3,-0.003819,0.010363,0.356
4,-0.015133,0.010132,0.0725
5,-0.00535,0.008492,0.2615


In [10]:
# Same comparison using a normal approximation of the paired loss difference.
rit = RIT(
    learner,
    rival,
    infer_method = "normal")
rit.infer()
rit.summarize()

Unnamed: 0_level_0,estimate,std_error,p_value
split,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,0.002556,0.011557,0.587513
1,-0.011015,0.010451,0.145949
2,-0.02782,0.010005,0.002712
3,-0.003819,0.010288,0.355246
4,-0.015133,0.010017,0.065442
5,-0.00535,0.008652,0.268161


In [11]:
# Conditional test: learner loss vs loss with `column` replaced by remover draws
# (many draws per row); seeded so the draws are reproducible.
cit = CIT(
    learner,
    remover,
    column,
    infer_method = "perm",
    random_state = 0)
cit.infer()
cit.summarize()

Unnamed: 0_level_0,estimate,std_error,p_value
split,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,-0.020521,0.007544,0.0015
1,-0.011597,0.008462,0.0905
2,-0.013881,0.00619,0.0115
3,-0.01749,0.007885,0.013
4,-0.025849,0.009572,0.002
5,-0.020276,0.008542,0.008


In [12]:
# Conditional test with a single remover draw per row and a normal approximation.
cit = CIT(
    learner,
    remover,
    column,
    infer_method = "normal",
    random_state = 0)
cit.infer()
cit.summarize()

Unnamed: 0_level_0,estimate,std_error,p_value
split,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,-0.013882,0.008926,0.059945
1,-0.020936,0.010826,0.026559
2,0.001111,0.008767,0.550424
3,-0.018776,0.011302,0.048328
4,-0.024897,0.013682,0.034403
5,-0.028159,0.00932,0.001258
