# Causal Effect VAE with Pyro
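Reference: Pyro's CEVAE example, https://pyro.ai/examples/cevae.html. This notebook trains a Causal Effect VAE on synthetic data and compares the true, naive, and CEVAE-estimated average treatment effect (ATE).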

In [1]:
import argparse
import logging

In [2]:
import torch

In [4]:
# %pip install pyro-ppl==1.4.0  # uncomment to install; pinned to the version asserted below

Collecting pyro-ppl
[?25l  Downloading https://files.pythonhosted.org/packages/f9/85/c294ab8062bef14aac4f87770f22a0b7a5cd9204d0777c6eb37fa4b9320d/pyro_ppl-1.4.0-py3-none-any.whl (573kB)
[K     |████████████████████████████████| 573kB 3.5MB/s 
[?25hCollecting pyro-api>=0.1.1
  Downloading https://files.pythonhosted.org/packages/fc/81/957ae78e6398460a7230b0eb9b8f1cb954c5e913e868e48d89324c68cec7/pyro_api-0.1.2-py3-none-any.whl
Installing collected packages: pyro-api, pyro-ppl
Successfully installed pyro-api-0.1.2 pyro-ppl-1.4.0


In [5]:
import pyro

In [9]:
# Fail fast unless the installed Pyro reports a 1.4.0 version string, the release
# this notebook was run with.
assert pyro.__version__.startswith('1.4.0')

In [6]:
import pyro.distributions as dist

In [7]:
from pyro.contrib.cevae import CEVAE

In [8]:
# Surface Pyro's DEBUG-level messages (the per-step training loss is logged at DEBUG).
logging.getLogger("pyro").setLevel(logging.DEBUG)
# Lower the threshold on every handler attached to the "pyro" logger instead of
# indexing handlers[0]: if the logger has no handler of its own (presumably the case
# when the root logger was configured before Pyro was imported -- confirm), the
# index lookup would raise IndexError.
for handler in logging.getLogger("pyro").handlers:
    handler.setLevel(logging.DEBUG)

In [32]:
class ARGS:
    """Plain attribute container holding this notebook's settings.

    ``ARGS()`` yields the defaults listed in ``__init__``. Any of them can be
    overridden by keyword, e.g. ``ARGS(num_epochs=5, cuda=False)``.

    Raises:
        TypeError: if a keyword does not name one of the existing options
            (guards against silently ignored typos).
    """

    def __init__(self, **overrides):
        # Synthetic data.
        self.num_data = 1000            # samples drawn per generate_data() call
        self.feature_dim = 5            # number of feature columns in x
        # Model architecture (forwarded to the CEVAE constructor).
        self.latent_dim = 20
        self.hidden_dim = 200
        self.num_layers = 3
        # Optimisation (forwarded to CEVAE.fit).
        self.num_epochs = 50
        self.batch_size = 100
        self.learning_rate = 1e-3
        self.learning_rate_decay = 0.1
        self.weight_decay = 1e-4
        # Run-time switches.
        self.seed = 1234567890          # passed to pyro.set_rng_seed
        self.jit = True                 # evaluate through cevae.to_script_module()
        self.cuda = True                # make CUDA float tensors the default tensor type

        # Apply caller overrides last; only names defined above are accepted.
        for name, value in overrides.items():
            if name not in vars(self):
                raise TypeError("ARGS got an unknown option: {!r}".format(name))
            setattr(self, name, value)

In [None]:
# Instantiate the default configuration; later cells read their settings from `args`.
args = ARGS()

In [37]:
def generate_data(args):
    """
    Draw one synthetic dataset of ``args.num_data`` units.

    This implements the generative process of [1], but using larger feature and
    latent spaces ([1] assumes ``feature_dim=1`` and ``latent_dim=5``).
    NOTE(review): "[1]" is not defined in this notebook; presumably Louizos et
    al. (2017), "Causal Effect Inference with Deep Latent-Variable Models", as
    cited by the referenced Pyro example -- confirm.

    Sampling order (kept fixed so a given RNG seed reproduces the same data):
    ``z`` -> ``x`` -> ``t`` -> ``y``.

    :param args: object exposing ``num_data`` and ``feature_dim``.
    :returns: tuple ``(x, t, y, true_ite)`` where ``x`` has shape
        ``(num_data, feature_dim)`` and the other three have shape ``(num_data,)``.
    """
    # Binary confounder, one draw per unit.
    z = dist.Bernoulli(0.5).sample([args.num_data])

    # Features: Normal draws centred on z, with std 5 where z == 1 and 3 where z == 0.
    # Sampling gives (feature_dim, num_data); transpose to one row per unit.
    feature_scale = 5 * z + 3 * (1 - z)
    x = dist.Normal(z, feature_scale).sample([args.feature_dim]).t()

    # Treatment: probability 0.75 where z == 1, 0.25 where z == 0.
    treatment_prob = 0.75 * z + 0.25 * (1 - z)
    t = dist.Bernoulli(treatment_prob).sample()

    def outcome_logits(treatment):
        # Outcome logits as a function of the confounder and a treatment value.
        return 3 * (z + 2 * (2 * treatment - 2))

    y = dist.Bernoulli(logits=outcome_logits(t)).sample()

    # True ITE for evaluation: difference of the exact Bernoulli means under
    # t=1 and t=0 (computed analytically via ``.mean``, no sampling involved).
    # The (2, 1) treatment column broadcasts against z to give one row per arm.
    both_arms = torch.tensor([[0.], [1.]])
    potential_means = dist.Bernoulli(logits=outcome_logits(both_arms)).mean
    true_ite = potential_means[1] - potential_means[0]
    return x, t, y, true_ite

In [38]:
# Enable Pyro's validation checks whenever Python runs with assertions on
# (__debug__ is True unless the interpreter was started with -O).
pyro.enable_validation(__debug__)

In [39]:
# When requested, make CUDA float tensors the default type for tensors created
# from here on. Guarded by torch.cuda.is_available() because ARGS enables `cuda`
# by default, so a CPU-only machine falls back to the CPU instead of failing.
if args.cuda:
    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        print("CUDA requested but not available; continuing on the CPU.")

In [40]:
# Generate synthetic data.
# Seed first so the training set is reproducible. The fourth return value (the
# true ITE) is discarded here; it is only used in the evaluation cell.
pyro.set_rng_seed(args.seed)
x_train, t_train, y_train, _ = generate_data(args)

In [42]:
# Sanity check: one row per sample, one column per feature.
x_train.shape

torch.Size([1000, 5])

In [44]:
# Sanity check: one outcome per sample.
y_train.shape

torch.Size([1000])

In [45]:
# Train.
# Re-seed and clear Pyro's parameter store so this cell starts from a clean,
# reproducible state each time it is run.
pyro.set_rng_seed(args.seed)
pyro.clear_param_store()

# Model sizes come from `args`.
cevae = CEVAE(feature_dim=args.feature_dim,
              latent_dim=args.latent_dim,
              hidden_dim=args.hidden_dim,
              num_layers=args.num_layers,
              # NOTE(review): hard-coded rather than read from `args`; presumably the
              # number of Monte Carlo samples used when estimating effects -- confirm
              # against the CEVAE documentation.
              num_samples=10)

# Fit on the synthetic training set. As the last expression of the cell, the
# return value of fit() (a list of loss values) is displayed as the cell output.
cevae.fit(x_train, t_train, y_train,
          num_epochs=args.num_epochs,
          batch_size=args.batch_size,
          learning_rate=args.learning_rate,
          learning_rate_decay=args.learning_rate_decay,
          weight_decay=args.weight_decay)

INFO 	 Training with 10 minibatches per epoch
DEBUG 	 step     0 loss = 13.9879
DEBUG 	 step     1 loss = 13.4347
DEBUG 	 step     2 loss = 11.9022
DEBUG 	 step     3 loss = 10.9596
DEBUG 	 step     4 loss = 10.679
DEBUG 	 step     5 loss = 10.3065
DEBUG 	 step     6 loss = 10.4947
DEBUG 	 step     7 loss = 10.0073
DEBUG 	 step     8 loss = 9.91819
DEBUG 	 step     9 loss = 9.25705
DEBUG 	 step    10 loss = 8.96875
DEBUG 	 step    11 loss = 9.22228
DEBUG 	 step    12 loss = 9.74679
DEBUG 	 step    13 loss = 9.69091
DEBUG 	 step    14 loss = 9.32762
DEBUG 	 step    15 loss = 9.48919
DEBUG 	 step    16 loss = 9.40744
DEBUG 	 step    17 loss = 9.59345
DEBUG 	 step    18 loss = 9.72686
DEBUG 	 step    19 loss = 9.5217
DEBUG 	 step    20 loss = 9.67367
DEBUG 	 step    21 loss = 9.70313
DEBUG 	 step    22 loss = 9.18544
DEBUG 	 step    23 loss = 9.17808
DEBUG 	 step    24 loss = 9.17006
DEBUG 	 step    25 loss = 9.18559
DEBUG 	 step    26 loss = 8.9947
DEBUG 	 step    27 loss = 9.40237
DEBUG

[13.987886169433594,
 13.434739135742188,
 11.902153564453124,
 10.959574279785157,
 10.678965240478515,
 10.306456726074218,
 10.49474185180664,
 10.007325927734374,
 9.918187561035156,
 9.257052581787109,
 8.968751525878906,
 9.222281005859376,
 9.746790405273437,
 9.690911865234375,
 9.327623184204102,
 9.489188262939454,
 9.407439483642579,
 9.593446746826173,
 9.726864517211913,
 9.521697418212892,
 9.673669647216796,
 9.703134918212891,
 9.1854423828125,
 9.178078079223633,
 9.170055541992188,
 9.185591186523437,
 8.994696746826172,
 9.402368194580077,
 8.78461604309082,
 9.513079696655273,
 9.271586868286132,
 9.277921173095702,
 9.534161437988281,
 8.809894348144532,
 9.147845886230469,
 9.222864715576172,
 9.424579040527345,
 9.440942657470703,
 8.881433197021485,
 9.061224319458008,
 9.699822875976562,
 9.089233520507813,
 8.939717864990234,
 9.313008544921875,
 9.11053709411621,
 8.934159637451172,
 9.439756210327149,
 9.133090698242187,
 9.023703140258789,
 8.94488186645507

In [46]:
# Evaluate.
# Draw a fresh test set from the same generative process. The RNG is not re-seeded
# here, so these draws continue from the state left by training.
x_test, t_test, y_test, true_ite = generate_data(args)
true_ate = true_ite.mean()
print("true ATE = {:0.3g}".format(true_ate.item()))
# Naive estimate: raw difference in mean outcome between treated and untreated
# units, with no adjustment for the confounder.
naive_ate = y_test[t_test == 1].mean() - y_test[t_test == 0].mean()
# .item() for consistency with the other two prints (plain float, not a tensor).
print("naive ATE = {:0.3g}".format(naive_ate.item()))
# Bind the (optionally scripted) model to a separate name instead of overwriting
# `cevae`: the trained model stays available, and re-running this cell does not
# call to_script_module() on an object that was already converted.
predictor = cevae.to_script_module() if args.jit else cevae
est_ite = predictor.ite(x_test)
est_ate = est_ite.mean()
print("estimated ATE = {:0.3g}".format(est_ate.item()))

INFO 	 Evaluating 1 minibatches


true ATE = 0.723
naive ATE = 0.834
estimated ATE = 0.813
