In [None]:
from torch import tensor

import pyro
from pyro import condition, do, sample
from pyro.infer import Importance, EmpiricalMarginal
from pyro.distributions import Bernoulli as Flip
from pyro.distributions import Delta

from matplotlib import pyplot as plt
%matplotlib inline

# Goal of this version

The goal here is to start working on key abstractions.

# Toy Model:  Blindness Clinical Trial

## Problem
1.  Everybody has a blindness disease.
2.  We randomize people to control (T=0) or treatment (T=1)
3.  Generally people who DON'T get treatment (T=0) go blind (B=1) the next day.  If they do get treatment (T=1), then they don't go blind the next day (B=0).
4. 1% of the population has this rare genotype where if they are treated (T=1) then they DO go blind the next day (B=1). And if they are not treated (T=0) then they DON'T go blind. 

### Reasoning question: 
A person was treated and they went blind.  Would they have gone blind if they had not been treated?

$P_M^{T=1, B=1, do(T=0)}$

### Structural Causal Model approach to probabilistic programming

Parameterize the model by splitting it into noise (exogenous) variables and endogenous variables.  The noise variables are the only random variables; the endogenous variables are deterministic functions of the noise variables (possibly via other endogenous variables).

*Randomized assignment*
$N_t \sim \text{Bernoulli}(0.5)$ 

*Population prior on genotype.*
$N_b \sim \text{Bernoulli}(0.01)$ 

*Treatment state*
$T \sim \text{Delta}(N_t)$

*Blindness outcome*
$B \sim \text{Delta}\big(T \cdot N_b + (1-T)(1-N_b)\big)$


Implement the clinical trial model.

In [None]:
def rct(NB_dist, NT_dist=None):
    """Structural causal model of the blindness clinical trial.

    The noise variables ``NT`` and ``NB`` are the only genuinely random
    sites. The endogenous variables ``T`` and ``B`` are Delta-distributed,
    i.e. deterministic functions of the noise. Because both noise
    distributions are arguments, either can be replaced, e.g. by a posterior
    obtained by conditioning the model on evidence.

    Args:
        NB_dist: distribution of the genotype noise ``NB``. With this model,
            ``NB = 1`` reproduces the rare genotype (treated -> blind,
            untreated -> not blind).
        NT_dist: distribution of the treatment-assignment noise ``NT``.
            ``None`` (default) means ``Flip(.5)``, i.e. fair randomized
            assignment, which matches the original behaviour.

    Returns:
        The sampled blindness outcome ``B`` (1 = blind).
    """
    if NT_dist is None:
        NT_dist = Flip(.5)  # randomized assignment to control / treatment
    NT = sample('NT', NT_dist)
    NB = sample('NB', NB_dist)
    T = sample('T', Delta(NT))
    # When treated (T=1) B equals NB; when untreated (T=0) B equals 1 - NB.
    B = sample('B', Delta(T * NB + (1 - T) * (1 - NB)))
    return B


## Compute the marginal on B

In [None]:
# Seed Pyro/torch so that Restart & Run All reproduces every histogram below.
pyro.set_rng_seed(0)

# Marginal of B with no evidence: the model is run forward from its priors.
# Analytically P(B = 1) = 0.5 * 0.01 + 0.5 * 0.99 = 0.5.
dist = Importance(rct, num_samples=1000)
marginal = EmpiricalMarginal(dist.run(Flip(.01)), sites='B')
plt.hist([marginal().item() for _ in range(100)], range=(0.0, 1.0))
plt.title("P(B)")
plt.xlabel("B")
plt.ylabel("#");

## Compute the counterfactual

The marginal distribution of what B would have been under $do(T = 0)$, conditional on having observed T = 1 and B = 1.

1. Obtain the posterior on the noise distribution, conditional on evidence T = 1 & B = 1.

In [None]:
# Condition the model on the factual evidence T = 1 and B = 1.
# Importance sampling without a guide proposes from the prior, so only draws
# with NT = 1 and NB = 1 (prior probability 0.5 * 0.01 = 0.005) receive
# non-zero weight. With 1000 draws that is ~5 useful samples and a ~0.7%
# chance of none at all, in which case the resulting EmpiricalMarginal has
# no valid weights to sample from. 10000 draws makes that practically
# impossible.
conditioned_model = condition(rct, data={"T": tensor(1.), "B": tensor(1.)})
conditional_dist = Importance(conditioned_model, num_samples=10000)
NB_marginal = EmpiricalMarginal(conditional_dist.run(Flip(.01)), sites='NB')

2. Obtain the interventional distribution for T = 0.

In [None]:
# Intervene on the treatment: `do` fixes the site "T" to the constant 0,
# overriding its Delta(NT) mechanism; the rest of the model is unchanged.
intervention_model = do(rct, data=dict(T=tensor(0.)))
intervention_dist = Importance(intervention_model, num_samples=10**3)

3. Pass the updated noise distribution to the intervention distribution.

In [None]:
# Run the intervened model with NB drawn from its posterior (given T = 1,
# B = 1) and keep only the marginal over the outcome site "B".
counterfactual_dist = EmpiricalMarginal(
    intervention_dist.run(NB_marginal),
    sites='B',
)

In [None]:
# Histogram of the counterfactual outcome. The evidence T = 1, B = 1 forces
# NB = 1, and with T = 0 the model gives B = 1 - NB, so all mass should sit
# at B = 0: this person would not have gone blind without treatment.
plt.hist([counterfactual_dist().item() for _ in range(100)], range=(0.0, 1.0))
plt.title("P(B | observe(T = 1 & B = 1), do(T = 0))")
plt.xlabel("B")
plt.ylabel("#");