In [32]:
import scipy
import numpy as np
from sklearn.base import clone
from sklearn.model_selection import check_cv
from sklearn.base import is_classifier
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import get_scorer, check_scoring
from sklearn.preprocessing import LabelBinarizer
from scipy.special import xlogy
from numpy.random import multinomial, choice, permutation

In [40]:
class CrossEstimator():
    """Cross-fitting wrapper: one clone of ``estimator`` is fitted per CV split.

    After ``fit`` the instance keeps copies of the training data, the
    (train, test) index pairs of every split and the per-split fitted
    estimators, so that held-out targets, features and predictions can be
    retrieved split by split.

    Parameters
    ----------
    estimator : estimator object
        Cloned and fitted once per split. If the fitted clone exposes
        ``best_estimator_`` (e.g. a search object), that inner estimator is
        the one that is stored.
    cv : int, splitter, iterable or None
        Passed to ``sklearn.model_selection.check_cv``.

    Attributes (set by ``fit``)
    ---------------------------
    X_, y_ : copies of the data given to ``fit``.
    estimators_ : list with one fitted estimator per split.
    cv_indexes_ : list of ``(train_index, test_index)`` pairs.
    is_classifier_ : bool, whether ``estimator`` is a classifier.
    label_binarizer_ : ``LabelBinarizer`` fitted on ``y`` (classifiers only).
    """

    def __init__(
        self,
        estimator,
        cv = None
    ):
        self.estimator = estimator
        self.cv = cv

    def _is_classification(self):
        """Return the task type cached by ``fit``, computing it if absent."""
        flag = getattr(self, "is_classifier_", None)
        if flag is None:
            flag = is_classifier(self.estimator)
        return flag

    @staticmethod
    def _majority_vote(votes):
        """Most frequent label along axis 0; ties go to the smallest label.

        Labels are first encoded as 0..k-1 with ``np.unique`` so that the
        vote works for arbitrary label values (negative integers, strings,
        non-contiguous codes), and the winning code is mapped back to the
        original label.
        """
        votes = np.asarray(votes)
        classes, codes = np.unique(votes, return_inverse = True)
        # The shape of the inverse differs across numpy versions for n-d input.
        codes = codes.reshape(votes.shape)
        class_ids = np.arange(len(classes)).reshape(
            (-1,) + (1,) * votes.ndim)
        # counts[c, ...] = number of estimators that voted for class c.
        counts = (codes[np.newaxis] == class_ids).sum(axis = 1)
        return classes[counts.argmax(axis = 0)]

    def _aggregate(self, method, X, split):
        """Call ``method`` on one split's estimator, or average over all."""
        if split is None:
            stacked = np.array(
                [getattr(estimator, method)(X)
                 for estimator in self.estimators_])
            return stacked.mean(axis = 0)
        return getattr(self.estimators_[split], method)(X)

    def fit(
        self,
        X,
        y
    ):
        """Fit one clone of ``estimator`` on the training part of each split.

        ``X`` and ``y`` must support integer-array indexing and ``.copy()``.
        Returns ``self``.
        """
        is_clf = is_classifier(self.estimator)
        cv = check_cv(self.cv, y, classifier = is_clf)
        cv_indexes = []
        estimators = []
        # y must be forwarded: stratified splitters (which check_cv selects
        # for classifiers when cv is None or an int) require it.
        for train_index, test_index in cv.split(X, y):
            estimator = clone(self.estimator)
            estimator.fit(X[train_index], y[train_index])
            if hasattr(estimator, "best_estimator_"):
                estimator = estimator.best_estimator_
            estimators.append(estimator)
            cv_indexes.append((train_index, test_index))
        self.X_, self.y_ = X.copy(), y.copy()
        self.estimators_ = estimators
        self.cv_indexes_ = cv_indexes
        self.is_classifier_ = is_clf
        if is_clf:
            label_binarizer = LabelBinarizer()
            label_binarizer.fit(y)
            self.label_binarizer_ = label_binarizer
        return self

    def predict(
        self,
        X,
        split = None
    ):
        """Predict with one split's estimator or with the whole ensemble.

        With ``split=None`` classifiers return the majority vote over the
        per-split predictions and regressors return their mean; otherwise
        the prediction of ``estimators_[split]`` is returned as is.
        """
        if split is not None:
            return self.estimators_[split].predict(X)
        preds = np.array(
            [estimator.predict(X)
             for estimator in self.estimators_])
        if self._is_classification():
            return self._majority_vote(preds)
        return preds.mean(axis = 0)

    def predict_proba(
        self,
        X,
        split = None
    ):
        """``predict_proba`` of one split, or the mean over all splits."""
        return self._aggregate("predict_proba", X, split)

    def predict_log_proba(
        self,
        X,
        split = None
    ):
        """``predict_log_proba`` of one split, or the mean over all splits.

        Note that the mean of log-probabilities is not the log of the
        averaged probabilities returned by ``predict_proba``.
        """
        return self._aggregate("predict_log_proba", X, split)

    def decision_function(
        self,
        X,
        split = None
    ):
        """``decision_function`` of one split, or the mean over all splits."""
        return self._aggregate("decision_function", X, split)

    def sample(
        self,
        X,
        split = None
    ):
        """Draw one random response per row of ``X``.

        Classifiers: one label per row is drawn from the predicted class
        probabilities and mapped back through ``label_binarizer_``.
        Regressors: a held-out residual (from every split, or only from
        ``split``) is drawn with replacement and added to the prediction.
        """
        if self._is_classification():
            proba = self.predict_proba(X, split)
            # One multinomial draw of size 1 per row gives a one-hot row.
            one_hot = np.apply_along_axis(
                lambda proba_i: multinomial(1, proba_i),
                axis = 1,
                arr = proba)
            return self.label_binarizer_.inverse_transform(one_hot)
        if split is None:
            splits = range(len(self.cv_indexes_))
        else:
            splits = [split]
        residuals = []
        for i in splits:
            _, test_index = self.cv_indexes_[i]
            held_out_pred = self.estimators_[i].predict(self.X_[test_index])
            residuals.append(self.y_[test_index] - held_out_pred)
        res = np.concatenate(residuals)
        pred = self.predict(X, split)
        return pred + choice(res, len(pred))

    def get_targets(
        self,
        binarize = False
    ):
        """Return the held-out targets of every split as a list.

        For classifiers, ``binarize=True`` returns indicator matrices; a
        single-column (binary) indicator is expanded to two columns.
        """
        is_clf = self._is_classification()
        targets = []
        for _, test_index in self.cv_indexes_:
            target = self.y_[test_index]
            if is_clf and binarize is True:
                target = self.label_binarizer_.transform(target)
                if target.shape[1] == 1:
                    target = np.append(1 - target, target, axis=1)
            targets.append(target)
        return targets

    def get_features(
        self
    ):
        """Return the held-out feature rows of every split as a list."""
        return [self.X_[test_index]
                for _, test_index in self.cv_indexes_]

    def get_preds(
        self,
        response_method = "predict"
    ):
        """Return each split's held-out output of ``response_method``.

        ``response_method`` is the name of the estimator method to call
        (for instance ``"predict"`` or ``"predict_proba"``).
        """
        preds = []
        for (_, test_index), estimator in zip(
            self.cv_indexes_, self.estimators_):
            predict_func = getattr(estimator, response_method)
            preds.append(predict_func(self.X_[test_index]))
        return preds


In [41]:
from sklearn import datasets
from sklearn.linear_model import Lasso
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import KFold, RepeatedKFold

# Regression demo: cross-fit a grid-searched Lasso on a subset of the
# diabetes data with CrossEstimator.
X, y = datasets.load_diabetes(return_X_y=True)
# NOTE(review): the slice 1:150 skips row 0 and keeps 149 rows; use [:150]
# if the first 150 rows were intended.
X, y = X[1:150,], y[1:150]
lasso = Lasso(
    random_state=0,
    max_iter=10000)
alphas = np.logspace(-4, -0.5, 30)
tuned_parameters = [{"alpha": alphas}]
n_folds = 5
rgs = GridSearchCV(lasso, tuned_parameters, cv=n_folds)
# RepeatedKFold shuffles the rows; seed it so the folds (and every per-fold
# output shown below) are reproducible on Restart & Run All.
rkf = RepeatedKFold(n_splits = 5, n_repeats = 1, random_state = 0)
learner = CrossEstimator(rgs, rkf)
_ = learner.fit(X, y)

In [42]:
# Display the held-out predictions of the current `learner`: get_preds()
# returns one array per CV split, predicted by that split's estimator.
learner.get_preds()

[array([105.79438737, 166.89914515, 160.7622515 , 233.14868534,
        148.13670405, 104.31745153,  88.86285186, 116.00571407,
        138.07303455, 224.91564138, 100.07931543, 132.74632262,
        116.69114216, 157.59565205, 167.35221425,  79.35270248,
        167.51406377, 160.61056693,  87.24828536, 127.90784069,
        196.68020736, 231.45769718, 158.70391687, 158.9382827 ,
        107.53584674, 224.03842119, 166.13255176, 184.04622765,
        110.65654149, 193.19733546]),
 array([168.75543633, 107.42744694,  77.95286186, 223.42865644,
        109.99784727, 166.53045141,  80.83213532, 141.87535743,
        116.94103021, 162.21507665, 138.35757301,  77.39372089,
         59.0860034 , 104.82095552, 188.35328669, 145.32982145,
        116.32216709, 155.6789585 , 185.17072341,  94.08171873,
        121.59908272, 113.5184474 , 148.97302591, 227.71350469,
        145.8388188 , 302.51979902, 162.78661719, 204.81408445,
        118.24188608, 220.71782872]),
 array([126.20219777, 116.65

In [43]:
from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import KFold, RepeatedKFold

# Classification demo: cross-fit a grid-searched SVC on iris with
# CrossEstimator. Note that `rkf` and `learner` are rebound here.
# Seed the shuffled splitter so the folds are reproducible.
rkf = RepeatedKFold(n_splits = 5, n_repeats = 1, random_state = 0)
iris = datasets.load_iris()
parameters = {'kernel':('linear', 'rbf'), 'C':[1, 10]}
# probability=True calibrates with an internal randomised cross-validation;
# random_state makes the resulting predict_proba values reproducible.
svc = svm.SVC(probability = True, random_state = 0)
clf = GridSearchCV(svc, parameters)
learner = CrossEstimator(clf, rkf)
_ = learner.fit(iris.data, iris.target)

In [45]:
# Display, for every CV split, the output of that split's estimator's
# `predict_proba` on the split's held-out rows (one array per split).
learner.get_preds("predict_proba")

[array([[9.66867609e-01, 2.24021021e-02, 1.07302891e-02],
        [9.50992884e-01, 3.33200593e-02, 1.56870571e-02],
        [9.67939988e-01, 2.03115463e-02, 1.17484659e-02],
        [9.61519182e-01, 2.59796335e-02, 1.25011848e-02],
        [9.65254750e-01, 2.33773907e-02, 1.13678598e-02],
        [9.42748254e-01, 4.11448532e-02, 1.61068931e-02],
        [9.79327255e-01, 1.19264506e-02, 8.74629430e-03],
        [9.61063295e-01, 2.72381467e-02, 1.16985585e-02],
        [9.67579545e-01, 2.17432915e-02, 1.06771631e-02],
        [9.46486755e-01, 3.89389316e-02, 1.45743130e-02],
        [9.67623820e-01, 2.09496743e-02, 1.14265057e-02],
        [9.62491127e-01, 2.56487946e-02, 1.18600782e-02],
        [6.74585754e-03, 9.57691419e-01, 3.55627236e-02],
        [4.45936241e-03, 9.71683527e-01, 2.38571103e-02],
        [1.29501807e-02, 8.06904149e-01, 1.80145670e-01],
        [1.57486720e-02, 9.78036309e-01, 6.21501923e-03],
        [2.04013504e-02, 3.53987221e-01, 6.25611428e-01],
        [8.446

In [329]:
class CIT():
    """Unfinished container for a learner, a remover and an optional loss.

    NOTE(review): this cell is work in progress (its recorded output is a
    SyntaxError); the roles of ``remover`` and ``loss_func`` cannot be told
    from the code visible here.
    """
    def __init__(
        self, 
        learner,
        remover,
        loss_func = None
    ):
        """Store the three arguments unchanged on the instance."""
        self.learner = learner
        self.remover = remover
        self.loss_func = loss_func

    def infer(
        self
    ):
        """Unfinished: assigns two locals and implicitly returns None."""
        # NOTE(review): both calls use the module-level name `learner`, not
        # `self.learner` — presumably unintended; confirm.
        # NOTE(review): if `learner` is a CrossEstimator as defined in this
        # notebook, it has no `reply` method and `predict` requires `X`.
        learner_targets = learner.reply()
        learner_preds = learner.predict()
        

        

SyntaxError: incomplete input (2282326014.py, line 21)

In [16]:
# NOTE(review): CrossEstimator.sample as defined in this notebook requires
# an `X` argument, so this call fails against that definition; the recorded
# output below presumably comes from a different version of the class.
learner.sample()

[array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
        2, 2, 2, 2, 1, 2, 2, 2]),
 array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 1, 1, 2, 1,
        1, 2, 2, 2, 2, 2, 2, 2]),
 array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1,
        1, 0, 1, 2, 2, 2, 2, 2]),
 array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2]),
 array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2])]

In [390]:
# NOTE(review): the CrossEstimator class defined in this notebook has no
# `quantify` method; this cell presumably ran against a different version
# of the class and will not survive Restart & Run All as is.
learner.quantify()

[array([0.03047872, 0.03371168, 0.04622066, 0.04113181, 0.0199701 ,
        0.05687095, 0.04799187, 0.03460839, 0.05022425, 0.03405066,
        0.17229335, 0.25204089, 0.08572921, 0.34540535, 0.0699074 ,
        0.06635565, 0.03908081, 0.05198329, 0.05938119, 0.00350774,
        0.01521838, 0.04411749, 0.09087275, 0.01507744, 0.21357253,
        0.02099946, 0.09016782, 0.11047597, 0.02437719, 0.1494673 ]),
 array([0.03147789, 0.04726225, 0.03124575, 0.03744061, 0.08696101,
        0.0271521 , 0.02547536, 0.04119737, 0.03261583, 0.05335157,
        0.05107176, 0.15538328, 0.03916825, 0.05364395, 0.18464269,
        0.753597  , 0.02594374, 0.0583422 , 0.08509803, 0.0285936 ,
        0.05989299, 0.00355041, 0.01405706, 0.13744988, 0.02667314,
        0.5064949 , 0.01468881, 0.01553904, 0.00971281, 0.04256084]),
 array([0.0543839 , 0.02965617, 0.04939194, 0.02346165, 0.07150733,
        0.05849003, 0.05560787, 0.02689555, 0.06556507, 0.03648379,
        0.18724785, 0.11322675, 0.04016516, 

In [234]:
# Scratch check of the majority vote: most frequent label along axis 0.
# NOTE(review): CrossEstimator.predict as defined in this notebook already
# returns the voted labels when `split` is None, so this cell presumably
# dates from a version where predict returned the stacked per-split
# predictions — confirm or delete.
np.apply_along_axis(
                lambda x: np.argmax(np.bincount(x)),
                axis=0,
                arr=learner.predict(iris.data),
            )

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

In [157]:
# NOTE(review): the CrossEstimator class defined in this notebook has no
# `get_losses` method; this cell presumably ran against a different version
# of the class and will not survive Restart & Run All as is.
learner.get_losses(loss_func = "log_loss")

[array([0.05645564, 0.05238995, 0.03368536, 0.05384512, 0.05161804,
        0.04756854, 0.02234061, 0.02452637, 0.03387893, 0.07033476,
        0.07788535, 0.04866863, 0.1046384 , 0.06914687, 0.04841894,
        0.04495078, 0.07118952, 0.02451234, 0.06079586, 0.03506321,
        0.02836699, 0.08203701, 0.06065754, 0.04027421, 0.04471997,
        0.02299849, 0.00648436, 0.01445789, 0.00856427, 0.14839065,
        0.02044827, 0.02350801, 0.00827815, 0.01438462, 0.0348732 ,
        0.0050696 , 0.0394479 , 0.00631385, 0.02726235, 0.00327092,
        0.00959646, 0.08875323, 0.01586876, 0.00797441, 0.0354622 ,
        0.02767817, 0.01763332, 0.01074017, 0.12976665, 0.02277979]),
 array([0.02837371, 0.03534768, 0.03498755, 0.03756079, 0.03942477,
        0.06382299, 0.02598236, 0.01590016, 0.03017132, 0.02642951,
        0.01289047, 0.08525378, 0.07070279, 0.06377964, 0.06635004,
        0.01645138, 0.02134683, 0.04700615, 0.03857348, 0.02531727,
        0.09849957, 0.03774065, 0.0549704 , 0.