Permalink
Browse files

make target_log_prob_fn to use class based interceptor

  • Loading branch information...
sharanry committed Jul 6, 2018
1 parent 9f46878 commit f223e4ed9042f020c7799ca64c285d24387be630
Showing with 15 additions and 23 deletions.
  1. +1 −0 pymc4/__init__.py
  2. +2 −14 pymc4/model/base.py
  3. +12 −9 pymc4/util/interceptors.py
View
@@ -8,4 +8,5 @@
from .inference import (
sampling
)
from .inference.sampling.sample import sample
from . import util
View
@@ -83,23 +83,11 @@ def log_joint_fn(*args, **kwargs): # pylint: disable=unused-argument
states = dict(zip(self.unobserved.keys(), args))
states.update(self.observed)
log_probs = []
def interceptor(f, *args, **kwargs):
name = kwargs.get("name")
for name in states:
value = states[name]
if kwargs.get("name") == name:
kwargs["value"] = value
rv = f(*args, **kwargs)
log_prob = tf.reduce_sum(rv.distribution.log_prob(rv.value))
log_probs.append(log_prob)
return rv
interceptor = interceptors.CollectLogProb(states)
with ed.interception(interceptor):
self._f(self._cfg)
log_prob = sum(log_probs)
log_prob = sum(interceptor.log_probs)
return log_prob
return log_joint_fn
View
@@ -109,22 +109,25 @@ def after(self, rv, *args, **kwargs):
class CollectLogProb(SetState):
def __init__(self, state):
super().__init__(state)
with self.name_scope():
self._result = tf.constant(0.)
def __init__(self, states):
super().__init__(states)
self.log_probs = []
def before(self, f, *args, **kwargs):
if kwargs['name'] not in self.state:
raise RuntimeError(kwargs.get('name'), 'All RV should be present in state dict')
return super().before(f, *args, **kwargs)
def after(self, rv, *args, **kwargs):
with self.name_scope():
log_prob = tf.reduce_sum(rv.distribution.log_prob(rv.value))
self._result += log_prob
name = kwargs.get("name")
for name in self.state:
value = self.state[name]
if kwargs.get("name") == name:
kwargs["value"] = value
log_prob = tf.reduce_sum(rv.distribution.log_prob(rv.value))
self.log_probs.append(log_prob)
return rv
@property
def result(self):
with self.name_scope():
return tf.identity(self._result, 'result')
return self.log_probs

0 comments on commit f223e4e

Please sign in to comment.