In [1]:
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

import pandas as pd
import numpy as np
import torch

In [2]:
def ProposedModelSmall(a=None, r=None, y=None, gamma_shift=0., C=0.):
    gamma_shift = torch.tensor(gamma_shift)
    C = torch.tensor(C)

    lambda1 = pyro.sample("lambda1", dist.Normal(0., 1.))
    lambda2 = pyro.sample("lambda2", dist.Normal(0., 1.))
    lambda3 = pyro.sample("lambda3", dist.Normal(0., 1.))
    lambda4 = pyro.sample("lambda4", dist.Normal(0., 1.))
    
    with pyro.plate("data", a.size(0)):
        u1 = pyro.sample("u1", dist.Normal(torch.tensor(0.), torch.tensor(0.6)))
        u2 = pyro.sample("u2", dist.Normal(torch.tensor(0.), torch.tensor(0.8)))
        u3 = pyro.sample("u3", dist.Normal(torch.tensor(0.), torch.tensor(0.9)))
        
        a_prob = torch.sigmoid(gamma_shift*lambda1*C + u1)
        a_val = pyro.sample("a", dist.Bernoulli(probs=a_prob), obs=a)

        r_val = pyro.sample("r", dist.Normal(lambda2*a_val + u2, 1.), obs=r) 

        y_prob = torch.sigmoid(lambda3*a_val + lambda4*r_val + u3)
        y_val = pyro.sample("y", dist.Bernoulli(probs=y_prob), obs=y)
        
        return a_val, r_val, y_val

In [3]:
# Sample data
samples = pd.DataFrame(columns=['A', 'R', 'Y'])
samples['A'] = [1., 0., 1.]
samples['R'] = [1.25, 0.87, 0.43]
samples['Y'] = [0., 1., 1.]

In [7]:
def infer_exogenous(obs, model):
    input_a = torch.tensor(np.array([obs['a']]))
    input_r = torch.tensor(np.array([obs['r']]))
    input_y = torch.tensor(np.array([obs['y']]))
    
    cond_sample = pyro.condition(lambda: model(input_a, input_r, input_y), obs)
    cond_trace = pyro.poutine.trace(cond_sample).get_trace()
    
    exogenous = {k: cond_trace.nodes[k]['value'] for k in  ['u1', 'u2', 'u3']}
    return exogenous


def counterfactual(model, obs, learned_params):
    # Infer state of world (ie learn noise)
    exogenous = infer_exogenous(obs, model)
    exogenous_and_learned = {**exogenous, **learned_params}
    
    # Find counterfactual A value (a') for this sample
    input_a = torch.tensor(np.array([obs['a']]))
    cf_a = 0 if input_a.numpy()[0] == 1. else 1.
    
    # Compute counterfactual sample
    cf_model = pyro.do(pyro.condition(model, data=exogenous_and_learned), data={'a': torch.tensor(cf_a)})
    
    return cf_model, cf_a, state_of_world

In [8]:
# Get sample
obs = {k.lower(): torch.tensor(samples.iloc[0][k]) for k in ['A', 'R', 'Y']} 

# Infer noise/state of the world and generate counterfactual
learned_params = {'lambda1': torch.tensor(8.68e-24), 'lambda2': torch.tensor(-0.15), 'lambda3': torch.tensor(-0.81), 'lambda4': torch.tensor(0.85)}
cf, cf_a, state_of_world = counterfactual(ProposedModelSmall, obs, learned_params)

input_a = torch.tensor(np.array([obs['a']]))
print(cf(input_a))

(tensor(0), tensor([-1.4738]), tensor([0.]))


In [6]:
obs

{'a': tensor(1., dtype=torch.float64),
 'r': tensor(1.2500, dtype=torch.float64),
 'y': tensor(0., dtype=torch.float64)}