Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
restructure + test point implementation
- Loading branch information
Showing
16 changed files
with
234 additions
and
244 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,11 @@ | ||
__version__ = "0.0.1" | ||
from . import model | ||
from . import sampling | ||
from .model import ( | ||
Model, | ||
inline | ||
) | ||
from . import inference | ||
from .inference import ( | ||
sampling | ||
) | ||
from . import util |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import tensorflow as tf | ||
from tensorflow_probability import edward2 as ed | ||
|
||
|
||
__all__ = [ | ||
'Input' | ||
] | ||
|
||
|
||
class InputDistribution(tf.contrib.distributions.Deterministic): | ||
""" | ||
detectable class for input | ||
is input <==> isinstance(rv.distribution, InputDistribution) | ||
""" | ||
|
||
|
||
def Input(name, shape, dtype=None): | ||
return ed.as_random_variable(InputDistribution(name=name, shape=shape, dtype=dtype)) | ||
|
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import numpy as np | ||
import tensorflow as tf | ||
import tensorflow_probability as tfp | ||
|
||
|
||
def sample(model, | ||
num_results=5000, | ||
num_burnin_steps=3000, | ||
step_size=.4, | ||
num_leapfrog_steps=3, | ||
numpy=True): | ||
initial_state = [] | ||
for name, shape in model.unobserved.iteritems(): | ||
initial_state.append(.5 * tf.ones(shape, name="init_{}".format(name))) | ||
|
||
states, kernel_results = tfp.mcmc.sample_chain( | ||
num_results=num_results, | ||
num_burnin_steps=num_burnin_steps, | ||
current_state=initial_state, | ||
kernel=tfp.mcmc.HamiltonianMonteCarlo( | ||
target_log_prob_fn=model.target_log_prob_fn(), | ||
step_size=step_size, | ||
num_leapfrog_steps=num_leapfrog_steps)) | ||
|
||
if numpy: | ||
with tf.Session() as sess: | ||
states, is_accepted_ = sess.run([states, kernel_results.is_accepted]) | ||
accepted = np.sum(is_accepted_) | ||
print("Acceptance rate: {}".format(accepted / num_results)) | ||
return dict(zip(model.unobserved.keys(), states)) | ||
|
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .base import ( | ||
Model, | ||
inline | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import biwrap | ||
import tensorflow as tf | ||
from tensorflow_probability import edward2 as ed | ||
from pymc4.util import interceptors | ||
|
||
__all__ = ['Model', 'inline'] | ||
|
||
|
||
class Config(dict): | ||
def __getattr__(self, item): | ||
try: | ||
return self.__getitem__(item) | ||
except KeyError as e: | ||
error = KeyError(item, '"{i}" is not found in configuration for the model, ' | ||
'you probably need to pass "{i}" in model definition as ' | ||
'\n`model = pm.Model(..., {i}=value)`' | ||
'\nor' | ||
'\n' | ||
'\n@pm.inline(..., {i}=value)' | ||
'\ndef model(cfg):' | ||
'\n # your model starts here' | ||
'\n ...'.format(i=item)) | ||
raise error from e | ||
|
||
|
||
class Model(object): | ||
def __init__(self, name=None, graph=None, session=None, **config): | ||
self._cfg = Config(**config) | ||
self.name = name | ||
self._f = None | ||
self._variables = None | ||
self._observed = dict() | ||
if session is None: | ||
session = tf.Session(graph=graph) | ||
self.session = session | ||
|
||
def define(self, f): | ||
self._f = f | ||
self._init_variables() | ||
return f | ||
|
||
def configure(self, **override): | ||
self._cfg.update(**override) | ||
self._init_variables() | ||
return self | ||
|
||
def _init_variables(self): | ||
info_collector = interceptors.CollectVariablesInfo() | ||
with self.graph.as_default(), ed.interception(info_collector): | ||
self._f(self.cfg) | ||
self._variables = info_collector.result | ||
|
||
def test_point(self, sample=True): | ||
def not_observed(var, *args, **kwargs): | ||
return kwargs['name'] not in self.observed | ||
values_collector = interceptors.CollectVariables(filter=not_observed) | ||
chain = [values_collector] | ||
if not sample: | ||
|
||
def get_mode(state, rv, *args, **kwargs): | ||
return rv.distribution.mode() | ||
chain.insert(0, interceptors.Generic(after=get_mode)) | ||
|
||
with self.graph.as_default(), ed.interception(interceptors.Chain(*chain)): | ||
self._f(self.cfg) | ||
with self.session.as_default(): | ||
returns = self.session.run(list(values_collector.result.values())) | ||
return dict(zip(values_collector.result.keys(), returns)) | ||
|
||
def observe(self, **observations): | ||
self._observed = observations | ||
return self | ||
|
||
def reset(self): | ||
self._observed = dict() | ||
return self | ||
|
||
@property | ||
def graph(self): | ||
return self.session.graph | ||
|
||
@property | ||
def observed(self): | ||
return self._observed | ||
|
||
@property | ||
def variables(self): | ||
return self._variables | ||
|
||
@property | ||
def cfg(self): | ||
return self._cfg | ||
|
||
|
||
@biwrap.biwrap | ||
def inline(f, **kwargs): | ||
model = Model(**kwargs) | ||
model.define(f) | ||
return model |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.