In [None]:
#| default_exp estimators

In [None]:
#| export
#| output: false
from asbe.base import *
#from econml.orf import DMLOrthoForest
from econml.dml import CausalForestDML
# from causalml.inference.nn import CEVAE
# from openbt.openbt import OPENBT
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression, LogisticRegression
import econml

In [None]:
#| export 
class CausalForestEstimator(BaseITEEstimator):
    """ITE estimator backed by EconML's ``CausalForestDML``.

    ``fit`` creates a default ``CausalForestDML()`` when ``self.model`` is
    None; ``predict`` returns the point CATE estimates, optionally together
    with their variances.
    """

    def fit(self, **kwargs):
        """Fit the causal forest.

        Keyword Args:
            X_training: covariate matrix.
            t_training: treatment assignments.
            y_training: observed outcomes.
        """
        if self.model is None:
            self.model = CausalForestDML()
        self.model.fit(Y=kwargs["y_training"],
                       T=kwargs["t_training"],
                       X=kwargs["X_training"])

    def predict(self, **kwargs):
        """Predict individual treatment effects for ``kwargs["X"]``.

        Keyword Args:
            X: covariate matrix to predict on.
            return_mean: if truthy, return only the point estimates.

        Returns:
            ``pred`` of ``effect_inference(X)`` when ``return_mean`` is truthy,
            otherwise the tuple ``(pred, var)`` from the same inference result.
        """
        inference = self.model.effect_inference(kwargs["X"])
        # Check the flag's value, not just its presence, so that
        # ``return_mean=False`` behaves like omitting it (as in the OPENBT
        # estimator, which also inspects the value).
        if kwargs.get("return_mean", False):
            return inference.pred
        return inference.pred, inference.var

In [None]:
#| export
class OPENBTITEEstimator(BaseITEEstimator):
    """Modified ITE estimator for OPENBT.

    The posterior draws are transposed so the uncertainty sampler can
    calculate uncertainty easily (presumably OPENBT's ``mdraws`` is
    (n_draws, n_units), so the transpose is one row per unit — TODO confirm).
    """

    def predict(self, **kwargs):
        """Predict individual treatment effects for ``kwargs["X"]``.

        If ``self.ps_model`` is set, its ``predict_proba`` column 1 is appended
        as an extra feature. A treatment column (all zeros / all ones) is then
        appended as the last feature to build the two counterfactual inputs.

        Keyword Args:
            X: covariate matrix to predict on.
            return_mean: if truthy, return the difference of the posterior
                means (``mmean``) instead of per-draw differences.
            return_per_cf: if truthy, return a dict with the transposed draws
                of each counterfactual under ``"pred1"`` and ``"pred0"``;
                takes precedence over ``return_mean``.

        Returns:
            ``preds1["mmean"] - preds0["mmean"]`` when ``return_mean`` is
            truthy, otherwise ``preds1["mdraws"].T - preds0["mdraws"].T``.
        """
        X = kwargs["X"]
        if self.ps_model is not None:
            ps_scores = self.ps_model.predict_proba(X)
            X = np.hstack((X, ps_scores[:, 1].reshape((-1, 1))))
        n_rows = X.shape[0]
        X0 = np.concatenate((X, np.zeros((n_rows, 1))), axis=1)
        X1 = np.concatenate((X, np.ones((n_rows, 1))), axis=1)
        preds0 = self.model.predict(X0)
        preds1 = self.model.predict(X1)
        if kwargs.get("return_per_cf", False):
            return {"pred1": preds1["mdraws"].T, "pred0": preds0["mdraws"].T}
        # Inspect the flag's value: ``return_mean=False`` must fall through to
        # the per-draw result instead of leaving the output undefined.
        if kwargs.get("return_mean", False):
            return preds1["mmean"] - preds0["mmean"]
        return preds1["mdraws"].T - preds0["mdraws"].T

In [None]:
#no-export 
class CEVAEEstimator(BaseITEEstimator):
    """ITE estimator wrapping a CEVAE model.

    NOTE(review): the ``CEVAE`` import at the top of this notebook is
    commented out, so ``fit`` raises ``NameError`` unless ``CEVAE`` is brought
    into scope or ``self.model`` has already been set.
    """

    def fit(self, **kwargs):
        """Fit the wrapped model on the training covariates, treatment and outcome.

        A default ``CEVAE()`` is created first when ``self.model`` is None.
        """
        if self.model is None:
            self.model = CEVAE()
        fitter = self.model.fit
        fitter(kwargs["X_training"],
               kwargs["t_training"],
               y=kwargs["y_training"])

    def predict(self, **kwargs):
        """Delegate to the wrapped model's ``predict`` with ``X=kwargs["X"]``."""
        predictor = self.model.predict
        return predictor(X=kwargs["X"])

In [None]:
#| export
class GPEstimator(BaseITEEstimator):
    """ITE estimator built from two separately fitted outcome models.

    ``self.model`` supplies the ``0`` (control) prediction and ``self.m1`` the
    ``1`` (treated) prediction; the ITE is their difference. Both must expose
    ``predict`` and ``sample_y`` (presumably scikit-learn
    ``GaussianProcessRegressor`` instances — confirm against the base-class
    ``fit``).
    """

    def predict(self, **kwargs):
        """Predict individual treatment effects for ``kwargs["X"]``.

        Keyword Args:
            X: covariate matrix to predict on.
            return_mean: if truthy, return ``m1.predict(X) - model.predict(X)``.
            n_samples: number of posterior draws per model when
                ``return_mean`` is falsy (default 100).

        Returns:
            The point-estimate difference, or the difference of the
            ``sample_y`` draws of the two models.
        """
        X = kwargs["X"]
        # Check the flag's value so ``return_mean=False`` yields draws.
        if kwargs.get("return_mean", False):
            pred0 = self.model.predict(X)
            pred1 = self.m1.predict(X)
            return pred1 - pred0
        n_samples = kwargs.get("n_samples", 100)
        draws0 = self.model.sample_y(X, n_samples=n_samples)
        draws1 = self.m1.sample_y(X, n_samples=n_samples)
        return draws1 - draws0

In [None]:
#no-export, crashes openbt
class BLREstimator(BaseITEEstimator):
    """Bayesian linear-regression ITE estimator built with the PyMC3 GLM module.

    The treatment indicator is the last column of the design matrix: ``fit``
    appends ``t_training`` to ``X_training`` and ``predict`` appends a column of
    zeros / ones to build the two counterfactual inputs.

    NOTE(review): ``pm`` (PyMC3) is not imported in this notebook; it must be in
    scope (e.g. ``import pymc3 as pm``) for this class to run.
    """

    def fit(self, **kwargs):
        """Fit a Normal-family GLM of ``y_training`` on covariates plus treatment.

        Keyword Args:
            X_training: covariate matrix.
            t_training: treatment assignments (appended as the last column).
            y_training: observed outcomes.

        Sets ``self.model`` (the PyMC3 model) and ``self.trace`` (the result of
        ``pm.sample(3000, cores=2)``).
        """
        X_training = kwargs["X_training"]
        # ``predict`` evaluates the model on covariates with a 0/1 treatment
        # column appended, so training must use the same layout and width.
        # Otherwise, treatment would never enter the regression.
        t_column = np.asarray(kwargs["t_training"]).reshape((-1, 1))
        design = np.hstack((np.asarray(X_training), t_column))
        labels = ["x" + str(i) for i in range(design.shape[1])]
        with pm.Model() as self.model:
            # https://juanitorduz.github.io/glm_pymc3/
            family = pm.glm.families.Normal()
            data = pm.Data("data", design)
            pm.glm.GLM(y=kwargs["y_training"], x=data, family=family, labels=labels)
            self.trace = pm.sample(3000, cores=2)

    def predict(self, **kwargs):
        """Predict ITEs by posterior-predictive sampling under t=1 and t=0.

        Keyword Args:
            X: covariate matrix (without the treatment column).
            return_mean: if truthy, return the mean over posterior draws.

        Returns:
            ``ite.mean(axis=0)`` when ``return_mean`` is truthy, otherwise
            ``ite.T``, where ``ite = p1["y"] - p0["y"]``.
        """
        X = kwargs["X"]
        n_rows = X.shape[0]
        X0 = np.concatenate((X, np.zeros((n_rows, 1))), axis=1)
        X1 = np.concatenate((X, np.ones((n_rows, 1))), axis=1)
        # NOTE(review): "y" is observed with the training outcomes; confirm that
        # posterior-predictive sampling handles an X with a different row count.
        pm.set_data({"data": X1}, model=self.model)
        p1 = pm.sample_posterior_predictive(self.trace, model=self.model)
        pm.set_data({"data": X0}, model=self.model)
        p0 = pm.sample_posterior_predictive(self.trace, model=self.model)
        ite = p1["y"] - p0["y"]
        # Check the flag's value so ``return_mean=False`` yields the draws.
        if kwargs.get("return_mean", False):
            return ite.mean(axis=0)
        return ite.T