In [None]:
#| default_exp estimators

In [None]:
#| export
#| output: false
#| output: false
from asbe.base import *
from econml.dml import CausalForestDML
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression, LogisticRegression
import econml
import sklift
# import causalml
# import pymc as pm 
# import pymc_bart as pmb

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export 
class CausalForestEstimator(BaseITEEstimator):
    """ITE estimator backed by econml's ``CausalForestDML``.

    Keyword arguments used:
      * ``fit``: ``X_training``, ``t_training``, ``y_training``.
      * ``predict``: ``X``; optional flags ``return_mean`` and
        ``return_counterfactuals``.
    """

    def fit(self, **kwargs):
        """Fit the causal forest on the training data.

        A default ``CausalForestDML()`` is created when ``self.model`` is
        ``None``; otherwise the pre-configured model is fitted.
        """
        if self.model is None:
            self.model = CausalForestDML()
        self.model.fit(Y=kwargs["y_training"],
                       T=kwargs["t_training"],
                       X=kwargs["X_training"])

    def predict(self, **kwargs):
        """Predict treatment effects for ``kwargs["X"]``.

        Returns:
            ``preds.pred`` (point estimates) when ``return_mean`` is truthy,
            otherwise the tuple ``(preds.pred, preds.var)`` from
            ``effect_inference``.

        Raises:
            ValueError: if ``return_counterfactuals`` is truthy.
        """
        # Flags are read by value (not key presence) so that passing
        # ``return_mean=False`` / ``return_counterfactuals=False`` behaves as
        # the caller intends, matching the other estimators in this module.
        if kwargs.get("return_counterfactuals"):
            raise ValueError("Causal Forest does not support counterfactual predictions out of the box")
        preds = self.model.effect_inference(kwargs["X"])
        if kwargs.get("return_mean"):
            return preds.pred
        return preds.pred, preds.var

In [None]:
#| export
class OPENBTITEEstimator(BaseITEEstimator):
    """Modified ITE estimator for OPENBT.

    The predictions are transposed so the uncertainty sampler can calculate
    uncertainty easily."""

    def predict(self, **kwargs):
        """Predict ITEs for ``kwargs["X"]`` by querying the model at t=0 and t=1.

        If ``self.ps_model`` is set, its ``predict_proba(X)[:, 1]`` is appended
        to ``X`` as an extra feature before the treatment column is added.
        ``self.model.predict`` is expected to return a dict with ``"mmean"``
        and ``"mdraws"`` entries (as read below).

        Keyword flags (checked by value):
          * ``return_per_cf``: return ``{"pred1": ..., "pred0": ...}`` with the
            transposed draws of each counterfactual; takes precedence.
          * ``return_mean``: return ``mmean(t=1) - mmean(t=0)``.
          * otherwise: return the transposed draw difference
            ``mdraws(t=1).T - mdraws(t=0).T``.
        """
        X = kwargs["X"]
        if self.ps_model is not None:
            ps_scores = self.ps_model.predict_proba(X)
            X = np.hstack((X, ps_scores[:, 1].reshape((-1, 1))))
        n_rows = X.shape[0]
        # The treatment indicator is the last feature column.
        X0 = np.concatenate((X, np.zeros((n_rows, 1))), axis=1)
        X1 = np.concatenate((X, np.ones((n_rows, 1))), axis=1)
        preds0 = self.model.predict(X0)
        preds1 = self.model.predict(X1)
        if kwargs.get("return_per_cf"):
            return {"pred1": preds1["mdraws"].T, "pred0": preds0["mdraws"].T}
        # Previously ``return_mean=False`` left the result unbound
        # (UnboundLocalError); now a falsy flag returns the draws.
        if kwargs.get("return_mean"):
            return preds1["mmean"] - preds0["mmean"]
        # NOTE(review): presumably mdraws is (n_draws, n_rows), so .T gives one
        # row per unit — confirm against the OPENBT wrapper.
        return preds1["mdraws"].T - preds0["mdraws"].T

In [None]:
#no-export 
class CEVAEEstimator(BaseITEEstimator):
    """ITE estimator that delegates to a CEVAE model.

    NOTE(review): ``CEVAE`` is not imported anywhere in this file (the
    ``causalml`` import at the top is commented out), so ``fit`` with
    ``self.model is None`` raises ``NameError`` unless the name is provided
    elsewhere — confirm before exporting this cell.
    """

    def fit(self, **kwargs):
        """Fit on ``X_training``/``t_training``/``y_training``, building ``CEVAE()`` if no model is set."""
        if self.model is None:
            self.model = CEVAE()
        features = kwargs["X_training"]
        treatment = kwargs["t_training"]
        outcome = kwargs["y_training"]
        self.model.fit(features, treatment, y=outcome)

    def predict(self, **kwargs):
        """Return the wrapped model's ``predict`` output for ``kwargs["X"]`` unchanged."""
        features = kwargs["X"]
        return self.model.predict(X=features)

In [None]:
#| export
class GPEstimator(BaseITEEstimator):
    """ITE estimator built from two Gaussian-process outcome models.

    ``self.model`` provides the t=0 outcome and ``self.m1`` the t=1 outcome
    (ITE = m1 - model). Both are presumably sklearn
    ``GaussianProcessRegressor`` instances fitted elsewhere — TODO confirm.
    """

    def predict(self, **kwargs):
        """Predict ITEs for ``kwargs["X"]``.

        Keyword args:
            return_mean: if truthy, return the difference of the two models'
                ``predict`` outputs.
            n_samples: number of posterior draws per model when
                ``return_mean`` is falsy (default 100, as before).

        Returns:
            The mean ITE, or ``sample_y`` draws of (t=1) minus draws of (t=0).
        """
        X = kwargs["X"]
        # Read the flag's value, not key presence: ``return_mean=False`` used
        # to return the mean anyway.
        if kwargs.get("return_mean"):
            pred0 = self.model.predict(X)
            pred1 = self.m1.predict(X)
            return pred1 - pred0
        n_samples = kwargs.get("n_samples", 100)
        # NOTE(review): sklearn's ``sample_y`` defaults to ``random_state=0``,
        # so both models (and every call) draw with the same seed — confirm
        # this is intended.
        draws0 = self.model.sample_y(X, n_samples=n_samples)
        draws1 = self.m1.sample_y(X, n_samples=n_samples)
        return draws1 - draws0

In [None]:
#no-export
class BLREstimator(BaseITEEstimator):
    """Bayesian linear-regression ITE estimator using the PyMC3 GLM module.

    NOTE(review): relies on ``pm`` exposing ``pm.glm`` (PyMC3-era API), but the
    ``pymc`` import at the top of this file is commented out — confirm the
    intended PyMC version before exporting this cell.
    """

    def fit(self, **kwargs):
        """Fit a normal GLM on ``X_training`` plus a trailing treatment column."""
        # Append treatment as the last feature so the design matrix matches
        # what ``predict`` builds (X with a 0/1 column appended); previously
        # the treatment column was missing here, causing a shape mismatch.
        X = np.concatenate((kwargs["X_training"],
                            np.asarray(kwargs["t_training"]).reshape((-1, 1))), axis=1)
        with pm.Model() as self.model:
            # https://juanitorduz.github.io/glm_pymc3/
            family = pm.glm.families.Normal()
            data = pm.Data("data", X)
            labels = ["x" + str(i) for i in range(X.shape[1])]
            # ``glm`` and ``sample`` were unbound names; use the ``pm`` namespace.
            pm.glm.GLM(y=kwargs["y_training"], x=data, family=family, labels=labels)
            self.trace = pm.sample(3000, cores=2)

    def predict(self, **kwargs):
        """Predict ITEs for ``kwargs["X"]`` via posterior-predictive draws at t=1 and t=0.

        Returns the per-unit mean ITE when ``return_mean`` is truthy, otherwise
        the transposed draw differences ``(p1["y"] - p0["y"]).T``.
        """
        X = kwargs["X"]
        n_rows = X.shape[0]
        X0 = np.concatenate((X, np.zeros((n_rows, 1))), axis=1)
        X1 = np.concatenate((X, np.ones((n_rows, 1))), axis=1)
        # NOTE(review): the observed "y" keeps the training length; predicting
        # on a different number of rows may fail in PyMC3 — verify.
        pm.set_data({"data": X1}, model=self.model)
        p1 = pm.sample_posterior_predictive(self.trace, model=self.model)
        pm.set_data({"data": X0}, model=self.model)
        p0 = pm.sample_posterior_predictive(self.trace, model=self.model)
        ite = p1["y"] - p0["y"]
        if kwargs.get("return_mean"):
            return ite.mean(axis=0)
        return ite.T

In [None]:
#| export
class SKLiftEstimator(BaseITEEstimator):
    """ITE (uplift) estimator wrapping a scikit-uplift model.

    Args:
        sk_model: a model exposing ``fit(X=..., y=..., treatment=...)`` and
            ``predict(X)`` (presumably a ``sklift.models`` meta-learner).
    """

    def __init__(self, sk_model):
        self.model = sk_model

    def fit(self, **kwargs):
        """Fit the wrapped model on ``X_training``/``y_training``/``t_training``."""
        self.model.fit(X=kwargs["X_training"],
                       y=kwargs["y_training"],
                       treatment=kwargs["t_training"])

    def predict(self, **kwargs):
        """Return the wrapped model's uplift predictions.

        The feature matrix is read from ``X_test`` (original key) or, if that
        is absent, from ``X`` (the key every other estimator here uses).

        Raises:
            ValueError: if ``return_counterfactuals`` is truthy.
        """
        if kwargs.get("return_counterfactuals"):
            raise ValueError("SKLift models do not support counterfactual predictions out of the box")
        X = kwargs["X_test"] if "X_test" in kwargs else kwargs["X"]
        return self.model.predict(X)

In [None]:
#no-export
class BARTEstimator(BaseITEEstimator):
    """ITE estimator using a PyMC BART model with treatment as the last feature.

    NOTE(review): depends on ``pm`` / ``pmb`` (pymc, pymc_bart), whose imports
    are commented out at the top of this file.
    """

    def fit(self, **kwargs):
        """Fit BART on ``X_training`` with ``t_training`` appended as a column."""
        self.model_bart = pm.Model()
        X = np.concatenate((kwargs["X_training"], kwargs["t_training"].reshape((-1, 1))), axis=1)
        with self.model_bart:
            x_obs = pm.MutableData("X", X)
            # NOTE(review): the "X" container is removed from ``named_vars`` and
            # immediately re-created; in a fresh model this looks redundant —
            # confirm why before removing.
            self.model_bart.named_vars.pop("X")
            x_obs = pm.MutableData("X", X)
            μ = pmb.BART("μ", X=x_obs, Y=kwargs["y_training"], m=100)
            # shape=μ.shape presumably lets the likelihood resize when
            # ``predict`` swaps in data with a different number of rows.
            y_pred = pm.Normal("y_pred", mu=μ, observed=kwargs["y_training"],
                               shape=μ.shape)
            self.trace = pm.sample(random_seed=1005)

    def predict(self, **kwargs):
        """Predict ITEs for ``kwargs["X"]`` from posterior-predictive draws at t=1 and t=0.

        Draws are averaged across chains first. Returns the per-unit mean ITE
        when ``return_mean`` is truthy, otherwise the transposed draws.
        """
        X = kwargs["X"]
        n_rows = X.shape[0]
        X0 = np.concatenate((X, np.zeros((n_rows, 1))), axis=1)
        X1 = np.concatenate((X, np.ones((n_rows, 1))), axis=1)
        with self.model_bart:
            pm.set_data({"X": X1})
            p1 = pm.sample_posterior_predictive(self.trace)
            pm.set_data({"X": X0})
            p0 = pm.sample_posterior_predictive(self.trace)
        ite = p1["posterior_predictive"]["y_pred"] - p0["posterior_predictive"]["y_pred"]
        ite = ite.mean(axis=0).to_numpy()  # across chains
        # Previously ``return_mean=False`` fell through and returned None.
        if kwargs.get("return_mean"):
            return ite.mean(axis=0)
        return ite.T