Skip to content

Commit

Permalink
Add model.target_log_prob_fn()
Browse files Browse the repository at this point in the history
  • Loading branch information
sharanry committed Jun 18, 2018
1 parent bb4de21 commit a703c21
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
19 changes: 19 additions & 0 deletions pymc4/model/base.py
@@ -1,3 +1,4 @@
import collections
import biwrap
import tensorflow as tf
from tensorflow_probability import edward2 as ed
Expand Down Expand Up @@ -67,6 +68,14 @@ def get_mode(state, rv, *args, **kwargs):
returns = self.session.run(list(values_collector.result.values()))
return dict(zip(values_collector.result.keys(), returns))

def target_log_prob_fn(self, *args, **kwargs):
logp = 0
for i in self.unobserved.keys():
print(kwargs.get(i))
logp += self.unobserved[i].rv.distribution.log_prob(value=kwargs.get(i))

return logp

def observe(self, **observations):
self._observed = observations
return self
Expand All @@ -83,6 +92,16 @@ def graph(self):
def observed(self):
return self._observed

@property
def unobserved(self):
unobserved = {}
for i in self.variables:
if self.variables[i] not in self.observed.values():
unobserved[i] = self.variables[i]

unobserved = collections.OrderedDict(unobserved)
return unobserved

@property
def variables(self):
return self._variables
Expand Down
4 changes: 2 additions & 2 deletions pymc4/util/interceptors.py
Expand Up @@ -9,7 +9,7 @@
'CollectLogProb'
]

VariableDescription = collections.namedtuple('VariableDescription', 'Dist,shape')
VariableDescription = collections.namedtuple('VariableDescription', 'Dist,shape,rv')


class Interceptor(object):
Expand Down Expand Up @@ -85,7 +85,7 @@ def __init__(self):
def after(self, rv, *args, **kwargs):
name = kwargs["name"]
if name not in self.result:
self.result[name] = VariableDescription(rv.distribution.__class__, rv.shape)
self.result[name] = VariableDescription(rv.distribution.__class__, rv.shape, rv)
else:
raise KeyError(name, 'Duplicate name')
return rv
Expand Down

0 comments on commit a703c21

Please sign in to comment.