-
-
Notifications
You must be signed in to change notification settings - Fork 986
/
traceenum_elbo.py
134 lines (110 loc) · 5.45 KB
/
traceenum_elbo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from __future__ import absolute_import, division, print_function
import warnings
import pyro
import pyro.infer as infer
import pyro.poutine as poutine
from pyro.distributions.util import is_identically_zero
from pyro.infer.elbo import ELBO
from pyro.infer.enum import iter_discrete_traces
from pyro.infer.util import Dice
from pyro.poutine import EnumerateMessenger
from pyro.poutine.util import prune_subsample_sites
from pyro.util import check_model_guide_match, check_site_shape, check_traceenum_requirements, torch_isnan
def _compute_dice_elbo(model_trace, guide_trace):
    """
    Compute a DiCE-weighted ELBO estimate from a model trace and a guide
    trace, summing a weighted cost term over every model sample site.
    """
    # y depends on x iff ordering[x] <= ordering[y]
    # TODO refine this coarse dependency ordering.
    ordering = {}
    for name, site in model_trace.nodes.items():
        if site["type"] == "sample":
            ordering[name] = frozenset(frame for frame in site["cond_indep_stack"]
                                       if frame.vectorized)

    dice = Dice(guide_trace, ordering)
    result = 0.0
    for name, site in model_trace.nodes.items():
        if site["type"] != "sample":
            continue
        cost = site["log_prob"]
        if not site["is_observed"]:
            # Latent sites also subtract the guide's log density.
            cost = cost - guide_trace.nodes[name]["log_prob"]
        # TODO use score_parts.entropy_term to "stick the landing"
        weight = dice.in_context(cost.shape, ordering[name])
        result = result + (weight * cost).sum()
    return result
class TraceEnum_ELBO(ELBO):
    """
    A trace implementation of ELBO-based SVI that supports enumeration
    over discrete sample sites.

    To enumerate over a sample site, the ``guide``'s sample site must specify
    either ``infer={'enumerate': 'sequential'}`` or
    ``infer={'enumerate': 'parallel'}``. To configure all sites at once, use
    :func:`~pyro.infer.enum.config_enumerate``.

    This assumes restricted dependency structure on the model and guide:
    variables outside of an :class:`~pyro.iarange` can never depend on
    variables inside that :class:`~pyro.iarange`.
    """

    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Run the guide, then run the model replayed against the guide,
        yielding ``(model_trace, guide_trace)`` pairs as a generator.
        """
        # Wrap the guide so that discrete sites may be enumerated in parallel.
        enum_guide = EnumerateMessenger(first_available_dim=self.max_iarange_nesting)(guide)
        for _ in range(self.num_particles):
            for guide_trace in iter_discrete_traces("flat", enum_guide, *args, **kwargs):
                replayed_model = poutine.replay(model, guide_trace)
                model_trace = poutine.trace(replayed_model,
                                            graph_type="flat").get_trace(*args, **kwargs)

                if infer.is_validation_enabled():
                    check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting)
                guide_trace = prune_subsample_sites(guide_trace)
                model_trace = prune_subsample_sites(model_trace)
                if infer.is_validation_enabled():
                    check_traceenum_requirements(model_trace, guide_trace)

                model_trace.compute_log_prob()
                guide_trace.compute_score_parts()
                if infer.is_validation_enabled():
                    # Check model sites first, then guide sites.
                    for trace in (model_trace, guide_trace):
                        for site in trace.nodes.values():
                            if site["type"] == "sample":
                                check_site_shape(site, self.max_iarange_nesting)

                yield model_trace, guide_trace

    def loss(self, model, guide, *args, **kwargs):
        """
        :returns: an estimate of the ELBO
        :rtype: float

        Estimates the ELBO using ``num_particles`` many samples (particles).
        """
        elbo = 0.0
        for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
            particle_elbo = _compute_dice_elbo(model_trace, guide_trace)
            # A particle with no sample sites contributes exactly zero; skip it.
            if not is_identically_zero(particle_elbo):
                elbo += particle_elbo.item() / self.num_particles

        loss = -elbo
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss

    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        :returns: an estimate of the ELBO
        :rtype: float

        Estimates the ELBO using ``num_particles`` many samples (particles)
        and performs backward on the ELBO of each particle.
        """
        elbo = 0.0
        for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
            particle_elbo = _compute_dice_elbo(model_trace, guide_trace)
            if is_identically_zero(particle_elbo):
                continue
            elbo += particle_elbo.item() / self.num_particles

            # Collect parameters to train from both the model and the guide.
            trainable_params = {site["value"].unconstrained()
                                for trace in (model_trace, guide_trace)
                                for site in trace.nodes.values()
                                if site["type"] == "param"}
            if trainable_params and particle_elbo.requires_grad:
                # Backprop the (negated, averaged) particle ELBO.
                (-particle_elbo / self.num_particles).backward()
                pyro.get_param_store().mark_params_active(trainable_params)

        loss = -elbo
        if torch_isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss