mh.py
import random

import torch

import pyro
import pyro.poutine as poutine
from pyro.distributions import Uniform
from pyro.infer import TracePosterior


class MH(TracePosterior):
    """
    Initial implementation of Metropolis-Hastings MCMC over execution traces.

    Takes exactly one of ``guide`` or ``proposal``: a ``guide`` is an
    independence proposal called without the previous trace, while a
    ``proposal`` receives the previous model trace as its first argument.
    """
    def __init__(self, model, guide=None, proposal=None, samples=10, lag=1, burn=0):
        super(MH, self).__init__()
        self.samples = samples
        self.lag = lag
        self.burn = burn
        self.model = model
        assert (guide is None) != (proposal is None), \
            "requires exactly one of guide or proposal, not both or neither"
        if guide is not None:
            # wrap the guide so it has the same signature as a proposal
            self.guide = lambda tr, *args, **kwargs: guide(*args, **kwargs)
        else:
            self.guide = proposal

    def _traces(self, *args, **kwargs):
        """
        Generator over (trace, log_weight) pairs from the Markov chain.
        """
        # initialize the chain with a draw from the prior
        old_model_trace = poutine.trace(self.model)(*args, **kwargs)
        t = 0
        i = 0
        while t < self.burn + self.lag * self.samples:
            i += 1
            # propose z' ~ q(z' | z)
            new_guide_trace = poutine.block(
                poutine.trace(self.guide))(old_model_trace, *args, **kwargs)
            # score the proposal under the model: p(x, z')
            new_model_trace = poutine.trace(
                poutine.replay(self.model, new_guide_trace))(*args, **kwargs)
            # score the reverse move: q(z | z')
            old_guide_trace = poutine.block(
                poutine.trace(
                    poutine.replay(self.guide, old_model_trace)))(new_model_trace,
                                                                  *args, **kwargs)
            # log acceptance ratio:
            # log p(x, z') + log q(z | z') - log p(x, z) - log q(z' | z)
            logr = new_model_trace.log_pdf() + old_guide_trace.log_pdf() - \
                old_model_trace.log_pdf() - new_guide_trace.log_pdf()
            rnd = pyro.sample("mh_step_{}".format(i),
                              Uniform(pyro.zeros(1), pyro.ones(1)))
            if torch.log(rnd).data[0] < logr.data[0]:
                # accept: move the chain to the proposed trace
                old_model_trace = new_model_trace
            # a rejected proposal repeats the current state, so the chain
            # advances every iteration regardless of acceptance
            t += 1
            # discard burn-in samples, then keep every lag-th state
            if t > self.burn and (t - self.burn) % self.lag == 0:
                yield (old_model_trace, old_model_trace.log_pdf())
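

# Usage sketch (illustrative; `model` and `prior_proposal` are hypothetical,
# not part of the original file): a `proposal` passed to MH receives the
# previous model trace as its first argument and makes pyro.sample calls for
# the sites it wants to move. Resampling every site from its prior, as below,
# gives an independence sampler:
#
# def prior_proposal(tr, *args, **kwargs):
#     for name in tr.keys():
#         if tr[name]["type"] == "sample":
#             pyro.sample(name, tr[name]["fn"],
#                         *tr[name]["args"][0], **tr[name]["args"][1])
#
# mh = MH(model, proposal=prior_proposal, samples=100, burn=50)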


##############################################
# MH subclasses and helpers
##############################################


def single_site_proposal(model):
    """
    Returns a proposal that resamples a single site, chosen uniformly at
    random from the previous trace, from its prior.
    """
    def _fn(tr, *args, **kwargs):
        # choose one sample site from the previous trace
        choice_name = random.choice(
            [s for s in tr.keys() if tr[s]["type"] == "sample"])
        # redraw its value from the stored distribution and arguments
        return pyro.sample(choice_name,
                           tr[choice_name]["fn"],
                           *tr[choice_name]["args"][0],
                           **tr[choice_name]["args"][1])
    return _fn


class SingleSiteMH(MH):
    """
    MH where each step resamples one site at a time from its prior.
    """
    def __init__(self, model, **kwargs):
        super(SingleSiteMH, self).__init__(
            model, guide=None, proposal=single_site_proposal(model), **kwargs)


# Unfinished sketch of an HMC proposal; `steps`, `autograd`, and `optimizer`
# are placeholders:
# def hmc_proposal(model, sites=None):
#     def _fn(tr, *args, **kwargs):
#         for i in range(steps):
#             tr = poutine.block(
#                 poutine.trace(poutine.replay(model, tr, sites=sites)))(
#                     *args, **kwargs)
#             logp = tr.log_pdf()
#             samples = [tr[name]["value"] for name in tr.keys()
#                        if tr[name]["type"] == "sample"]
#             autograd.backward(samples, logp)
#             optimizer.step(samples)
#         return tr
#     return _fn
#
#
# class HMC(MH):
#     def __init__(self, model, **kwargs):
#         super(HMC, self).__init__(
#             model, guide=None, proposal=hmc_proposal(model), **kwargs)
#
#
# def mixture_guide(guides):
#     return lambda *args, **kwargs: \
#         guides[pyro.sample(gensym(), discrete, guides, ones())](*args, **kwargs)
#
#
# class MixedHMCMH(MH):
#     def __init__(self, model):
#         proposal = mixture_guide([hmc_proposal(model),
#                                   single_site_proposal(model)])
#         super(MixedHMCMH, self).__init__(model, proposal=proposal)
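

# Minimal usage sketch (illustrative; `toy_model` is hypothetical and this
# assumes the early-Pyro API used above):
if __name__ == "__main__":
    def toy_model():
        # a single latent site drawn from Uniform(0, 1)
        return pyro.sample("z", Uniform(pyro.zeros(1), pyro.ones(1)))

    # resample "z" from its prior at each step; with no observations the
    # acceptance ratio is always 1
    mh = SingleSiteMH(toy_model, samples=5, lag=1, burn=2)
    for trace, log_weight in mh._traces():
        print(trace["z"]["value"].data[0], log_weight.data[0])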