restructure + test point implementation
ferrine committed Jun 10, 2018
1 parent e334115 commit d07338e
Showing 16 changed files with 234 additions and 244 deletions.
9 changes: 8 additions & 1 deletion pymc4/__init__.py
@@ -1,4 +1,11 @@
__version__ = "0.0.1"
from . import model
from . import sampling
from .model import (
    Model,
    inline
)
from . import inference
from .inference import (
    sampling
)
from . import util
Empty file added pymc4/distributions/__init__.py
Empty file.
19 changes: 19 additions & 0 deletions pymc4/distributions/base.py
@@ -0,0 +1,19 @@
import tensorflow as tf
from tensorflow_probability import edward2 as ed


__all__ = [
    'Input'
]


class InputDistribution(tf.contrib.distributions.Deterministic):
    """Marker distribution for model inputs.

    A random variable ``rv`` is a model input if and only if
    ``isinstance(rv.distribution, InputDistribution)``.
    """


def Input(name, shape, dtype=None):
    return ed.as_random_variable(InputDistribution(name=name, shape=shape, dtype=dtype))

Empty file added pymc4/inference/__init__.py
Empty file.
Empty file added pymc4/inference/sampling/__init__.py
Empty file.
31 changes: 31 additions & 0 deletions pymc4/inference/sampling/sample.py
@@ -0,0 +1,31 @@
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp


def sample(model,
           num_results=5000,
           num_burnin_steps=3000,
           step_size=.4,
           num_leapfrog_steps=3,
           numpy=True):
    # One initial state tensor per unobserved variable, all set to 0.5.
    initial_state = []
    for name, shape in model.unobserved.items():
        initial_state.append(.5 * tf.ones(shape, name="init_{}".format(name)))

    states, kernel_results = tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=initial_state,
        kernel=tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=model.target_log_prob_fn(),
            step_size=step_size,
            num_leapfrog_steps=num_leapfrog_steps))

    if numpy:
        with tf.Session() as sess:
            states, is_accepted_ = sess.run([states, kernel_results.is_accepted])
            accepted = np.sum(is_accepted_)
            print("Acceptance rate: {}".format(accepted / num_results))
    return dict(zip(model.unobserved.keys(), states))
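
For context, here is a minimal usage sketch of the `sample` function above. It is not part of this commit: the `ToyModel` stand-in and its normal log-density are invented for illustration, and the sketch only assumes what `sample` itself relies on, namely a model object exposing an `unobserved` mapping of variable names to shapes and a `target_log_prob_fn()` returning a callable (TF 1.x graph mode, as in the rest of the commit).

import tensorflow as tf

from pymc4.inference.sampling.sample import sample


class ToyModel:
    # Hypothetical stand-in exposing only what sample() needs.
    unobserved = {"mu": ()}  # variable name -> shape of the latent

    def target_log_prob_fn(self):
        # Standard normal log-density over the single latent `mu`.
        def log_prob(mu):
            return -0.5 * tf.reduce_sum(tf.square(mu))
        return log_prob


trace = sample(ToyModel(), num_results=100, num_burnin_steps=50)
print(trace["mu"].shape)  # (100,): one draw of `mu` per kept step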

51 changes: 0 additions & 51 deletions pymc4/model.py

This file was deleted.

4 changes: 4 additions & 0 deletions pymc4/model/__init__.py
@@ -0,0 +1,4 @@
from .base import (
    Model,
    inline
)
99 changes: 99 additions & 0 deletions pymc4/model/base.py
@@ -0,0 +1,99 @@
import biwrap
import tensorflow as tf
from tensorflow_probability import edward2 as ed
from pymc4.util import interceptors

__all__ = ['Model', 'inline']


class Config(dict):
    def __getattr__(self, item):
        try:
            return self.__getitem__(item)
        except KeyError as e:
            error = KeyError(item, '"{i}" is not found in configuration for the model, '
                                   'you probably need to pass "{i}" in model definition as '
                                   '\n`model = pm.Model(..., {i}=value)`'
                                   '\nor'
                                   '\n'
                                   '\n@pm.inline(..., {i}=value)'
                                   '\ndef model(cfg):'
                                   '\n # your model starts here'
                                   '\n ...'.format(i=item))
            raise error from e


class Model(object):
    def __init__(self, name=None, graph=None, session=None, **config):
        self._cfg = Config(**config)
        self.name = name
        self._f = None
        self._variables = None
        self._observed = dict()
        if session is None:
            session = tf.Session(graph=graph)
        self.session = session

    def define(self, f):
        self._f = f
        self._init_variables()
        return f

    def configure(self, **override):
        self._cfg.update(**override)
        self._init_variables()
        return self

    def _init_variables(self):
        # Trace the model function once to collect variable names and shapes.
        info_collector = interceptors.CollectVariablesInfo()
        with self.graph.as_default(), ed.interception(info_collector):
            self._f(self.cfg)
        self._variables = info_collector.result

    def test_point(self, sample=True):
        # Evaluate every unobserved variable once: sampled from its prior by
        # default, or taken at the distribution mode when sample=False.
        def not_observed(var, *args, **kwargs):
            return kwargs['name'] not in self.observed
        values_collector = interceptors.CollectVariables(filter=not_observed)
        chain = [values_collector]
        if not sample:

            def get_mode(state, rv, *args, **kwargs):
                return rv.distribution.mode()
            chain.insert(0, interceptors.Generic(after=get_mode))

        with self.graph.as_default(), ed.interception(interceptors.Chain(*chain)):
            self._f(self.cfg)
        with self.session.as_default():
            returns = self.session.run(list(values_collector.result.values()))
        return dict(zip(values_collector.result.keys(), returns))

    def observe(self, **observations):
        self._observed = observations
        return self

    def reset(self):
        self._observed = dict()
        return self

    @property
    def graph(self):
        return self.session.graph

    @property
    def observed(self):
        return self._observed

    @property
    def variables(self):
        return self._variables

    @property
    def cfg(self):
        return self._cfg


@biwrap.biwrap
def inline(f, **kwargs):
    model = Model(**kwargs)
    model.define(f)
    return model
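
As a rough sketch of how the `inline` decorator and the new `test_point` method fit together (again not from the commit: the normal prior, the `sigma` config key, and the variable name `mu` are invented, and it assumes the `pymc4.util.interceptors` helpers imported above behave as their names suggest):

import pymc4 as pm
from tensorflow_probability import edward2 as ed


@pm.inline(sigma=1.0)                  # keyword arguments end up in the model Config
def model(cfg):
    mu = ed.Normal(loc=0., scale=cfg.sigma, name="mu")
    return mu


print(model.variables)                 # name/shape info collected while tracing the definition
print(model.test_point())              # {"mu": <value sampled from the prior>}
print(model.test_point(sample=False))  # {"mu": 0.0}: the mode of Normal(0, sigma)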
28 changes: 0 additions & 28 deletions pymc4/sampling.py

This file was deleted.

105 changes: 0 additions & 105 deletions pymc4/util/graph.py
@@ -24,108 +24,3 @@ def make_shared_vectorized_input(rvs, test_values):
        mapping[name] = tf.reshape(vec[j:d], shape, name=name)
        j += d
    return vec, mapping


# from Theano codebase
def stack_search(start, expand, mode='bfs', build_inv=False):
    """
    Search through a graph, either breadth- or depth-first.

    Parameters
    ----------
    start : deque
        Search from these nodes.
    expand : callable
        When we get to a node, add expand(node) to the list of nodes to visit.
        This function should return a list, or None.
    mode : string
        'bfs' or 'dfs' for breadth-first or depth-first search.

    Returns
    -------
    list of `Variable` or `Apply` instances (depends on `expand`)
        The list of nodes in order of traversal.

    Notes
    -----
    A node will appear at most once in the return value, even if it
    appears multiple times in the start parameter.

    :postcondition: every element of start is transferred to the returned list.
    :postcondition: start is empty.
    """
    if mode not in ('bfs', 'dfs'):
        raise ValueError('mode should be bfs or dfs', mode)
    rval_set = set()
    rval_list = list()
    if mode == 'bfs':
        start_pop = start.popleft
    else:
        start_pop = start.pop
    expand_inv = {}  # var: clients
    while start:
        l = start_pop()
        if id(l) not in rval_set:
            rval_list.append(l)
            rval_set.add(id(l))
            expand_l = expand(l)
            if expand_l:
                if build_inv:
                    for r in expand_l:
                        expand_inv.setdefault(r, []).append(l)
                start.extend(expand_l)
    assert len(rval_list) == len(rval_set)
    if build_inv:
        return rval_list, expand_inv
    return rval_list


# from Theano codebase
def ancestors(variable_list, blockers=None):
    """
    Return the variables that contribute to those in variable_list (inclusive).

    Parameters
    ----------
    variable_list : list of `Variable` instances
        Output `Variable` instances from which to search backward through
        owners.

    Returns
    -------
    list of `Variable` instances
        All input nodes, in the order found by a left-recursive depth-first
        search started at the nodes in `variable_list`.
    """
    def expand(r):
        if r.owner and (not blockers or r not in blockers):
            return reversed(r.owner.inputs)
    dfs_variables = stack_search(collections.deque(variable_list), expand, 'dfs')
    return dfs_variables


# from Theano codebase
def inputs(variable_list, blockers=None):
    """
    Return the inputs required to compute the given Variables.

    Parameters
    ----------
    variable_list : list of `Variable` instances
        Output `Variable` instances from which to search backward through
        owners.

    Returns
    -------
    list of `Variable` instances
        Input nodes with no owner, in the order found by a left-recursive
        depth-first search started at the nodes in `variable_list`.
    """
    from ..model import InputDistribution
    vlist = ancestors(variable_list, blockers)
    rval = [r for r in vlist if hasattr(r, 'distribution') and
            isinstance(r.distribution, InputDistribution)]
    return rval
