In [199]:
%load_ext autoreload

In [200]:
%autoreload 2

# Pyro Effect Handler

Pyro effect handlers are used extensively to estimate 

In [100]:
import seaborn as sns
import torch

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine


from pyro.poutine.runtime import effectful

pyro.set_rng_seed(101)

In [101]:
def scale(guess):
    weight = pyro.sample('weight', dist.Normal(guess, 1.0))
    return pyro.sample("measurement", dist.Normal(weight, 0.75))

In [102]:
guess = 10.

# Log_joint_fn Generation
We can compute a log_joint function for a given model using ``poutine.condition`` and ``poutine.trace``.

In [103]:
conditioned_scale = pyro.condition(scale, data={'weight': 9.5, 'measurement': 10.})

In [104]:
guess = 10.
a_trace = poutine.trace(conditioned_scale).get_trace(guess)

In [105]:
a_trace.log_prob_sum() == dist.Normal(guess, 1.0).log_prob(9.5) + dist.Normal(9.5, 0.75).log_prob(10.)

tensor(1, dtype=torch.uint8)

In [106]:
def make_log_joint2(model):
    def _log_joint(cond_data, *args, **kwargs):
        conditioned_model = poutine.condition(model, data=cond_data)
        trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs)
        return trace.log_prob_sum()
    return _log_joint

scale_log_joint = make_log_joint(scale)
print(scale_log_joint({"measurement": 9.5, "weight": 8.23}, 8.5))

tensor(-3.0203)


We can go into more details

In [109]:
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.condition_messenger import ConditionMessenger

In [124]:
def make_log_joint_2(model):
    def _log_joint(cond_data, *args, **kwargs):
        with TraceMessenger() as tracer:
            with ConditionMessenger(data=cond_data):
                model(*args, **kwargs)
                
        trace = tracer.trace
        logp = 0.
        for name, node in trace.nodes.items():
            if node['type'] == 'sample':
                if node['is_observed']:
                    assert node['value'] is cond_data[name]
                logp = logp + node['fn'].log_prob(node['value']).sum()
        return logp
    return _log_joint

In [125]:
scale_log_joint = make_log_joint_2(scale)
print(scale_log_joint({"measurement": 9.5, "weight": 8.23}, 8.5))

tensor(-3.0203)


In [133]:
print(scale_log_joint({"measurement": torch.tensor([9.5, 10.5]), "weight": torch.tensor([8.23, 9.23])}, 8.5))

tensor(-6.2707)


In [139]:
log_prob_1 = scale_log_joint({"measurement": torch.tensor([9.5]), "weight": torch.tensor([8.23])}, 8.5)

In [140]:
log_prob_2 = scale_log_joint({"measurement": torch.tensor([10.5]), "weight": torch.tensor([9.23])}, 8.5)

In [141]:
log_prob_1 + log_prob_2

tensor(-6.2707)

In [147]:
a_trace = poutine.trace(pyro.condition(scale, data={"measurement": torch.tensor([9.5]), "weight": torch.tensor([8.23])})).get_trace(8.5) # guess = 8.5

In [150]:
a_trace.log_prob_sum() == log_prob_1

tensor(1, dtype=torch.uint8)

# New Messenger

In [185]:
class LogJointMessenger(poutine.messenger.Messenger):
    
    def __init__(self, cond_data):
        self.data = cond_data
        
    def __call__(self, fn):
        def _fn(*args, **kwargs):
            with self:
                fn(*args, **kwargs)
                return self.logp.clone()
        return _fn
    
    def __enter__(self):
        self.logp = torch.tensor(0.)
        return super(LogJointMessenger, self).__enter__()
    
    def __exit__(self, exc_type, exc_value, traceback):
        self.logp = torch.tensor(0.)
        return super(LogJointMessenger, self).__exit__(exc_type, exc_value, traceback)
    
    def _pyro_sample(self, msg):
        assert msg['name'] in self.data
        msg['value'] = self.data[msg['name']]
        msg['is_observed'] = True
        self.logp = self.logp + (msg["scale"] * msg["fn"].log_prob(msg["value"])).sum()


In [186]:
with LogJointMessenger(cond_data={"measurement": 9.5, "weight": 8.23}) as m:
    scale(8.5)
    print(m.logp.clone())

tensor(-3.0203)


In [187]:
scale_log_joint = LogJointMessenger(cond_data={"measurement": 9.5, "weight": 8.23})(scale)
print(scale_log_joint(8.5))

tensor(-3.0203)


In [183]:
xs = dist.Bernoulli(torch.tensor([0.2, 0.3, 0.9]))

In [184]:
xs_samples = xs.sample()

In [175]:
xs.log_prob(xs_samples)

tensor([-1.6094, -0.3567, -0.1054])

In [177]:
xs.log_prob(xs_samples).sum()

tensor(-2.0715)

# Experiment

In [23]:
import pyro.distributions as dist

In [31]:
dist.Normal?

[0;31mInit signature:[0m [0mdist[0m[0;34m.[0m[0mNormal[0m[0;34m([0m[0mloc[0m[0;34m,[0m [0mscale[0m[0;34m,[0m [0mvalidate_args[0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
Wraps :class:`torch.distributions.normal.Normal` with
:class:`~pyro.distributions.torch_distribution.TorchDistributionMixin`.
[0;31mFile:[0m           ~/miniconda3/envs/spectrumdev/lib/python3.6/site-packages/pyro/distributions/torch.py
[0;31mType:[0m           ABCMeta
[0;31mSubclasses:[0m     


In [24]:
import minipyro as pyro

In [25]:
pyro.PARAM_STORE

{}

In [26]:
pyro.PYRO_STACK

[]

In [30]:
def model():
    pyro.sample('normal_rv', dist.Normal(1.,1.))

trace = pyro.trace(model)

trace.get_trace()

ok running apply stack


OrderedDict([('normal_rv',
              {'type': 'sample',
               'name': 'normal_rv',
               'fn': Normal(loc: 1.0, scale: 1.0),
               'args': (),
               'value': tensor(1.0842)})])

In [22]:
trace

<minipyro.trace at 0xa24d95e10>