In [1]:
%matplotlib inline
import torch
import pyro
import pyro.infer
import pyro.distributions as dist
import numpy as np
import matplotlib.pyplot as plt

from torch.autograd import Variable

# Introduction

(This tutorial is largely adapted from the disease example in [Chapter 4 of ProbMods](http://probmods.org/chapters/04-patterns-of-inference.html))

# Conditioning and intervention

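`pyro.condition` constrains a sample site to an observed value, so inference updates our beliefs about everything related to that site, including its causes. Pyro also provides `pyro.do`, which models an *intervention*: the site is forced to a given value from outside the model, so only quantities downstream of that site are affected. The cells below contrast the two on a small disease model.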

In [2]:
def disease():
    """Generative model linking smoking, lung disease and a cold to three symptoms.

    Every random choice is a named Bernoulli sample site. Lung disease has a
    small baseline probability plus an additional cause that is only sampled
    for smokers. Each symptom (cough, fever, chest pain) has a small baseline
    probability plus additional causes that are only sampled when the
    relevant disease is present.

    Returns:
        dict: truth values for the keys "cough", "fever" and "chest_pain".
    """
    def flip(site, prob):
        # draw a Bernoulli sample at the named site and report whether it came up 1
        outcome = pyro.sample(site, dist.bernoulli, Variable(torch.Tensor([prob])))
        return (outcome == 1.0).all()

    def indicator(flag):
        # one-element 0/1 tensor, used below as a degenerate Bernoulli parameter
        return Variable(torch.Tensor([1 if flag else 0]))

    # risk factor
    smokes = flip("smokes", 0.2)

    # lung disease: rare baseline cause, plus an extra cause for smokers
    lung_disease = flip("lung_disease_prior", 0.001)
    if smokes:
        # sample first, then combine, so the site is always visited for smokers
        caused_by_smoking = flip("lung_disease_smokes", 0.1)
        lung_disease = lung_disease or caused_by_smoking

    # confounding disease
    cold = flip("cold", 0.02)

    # symptom: coughing (baseline, cold, lung disease)
    cough = flip("cough_prior", 0.01)
    if cold:
        cough_from_cold = flip("cough_cold", 0.5)
        cough = cough or cough_from_cold
    if lung_disease:
        cough_from_lung = flip("cough_lung", 0.5)
        cough = cough or cough_from_lung

    # symptom: fever (baseline, cold)
    fever = flip("fever_prior", 0.01)
    if cold:
        fever_from_cold = flip("fever_cold", 0.3)
        fever = fever or fever_from_cold

    # symptom: chest pain (baseline, lung disease)
    chest_pain = flip("chest_pain_prior", 0.01)
    if lung_disease:
        pain_from_lung = flip("chest_pain_lung", 0.2)
        chest_pain = chest_pain or pain_from_lung

    # Dummy sample sites whose parameter is exactly 0 or 1: they exist only so
    # that the derived quantities have site names that can be constrained.
    # Per the original author's warning this trick is only expected to work
    # with pyro.infer.Search.
    # NOTE(review): the original comment said other methods "should use
    # dist.bernoulli", which is what is already used here - presumably a
    # different distribution was meant; confirm.
    for site, value in (
        ("lung_disease", lung_disease),
        ("cough", cough),
        ("fever", fever),
        ("chest_pain", chest_pain),
    ):
        pyro.sample(site, dist.bernoulli, indicator(value))

    # report the symptoms only
    return {"cough": cough, "fever": fever, "chest_pain": chest_pain}

In [3]:
# Constrain the "cough" site of the model to 1 (cough observed) and run Search
# inference over the constrained model.
conditioned_disease = pyro.condition(disease, data={"cough": Variable(torch.ones(1))})
cough_posterior = pyro.infer.Search(conditioned_disease)

# Marginal over the "lung_disease" site given the observed cough.
# NOTE(review): calling the marginal presumably draws a single random value, so
# the printed result may differ between runs (no seed is set) - confirm.
lung_marginal = pyro.infer.Marginal(cough_posterior, sites=["lung_disease"])
print(lung_marginal())

# Same posterior, marginalized over the "cold" site; reused by a later cell.
cold_marginal = pyro.infer.Marginal(cough_posterior, sites=["cold"])
print(cold_marginal())

{'lung_disease': Variable containing:
 1
[torch.FloatTensor of size 1]
}
{'cold': Variable containing:
 0
[torch.FloatTensor of size 1]
}


In [5]:
# Probability of cold == 0 under the cough-conditioned marginal
# (log_pdf gives a log-probability, hence the exp).
print(torch.exp(cold_marginal.log_pdf({"cold": Variable(torch.zeros(1))})))
# NOTE(review): _dist_and_values is an underscore-prefixed (private) method; this
# looks like a debugging peek at the marginal's probabilities (dd.ps) and the
# values they belong to (vvs) - confirm it should stay in the notebook.
dd, vvs = cold_marginal._dist_and_values()
print(dd.ps, vvs)

Variable containing:
 0.6619
[torch.FloatTensor of size 1]

Variable containing:
 0.6619
 0.3381
[torch.FloatTensor of size 2]
 [{'cold': Variable containing:
 0
[torch.FloatTensor of size 1]
}, {'cold': Variable containing:
 1
[torch.FloatTensor of size 1]
}]


In [11]:
# Intervene on the cause: build two copies of the model in which the "cold" site
# is forced to 1 and to 0 respectively.
intervened_cold = pyro.do(disease, data={"cold": Variable(torch.ones(1))})
intervened_no_cold = pyro.do(disease, data={"cold": Variable(torch.zeros(1))})

# Marginal over the "cough" site under each intervention.
marginal_intervened_cold = pyro.infer.Marginal(pyro.infer.Search(intervened_cold), sites=["cough"])
marginal_intervened_no_cold = pyro.infer.Marginal(pyro.infer.Search(intervened_no_cold), sites=["cough"])

# Probability of cough == 1 under do(cold = 1) and under do(cold = 0).
p_cough_do_cold = torch.exp(marginal_intervened_cold.log_pdf({"cough": Variable(torch.ones(1))}))
p_cough_do_no_cold = torch.exp(marginal_intervened_no_cold.log_pdf({"cough": Variable(torch.ones(1))}))

print(p_cough_do_cold, p_cough_do_no_cold)

Variable containing:
 0.5102
[torch.FloatTensor of size 1]
 Variable containing:
1.00000e-02 *
  2.0386
[torch.FloatTensor of size 1]



In [16]:
# Intervene on the symptom instead: force the "cough" site to 1 and to 0.
intervened_cough = pyro.do(disease, data={"cough": Variable(torch.ones(1))})
intervened_no_cough = pyro.do(disease, data={"cough": Variable(torch.zeros(1))})
# Marginal over "cold" for the unmodified model (baseline) and for each intervention.
marginal_cold = pyro.infer.Marginal(pyro.infer.Search(disease), sites=["cold"])
marginal_cold_do_cough = pyro.infer.Marginal(pyro.infer.Search(intervened_cough), sites=["cold"])
marginal_cold_do_no_cough = pyro.infer.Marginal(pyro.infer.Search(intervened_no_cough), sites=["cold"])

# Probability of cold == 1 in each of the three cases, printed side by side so
# the baseline can be compared with the two interventions on the symptom.
p_cold = torch.exp(marginal_cold.log_pdf({"cold": Variable(torch.ones(1))}))
p_cold_do_cough = torch.exp(marginal_cold_do_cough.log_pdf({"cold": Variable(torch.ones(1))}))
p_cold_do_no_cough = torch.exp(marginal_cold_do_no_cough.log_pdf({"cold": Variable(torch.ones(1))}))

print(p_cold, p_cold_do_cough, p_cold_do_no_cough)

Variable containing:
1.00000e-02 *
  2.0000
[torch.FloatTensor of size 1]
 Variable containing:
1.00000e-02 *
  2.0000
[torch.FloatTensor of size 1]

