In [None]:
# Reload imported modules automatically before each cell executes, so edits to
# the local package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Enable the %tensorboard magic that is used near the end of the notebook.
%load_ext tensorboard
# Render matplotlib figures inline in the cell output.
%matplotlib inline

# Mounting Google Drive and Installing the Package (Google Colab Only)

In [None]:
# Colab only: mount Google Drive at /content/drive so files stored there are
# reachable from the notebook.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
cd /content/drive/My\ Drive/Colab\ Notebooks/adversarial_reisz

In [None]:
# Colab only: run the setup.py found in the current directory in "develop" mode
# (presumably this installs the advreisz package imported below — confirm).
!python setup.py develop

# Library Imports

In [None]:
import numpy as np
import torch
from pathlib import Path
import torch.nn as nn
import matplotlib.pyplot as plt
import scipy
import scipy.special
import advreisz
from advreisz.deepreisz import DeepReisz

# Data Generation

In [None]:
def nonlin(x):
    """Nonlinear term shared by the propensity and the outcome model.

    Returns 1.5 * x[:, 1]**2, one value per row of x.

    Commented-out alternatives kept from the original cell for reference:
      - scipy.special.expit(10 * x[:, 1])
      - np.abs(x[:, 1])
      - 1.5*scipy.special.expit(10 * x[:, 1]) - 1.5*scipy.special.expit(10 * x[:, 2])
    """
    second_column = x[:, 1]
    return 1.5 * (second_column ** 2)

def true_propensity(X):
    """Per-row treatment probability: 0.5 plus 0.3 times nonlin(X)."""
    nonlinear_part = .3 * nonlin(X)
    return .5 + nonlinear_part

def true_f(X):
    """Noise-free outcome for each row: the first column of X plus nonlin(X)."""
    first_column = X[:, 0]
    return first_column + nonlin(X)

def gen_data(n, p):
    """Simulate n observations with p columns.

    All columns are first drawn uniformly on [-1, 1]; column 0 is then
    overwritten with a Bernoulli draw whose success probability is
    true_propensity evaluated on the uniform draw. The outcome is true_f of
    the resulting matrix plus standard normal noise.

    Returns:
        (X, y): X has shape (n, p), y has shape (n,).
    """
    features = np.random.uniform(-1, 1, size=(n, p))
    treat_prob = true_propensity(features)
    features[:, 0] = np.random.binomial(1, treat_prob)
    noise = np.random.normal(size=(n,))
    outcome = true_f(features) + noise
    return features, outcome

# Fix the NumPy seed so the two simulated samples below are reproducible.
np.random.seed(123)
n = 10000  # number of rows in each simulated sample
p = 3  # number of columns of X (gen_data overwrites column 0 with the treatment)
# First draw: converted to the tensor Xt that is passed to agmm.fit.
X, y = gen_data(n, p)
# Second, independent draw: passed to agmm.fit as Xval and used for evaluation.
X_test, y_test = gen_data(n, p)

# Moment Definition

### ATE Moment

In [None]:
def moment_fn(x, test_fn):
    """ATE moment: test_fn with column 0 set to 1 minus test_fn with column 0 set to 0.

    Note: this notebook defines ``moment_fn`` in several alternative cells;
    whichever of those cells was executed last is the one used downstream.

    Args:
        x: 2-D torch tensor or NumPy array; column 0 is replaced, the
            remaining columns are passed through unchanged.
        test_fn: callable applied to the two modified copies of x.

    Returns:
        test_fn(t1) - test_fn(t0), one value per row of x (the exact shape is
        whatever test_fn returns).
    """
    n_rows = x.shape[0]
    if torch.is_tensor(x):
        with torch.no_grad():
            # Create the constant columns on x's own device rather than on the
            # notebook-global `device`, which is only defined in a later cell.
            ones = torch.ones((n_rows, 1), device=x.device)
            zeros = torch.zeros((n_rows, 1), device=x.device)
            t1 = torch.cat([ones, x[:, 1:]], dim=1)
            t0 = torch.cat([zeros, x[:, 1:]], dim=1)
    else:
        t1 = np.hstack([np.ones((n_rows, 1)), x[:, 1:]])
        t0 = np.hstack([np.zeros((n_rows, 1)), x[:, 1:]])
    return test_fn(t1) - test_fn(t0)

def true_reisz(x, propensity=true_propensity):
    """Closed-form representer for the ATE moment.

    Computes 1{x[:, 0] == 1} / e(x) - 1{x[:, 0] == 0} / (1 - e(x)), where
    e = propensity.

    Args:
        x: 2-D array whose first column is the binary treatment.
        propensity: callable mapping x to per-row treatment probabilities.
    """
    treated_term = (x[:, 0] == 1) / propensity(x)
    control_term = (x[:, 0] == 0) / (1 - propensity(x))
    return treated_term - control_term

### Policy Moment

In [None]:
def policy(x):
    """Probability assigned by the policy to each row: sigmoid(10 * x[:, 1]).

    Uses torch.sigmoid for tensor inputs and scipy.special.expit otherwise.
    """
    if torch.is_tensor(x):
        return torch.sigmoid(10 * x[:, 1])
    return scipy.special.expit(10 * x[:, 1])

def moment_fn(x, test_fn):
    """Policy-value moment: policy-weighted mix of test_fn under treatment 1 and 0.

    Computes test_fn(t1) * pi(x) + test_fn(t0) * (1 - pi(x)), where t1 / t0
    are copies of x with column 0 replaced by ones / zeros and pi = policy.

    Note: this notebook defines ``moment_fn`` in several alternative cells;
    whichever of those cells was executed last is the one used downstream.

    Args:
        x: 2-D torch tensor or NumPy array; column 0 is replaced, the
            remaining columns are passed through unchanged.
        test_fn: callable applied to the two modified copies of x.
    """
    with torch.no_grad():
        n_rows = x.shape[0]
        if torch.is_tensor(x):
            # Create the constant columns on x's own device rather than on the
            # notebook-global `device`, which is only defined in a later cell.
            t1 = torch.cat([torch.ones((n_rows, 1), device=x.device), x[:, 1:]], dim=1)
            t0 = torch.cat([torch.zeros((n_rows, 1), device=x.device), x[:, 1:]], dim=1)
        else:
            t1 = np.hstack([np.ones((n_rows, 1)), x[:, 1:]])
            t0 = np.hstack([np.zeros((n_rows, 1)), x[:, 1:]])
        p1 = policy(x)
    out1 = test_fn(t1)
    out0 = test_fn(t0)
    # When test_fn returns a column (n, 1), make the weights a column as well
    # so the products below stay element-wise instead of broadcasting to (n, n).
    if len(out1.shape) > 1:
        p1 = p1.reshape(-1, 1)
    return out1 * p1 + out0 * (1 - p1)

def true_reisz(x, propensity=true_propensity):
    """Closed-form representer for the policy moment.

    Computes pi(x) * 1{x[:, 0] == 1} / e(x) + (1 - pi(x)) * 1{x[:, 0] == 0} / (1 - e(x)),
    with pi = policy and e = propensity.
    """
    policy_prob = policy(x)
    treat_prob = propensity(x)
    treated_term = policy_prob * (x[:, 0] == 1) / treat_prob
    control_term = (1 - policy_prob) * (x[:, 0] == 0) / (1 - treat_prob)
    return treated_term + control_term

### X-Transformation Moment

In [None]:
def trans(x):
    """Covariate transformation sqrt(1 + x) / sqrt(2)."""
    return (1 + x)**(1/2)/2**(1/2)


def invtrans(u):
    """Inverse of trans: 2 * u**2 - 1."""
    return 2 * (u**2) - 1


def grad_invtrans(u):
    """Derivative of invtrans with respect to u: 4 * u."""
    return 4 * u

def moment_fn(x, test_fn):
    """Transformation moment: test_fn on x with column 1 passed through trans, minus test_fn on x.

    Accepts either a 2-D torch tensor or a NumPy array; the transformed copy
    is built with the matching library.
    """
    with torch.no_grad():
        columns = [x[:, [0]], trans(x[:, [1]]), x[:, 2:]]
        tx = torch.cat(columns, dim=1) if torch.is_tensor(x) else np.hstack(columns)
    return test_fn(tx) - test_fn(x)

def true_reisz(x, propensity=true_propensity):
    """Closed-form representer for the x-transformation moment (NumPy inputs only).

    For each row, forms the ratio of propensity at the inverse-transformed
    point to propensity at x (or the same ratio with 1 - propensity), scaled
    by grad_invtrans(x[:, 1]) and clipped below at zero. Rows are weighted by
    x[:, 0] and 1 - x[:, 0] respectively, and 1 is subtracted.
    """
    invtx = np.hstack([x[:, [0]], invtrans(x[:, [1]]), x[:, 2:]])
    ratio_arm1 = propensity(invtx) * grad_invtrans(x[:, 1]) / propensity(x)
    out0 = np.clip(ratio_arm1, 0, np.inf)
    ratio_arm0 = (1 - propensity(invtx)) * grad_invtrans(x[:, 1]) / (1 - propensity(x))
    out1 = np.clip(ratio_arm0, 0, np.inf)
    weighted = out0 * x[:, 0] + out1 * (1 - x[:, 0])
    return weighted - 1

# Adversarial Riesz Estimator for ATE Moment

In [None]:
# Architecture shared by the networks in this notebook
n_hidden = 100  # width of hidden layers throughout notebook
drop_prob = 0.2  # dropout prob of dropout layers throughout notebook

# Training params (learner and adversary use the same values)
learner_lr = adversary_lr = 1e-4
learner_l2 = adversary_l2 = 1e-3
n_epochs = 1000
bs = 100  # passed to agmm.fit as bs (presumably the mini-batch size)

# Early-stopping params
preprocess_epochs = 200  # how many epochs to use to create an approximation to the max objective for earlystopping
earlystop_rounds = 40  # how many epochs to wait for an out-of-sample improvement
store_test_every = 20  # after how many training iterations to store a test function during preprocessing

# Index of the current CUDA device when one is available, otherwise None
if torch.cuda.is_available():
    device = torch.cuda.current_device()
else:
    device = None

def _get_learner(n_t, n_hidden, p):
    """Returns a deep model for the Riesz representer.

    Two hidden blocks of Dropout -> Linear -> LeakyReLU followed by a
    Dropout -> Linear output layer with a single unit.

    Args:
        n_t: number of input features.
        n_hidden: width of both hidden layers.
        p: dropout probability used by every Dropout layer.
    """
    layers = []
    for in_features in (n_t, n_hidden):
        layers.extend([nn.Dropout(p=p), nn.Linear(in_features, n_hidden), nn.LeakyReLU()])
    layers.extend([nn.Dropout(p=p), nn.Linear(n_hidden, 1)])
    return nn.Sequential(*layers)

def _get_adversary(n_z, n_hidden, p):
    """Returns a deep model for the test functions.

    Two hidden blocks of Dropout -> Linear -> ReLU followed by a
    Dropout -> Linear output layer with a single unit.

    Args:
        n_z: number of input features.
        n_hidden: width of both hidden layers.
        p: dropout probability used by every Dropout layer.
    """
    layers = []
    for in_features in (n_z, n_hidden):
        layers.extend([nn.Dropout(p=p), nn.Linear(in_features, n_hidden), nn.ReLU()])
    layers.extend([nn.Dropout(p=p), nn.Linear(n_hidden, 1)])
    return nn.Sequential(*layers)

def violation_fn(x, learner, adversary):
    """Per-sample moment violation: moment_fn(x, adversary) - learner(x) * adversary(x)."""
    moment_term = moment_fn(x, adversary)
    product_term = learner(x) * adversary(x)
    return moment_term - product_term

# Report whether PyTorch can see a CUDA device.
print("GPU:", torch.cuda.is_available())

In [None]:
# Seed PyTorch as well (only NumPy is seeded in the data cell), so that weight
# initialisation and any torch-driven randomness during training are
# reproducible when the notebook is re-run from a fresh kernel.
torch.manual_seed(123)
torch.cuda.empty_cache()
# Tensor copies of the simulated samples, moved to `device`.
Xt = torch.Tensor(X).to(device)
Xt_test = torch.Tensor(X_test).to(device)
# Closed-form representer on the test sample as a column tensor; the logger
# compares the learner's output against it.
At_test = torch.Tensor(true_reisz(X_test).reshape(-1, 1)).to(device)
# Both networks take every column of X as input.
learner = _get_learner(Xt.shape[1], n_hidden, drop_prob)
adversary = _get_adversary(Xt.shape[1], n_hidden, drop_prob)
agmm = DeepReisz(learner, adversary, moment_fn)

def logger(estimator, learner, adversary, epoch, writer):
    """Write training diagnostics for one epoch to `writer`.

    Logs histograms of the last-layer weights of both networks, the mean
    moment violation on Xt and Xt_test, the mean squared difference between
    At_test and the learner's output on Xt_test (divided by 4), and, when the
    estimator exposes it, `estimator.curr_eval`.
    """
    for tag, net in (('learner', learner), ('adversary', adversary)):
        writer.add_histogram(tag, net[-1].weight, epoch)

    train_violation = torch.mean(violation_fn(Xt, learner, adversary))
    writer.add_scalar('moment', train_violation, epoch)

    val_violation = torch.mean(violation_fn(Xt_test, learner, adversary))
    writer.add_scalar('moment_val', val_violation, epoch)

    scaled_mse = torch.mean((At_test - learner(Xt_test))**2)/4
    writer.add_scalar('true_mse_val', scaled_mse, epoch)

    if not hasattr(estimator, 'curr_eval'):
        return
    writer.add_scalar('approx_violation_val', estimator.curr_eval, epoch)

In [None]:
# Train the adversarial estimator on Xt with the hyperparameters from the
# configuration cell. Xt_test is passed as the validation sample (Xval),
# `logger` is passed as the logging callback, and model_dir is the user's home
# directory (presumably where model checkpoints are stored — confirm in
# DeepReisz.fit).
agmm.fit(Xt, Xval=Xt_test, preprocess_epochs=preprocess_epochs, earlystop_rounds=earlystop_rounds, store_test_every=store_test_every,
         learner_lr=learner_lr, adversary_lr=adversary_lr,
         learner_l2=learner_l2, adversary_l2=adversary_l2,
         n_epochs=n_epochs, bs=bs,
         logger=logger, model_dir=str(Path.home()), device=device, verbose=1)

# Evaluation of Learned Riesz Representer

In [None]:
# Scatter the closed-form representer against the 'earlystop' model's
# predictions on the test sample.
model = 'earlystop'
fig, ax = plt.subplots()
estimated = agmm.predict(torch.Tensor(X_test).to(device), model=model)
ax.scatter(true_reisz(X_test), estimated)
ax.set_xlabel('true reisz')
ax.set_ylabel('estimated reisz')
plt.show()

In [None]:
# For each stored model, plot the estimated and the closed-form representer as
# a function of the second covariate, one panel per treatment value.
for model in ['earlystop', 'avg']:
    plt.figure(figsize=(15, 5))
    error = true_reisz(X_test) - agmm.predict(torch.Tensor(X_test).to(device), model=model)
    # Use a figure-level title: plt.title at this point would attach to an
    # implicitly created full-figure axes that the subplots below then overlap,
    # so the RMSE/MAE summary would not be displayed cleanly.
    plt.suptitle("RMSE: {:.3f}, MAE: {:.3f}".format(np.sqrt(np.mean(error**2)), np.mean(np.abs(error))))
    for t in [0, 1]:
        plt.subplot(1, 2, t + 1)
        treated = (X_test[:, 0]==t)
        Xtreated = X_test[treated].copy()
        # Replace the second covariate by an evenly spaced grid on [-1, 1] so
        # the curves are drawn against a sorted x-axis.
        Xtreated[:, 1] = np.linspace(-1, 1, Xtreated.shape[0])
        if model == 'avg':
            # With alpha set, predict returns a point estimate plus lower and
            # upper bounds, which are drawn as a shaded band.
            point, lb, ub = agmm.predict(torch.Tensor(Xtreated).to(device), model=model, alpha=.05, burn_in=0)
            plt.plot(Xtreated[:, 1], point, label='est a({}, X)'.format(t))
            plt.fill_between(Xtreated[:, 1], lb, ub, alpha=.4)
        else:
            point = agmm.predict(torch.Tensor(Xtreated).to(device), model=model)
            plt.plot(Xtreated[:, 1], point, label='est a({}, X)'.format(t))
        plt.plot(Xtreated[:, 1], true_reisz(Xtreated), label='true a({}, X)'.format(t))
        plt.legend()
    plt.show()

In [None]:
# Mean moment violation of the learned representer on the training tensor Xt,
# i.e. mean(moment_fn(x, f) - learner(x) * f(x)), with the true regression
# function f(T, X) = E[Y | T, X] (true_f, reshaped to a column) as test function.
torch.mean(violation_fn(Xt, agmm.learner, lambda x: true_f(x).reshape(-1, 1)))

In [None]:
# Same mean moment violation, but with the closed-form representer true_reisz in
# place of the learned one; evaluated with NumPy on the training sample X as a
# reference value for the cell above.
np.mean(violation_fn(X, lambda x: true_reisz(x).reshape(-1, 1), lambda x: true_f(x).reshape(-1, 1)))

In [None]:
# Mean moment violation of the learned representer on Xt when the test function
# is the adversary network in its final trained state (agmm.adversary).
torch.mean(violation_fn(Xt, agmm.learner, agmm.adversary))

In [None]:
# Open TensorBoard on the ./runs directory (assumes the training run writes its
# event files there — TODO confirm against DeepReisz.fit).
%tensorboard --logdir=runs

# Debiasing ATE

Apply the learned Riesz representer to debias a preliminary regression-based estimate of the ATE.

In [None]:
def mean_ci(data, confidence=0.95):
    """Sample mean with a two-sided Student-t confidence interval.

    Args:
        data: array-like of observations.
        confidence: coverage level of the interval (default 0.95).

    Returns:
        (mean, lower, upper), where lower/upper are the mean minus/plus the
        standard error times the t quantile with len(data) - 1 degrees of
        freedom.
    """
    # Import the submodule explicitly: the notebook only imports `scipy` and
    # `scipy.special`, which does not guarantee that `scipy.stats` is loaded.
    from scipy import stats

    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), stats.sem(a)
    h = se * stats.t.ppf((1 + confidence) / 2., n - 1)
    return m, m - h, m + h

In [None]:
from sklearn.linear_model import LassoCV, Lasso, LogisticRegressionCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures

# Outcome regression: degree-3 polynomial features followed by a Lasso fit of y on X.
outcome_steps = [('p', PolynomialFeatures(degree=3)), ('l', Lasso(alpha=.05))]
est = Pipeline(outcome_steps).fit(X, y)

# Propensity model: degree-3 polynomial features of the non-treatment columns
# followed by a cross-validated logistic regression of the treatment column.
propensity_steps = [('p', PolynomialFeatures(degree=3)), ('l', LogisticRegressionCV())]
propensity = Pipeline(propensity_steps).fit(X[:, 1:], X[:, 0])

In [None]:
# Representer values on the test sample: learned (adversarial) and plug-in
# (closed form evaluated with the fitted propensity model).
a_test = agmm.predict(Xt_test, model='earlystop')


def estimated_propensity(x):
    """Fitted probability of treatment == 1 given the non-treatment columns."""
    return propensity.predict_proba(x[:, 1:])[:, 1]


invp_test = true_reisz(X_test, estimated_propensity)

# Reference value and the simple difference in group means.
true_ate = np.mean(moment_fn(X_test, true_f))
treated_mean = np.mean(y_test[X_test[:, 0]==1])
control_mean = np.mean(y_test[X_test[:, 0]==0])
naive_ate = treated_mean - control_mean

# Building blocks shared by the regression-based and doubly robust estimates.
plugin_terms = moment_fn(X_test, est.predict)
residual_test = y_test - est.predict(X_test)

biased_ate = np.mean(plugin_terms)
ips_ate = np.mean(invp_test * y_test)
reisz_ate = np.mean(a_test * y_test)
dr_ate = biased_ate + np.mean(invp_test * residual_test)
dr_reisz_ate, dr_reisz_low, dr_reisz_up = mean_ci(plugin_terms + a_test * residual_test)

In [None]:
# Each entry pairs a message template with the values formatted into it.
report_lines = [
    ("True ATE: {:.3f}", (true_ate,)),
    ("Mean of Treated - Mean of Untreated: {:.3f}", (naive_ate,)),
    ("Lasso Regression based Estimate: {:.3f}", (biased_ate,)),
    ("IPS estimate with explicit propensity based reisz estimate (i.e. mean(a(X) Y)): {:.3f}", (ips_ate,)),
    ("IPS estimate with adversarial reisz: {:.3f}", (reisz_ate,)),
    ("DR estimate with explict propensity based reisz estimate: {:.3f}", (dr_ate,)),
    ("DR estimate with adversarial reisz: {:.3f} ({:.3f}, {:.3f})", (dr_reisz_ate, dr_reisz_low, dr_reisz_up)),
]
for template, values in report_lines:
    print(template.format(*values))

In [None]:
# Repeated-sampling check of the two adversarial estimators: every iteration
# draws a fresh sample from gen_data, computes both estimates with their
# confidence intervals, and records whether each interval contains true_ate.
# NOTE: these lists reuse the names of the scalar estimates computed in the
# earlier cell, so that cell's print statements no longer work after this one
# has run; `biased_ate` is reset here but not filled in the loop.
dr_reisz_ate = []
reisz_ate = []
biased_ate = []
dr_reisz_cov = []
reisz_cov = []
for exp in range(1000):
    Xboot, yboot = gen_data(n, p)
    # Pass a tensor on `device`, consistent with every other agmm.predict call
    # in this notebook (the original passed the raw NumPy array here).
    aboot = agmm.predict(torch.Tensor(Xboot).to(device), model='earlystop')
    drm, drl, dru = mean_ci(moment_fn(Xboot, est.predict) + aboot * (yboot - est.predict(Xboot)))
    rm, rl, ru = mean_ci(aboot * yboot)
    dr_reisz_ate.append(drm)
    reisz_ate.append(rm)
    dr_reisz_cov.append((true_ate <= dru) & (true_ate >= drl))
    reisz_cov.append((true_ate <= ru) & (true_ate >= rl))

In [None]:
# Overlay the distributions of both estimators across replications; the title
# reports each estimator's average and the fraction of intervals covering true_ate.
summary_values = (np.mean(dr_reisz_ate),
                  np.mean(dr_reisz_cov),
                  np.mean(reisz_ate),
                  np.mean(reisz_cov))
title_template = ("Mean advDR: {:.3f}, Coverage advDR: {:.3f}, "
                  "Mean advIPS: {:.3f}, Coverage advIPS: {:.3f}")
plt.title(title_template.format(*summary_values))
for estimates, label in ((dr_reisz_ate, 'advDR'), (reisz_ate, 'advIPS')):
    plt.hist(np.array(estimates), alpha=.5, label=label)
plt.legend()
plt.show()