# Installing packages

This cell (re)installs the packages required to run the notebook. It currently only works with the GPU backend.

In [None]:
# import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [None]:
try:
  import asbe
except:
  !pip install --upgrade --force-reinstall git+https://github.com/puhazoli/asbe.git
  !pip install pytorch-lightning
  !pip install scikit-uplift
  !pip install pylift
  !pip install causeinfer
  !pip install matplotlib==3.7.1

# Loading packages, defining functions and creating classes

We also define helper functions here, such as the weighted loss, as well as classes for the DGPs (data-generating processes) and the base neural network.

In [None]:
from sklearn.base import BaseEstimator

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
import pytorch_lightning as pl

from typing import Union, Callable, Optional, Tuple, List, Iterator, Any
from copy import deepcopy
from dataclasses import dataclass, field

import pandas as pd
from scipy.optimize import linear_sum_assignment


from asbe.base import *
from asbe.models import *
from asbe.helper import *
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.neighbors import KernelDensity

from scipy.stats import rankdata

import pickle
import io

In [None]:
from psutil import Process

In [None]:
# Allow faster, reduced-precision internal math for float32 matrix
# multiplications ("medium" lets PyTorch use bfloat16/TF32 internally where the
# hardware supports it), trading a little accuracy for speed.
torch.set_float32_matmul_precision('medium')

In [None]:
def weighted_loss(output, target, t, weights):
    """Return the weighted mean squared error between ``output`` and ``target``.

    Each squared residual is multiplied by the matching entry of ``weights``
    before averaging. ``t`` is accepted for interface compatibility only and is
    not used.
    """
    squared_error = (output - target) ** 2
    return (weights * squared_error).mean()


def pdist2sq(x, y):
    """Return the pairwise squared Euclidean distances between rows of ``x`` and ``y``.

    Uses the expansion ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>, giving a matrix of
    shape (x.size(0), y.size(0)). Rounding can produce tiny negative entries.
    """
    x_sq = x.pow(2).sum(dim=1, keepdim=True)
    y_sq = y.pow(2).sum(dim=1, keepdim=True)
    cross = x @ y.T
    return x_sq + y_sq.T - 2. * cross

In [None]:
class IHDP(Dataset):
    """Dataset over an IHDP CSV file, served as ``[X, y, ite, t]`` rows.

    Reads the columns ``treatment``, ``y_factual`` and ``y_cfactual``; every
    column from positional index 5 onward is used as a feature.

    Arguments:
    - path : path to the CSV file
    - subset : if True, keep only control units (``treatment == 0``)
    """

    def __init__(self, path, subset=False):
        self.data = pd.read_csv(path)
        if subset:
            self.data = self.data[self.data["treatment"] == 0].reset_index(drop=True)
        self.y = self.data["y_factual"].to_numpy()
        treated = self.data["treatment"] == 1
        # ITE = y(1) - y(0), whichever of the two outcomes was the factual one.
        self.ite = np.where(treated,
                            self.data["y_factual"] - self.data["y_cfactual"],
                            self.data["y_cfactual"] - self.data["y_factual"])

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        features = self.data.iloc[idx, 5:].to_numpy().reshape(1, -1)
        return [torch.from_numpy(features).float(),
                torch.tensor(self.y[idx]).float(),
                # float32 tensor (previously a raw float64 scalar) so it matches
                # the model's float32 predictions in losses such as F.mse_loss.
                torch.tensor(self.ite[idx]).float(),
                torch.tensor(self.data.loc[idx, "treatment"]).float()]

In [None]:
class ASBEDATA(Dataset):
    """Dataset that supplies rows of an asbe data dict in ``[X, y, ite, t]`` format.

    Arguments:
    - ds : dict holding ``X_<split>``, ``y_<split>``, ``ite_<split>`` and
      ``t_<split>`` arrays for the splits ``training``, ``pool`` and ``test``
    - training : if True, serve the training split
    - pool : if True, serve the pool split (takes precedence over ``training``)
    - normalize : if True, standardize features with the training-split mean/std
    If both ``training`` and ``pool`` are False, the test split is served.
    """

    def __init__(self, ds, training=True, pool=False, normalize=False):
        self.ds = ds
        self.training = training
        self.pool = pool
        self.normalize = normalize
        if self.normalize:
            X_train = self.ds["X_training"]
            self.means = np.mean(X_train, axis=0)
            stds = np.std(X_train, axis=0)
            # Constant features would divide by zero; such columns are only centred.
            self.stds = np.where(stds == 0, 1.0, stds)

    def _split(self):
        """Return the name of the split being served: pool, training or test."""
        if self.pool:
            return "pool"
        return "training" if self.training else "test"

    def __len__(self):
        # Use the same split as __getitem__ (pool mode used to report the
        # training/test length instead of the pool length).
        return self.ds[f"X_{self._split()}"].shape[0]

    def __getitem__(self, idx):
        split = self._split()
        X = self.ds[f"X_{split}"][idx, :]
        if self.normalize:
            X = torch.tensor((X.astype(float) - self.means) / self.stds)
        else:
            X = torch.tensor(X)
        return [X.float(),
                torch.tensor(self.ds[f"y_{split}"][idx]).float(),
                torch.tensor(self.ds[f"ite_{split}"][idx]).float(),
                torch.tensor(self.ds[f"t_{split}"][idx]).float()]

In [None]:
class DDALIPM(BaseAcquisitionFunction):
    """Custom acquisition function to test out DDAL-style selection for TARNet-IPM.

    Candidate pool units are scored with signals computed from the fitted
    estimator (``model.model`` is the underlying network):
    - predictive uncertainty;
    - distances in the learned representation ``phi``, optionally measured as
      the change of the network's weighted MMD;
    - a kernel density estimate.

    ``calculate_metrics`` combines the signals chosen via ``mode`` into one
    score per pool unit.
    """

    _DISTANCE_METHODS = ("dist_to_labeled", "dist_to_counter", "dist_to_pool",
                         "dist_to_pool_counter", "dist_to_selected")

    def treat_v_control(self, model, dataset, method="dist_to_labeled", mmd=False, selection_count=0):
        """Score pool units by their representation distance to a reference group.

        Pool units are split by the network's predicted treatment (``t_hat <= .5``
        is "predicted control", the rest "predicted treated"). Each unit is then
        scored against the matching reference set:
        - "dist_to_labeled": mean squared distance to labeled units of the same arm.
        - "dist_to_counter": negated distance to labeled units of the other arm.
        - "dist_to_pool" / "dist_to_pool_counter": as above, but against pool
          units grouped by ``t_pool``.
        - "dist_to_selected": distance to the last ``selection_count`` rows of
          ``X_training``; all ones when ``selection_count < 1``.

        Args:
            model: fitted estimator wrapper; ``model.model(X, t)`` returns
                ``(y_hat, t_hat)`` and ``model.model(X, t, return_phi=True)``
                returns the representation.
            dataset: dict with ``X_training``, ``t_training``, ``X_pool`` and ``t_pool``.
            method: one of the method names above.
            mmd: for "dist_to_labeled" / "dist_to_counter", score by the expected
                change of the weighted MMD instead of mean squared distance.
            selection_count: number of already selected units ("dist_to_selected").

        Returns:
            numpy array with one score per pool unit.

        Raises:
            ValueError: if ``method`` is not recognised.
        """
        if method not in self._DISTANCE_METHODS:
            raise ValueError(f"Unknown method: {method}")
        net = model.model
        t = torch.Tensor(dataset["t_training"])
        X = torch.Tensor(dataset["X_training"])
        X_pool = torch.Tensor(dataset["X_pool"])
        t_pool = torch.Tensor(dataset["t_pool"])
        zero_treatments, one_treatments = t.eq(0), t.eq(1)
        zero_pool, one_pool = t_pool.eq(0), t_pool.eq(1)

        # Scoring only needs forward passes, so no autograd graphs are built.
        with torch.no_grad():
            _, t_pred = net(X, t)
            _, t_pool_pred = net(X_pool, t_pool)
            phi = net(X, t, return_phi=True)
            phi_pool = net(X_pool, t_pool, return_phi=True)

        # Use the complement so every pool unit lands in exactly one group
        # (the former ``ge(.500001)`` test left units in neither group).
        zero_predicted_pool = t_pool_pred.le(.5).squeeze()
        one_predicted_pool = ~zero_predicted_pool

        if mmd and method == "dist_to_labeled":
            return self._mmd_change_to_labeled(net, phi, t, t_pred, phi_pool, t_pool_pred)
        if mmd and method == "dist_to_counter":
            return self._mmd_change_to_counter(net, phi, t, t_pred, phi_pool, t_pool_pred)

        if method == "dist_to_selected":
            if selection_count >= 1:
                print(f"Number of selections: {selection_count}")
                # In "accounting" mode selected units are appended to X_training,
                # so they are its last rows.
                return torch.mean(pdist2sq(phi_pool, phi[-selection_count:]), dim=1).numpy()
            # Nothing selected yet: every candidate gets the same score.
            return torch.ones(t_pool.size()).numpy()

        # Reference sets for predicted-control / predicted-treated units and the sign.
        if method == "dist_to_labeled":
            ref_c, ref_t, sign = phi[zero_treatments], phi[one_treatments], 1
        elif method == "dist_to_counter":
            ref_c, ref_t, sign = phi[one_treatments], phi[zero_treatments], -1
        elif method == "dist_to_pool":
            ref_c, ref_t, sign = phi_pool[zero_pool], phi_pool[one_pool], -1
        else:  # "dist_to_pool_counter"
            ref_c, ref_t, sign = phi_pool[one_pool], phi_pool[zero_pool], -1

        distances = torch.zeros(t_pool.size())
        distances[zero_predicted_pool] = sign * torch.mean(
            pdist2sq(phi_pool[zero_predicted_pool], ref_c), dim=1)
        distances[one_predicted_pool] = sign * torch.mean(
            pdist2sq(phi_pool[one_predicted_pool], ref_t), dim=1)
        return distances.numpy()

    @staticmethod
    def _append_candidate(phi, t, t_pred, phi_row, t_pred_row, cf):
        """Return ``(phi, t, t_pred)`` with one pool unit appended under treatment ``cf``."""
        return (torch.cat((phi, phi_row.view(1, -1)), 0),
                torch.cat((t, t.new_tensor([cf])), 0),
                torch.cat((t_pred, t_pred_row.view(1, 1)), 0))

    def _mmd_change_to_labeled(self, net, phi, t, t_pred, phi_pool, t_pool_pred):
        """Expected drop in MMD over all labeled units when each pool unit is added."""
        base_loss = net.mmdsq_loss(phi, t, t_pred)
        out = np.zeros(phi_pool.size(0))
        for ix in range(phi_pool.size(0)):
            p = t_pool_pred[ix]
            # NOTE(review): control is weighted by p and treated by 1 - p, the
            # reverse of _mmd_change_to_counter. This matches IPAssignment, which
            # treats with probability 1 - p; confirm which weighting is intended.
            change = 0
            for cf, weight in ((0, p), (1, 1 - p)):
                phi_new, t_new, t_pred_new = self._append_candidate(phi, t, t_pred, phi_pool[ix], p, cf)
                change = change + weight * (base_loss - net.mmdsq_loss(phi_new, t_new, t_pred_new))
            out[ix] = float(change)
        return out

    def _mmd_change_to_counter(self, net, phi, t, t_pred, phi_pool, t_pool_pred):
        """Expected drop in within-arm MMD when each pool unit is added, averaged over arms."""
        zero, one = t.eq(0), t.eq(1)
        base = {0: net.mmdsq_loss(phi[zero], t[zero], t_pred[zero]),
                1: net.mmdsq_loss(phi[one], t[one], t_pred[one])}
        out = np.zeros(phi_pool.size(0))
        for ix in range(phi_pool.size(0)):
            p = t_pool_pred[ix]
            change = 0
            for cf, weight in ((0, 1 - p), (1, p)):
                phi_new, t_new, t_pred_new = self._append_candidate(phi, t, t_pred, phi_pool[ix], p, cf)
                arm = t_new.eq(cf)
                change = change + weight * (
                    base[cf] - net.mmdsq_loss(phi_new[arm], t_new[arm], t_pred_new[arm]))
            out[ix] = float(change / 2)
        return out

    def train_v_pool(self, model, dataset):
        """Negative mean squared representation distance of each pool unit to the whole pool."""
        X_pool = torch.Tensor(dataset["X_pool"])
        t_pool = torch.Tensor(dataset["t_pool"])
        with torch.no_grad():
            phi_pool = model.model(X_pool, t_pool, return_phi=True)
        return -1 * torch.mean(pdist2sq(phi_pool, phi_pool), dim=1).numpy()

    def uncertainty(self, model, dataset):
        """Variance of the sampled ITE predictions for each pool unit."""
        pred = model.predict(X=dataset["X_pool"], return_mean=False)
        return pred.var(1).detach().numpy()

    def subsample(self, dataset, number_to_get):
        """Draw a random subset of the pool.

        Returns ``[sub_ix, subsampled]``:
        - ``sub_ix`` holds positions in the current pool, drawn without
          replacement and capped at the pool size.
        - ``subsampled`` is a dict in which every key containing "pool" is
          restricted to those positions. Other entries are shared with
          ``dataset``, not copied; callers only read them.
        """
        n_pool = dataset["X_pool"].shape[0]
        sub_ix = np.random.choice(n_pool, size=min(number_to_get, n_pool), replace=False)
        subsampled = {key: (value[sub_ix] if "pool" in key else value)
                      for key, value in dataset.items()}
        return [sub_ix, subsampled]

    def kde(self, model, dataset):
        """Gaussian-KDE log-density of each pool unit in the representation space.

        The KDE is fitted on the training and pool representations combined.
        """
        X = torch.Tensor(np.concatenate((dataset["X_training"], dataset["X_pool"])))
        t = torch.Tensor(np.concatenate((dataset["t_training"], dataset["t_pool"])))
        with torch.no_grad():
            phi = model.model(X, t, return_phi=True).numpy()
        kde = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(phi)
        return kde.score_samples(phi)[dataset["X_training"].shape[0]:]

    def _normalize(self, metric):
        """Min-max scale ``metric`` to [0, 1] so different signals can be summed.

        The previous formula divided by ``max(metric)`` instead of the range,
        which reversed the ordering when all values were negative (e.g.
        ``train_v_pool``). A constant metric maps to all zeros.
        """
        metric = np.asarray(metric, dtype=float)
        if metric.size == 0:
            return metric
        low, high = np.min(metric), np.max(metric)
        if high == low:
            return np.zeros_like(metric)
        return (metric - low) / (high - low)

    def calculate_metrics(self, model, dataset, **kwargs):
        """Compute one acquisition score per pool unit.

        Keyword Args:
            mode: str or list of str from "uncertainty", "treat_v_control",
                "train_v_pool", "batch", "accounting" (required).
            method: distance method for "treat_v_control" (default "dist_to_labeled").
            mmd: use MMD-change scores in the distance signals (default False).
            simulations: number of random candidate batches ("batch" mode).
            subsample_size: pool candidates scored per pick ("accounting" mode,
                default 1000).

        Returns:
            numpy array of scores. In "batch" and "accounting" modes the chosen
            units get 1 and the rest 0.

        Raises:
            ValueError: if no ``mode`` is supplied.
        """
        scores = np.zeros(dataset["X_pool"].shape[0])
        self.N_pool = dataset["X_pool"].shape[0]
        self.d_data = dataset["X_pool"].shape[1]
        if "mode" not in kwargs:
            raise ValueError("No mode has been supplied")
        mode = [kwargs["mode"]] if isinstance(kwargs["mode"], str) else kwargs["mode"]
        mmd_to_run = kwargs.get("mmd", False)

        if "uncertainty" in mode:
            scores += self._normalize(self.uncertainty(model, dataset))
        if "treat_v_control" in mode:
            method_to_run = kwargs.get("method", "dist_to_labeled")
            scores += self._normalize(
                self.treat_v_control(model, dataset, method=method_to_run, mmd=mmd_to_run))
        if "train_v_pool" in mode:
            scores += self._normalize(self.train_v_pool(model, dataset))

        if "batch" in mode:
            inb_tc = self.treat_v_control(model, dataset, "dist_to_labeled", mmd=mmd_to_run)
            inb_dc = self.treat_v_control(model, dataset, "dist_to_counter", mmd=mmd_to_run)
            inb_uc = self.uncertainty(model, dataset)
            # Previously referenced without ever being computed (NameError).
            kde_scores = self.kde(model, dataset)
            n_pool = dataset["t_pool"].shape[0]
            batch_size = min(self.no_query, n_pool)
            res, selected = [], []
            for sim in range(kwargs["simulations"]):
                # Without replacement so each candidate batch has distinct units.
                ix = np.random.choice(n_pool, batch_size, replace=False)
                res.append({"sim": sim,
                            "tc": np.mean(inb_tc[ix]),
                            "dc": np.mean(inb_dc[ix]),
                            "uc": np.mean(inb_uc[ix]),
                            "kde": np.mean(kde_scores[ix])})
                selected.append(ix)
            rdt = pd.DataFrame(res)
            # Sum of per-signal ranks; the old parenthesisation ranked
            # ``uc + rank(kde)``, which effectively ignored uncertainty.
            rdt["ranking_score"] = (rankdata(rdt["tc"]) + rankdata(rdt["dc"])
                                    + rankdata(rdt["uc"]) + rankdata(rdt["kde"]))
            best_group = rdt.loc[rdt["ranking_score"].idxmax(), "sim"]
            scores[selected[best_group]] = 1

        if "accounting" in mode:
            subsample_size = kwargs.get("subsample_size", 1000)
            sel = []
            ds = deepcopy(dataset)
            # Original pool positions of the units still in ds["X_pool"].
            ixs = np.arange(ds["X_pool"].shape[0])
            inb_uc = self.uncertainty(model, ds)
            for counter in range(self.no_query):
                subsampled_ix, subsampled_data = self.subsample(ds, subsample_size)
                if counter % 100 == 0:
                    print(f"Currently at {counter}")
                inb_tc = -1 * self.treat_v_control(model, subsampled_data, "dist_to_labeled", mmd=mmd_to_run)
                inb_dc = self.treat_v_control(model, subsampled_data, "dist_to_counter", mmd=mmd_to_run)
                inb_ds = self.treat_v_control(model, subsampled_data, "dist_to_selected",
                                              mmd=mmd_to_run, selection_count=counter)
                combined_rank = (rankdata(inb_dc) + rankdata(inb_uc[subsampled_ix])
                                 + rankdata(inb_tc) + rankdata(inb_ds))
                six = subsampled_ix[np.argmax(combined_rank)]
                sel.append(ixs[six])
                # Move the chosen unit from the pool into the training set so the
                # next pick accounts for it.
                ds["X_training"] = np.concatenate((ds["X_training"],
                                                   ds["X_pool"][six, :].reshape((1, -1))),
                                                  axis=0)
                ds["t_training"] = np.concatenate((ds["t_training"], ds["t_pool"][six].ravel()),
                                                  axis=0)
                ds["X_pool"] = np.delete(ds["X_pool"], six, axis=0)
                ds["t_pool"] = np.delete(ds["t_pool"], six)
                ixs = np.delete(ixs, six)
                inb_uc = np.delete(inb_uc, six)
            scores[sel] = 1
        return scores

In [None]:
class LightTarNetIPM(pl.LightningModule):
    """Lightweight PyTorch Lightning implementation of TARNet with an IPM penalty.

    Architecture:
    - A shared trunk (preceded by two conv layers when ``conv``) maps inputs to
      a 25-d representation ``phi``.
    - A propensity head predicts ``t_hat`` from ``phi``.
    - Two outcome heads (control / treated) predict the outcome.

    The training loss adds ``alpha`` times a propensity-weighted RBF-kernel MMD
    between control and treated representations.

    Arguments:
    - alpha : weight of the MMD term in the training loss
    - sigma : bandwidth of the RBF kernel used by the MMD
    - pt : share of treated units in the training data
    - weights : per-unit weights from the wrapper (stored on the module, not used in the loss)
    - num_features : number of input features (ignored when ``conv`` is True)
    - conv : reshape inputs to 1x28x28 and apply two conv layers first
    - binary : apply a sigmoid to the outcome heads; with ``conv`` or ``binary``
      the outcome loss is plain (unweighted) MSE
    """

    def __init__(self, alpha, sigma, pt, weights, num_features, conv, binary):
        super().__init__()
        print("Initializing model....")
        self.binary = binary
        self.conv = conv
        self.alpha = alpha
        self.rbf_sigma = sigma
        self.pt = pt
        if conv:
            self.conv1 = nn.Conv2d(1, 32, kernel_size=7, stride=(1, 1), padding="same")
            self.conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=(1, 1), padding="same")
        fc_features = 64 * 28 * 28 if self.conv else num_features
        # Shared representation trunk.
        self.fc1 = nn.Linear(fc_features, 200)
        self.fc2 = nn.Linear(200, 200)
        self.fc3 = nn.Linear(200, 25)
        # Propensity head.
        self.t_hidden = nn.Linear(25, 25)
        self.t_pred = nn.Linear(25, 1)
        # Outcome head for T == 1.
        self.t1fc1 = nn.Linear(25, 200)
        self.t1fc2 = nn.Linear(200, 200)
        self.t1fc3 = nn.Linear(200, 1)
        # Outcome head for T == 0.
        self.t0fc1 = nn.Linear(25, 200)
        self.t0fc2 = nn.Linear(200, 200)
        self.t0fc3 = nn.Linear(200, 1)
        self.drop_layer = nn.Dropout(p=0.1)
        self.batch_norm = nn.BatchNorm1d(25)
        self.weights = weights

    def K(self, xi, xj, sigma):
        """Gaussian-type kernel on the summed squared difference (not called by this class).

        NOTE(review): the exponent multiplies by ``sigma**2`` instead of dividing
        by it, so ``sigma`` behaves like an inverse bandwidth here.
        """
        return torch.exp(-0.5 * sigma**2 * torch.sum((xi - xj)**2))

    def RBF_K(self, xi, xj):
        """RBF kernel matrix exp(-||xi - xj||^2 / sigma^2) between the rows of xi and xj."""
        return torch.exp(-pdist2sq(xi, xj) / self.rbf_sigma**2)

    def mmdsq(self, Phi, t, t_pred):
        """Propensity-weighted squared MMD between control and treated representations.

        Representations are scaled by ``(1 - pt) / (1 - t_pred)`` (control) or
        ``pt / t_pred`` (treated) before the RBF kernel. The within-arm sums
        include the i == j terms and are normalised by m(m-1) / n(n-1); arms
        with fewer than two units contribute no within-arm term. Returns the
        scalar broadcast to the shape of ``t``.
        """
        zero_treatments = t.eq(0)
        one_treatments = t.eq(1)
        Phic, Phit = Phi[zero_treatments], Phi[one_treatments]
        tpredc, tpredt = t_pred[zero_treatments], t_pred[one_treatments]
        weightc = torch.Tensor((1-self.pt)/(1.0001-tpredc)).view(-1, 1)
        weightt = torch.Tensor(self.pt/(tpredt+0.0001)).view(-1, 1)
        weightedphic = weightc * Phic
        weightedphit = weightt * Phit
        m = Phic.size(0)
        n = Phit.size(0)
        mmd = 0
        if m > 1:
            mmd = 1.0/(m*(m-1.0))*(torch.sum(self.RBF_K(weightedphic, weightedphic)))
        if n > 1:
            mmd = mmd + 1.0/(n*(n-1.0))*(torch.sum(self.RBF_K(weightedphit, weightedphit)))
        if m > 0 and n > 0:
            mmd = mmd - 2.0/(m*n)*torch.sum(self.RBF_K(weightedphic, weightedphit))
        return mmd * torch.ones_like(t)

    def mmdsq_loss(self, phi, t, t_pred):
        """Scalar weighted MMD penalty (mean of ``mmdsq``)."""
        return torch.mean(self.mmdsq(phi, t, t_pred))

    def _outcome_head(self, h, fc1, fc2, fc3):
        """One outcome head: two ReLU layers, a linear output (sigmoid if binary), then dropout."""
        h = F.relu(fc1(h))
        h = F.relu(fc2(h))
        h = fc3(h)
        if self.binary:
            h = torch.sigmoid(h)
        return self.drop_layer(h)

    def forward(self, x, t, return_phi=False):
        """Run the network.

        Args:
            x: input batch, flattened per sample (viewed as (batch, 1, 28, 28) when ``conv``).
            t: treatment indicator of shape (batch,) with values 0/1.
            return_phi: if True, return only the 25-d representation.

        Returns:
            ``phi`` if ``return_phi``. Otherwise ``(y_hat, t_hat)``, both of
            shape (batch, 1): the outcome from the head selected by ``t``, and
            the predicted propensity squeezed into
            (0.001/1.002, 1.001/1.002).
        """
        if self.conv:
            x = F.elu(self.conv1(x.view(x.size(0), 1, 28, 28)))
            x = F.elu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        if return_phi:
            return x

        t_latent = F.relu(self.t_hidden(x))
        # Keep the propensity strictly inside (0, 1); it is used as a divisor
        # in the inverse-propensity weights.
        t_hat = (torch.sigmoid(self.t_pred(t_latent)) + 0.001) / 1.002

        zero_treatments = t.eq(0)
        # Each head only sees its own arm's representations; other rows are zeroed.
        x0 = x.clone()
        x1 = x.clone()
        x0[~zero_treatments, :] = 0
        x1[zero_treatments, :] = 0
        x0 = self._outcome_head(x0, self.t0fc1, self.t0fc2, self.t0fc3)
        x1 = self._outcome_head(x1, self.t1fc1, self.t1fc2, self.t1fc3)
        # Pick column 0 (control head) or 1 (treated head) per row according to t.
        out = torch.gather(torch.cat((x0, x1), 1), 1, t.long().unsqueeze(-1))
        return (out, t_hat)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer

    def training_step(self, batch, batch_idx):
        """Outcome loss (IPW-weighted unless conv/binary) plus ``alpha`` * MMD."""
        x, y, ite, t = batch
        y_hat, t_hat = self(x, t)
        phi = self(x, t, return_phi=True)
        mmd_loss = self.mmdsq_loss(phi, t, t_hat)
        # Targets arrive as (batch,); align them with the (batch, 1) predictions
        # so the losses are element-wise instead of broadcasting to (batch, batch).
        y = y.view(-1, 1)
        t_col = t.view(-1, 1)
        if self.conv or self.binary:
            loss = F.mse_loss(y_hat, y) + self.alpha * mmd_loss
        else:
            weightc = (1-self.pt)/(1.0001-t_hat)
            weightt = self.pt/(t_hat+0.0001)
            loss0 = torch.mean((1 - t_col) * torch.square(y - y_hat) * weightc)
            loss1 = torch.mean(t_col * torch.square(y - y_hat) * weightt)
            loss = loss0 + loss1 + self.alpha * mmd_loss
        return loss

    def test_step(self, batch, batch_idx):
        """Log the MSE between predicted and true ITE as ``test_loss``."""
        x, y, ite, t = batch
        y_hat0 = self(x, torch.zeros_like(t))[0]
        y_hat1 = self(x, torch.ones_like(t))[0]
        ite_hat = y_hat1 - y_hat0
        # Match shape (batch, 1) and dtype of the predictions.
        loss = F.mse_loss(ite_hat, ite.view(ite.size(0), -1).type_as(ite_hat))
        self.log('test_loss', loss)

In [None]:
class ASBETARNETIPM(BaseITEEstimator):
    """Wrap ``LightTarNetIPM`` so it can be used as an asbe ITE estimator.

    Training and default prediction data are read from ``self.dataset`` (an
    asbe data dict). The caller must assign it before ``fit``/``predict``
    (e.g. ``learner.estimator.dataset = learner.dataset``).

    Arguments:
    - two_model : placeholder, currently has no function
    - num_sim : number of passes through the NN when predicting with ``return_mean=False``
    - conv : use the convolutional trunk of ``LightTarNetIPM``
    - tpu : train on TPU (8 devices) instead of GPU
    - epochs : number of training epochs
    - binary_outcome : outcome heads apply a sigmoid
    """

    def __init__(self, two_model, num_sim=100, conv=False, tpu=False, epochs=300, binary_outcome=False):
        super().__init__()
        self.binary = binary_outcome
        self.conv = conv
        self.num_sim = num_sim
        self.tpu = tpu
        self.num_epochs = epochs

    def prepare_data(self, X_training=None, t_training=None, y_training=None, ps_scores=None):
        """Return the inputs unchanged (no preprocessing is needed)."""
        return X_training, t_training, y_training, ps_scores

    def fit(self, **kwargs):
        """Build a fresh ``LightTarNetIPM`` and train it on ``self.dataset``'s training split.

        ``kwargs["t_training"]`` and ``kwargs["X_training"]`` (numpy arrays) are
        used only for the unit weights, ``pt`` and the input width. The training
        rows themselves come from ``self.dataset``.
        """
        t_training = kwargs["t_training"]
        u = (1 / t_training.shape[0]) * t_training.sum()  # share of treated units
        ws = torch.from_numpy(t_training / (2 * u) + (1 - t_training) / (2 * (1 - u)))
        pt = np.sum(t_training == 1) / kwargs["X_training"].shape[0]
        self.model = LightTarNetIPM(alpha=1, sigma=1,
                                    weights=ws,
                                    pt=pt,
                                    conv=self.conv,
                                    binary=self.binary,
                                    num_features=kwargs["X_training"].shape[1])
        if self.tpu:
            # `tpu_cores` was removed in Lightning 2.0; use accelerator/devices,
            # and honour `epochs` instead of a hard-coded 600.
            self.trainer = pl.Trainer(max_epochs=self.num_epochs, accelerator="tpu", devices=8)
        else:
            self.trainer = pl.Trainer(max_epochs=self.num_epochs,
                                      accelerator="gpu",
                                      enable_progress_bar=True,
                                      enable_model_summary=False)
        data = ASBEDATA(ds=self.dataset)
        dl = DataLoader(data, batch_size=50, num_workers=4)
        self.trainer.fit(self.model, dl)

    def predict(self, **kwargs):
        """Predict individual treatment effects (y(1) - y(0)).

        Keyword Args:
            X: features to predict for. If absent, the test split (or the pool
                split when ``pool=True``) of ``self.dataset`` is used.
            pool: see ``X`` (default False).
            return_mean: if True (default), return a numpy array from one forward
                pass per arm. If False, return a torch tensor with ``num_sim``
                columns, one per repeated pass.
        """
        ret_mean = kwargs.get("return_mean", True)
        pbool = kwargs.get("pool", False)
        if "X" not in kwargs:
            preddata = ASBEDATA(self.dataset, training=False, pool=pbool)[:][0]
        else:
            preddata = torch.tensor(kwargs["X"]).float()
        n_pred = preddata.shape[0]
        ones, zeros = torch.ones(n_pred), torch.zeros(n_pred)
        # Inference only: skip building autograd graphs.
        with torch.no_grad():
            if ret_mean:
                out1 = self.model(preddata, ones)[0]
                out0 = self.model(preddata, zeros)[0]
                return (out1 - out0).numpy()
            out1 = torch.cat([self.model(preddata, ones)[0] for _ in range(self.num_sim)], dim=1)
            out0 = torch.cat([self.model(preddata, zeros)[0] for _ in range(self.num_sim)], dim=1)
        return out1 - out0

In [None]:
class IPAssignment(BaseAssignmentFunction):
    """Assign treatment to queried pool units using the network's propensity head.

    Each queried unit is treated with probability ``1 - t_hat``, one minus its
    predicted propensity. Presumably this is meant to favour the arm a unit
    would be less likely to receive.
    """

    def select_treatment(self, model, dataset, query_idx):
        X_query = torch.Tensor(dataset["X_pool"][query_idx, :])
        t_query = torch.Tensor(dataset["t_pool"][query_idx])
        _, propensity = model.model(X_query, t_query)
        treat_prob = 1 - propensity.detach().numpy()
        return np.random.binomial(1, treat_prob).squeeze()

# Running test code for TARNet

In this section we load an IHDP dataset and try out the model with the classes defined above.



In [None]:
def quick_sim(N1, N2, seed=1, ihdp=1, run_random=True):
    """Run one IHDP replication comparing acquisition strategies for TARNet-IPM.

    Args:
        N1: presumably the initial number of labeled units. ``(747 - N1) / 747``
            is passed to ``get_ihdp_dict``; 747 is presumably the IHDP sample size.
        N2: number of units queried by each strategy.
        seed: random seed forwarded to ``get_ihdp_dict``.
        ihdp: IHDP replication number.
        run_random: also evaluate random acquisition.

    Returns:
        dict mapping strategy name to the learner's ``score()`` after teaching
        the selected units and refitting. Keys are "random" (only when
        ``run_random``), "uncertainty", "accounting" and "accounting_mmd".
    """
    score = dict()
    N1_perc = (747 - N1) / 747
    ds = get_ihdp_dict(ihdp, N1_perc, True, 0.1, seed=seed)
    print(f'Propensity before: {np.sum(ds["t_training"]) /ds["t_training"].shape[0]}')
    # Pristine copy: every strategy starts from the same initial split.
    copied_ds = deepcopy(ds)

    if run_random:
        asl = BaseActiveLearner(estimator=ASBETARNETIPM(two_model=False),
                                acquisition_function=RandomAcquisitionFunction(no_query=N2),
                                assignment_function=BaseAssignmentFunction(),
                                stopping_function=None,
                                dataset=ds)
        asl.estimator.dataset = asl.dataset
        _, random_sel = asl.query(no_query=N2)
        asl.teach(random_sel)
        asl.fit()
        score["random"] = asl.score()
        print(f'Propensity after: {np.sum(asl.dataset["t_training"]) /asl.dataset["t_training"].shape[0]}')

    # DDAL: start from a fresh copy so units taught by the random learner cannot
    # leak in. The selections below are pool indices, which are later applied
    # to copies of `copied_ds`.
    asl_ddal = BaseActiveLearner(estimator=ASBETARNETIPM(two_model=False),
                                 acquisition_function=DDALIPM(no_query=N2),
                                 assignment_function=IPAssignment(),
                                 stopping_function=None,
                                 dataset=deepcopy(copied_ds))
    asl_ddal.estimator.dataset = asl_ddal.dataset
    asl_ddal.fit()
    _, mmd_sel = asl_ddal.query(no_query=N2, mode="accounting", mmd=False)
    _, mmd_sel_mmd = asl_ddal.query(no_query=N2, mode="accounting", mmd=True)
    _, unc_sel = asl_ddal.query(no_query=N2, mode="uncertainty")

    asl_ddal.teach(unc_sel)
    asl_ddal.fit()
    score["uncertainty"] = asl_ddal.score()

    for name, selection in (("accounting", mmd_sel), ("accounting_mmd", mmd_sel_mmd)):
        # Reset to the initial split. Learner and estimator share ONE dict; with
        # two separate deep copies, units added by teach() to the learner's dict
        # would presumably not reach the estimator, which trains from its own
        # `self.dataset`.
        asl_ddal.dataset = deepcopy(copied_ds)
        asl_ddal.estimator.dataset = asl_ddal.dataset
        asl_ddal.teach(selection)
        asl_ddal.fit()
        score[name] = asl_ddal.score()
    print(score)
    return score

In [None]:
# # random: 4.62, dropout at 0.1 = 4.16
# res = dict()
# for ihdp in [10]:
#   res[ihdp] = dict()
#   for i in range(10):
#     res[ihdp][i] = quick_sim(50, 50, ihdp=ihdp, seed=i,run_random=True)
  # ch = quick_sim(50, 50, ihdp=1, seed=i,run_random=False)
# res_all = list()
# for N1 in [50, 100]:
#   for N2 in [25, 50, 100]:
#     for ihdp in range(9):
#       for seed in range(10):
#         sc = quick_sim(N1, N2, ihdp=ihdp+1, seed=seed)
#         sc["ihdp"] = ihdp+1
#         sc["seed"] = seed
#         sc["N1"] = N1
#         sc["N2"] = N2
#         res_all.append(sc)
#         print(f"N1: {N1}, N2: {N2}, Ihdp: {ihdp}, seed: {seed}, values: {sc}")
#     pd.DataFrame(res_all).to_csv("/drive/MyDrive/Colab Notebooks/data/ihdp_quick_sim.csv")

In [None]:
import pandas as pd

# Per-strategy scores from ten quick_sim runs (one dict per run), presumably
# copied from quick_sim's printed output; the cell displays the column means.
_quick_sim_pehe = [
    {'random': 13.261134976726014, 'uncertainty': 12.654133973623953, 'accounting': 13.576951213858909, 'accounting_mmd': 14.117787994069714},
    {'random': 14.85508058355626, 'uncertainty': 14.286718574039888, 'accounting': 14.387758800415256, 'accounting_mmd': 14.835449686755718},
    {'random': 13.754437282442838, 'uncertainty': 13.344558513854864, 'accounting': 12.636939894258756, 'accounting_mmd': 13.34692728040138},
    {'random': 13.214758048534557, 'uncertainty': 13.338884998142, 'accounting': 14.036059370054994, 'accounting_mmd': 14.37753368447979},
    {'random': 13.889522192485034, 'uncertainty': 14.38458876320927, 'accounting': 12.584737184837856, 'accounting_mmd': 12.959747216785916},
    {'random': 12.223091326849154, 'uncertainty': 14.211100248717251, 'accounting': 14.02388522960542, 'accounting_mmd': 13.057153276908823},
    {'random': 14.010992729615008, 'uncertainty': 16.488582613143414, 'accounting': 14.615684341811795, 'accounting_mmd': 14.36954815656109},
    {'random': 16.08462250117471, 'uncertainty': 16.92655346649779, 'accounting': 15.999569473701493, 'accounting_mmd': 16.104217618220588},
    {'random': 13.318922183767329, 'uncertainty': 13.535266352679773, 'accounting': 11.379797683443073, 'accounting_mmd': 13.696285944397179},
    {'random': 14.272557554608497, 'uncertainty': 13.726303219788115, 'accounting': 13.497567905528156, 'accounting_mmd': 11.331218013014523},
]
pd.json_normalize(_quick_sim_pehe).mean()

In [None]:
# Flatten the nested results dict into one row per dotted key path
# (the next cell splits that path into ihdp / split / method).
df_unc = (
    pd.json_normalize(uncertainty_d)
    .transpose()
    .reset_index()
)

In [None]:
# Split the dotted key path into its three components and name the value column.
_key_parts = df_unc["index"].str.split(".", expand=True)
df_unc[["ihdp", "split", "method"]] = _key_parts
df_unc.rename(columns={0: "pehe"}, inplace=True)

In [None]:
# Stack the uncertainty-strategy rows underneath the existing `df` table.
df = pd.concat(objs=[df, df_unc], axis=0)

In [None]:
# Persist the combined table (relative path; presumably Google Drive is mounted
# under the working directory — confirm).
_results_csv = "drive/MyDrive/Colab Notebooks/data/ihdp_10_simulations_3_to_6.csv"
df.to_csv(_results_csv)

In [None]:
# Bar chart of the mean "pehe" value for every (ihdp, method) pair.
_mean_pehe = df.groupby(["ihdp", "method"])["pehe"].mean()
_mean_pehe.plot.bar()

In [None]:
# Display the combined results table built by the concat above.
df

In [None]:
# from google.colab import runtime
# runtime.unassign()

In [None]:
# df = pd.DataFrame(res_all, columns=["random", "ddal","ihdp","seed"])
# df.groupby(["ihdp"])["random", "ddal"].mean()

# Simulating easy DGP
The aim here is to give a high-level overview of the different acquisition functions and to see which units each of them selects. We also turn off parts of the proposed acquisition function to see their effect and gain some insight.

In [None]:
# Toy 1-D data-generating process on 1000 evenly spaced points in [-4, 4].
X = np.linspace(-4, 4, num=1000)
# V-shaped control outcome; the ITE is a shifted cubic and treated = control + ITE.
y0 = 3 * np.abs(X)
ite = (X ** 3 - 5) / 10
y1 = y0 + ite
# Fair-coin treatment; the observed outcome gets small Gaussian noise (sd 0.01).
t = np.random.binomial(1, p=0.5, size=1000)
y = np.where(t == 1, y1, y0) + np.random.normal(0, 0.01, size=1000)

In [None]:
# Ground-truth ITE over the covariate grid (blue "+" markers, no connecting line).
plt.plot(X, ite, color="b", marker="+", linestyle="None")
plt.title("ITE function");

In [None]:
# Draw a deliberately imbalanced initial sample: 3 controls, mostly from X < 0,
# and 7 treated units, mostly from X > 0 (sampling weight 0.8 vs 0.2).
prob0 = np.where(X < 0, 0.8, 0.2)
prob1 = np.where(X > 0, 0.8, 0.2)
_all_idx = np.arange(X.shape[0])
# Sample without replacement: with replacement the same unit could be drawn
# twice and appear as duplicated rows in the 10-unit training set.
selected0 = np.random.choice(_all_idx, size=3, replace=False, p=prob0 / np.sum(prob0))
# A unit already drawn as a control must not also be drawn as treated.
prob1[selected0] = 0
selected1 = np.random.choice(_all_idx, size=7, replace=False, p=prob1 / np.sum(prob1))
dat0 = X[selected0]
dat1 = X[selected1]
X_train = np.concatenate((dat0, dat1), axis=0)
t_train = np.concatenate((np.zeros(dat0.shape[0]), np.ones(dat1.shape[0])))
y_train = np.concatenate((y0[selected0], y1[selected1]))

In [None]:
# Initial sample: the first 3 rows are controls (blue +), the other 7 treated (red *).
plt.plot(X_train[:3], y_train[:3], color="b", marker="+", linestyle="None")
plt.plot(X_train[3:], y_train[3:], color="r", marker="*", linestyle="None")
plt.title("Training data");

In [None]:
# Package the toy data in the dict layout passed to BaseActiveLearner: the
# sampled units form the training set, every remaining unit goes to the pool.
_labeled = np.concatenate([selected0, selected1])
simds = {
    "X_training": X_train.reshape((-1, 1)),
    "y_training": y_train,
    "t_training": t_train,
    "ite_training": ite[_labeled],
    "X_pool": np.delete(X, _labeled).reshape((-1, 1)),
    "t_pool": np.delete(t, _labeled),
    "y_pool": np.delete(y, _labeled),
    "ite_pool": np.delete(ite, _labeled),
}

In [None]:
# Batch size used by every query in this toy experiment.
NO_QUERY = 10
asl = BaseActiveLearner(
    estimator=ASBETARNETIPM(two_model=False, epochs=50),
    acquisition_function=DDALIPM(no_query=NO_QUERY),
    assignment_function=IPAssignment(),
    stopping_function=None,
    dataset=simds,
    no_query=NO_QUERY,
)
# Let the estimator read the same dataset dict as the learner, then fit.
asl.estimator.dataset = asl.dataset
asl.fit()

In [None]:
# Rank pool units by the "dist_to_labeled" criterion (no MMD) and plot the ranks
# across X, with the 10 training points marked at y = -1.
_dist_scores = asl.acquisition_function.treat_v_control(
    asl.estimator, asl.dataset, "dist_to_labeled", mmd=False
)
ranking_dist_to_l = rankdata(_dist_scores)
plt.plot(X_train, np.full(10, -1.0), "b+")
plt.plot(asl.dataset["X_pool"], ranking_dist_to_l)
plt.title("Ranking of distance to labeled units measurement")
plt.xlabel("X (crosses represent the 10 training data points)")
plt.ylabel("Ranking score")
plt.savefig("dist_to_labeled.pdf")
plt.show()

In [None]:
import seaborn as sns  # used below; the notebook's own seaborn import comes later

# Combined ranking of pool units: sum of the ranks of four criteria
# ("dist_to_counter", "dist_to_selected", KDE score, model uncertainty).
kde_scores = asl.acquisition_function.kde(asl.estimator, asl.dataset)
ranking = (
    rankdata(asl.acquisition_function.treat_v_control(asl.estimator, asl.dataset, "dist_to_counter", mmd=False))
    + rankdata(asl.acquisition_function.treat_v_control(asl.estimator, asl.dataset, "dist_to_selected", mmd=False))
    + rankdata(kde_scores)
    + rankdata(asl.acquisition_function.uncertainty(asl.estimator, asl.dataset))
)

# Teach the top-ranked unit only to inspect how "dist_to_selected" reacts. Work
# on a snapshot: later cells index simds["X_pool"] with this learner's query
# indices, which only line up while the learner's pool is unchanged.
_dataset_before = deepcopy(asl.dataset)
selection = np.array([np.argmax(ranking)])
asl.teach(selection)
ranking_new = rankdata(
    asl.acquisition_function.treat_v_control(asl.estimator, asl.dataset,
                                             "dist_to_selected",
                                             mmd=False, selection_count=1)
)
_x_pool_after = asl.dataset["X_pool"].ravel()
# Restore the pre-teach state (learner and estimator share one dataset dict, as at setup).
asl.dataset = _dataset_before
asl.estimator.dataset = asl.dataset
sns.lineplot(x=_x_pool_after, y=ranking_new)

In [None]:
# Display the "dist_to_selected" ranks computed after teaching the top-ranked unit.
ranking_new

In [None]:
# Query NO_QUERY pool units under each acquisition mode for comparison; only the
# second element of each result (used as pool indices below) is kept.
_, unc_selection = asl.query(mode="uncertainty", no_query=NO_QUERY)
_, tc_selection = asl.query(mode="treat_v_control", no_query=NO_QUERY)
_, tp_selection = asl.query(mode="train_v_pool", no_query=NO_QUERY)
_, all_selection = asl.query(mode=["uncertainty", "treat_v_control", "train_v_pool"], no_query=NO_QUERY)
_, rank_selection = asl.query(mode=["accounting"], mmd=True, no_query=NO_QUERY)

In [None]:
# Re-plot the initial sample: controls as blue "+", treated as red "*".
plt.plot(X_train[:3], y_train[:3], color="b", marker="+", linestyle="None")
plt.plot(X_train[3:], y_train[3:], color="r", marker="*", linestyle="None")
# # cycle through

In [None]:
# Mean prediction on the pool, and the std over axis 1 of the un-averaged
# predictions (return_mean=False) — presumably one column per stochastic pass.
_pool_X = simds["X_pool"]
pool_mean_pred = asl.estimator.predict(X=_pool_X, pool=True).ravel()
_pool_samples = asl.estimator.predict(X=_pool_X, pool=True, return_mean=False)
pool_mean_pred_sd = _pool_samples.std(1).detach().numpy()

In [None]:
# # Plot for selection mechanisms
fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, sharex=True, sharey=False)
ax1.plot(X_train[:3], y_train[:3], "bo");
ax1.plot(X_train[3:], y_train[3:], "b+");
ax1.plot(X, ite, color="gray", alpha=0.3);
ax1.set_title("Training data");
ax2.scatter(simds["X_pool"][unc_selection, :], np.arange(NO_QUERY), color="purple", marker=">");
ax2.scatter(simds["X_pool"].ravel(), pool_mean_pred_sd, color="black");
ax2.set_title("Uncertainty selection (black points are the predicted uncertainties)");
ax3.scatter(simds["X_pool"][tc_selection, :], np.arange(NO_QUERY), color="purple", marker=">");
ax3.scatter(simds["X_pool"][tp_selection, :], np.arange(NO_QUERY), color="green", marker="<");
ax3.set_title("Treated vs. control selection (purple) and train vs. pool (green)");
# ax4.scatter(simds["X_pool"][changing_selection, :], np.arange(NO_QUERY), color="red", marker="*");
# ax4.set_title("Changing");
ax4.scatter(simds["X_pool"][rank_selection, :], np.arange(NO_QUERY), color="k", marker="*");
ax4.set_title("Bathch-aware DDAL (y-axis is only for better visualization)");
plt.tight_layout();
fig.suptitle('Different acquisition strategies (gray is default ITE function)');
fig.subplots_adjust(top=0.88);
# plt.savefig("acq_new.pdf")

In [None]:
import seaborn as sns

In [None]:
# Layered histogram of the X values picked by each acquisition strategy.
_selected_X = pd.DataFrame({
    "Uncertainty sampling": simds["X_pool"].ravel()[unc_selection].ravel(),
    "Treatment vs control sampling": simds["X_pool"][tc_selection, :].ravel(),
    "Training vs pool samples": simds["X_pool"][tp_selection, :].ravel(),
    "Batch-aware DDAL sampling": simds["X_pool"][rank_selection, :].ravel(),
})
_hist_ax = sns.histplot(_selected_X, multiple="layer", shrink=.8)
_hist_ax.set(xlim=(-4, 4), title='Selection of pool units of different acquisition functions',
             xlabel="X", ylabel="Number of selections");
plt.savefig("histograms.pdf")

In [None]:
# Kernel density of a small hand-written sample (values repeated 7, 2, 8, 3, 1, 8 times).
data = [1.5]*7 + [2.5]*2 + [3.5]*8 + [4.5]*3 + [5.5]*1 + [6.5]*8
sns.set_style('whitegrid')
# `bw` is deprecated since seaborn 0.11 (it was only forwarded to `bw_method`
# with a warning); pass the bandwidth through `bw_method` directly.
sns.kdeplot(np.array(data), bw_method=0.5)

# Simulation study

In [None]:
# Simulation study: for each IHDP replication, fit once and checkpoint; then, 10
# times, restore that fit, make one query of `num_query` units, record the
# selected DDAL criteria, teach + refit, and log the score change.
SIMULATION = [0, 1, 2, 3]  # keys of sim_settings to record
num_query = 50
res_ihdp = {}
sim_settings = {0: "dist_to_labeled",
                1: "dist_to_counter",
                2: "dist_to_pool",
                3: "uncertainty"}
# Joining handles any number of settings; the old single-setting branch indexed
# the dict with the list itself (sim_settings[SIMULATION]) and raised TypeError.
fstr = "_".join(sim_settings[x] for x in SIMULATION)
save_loc = f"drive/MyDrive/Colab Notebooks/data/no_query_{fstr}_{num_query}_fixed_assignment.csv"
for ihdp in range(10):
    ds = get_ihdp_dict(ihdp+1, 0.9, True, 0.1)
    asl = BaseActiveLearner(estimator=ASBETARNETIPM(two_model=False),
                            acquisition_function=RandomAcquisitionFunction(),
                            assignment_function=IPAssignment(),
                            stopping_function=None,
                            dataset=ds)
    asl.estimator.dataset = asl.dataset
    asl.fit()
    # Every repetition below restarts from this fitted model and dataset.
    asl.estimator.trainer.save_checkpoint("first_fit.ckpt")
    data_to_restore = deepcopy(asl.dataset)
    res = {}
    for i in range(10):
        t_lab = data_to_restore["t_training"]
        # Treated share among labeled units and the matching inverse-frequency weights.
        U = (1 / t_lab.shape[0]) * t_lab.sum()
        ws = torch.from_numpy((t_lab / (2 * U) + (1 - t_lab) / (2 * (1 - U))))
        pt = np.sum(t_lab == 1) / data_to_restore["X_training"].shape[0]
        # load_from_checkpoint is a classmethod that returns a *new* module; the
        # result must be assigned or the restore does nothing (recent Lightning
        # versions also refuse to call it on an instance).
        asl.estimator.model = type(asl.estimator.model).load_from_checkpoint(
            checkpoint_path="first_fit.ckpt",
            alpha=1,
            sigma=1,
            pt=pt,
            weights=ws,
            num_features=data_to_restore["X_training"].shape[1],
        )
        # Make query
        _, ix = asl.query(no_query=num_query)
        sc_old = asl.score()
        ddal = DDALIPM()
        avg_labeled_dist = avg_counter_dist = avg_pool_dist = avg_uncertainty = None
        if 0 in SIMULATION:
            avg_labeled_dist = np.mean(ddal.treat_v_control(asl.estimator, data_to_restore,
                                                            "dist_to_labeled", mmd=False))
        if 1 in SIMULATION:
            avg_counter_dist = np.mean(ddal.treat_v_control(asl.estimator, data_to_restore,
                                                            "dist_to_counter", mmd=False))
        if 2 in SIMULATION:
            avg_pool_dist = np.mean(ddal.treat_v_control(asl.estimator, data_to_restore,
                                                         "dist_to_pool", mmd=False))
        if 3 in SIMULATION:
            # Mean over the queried units of the variance (axis 1) of the un-averaged predictions.
            X_to_pred = asl.estimator.dataset["X_pool"][ix, :]
            pred = asl.estimator.predict(X=X_to_pred, return_mean=False)
            avg_uncertainty = torch.mean(pred.var(1)).detach().numpy()
        asl.teach(ix)
        asl.fit()
        sc_new = asl.score()
        res[i] = {"dist_to_labeled": avg_labeled_dist,
                  "dist_to_counter": avg_counter_dist,
                  "dist_to_pool": avg_pool_dist,
                  "uncertainty": avg_uncertainty,
                  "pehe_change": sc_old - sc_new}
        # Restore data; the model is reloaded at the top of the next repetition.
        asl.dataset = deepcopy(data_to_restore)
        asl.estimator.dataset = deepcopy(data_to_restore)
        # Rewrite the CSV after every repetition so partial results are kept.
        df = pd.DataFrame.from_dict(res).T
        df["ihdp"] = ihdp
        res_ihdp[ihdp] = df
        r = pd.concat(res_ihdp).reset_index(drop=True)
        r.to_csv(save_loc)

In [None]:
# Releasing the Colab runtime ends the session, so a top-to-bottom run would stop
# here and never reach the later sections. Set to True to free the GPU on purpose.
UNASSIGN_RUNTIME = False
if UNASSIGN_RUNTIME:
    from google.colab import runtime
    runtime.unassign()

In [None]:
# Show only the rows that have a recorded "pehe_change".
r[r["pehe_change"].notna()]

# Simulation study overlap

First, we concentrate on two different problems:

The first is when the overlap assumption does not hold. Then, we need to obtain data in a way that tries to attain full overlap as soon as possible. We can test this by measuring the overlap ratio (overlap in training data, compared to overlap in the full data) and how different acquisition functions do this.
The second is when the data generating process is not known and the researcher might make wrong assumptions. Some acquisition functions might rely on the linearity of the treatment effect and can therefore be susceptible to misspecification.
To test this, we run simulations similar to those in the previous notebook, using the same setup, and try to understand the relationship between overlap, DGP mismatch and the decrease in PEHE.

We use simulated data for better understanding.
For the overlap data, we simulate different levels of non-overlap.
For the covariate-shift part on simulated data, we follow the "uncertainty under covariate shift" experiment of Jesson et al. on the IHDP data.

In [None]:
# Make CEMNIST data
def transform_image(im):
   return im.view((1, -1)) / 255
train = pd.read_csv("./sample_data/mnist_train_small.csv", header=None)
test = pd.read_csv("./sample_data/mnist_test.csv")
X_train = train.iloc[:, 1:].to_numpy()
y_train_original = train.iloc[:, 0].to_numpy()
X_test = test.iloc[:, 1:].to_numpy()
y_test_original  = test.iloc[:, 0].to_numpy()

px_train = np.where(y_train_original == 9, 0.5, 0.5/9)
px_test = np.where(y_test_original == 9, 0.5, 0.5/9)

ps_train = np.where(y_train_original == 0, 1/9, 0.5)
ps_test = np.where(y_test_original == 0, 1/9, 0.5)

ps_train[y_train_original == 2] = 1
ps_test[y_test_original == 2] = 1


t_train = np.random.binomial(1, ps_train)
t_test = np.random.binomial(1, ps_test)

y_train = np.where((
    (y_train_original % 2 == 1) & (t_train == 0)) |(
     (y_train_original % 2 == 0) & (t_train == 1)), 1, 0)

y_test = np.where((
   ( y_test_original % 2 == 1) & (t_test == 0)) |(
         (y_test_original % 2 == 0) & (t_test == 1)), 1, 0)

ite = np.where(y_train_original % 2 == 1, -1, 1)
ite_test = np.where(y_test_original % 2 == 1, -1, 1)

X_train, X_pool,t_train, t_pool, y_train, y_pool, ite_train, ite_pool = train_test_split(X_train,t_train, y_train, ite, test_size=0.9)
N_TRAIN = 1000

# Create training data
px = np.where(y_train == 9, 0.5, 0.5/9)
six = np.random.choice(X_train.shape[0], size = N_TRAIN, p=px/np.sum(px))


ds = {"X_training": X_train[six,:],
      "y_training" : y_train[six],
      "t_training" : t_train[six],
      "ite_training" : ite_train[six],
      "X_pool": X_pool,
      "t_pool": t_pool,
      "y_pool": y_pool,
      "ite_pool": ite_pool,
      "X_test":X_test,
      "y_test":y_test,
      "t_test":t_test,
      "ite_test":ite_test}

In [None]:
# Learner for the CEMNIST data: ASBETARNETIPM with conv=True, DDAL acquisition
# and the base assignment function.
# NOTE(review): DDALIPM gets no_query=10 while the learner gets no_query=100;
# the queries below pass no_query=100 explicitly — confirm which is intended.
asl = BaseActiveLearner(
    estimator=ASBETARNETIPM(two_model=False, conv=True),
    acquisition_function=DDALIPM(no_query=10),
    assignment_function=BaseAssignmentFunction(),
    stopping_function=None,
    no_query=100,
    dataset=ds,
)
asl.estimator.dataset = asl.dataset
asl.fit()

In [None]:
# Compare the pool outcomes of the units chosen by each acquisition mode.
_, unc_selection = asl.query(mode="uncertainty", no_query=100)
_, random_selection = asl.query(mode="random", no_query=100)
_, rank_selection = asl.query(mode="accounting", no_query=100)

# Print only selections made in this cell: tc_/tp_/all_selection come from the
# earlier 1-D toy learner and index a different pool.
print(y_pool[unc_selection])
print(y_pool[random_selection])
print(y_pool[rank_selection])

# **CMF Micro data (main simulation)**  

In [None]:
def create_nonoverlap_data(data, N_train = 50, q=.4, sim_data=True, seed=42, col_to_use = 0):
    """Build training / pool / test splits with limited treatment overlap.

    The initial training set holds ``int(N_train / 2)`` controls and as many
    treated units; controls come from low values of column ``col_to_use`` and
    treated units from high values.

    Parameters
    ----------
    data : dict
        Needs "X" (2-D array) and row-aligned "y", "t", "y_0", "y_1", "tau".
    N_train : int
        Size of the initial training set (an odd value is rounded down).
    q : float or pair of floats
        Quantile(s) of column ``col_to_use``. With ``sim_data=True`` it must be
        a pair: controls are drawn strictly below the first quantile and treated
        units strictly above the second. Otherwise a single quantile: controls
        at or below it and treated units above it get weight 1000 (vs 1).
    sim_data : bool
        True: treatment is assigned here and outcomes are taken from "y_0"/"y_1".
        False: controls/treated are drawn among the observed ``t == 0``/``t == 1``.
    seed : int
        ``random_state`` of the pool/test split (the training draw uses the
        global NumPy RNG).
    col_to_use : int
        Covariate column that drives the non-overlap.

    Returns
    -------
    dict
        "X/y/t/ite" arrays for the "_training", "_pool" and "_test" splits
        (the test split is 10% of the non-training units). "ite_training" is
        all zeros when ``sim_data=False``.
    """
    X = data["X"]
    half = int(N_train / 2)
    qs = np.quantile(X[:, col_to_use], q=q)
    if sim_data:
        x0ind = np.random.choice(np.where(X[:, col_to_use] < qs[0])[0],
                                 size=half, replace=False)
        x1ind = np.random.choice(np.where(X[:, col_to_use] > qs[1])[0],
                                 size=half, replace=False)
        train_idx = np.concatenate([x0ind, x1ind])
        # Controls first, then treated; indexing with the 1-D index (instead of
        # reshape((N_train, -1))) also works for odd N_train.
        X_train = X[train_idx, :]
        t_train = np.concatenate([np.zeros(half), np.ones(half)])
        y_train = np.concatenate([data["y_0"][x0ind], data["y_1"][x1ind]])
        tau_train = data["tau"][train_idx]
    else:
        px0 = np.where(X[:, col_to_use] <= qs, 1000, 1)
        px1 = np.where(X[:, col_to_use] > qs, 1000, 1)
        units = np.arange(X.shape[0])
        is_control = data["t"] == 0
        is_treated = data["t"] == 1
        x0ind = np.random.choice(units[is_control], half, replace=False,
                                 p=px0[is_control] / np.sum(px0[is_control]))
        x1ind = np.random.choice(units[is_treated], half, replace=False,
                                 p=px1[is_treated] / np.sum(px1[is_treated]))
        train_idx = np.concatenate([x0ind, x1ind])
        X_train = X[train_idx, :]
        t_train = data["t"][train_idx]
        y_train = data["y"][train_idx]
        # Placeholder zeros: no true effect is used for observational data here.
        tau_train = np.zeros(y_train.shape[0])

    # Drop the training units, then split the remainder 90/10 into pool and test.
    X_pool = np.delete(X, train_idx, axis=0)
    t_pool = np.delete(data["t"], train_idx, axis=0)
    y_pool = np.delete(data["y"], train_idx)
    y1_pool = np.delete(data["y_1"], train_idx)
    y0_pool = np.delete(data["y_0"], train_idx)
    tau_pool = np.delete(data["tau"], train_idx)
    (X_pool, X_test, t_pool, t_test, y_pool, y_test, y1_pool, y1_test,
     y0_pool, y0_test, tau_pool, tau_test) = train_test_split(
        X_pool, t_pool, y_pool, y1_pool, y0_pool, tau_pool,
        test_size=0.1, random_state=seed)
    ds = {"X_training": X_train,
          "y_training": y_train,
          "t_training": t_train,
          "ite_training": tau_train,
          "X_pool": X_pool,
          "t_pool": t_pool,
          "y_pool": y_pool,
          "ite_pool": tau_pool,
          "X_test": X_test,
          "y_test": y_test,
          "t_test": t_test,
          "ite_test": tau_test}
    return ds

In [None]:
def create_covariateshift_data(data, N_train = 50, q=.4, sim_data=True, seed=42, col_to_use=0):
    """Build training / pool / test splits whose training set is covariate-shifted.

    Training units are drawn from (or heavily weighted toward) values of
    column ``col_to_use`` below a quantile; pool and test use all other units.

    Parameters
    ----------
    data : dict
        Needs "X" (2-D array) and row-aligned "y", "t", "y_0", "y_1", "tau".
    N_train : int
        Size of the initial training set.
    q : float or sequence of floats
        Quantile(s) of column ``col_to_use``. With ``sim_data=True`` it must be
        indexable; training units are drawn strictly below its first quantile.
        Otherwise a single quantile: units below it get weight 1000 (vs 1).
    sim_data : bool
        True: the first ``int(N_train / 2)`` drawn units become controls and the
        rest treated, with outcomes from "y_0"/"y_1". False: observed "t"/"y".
    seed : int
        ``random_state`` of the pool/test split (the training draw uses the
        global NumPy RNG).
    col_to_use : int
        Covariate column that drives the shift.

    Returns
    -------
    dict
        "X/y/t/ite" arrays for the "_training", "_pool" and "_test" splits
        (the test split is 10% of the non-training units). "ite_training" is
        all zeros when ``sim_data=False``.
    """
    X = data["X"]
    qs = np.quantile(X[:, col_to_use], q=q)
    if sim_data:
        xind = np.random.choice(np.arange(X.shape[0])[X[:, col_to_use] < qs[0]],
                                size=int(N_train), replace=False)
        half = int(N_train / 2)
        x0ind = xind[:half]
        x1ind = xind[half:]
        X_train = X[xind, :]
        # Sized from the actual halves so an odd N_train also works.
        t_train = np.concatenate([np.zeros(x0ind.shape[0]), np.ones(x1ind.shape[0])])
        y_train = np.concatenate([data["y_0"][x0ind], data["y_1"][x1ind]])
        tau_train = data["tau"][xind]
    else:
        prob = np.where(X[:, col_to_use] < qs, 1000, 1)
        xind = np.random.choice(np.arange(X.shape[0]), size=int(N_train),
                                replace=False, p=prob / prob.sum())
        X_train = X[xind, :]
        t_train = data["t"][xind]
        y_train = data["y"][xind]
        # Placeholder zeros: no true effect is used for observational data here.
        tau_train = np.zeros(y_train.shape[0])

    # Remove *every* training unit before the pool/test split. Previously only
    # x0ind/x1ind were removed and x1ind was built with t == 0, so treated
    # training units leaked into the pool and test sets.
    X_pool = np.delete(X, xind, axis=0)
    t_pool = np.delete(data["t"], xind, axis=0)
    y_pool = np.delete(data["y"], xind)
    y1_pool = np.delete(data["y_1"], xind)
    y0_pool = np.delete(data["y_0"], xind)
    tau_pool = np.delete(data["tau"], xind)
    (X_pool, X_test, t_pool, t_test, y_pool, y_test, y1_pool, y1_test,
     y0_pool, y0_test, tau_pool, tau_test) = train_test_split(
        X_pool, t_pool, y_pool, y1_pool, y0_pool, tau_pool,
        test_size=0.1, random_state=seed)
    ds = {"X_training": X_train,
          "y_training": y_train,
          "t_training": t_train,
          "ite_training": tau_train,
          "X_pool": X_pool,
          "t_pool": t_pool,
          "y_pool": y_pool,
          "ite_pool": tau_pool,
          "X_test": X_test,
          "y_test": y_test,
          "t_test": t_test,
          "ite_test": tau_test}
    return ds

In [None]:
from sklift.datasets import fetch_criteo, fetch_hillstrom
from sklift.metrics import qini_auc_score
from sklift.viz import plot_qini_curve

In [None]:
from causeinfer.data import hillstrom, cmf_micro

In [None]:
# Load the CMF micro dataset through causeinfer (covariates formatted and
# normalized) and keep a DataFrame of the full table with its column names.
data_raw = cmf_micro.load_cmf_micro(
    file_path="cmf_micro",
    format_covariates=True,
    normalize=True,
)
df_full = pd.DataFrame(data=data_raw["dataset_full"], columns=data_raw["dataset_full_names"])

In [None]:
# Dict layout read by create_nonoverlap_data / create_covariateshift_data:
# observed covariates, outcome and treatment, plus zero placeholders for
# "y_1", "y_0" and "tau".
_n_units = data_raw["features"].shape[0]
cmf_data = {
    "X": data_raw["features"],
    "y": data_raw["response_biz_index"],
    "t": data_raw["treatment"],
    "y_1": np.zeros(_n_units),
    "y_0": np.zeros(_n_units),
    "tau": np.zeros(_n_units),
}

In [None]:
# ds = create_covariateshift_data(cmf_data, 500, q = .5, sim_data=False)
# asl = BaseActiveLearner(estimator = ASBETARNETIPM(two_model=False,
#                                                   binary_outcome=False,
#                                                   epochs=2,
#                                                   num_sim=5),
#                             acquisition_function=DDALIPM(no_query=100),
#                             assignment_function=BaseAssignmentFunction(),
#                             stopping_function = None,
#                             dataset=ds)
# asl.estimator.dataset = asl.dataset
# asl.fit()
# X_new, ix = asl.query(mode="accounting", no_query=2, mmd = False)

In [None]:
import seaborn as sns

In [None]:
# import matplotlib.pyplot as plt
CTU = 6
sns.set_style("whitegrid")
qs = np.arange(0.0, 1.1, 0.1)
fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(10, 10))
for i, q in enumerate(qs):
    row, col = divmod(i, 3)
    ds = create_covariateshift_data(criteo_data, q=q,N_train=500,
                                sim_data=False, col_to_use = CTU)
    quan = np.quantile(criteo_data["X"][:, CTU], q=[q])
    axes[row, col].hist(ds["X_training"][ds["t_training"] == 0, CTU], density=True, alpha=0.5, label="Control")
    axes[row, col].hist(ds["X_training"][ds["t_training"] == 1, CTU], density=True, alpha=0.5, label="Treatment")
    axes[row, col].axvline(x=quan, color="black", linestyle="--")
    axes[row, col].set_xlim(-2, 7)
    axes[row, col].legend()
    axes[row, col].set_title(f"q = {q:.1f}")
axes[3, 2].hist(ds["X_test"][:, 0], density=True, alpha=0.5)
axes[3, 2].set_title("Density on test data")
plt.tight_layout()
# plt.savefig("covshift_data.pdf")
plt.show()

In [None]:
from pylift.eval import get_scores

In [None]:
# sns.set_style('whitegrid')
# for i in range(1,10):
#   sns.kdeplot(create_nonoverlap_data(
#       criteo_data, 5000,col_to_use = 54, q = i/10, sim_data=False)["X_training"][:, 54],
#             bw_method=0.3);

In [None]:
# Active-learning experiments on the Criteo data.
#
# For each configuration (feature column CTS, distortion type, distortion
# level, training size NT, query size NQ) and each seed 0..9 this cell:
#   1. builds a distorted train/pool/test split with that seed,
#   2. fits ASBETARNETIPM and records pylift's "q1_aqini" score as "original",
#   3. for every mode in SELECTION_MODES queries NQ pool points, teaches and
#      refits the learner, records the new score, then restores the data.
# Scores are pickled to FNAME as {seed: {mode: score}} after every mode. Seeds
# whose scores are all present in an existing FNAME are skipped, so an
# interrupted run can be resumed without recomputing or losing results.
import itertools

#NT = 50
#NQ = 20
#DISTORTION = "covshift"

# Selection modes compared against the "original" (no-query) score.
# Commented entries are handled by _query but currently disabled.
SELECTION_MODES = [
    "accounting-F",
    # "uncertainty",
    "random",
    # "batch",
    # "accounting-T",
]


def _aqini_score(dataset, preds):
    """Return pylift's "q1_aqini" score of `preds` on the test split of `dataset`."""
    return get_scores(dataset["t_test"],
                      dataset["y_test"],
                      preds.ravel(),
                      # NOTE(review): constant propensity 0.51 for every test
                      # unit — confirm this matches the Criteo treatment ratio.
                      p=np.zeros(dataset["y_test"].shape[0]) + 0.51,
                      plot_type='aqini')["q1_aqini"]


def _query(asl, selection_mode, n_query):
    """Run one acquisition step for `selection_mode` and return (X_new, ix).

    Raises ValueError for an unknown mode instead of silently reusing the
    indices from a previous query.
    """
    if selection_mode.startswith("treat_v_control"):
        # "treat_v_control-dist_to_pool" -> mode "treat_v_control",
        # method "dist_to_pool". One query per mode, using the parsed method.
        sel_mode, method = selection_mode.split("-")
        return asl.query(mode=sel_mode, no_query=n_query, method=method)
    if selection_mode.startswith("accounting"):
        # "accounting-T" / "accounting-F" toggles the `mmd` flag.
        sel_mode, mmd_str = selection_mode.split("-")
        return asl.query(mode=sel_mode, no_query=n_query, mmd=(mmd_str == "T"))
    if selection_mode == "uncertainty":
        return asl.query(mode=selection_mode, no_query=n_query)
    if selection_mode == "batch":
        return asl.query(mode=selection_mode, no_query=n_query, simulations=50)
    if selection_mode == "all":
        return asl.query(mode=["uncertainty", "treat_v_control", "train_v_pool"],
                         no_query=n_query)
    if selection_mode == "random":
        return asl.query(acquisition_function=RandomAcquisitionFunction(no_query=n_query),
                         no_query=n_query)
    raise ValueError(f"Unknown selection mode: {selection_mode!r}")


for CTS, DISTORTION, LEVEL, NT, NQ in itertools.product(
        [6],                # CTS: column passed as col_to_use to the data generator
        ["overlap"],        # DISTORTION: "overlap" or "covshift"
        [0.2, 0.5, 0.8],    # LEVEL: passed as q to the data generator
        [500],              # NT: training-set size passed to the data generator
        [500]):             # NQ: number of points acquired per query
    FNAME = f'drive/MyDrive/Colab Notebooks/data/{DISTORTION}_{LEVEL}_{NT}_{NQ}_ipm_cmf_col_{CTS}.pkl'
    try:
        with open(FNAME, 'rb') as f:
            res = pickle.load(f)
    except (FileNotFoundError, EOFError, pickle.UnpicklingError):
        # No usable checkpoint yet: start this configuration from scratch.
        res = {}
    for i in range(10):
        # Resume: skip seeds whose scores were all saved by an earlier run.
        if all(mode in res.get(i, {}) for mode in ["original", *SELECTION_MODES]):
            continue
        print(f"Running, n1 {NT}, n2 {NQ}, dist: {DISTORTION}, level {LEVEL}")
        if DISTORTION == "overlap":
            ds = create_nonoverlap_data(criteo_data,
                                        NT,
                                        q=LEVEL,
                                        sim_data=False, seed=i,
                                        col_to_use=CTS)
        elif DISTORTION == "covshift":
            ds = create_covariateshift_data(criteo_data,
                                            NT,
                                            q=LEVEL,
                                            sim_data=False,
                                            seed=i,
                                            col_to_use=CTS)
        else:
            raise ValueError(f"Unknown distortion: {DISTORTION!r}")
        asl = BaseActiveLearner(estimator=ASBETARNETIPM(two_model=False, binary_outcome=False,
                                                        epochs=100,
                                                        num_sim=100),
                                acquisition_function=DDALIPM(no_query=NQ),
                                assignment_function=BaseAssignmentFunction(),
                                stopping_function=None,
                                dataset=ds)
        asl.estimator.dataset = asl.dataset
        asl.fit()
        # Snapshot of the pre-query data, restored after every selection mode.
        data_to_restore = deepcopy(asl.dataset)
        res[i] = {}
        preds = asl.predict(asl.dataset["X_test"])
        sc_old = _aqini_score(asl.dataset, preds)
        res[i]["original"] = sc_old
        print(f"Selection score - original : {sc_old} ")
        for selection_mode in SELECTION_MODES:
            X_new, ix = _query(asl, selection_mode, NQ)
            asl.teach(ix)
            asl.estimator.dataset = asl.dataset
            asl.fit()
            preds = asl.predict(asl.dataset["X_test"])
            sc_new = _aqini_score(asl.dataset, preds)
            res[i][selection_mode] = sc_new
            print(f"Selection score - {selection_mode} : {sc_new} ")
            # Restore the pre-query data for the next selection mode.
            # NOTE(review): only the data is restored, not the fitted model;
            # this relies on asl.fit() retraining from scratch — confirm.
            asl.dataset = deepcopy(data_to_restore)
            asl.estimator.dataset = deepcopy(data_to_restore)
            # Checkpoint after every mode so progress survives interruptions.
            with open(FNAME, 'wb') as f:
                pickle.dump(res, f)

In [None]:
# Fit one learner on a non-overlap split (feature column 54, q = 0.1), used by
# the query and density-plot cells that follow.
# NOTE(review): `seed=i` reuses whatever value `i` last held in the notebook
# (e.g. the final seed of an earlier loop); if no earlier cell has bound `i`
# this raises NameError — consider an explicit seed.
ds = create_nonoverlap_data(
    criteo_data,
    500,
    q=0.1,
    sim_data=False,
    seed=i,
    col_to_use=54,
)
asl = BaseActiveLearner(
    estimator=ASBETARNETIPM(
        two_model=False,
        binary_outcome=False,
        epochs=100,
        num_sim=100,
    ),
    acquisition_function=DDALIPM(no_query=500),
    assignment_function=BaseAssignmentFunction(),
    stopping_function=None,
    dataset=ds,
)
# The estimator is given the learner's dataset object before fitting.
asl.estimator.dataset = asl.dataset
asl.fit()

In [None]:
# Query 500 pool points with the "accounting" mode and mmd=False; X_new are the
# selected rows and ix their pool indices.
X_new, ix = asl.query(
    mode="accounting",
    no_query=500,
    mmd=False,
)

In [None]:
# Kernel-density estimates of feature column 54 for treated (t == 1) and
# control (t == 0) units in four samples, one panel each: the training data,
# the pool points selected by the query above (X_new / ix), a random pool
# sample of the same size, and the test data.
sns.set_style('whitegrid')
fig, ax = plt.subplots(4, 1, sharex=True, sharey=True)


def _kde_by_arm(axis, x, t, name, col=54):
    """Plot treated/control KDEs of column `col` of `x`, split by `t`, on `axis`."""
    for arm, arm_name in ((1, "treated"), (0, "control")):
        sns.kdeplot(x[:, col][t == arm], bw_method=0.3,
                    label=f"{name} ({arm_name})", ax=axis)
    axis.set_title(name)
    # Without a legend the per-arm labels would not be shown.
    axis.legend()


_kde_by_arm(ax[0], asl.dataset["X_training"], asl.dataset["t_training"], "training data")
_kde_by_arm(ax[1], X_new, asl.dataset["t_pool"][ix], "DDAL")
# Random baseline: as many pool points as were selected above, drawn without
# replacement so no point is counted twice.
n_pool = asl.dataset["X_pool"].shape[0]
randix = np.random.choice(n_pool, size=min(len(ix), n_pool), replace=False)
_kde_by_arm(ax[2], asl.dataset["X_pool"][randix], asl.dataset["t_pool"][randix], "random")
_kde_by_arm(ax[3], asl.dataset["X_test"], asl.dataset["t_test"], "test")
plt.tight_layout()