# ASBE - Automatic Stopping for Batch Experiments

> API details for the `asbe.core` module.

In [None]:
#hide
from nbdev import *

In [None]:
%nbdev_default_export core

Cells will be exported to asbe.core,
unless a different module is specified after an export flag: `%nbdev_export special.module`


In [None]:
%nbdev_export
from modAL.models.base import BaseLearner
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from typing import Union, Optional
from copy import deepcopy
import numpy as np

In [None]:
from sklearn.linear_model import LogisticRegression

In [None]:
%nbdev_export
# Type alias: any scikit-learn classifier or regressor.
estimator_type = Union[ClassifierMixin, RegressorMixin]


class ASLearner(BaseLearner):
    """Learner holding an estimator, a query strategy and an assignment function.

    Subclasses modAL's ``BaseLearner`` but only stores its three constructor
    arguments; ``BaseLearner.__init__`` is not invoked.

    NOTE(review): this looks like a work-in-progress stub -- ``query_strategy``
    and ``assignment_fc`` are stored but not used by any method in this class.
    """

    def __init__(self,
                 estimator: Optional[estimator_type] = None,
                 query_strategy=None,
                 assignment_fc=None,
                 ) -> None:
        # Plain attribute storage, in the same order as the parameters.
        self.estimator, self.query_strategy, self.assignment_fc = (
            estimator,
            query_strategy,
            assignment_fc,
        )

    def teach(self, X, query_X):
        """Fit the wrapped estimator on ``X`` and return whatever ``fit`` returns.

        ``query_X`` is accepted but currently ignored.

        NOTE(review): ``fit`` is called with ``X`` only and no labels. A
        supervised scikit-learn estimator would most likely reject this call,
        and an estimator whose ``fit`` takes no arguments would too. Confirm
        the intended training inputs.
        """
        wrapped = self.estimator
        return wrapped.fit(X)

In [None]:
%nbdev_export
class ITEEstimator(BaseEstimator):
    """Naive estimator of individual treatment effects (ITE).

    Two modes are supported:

    * single model (default): one model is fitted on the covariates with the
      treatment indicator appended as an extra column. Effects are predicted
      by scoring ``X_unlabeled`` once with treatment = 1 and once with
      treatment = 0.
    * two models (``two_model=True``): ``model`` is fitted on the control rows
      (``t_training == 0``). A deep copy, ``m1``, is fitted on the remaining
      (treated) rows.

    Predictions use ``predict_proba(...)[:, 1]``, so ``model`` is assumed to be
    a probabilistic binary classifier.

    Parameters
    ----------
    model : estimator exposing ``fit`` and ``predict_proba``.
    X_training : 2-D array of training covariates, one row per unit.
    t_training : per-row treatment indicator; 0 marks control rows.
    y_training : per-row training outcomes.
    X_unlabeled : 2-D array of covariates to estimate effects for.
    two_model : if True, fit separate control / treated models.
    """
    def __init__(self,
                 model: estimator_type,
                 X_training,
                 t_training,
                 y_training,
                 X_unlabeled,
                 two_model: bool = False,
                ) -> None:
        self.model = model
        self.X_training = X_training
        self.y_training = y_training
        self.t_training = t_training
        self.two_model = two_model
        self.X_unlabeled = X_unlabeled
        self.N_training = X_training.shape[0]

    def fit(self):
        """Fit the outcome model(s) on the stored training data.

        Raises
        ------
        ValueError
            If ``two_model`` is True and either the control group
            (``t == 0``) or the treated group is empty.
        """
        if self.two_model:
            # Boolean row mask. The previous code indexed with the tuple
            # returned by ``np.where``, which inserted an extra leading axis
            # (a 3-D array such as (1, n_control, p)). It then tried to negate
            # that tuple to get the treated rows, which is not a valid
            # complement.
            is_control = np.asarray(self.t_training).ravel() == 0
            if is_control.all() or not is_control.any():
                raise ValueError(
                    "two_model=True requires both control (t == 0) and "
                    "treated rows in t_training")
            X_train = np.asarray(self.X_training)
            y_train = np.asarray(self.y_training)
            # Copy the unfitted model so the treated model is independent.
            self.m1 = deepcopy(self.model)
            self.model.fit(X_train[is_control], y_train[is_control])
            self.m1.fit(X_train[~is_control], y_train[~is_control])
        else:
            # Append the treatment indicator as one extra feature column.
            features = np.hstack((self.X_training,
                                  self.t_training.reshape((self.N_training, -1))))
            self.model.fit(features, self.y_training)

    def predict(self):
        """Estimate treatment effects for ``X_unlabeled``.

        Also stores the two probability vectors as ``self.y1_preds`` and
        ``self.y0_preds``.

        Returns
        -------
        tuple of (ite, y1_preds, y0_preds)
            ``y1_preds`` / ``y0_preds`` are the class-1 probabilities under
            treatment / control; ``ite = y1_preds - y0_preds``.
        """
        if self.two_model:
            self.y1_preds = self.m1.predict_proba(self.X_unlabeled)[:, 1]
            self.y0_preds = self.model.predict_proba(self.X_unlabeled)[:, 1]
        else:
            self.y1_preds = self.model.predict_proba(self._with_treatment(1))[:, 1]
            self.y0_preds = self.model.predict_proba(self._with_treatment(0))[:, 1]
        # Return in both modes; previously the two-model branch returned None.
        return self.y1_preds - self.y0_preds, self.y1_preds, self.y0_preds

    def _with_treatment(self, value):
        """Return ``X_unlabeled`` with a constant float treatment column appended."""
        n_rows = self.X_unlabeled.shape[0]
        return np.hstack((self.X_unlabeled,
                          np.full((n_rows, 1), value, dtype=float)))

In [None]:
# Synthetic data: Gaussian covariates, a fair-coin treatment assignment, and a
# binary outcome whose probability depends only on the second covariate.
n_train, n_test, n_features = 500, 100, 2
X = np.random.normal(size=n_train * n_features).reshape((n_train, n_features))
t = np.random.binomial(n=1, p=0.5, size=n_train)
y = np.random.binomial(n=1, p=1 / (1 + np.exp(2 * X[:, 1])))
X_test = np.random.normal(size=n_test * n_features).reshape((n_test, n_features))

In [None]:
# Two-model ITE estimator built around a logistic-regression base learner.
a = ITEEstimator(
    model=LogisticRegression(solver="lbfgs"),
    X_training=X,
    t_training=t,
    y_training=y,
    X_unlabeled=X_test,
    two_model=True,
)

In [None]:
# Fit the outcome model(s) of `a` on X, t and y.
a.fit()

(1, 249, 2)


ValueError: Found array with dim 3. Estimator expected <= 2.

In [None]:
# Predict individual treatment effects for the rows of X_test.
a.predict()

(array([-0.00631163, -0.00690893, -0.00968889, -0.01162548, -0.00966251,
        -0.00356955, -0.0011706 , -0.00709584, -0.00164621, -0.00579412,
        -0.00892574, -0.01123352, -0.00893689, -0.00591877, -0.01031452,
        -0.00443958, -0.01127404, -0.00727093, -0.0110698 , -0.01108541,
        -0.0088907 , -0.00479682, -0.0085442 , -0.00498098, -0.01054354,
        -0.01161135, -0.00848985, -0.00568276, -0.00663026, -0.00981189,
        -0.0091402 , -0.01065809, -0.0113113 , -0.01156236, -0.01138042,
        -0.01153759, -0.01148206, -0.01125006, -0.00844727, -0.01065506,
        -0.00990782, -0.01020665, -0.01143533, -0.00674161, -0.01018128,
        -0.01159302, -0.00729394, -0.00776386, -0.01161916, -0.00325214,
        -0.00720661, -0.0098169 , -0.00726728, -0.00207751, -0.00884736,
        -0.00782709, -0.00256556, -0.00510249, -0.01014151, -0.00944256,
        -0.00509608, -0.00959268, -0.00325168, -0.01156415, -0.0067078 ,
        -0.00922942, -0.00173808, -0.01159593, -0.0