$$
\begin{align} 
\mathbf{M} =
\left\{\begin{matrix}
n_x &\sim &\text{BernoulliBool}(p=0.5433931)\\ 
n_q &\sim &\text{BernoulliBool}(p=0.05)\\ 
n_y &\sim &\text{BernoulliBool}(p=0.0077)\\ 
x &= &n_x \\ 
q &= &n_q \\
y &= &(x \wedge q) \vee n_y
\end{matrix}\right. \nonumber
\end{align}
$$

In [1]:
import pyro
import torch
import pyro.distributions as dist

In [4]:
def f(x, q, n_y):
    """Structural equation for Y: ``y = (x AND q) OR n_y``.

    Accepts plain Python values or torch tensors. Tensor inputs are combined
    element-wise, so batched (multi-element) tensors work as well as scalar
    ones. The result keeps the promoted dtype of the inputs (e.g. float 0./1.
    for Bernoulli draws), so it can be passed straight to ``dist.Delta``.

    Args:
        x: value of X.
        q: value of Q.
        n_y: exogenous noise term for Y.

    Returns:
        ``(x and q) or n_y`` when no input is a tensor; otherwise a tensor
        holding the element-wise ``(x AND q) OR n_y``.
    """
    if not any(isinstance(v, torch.Tensor) for v in (x, q, n_y)):
        # Plain Python values: keep the original short-circuit semantics.
        return (x and q) or n_y
    x, q, n_y = (torch.as_tensor(v) for v in (x, q, n_y))
    # Python's `and`/`or` call bool() on a tensor, which raises for
    # multi-element tensors; the logical_* ops broadcast element-wise.
    result = torch.logical_or(torch.logical_and(x, q), n_y)
    out_dtype = torch.promote_types(torch.promote_types(x.dtype, q.dtype), n_y.dtype)
    return result.to(out_dtype)

def model(noise):
    """Draw one joint sample of the endogenous variables X, Q and Y.

    Args:
        noise: mapping with keys 'X', 'Q' and 'Y', each holding the
            distribution used for that variable's exogenous noise site.

    Returns:
        dict with keys 'X', 'Q', 'Y' holding the sampled values.
    """
    # Exogenous noise sites, sampled in the order n_x, n_q, n_y.
    u_x = pyro.sample('n_x', noise['X'])
    u_q = pyro.sample('n_q', noise['Q'])
    u_y = pyro.sample('n_y', noise['Y'])

    # Endogenous variables are deterministic (Delta) sample sites, which lets
    # them be addressed by name (e.g. by pyro.condition / pyro.do).
    x_val = pyro.sample('X', dist.Delta(u_x))
    q_val = pyro.sample('Q', dist.Delta(u_q))
    y_val = pyro.sample('Y', dist.Delta(f(x_val, q_val, u_y)))

    return dict(X=x_val, Q=q_val, Y=y_val)

In [24]:
# Bernoulli noise for each variable; probabilities follow the SCM spec above.
_noise_probs = (
    ('X', 0.5433931),
    ('Q', 0.05),
    ('Y', 0.0077),
)
noise = {site: dist.Bernoulli(torch.tensor(p)) for site, p in _noise_probs}

model(noise)

{'X': tensor(0.), 'Q': tensor(0.), 'Y': tensor(0.)}

In [25]:
def CFQuery(model, noise, obs_x, int_x, ssize, **model_kwargs):
    """
    Perform counterfactual inference on `model`.

    Conditions the model on the observed values, infers the noise given
    that observation, then samples the intervened model using the inferred
    noise.

    `model`: the subroutine encoding the SCM; called as
        `model(noise, **model_kwargs)`.
    `noise`: a dictionary containing distributions for each noise object.
    `obs_x`: a dictionary containing observed values for each variable in
        the model.
    `int_x`: a dictionary containing values for a subset of variables for
        intervention.
    `ssize`: number of samples to draw from the counterfactual distribution.
    `**model_kwargs`: optional extra keyword arguments forwarded to the
        intervened model on every sample (none by default).

    Returns a list of `ssize` outputs of the intervened model.
    """
    # Condition on the observed outcome.
    obs_mod = pyro.condition(model, data=obs_x)

    # Set the intervention for the counterfactual outcome.
    int_mod = pyro.do(model, data=int_x)

    # Infer the noise given the observed outcome.
    # NOTE(review): infer_dist / compute_posteriors are defined elsewhere;
    # presumably they return noise distributions keyed like `noise` — confirm.
    NPstr = infer_dist(obs_mod, noise)
    noise_posteriors = compute_posteriors(NPstr, noise)

    # Generate counterfactual outcomes by passing the updated noise to the
    # intervened model. Extra kwargs are forwarded only when the caller
    # supplies them, so models with the signature `model(noise)` work.
    return [int_mod(noise_posteriors, **model_kwargs) for _ in range(ssize)]

In [None]:
# Example counterfactual query: having observed X=1, Q=1 and Y=1, what would
# Y have been under the intervention do(X=0)?
# NOTE(review): the observed/intervened values and sample size are
# illustrative — adjust them to the intended query.
CFQuery(
    model,
    noise,
    obs_x={'X': torch.tensor(1.), 'Q': torch.tensor(1.), 'Y': torch.tensor(1.)},
    int_x={'X': torch.tensor(0.)},
    ssize=100,
)