Commit ba03900

Merge ee98816 into 9e0f260

rpgoldman committed Jul 16, 2019
2 parents 9e0f260 + ee98816
Showing 2 changed files with 373 additions and 24 deletions.
363 changes: 363 additions & 0 deletions pymc3/distributions/posterior_predictive.py
@@ -0,0 +1,363 @@
import numbers
import itertools
from typing import List, Dict, Any, Optional, Set, Tuple
import logging

import numpy as np
import theano
import theano.tensor as tt

from ..backends.base import MultiTrace
from .distribution import _DrawValuesContext, _DrawValuesContextBlocker, is_fast_drawable, _compile_theano_function
from ..model import Model, get_named_nodes_and_relations, ObservedRV, MultiObservedRV, modelcontext


PosteriorPredictiveTrace = Dict[str, np.ndarray]


def posterior_predictive_draw_values(vars: List[Any], trace: MultiTrace, samples: int, size: Optional[int] = None) -> PosteriorPredictiveTrace:
    if size is not None:
        raise NotImplementedError("size is not yet implemented for sample_posterior_predictive")

    sampler = _PosteriorPredictiveSampler(vars, trace, samples, None)
    sampler.draw_values()
    return sampler.pp_trace

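# An illustrative call (a sketch, not part of this module): given a pymc3
# Model and a MultiTrace returned by pm.sample(), one might draw posterior
# predictive samples of an observed RV `y` with
#
#     with model:
#         ppc = posterior_predictive_draw_values([y], trace, samples=200)
#
# where `model`, `trace`, and `y` are hypothetical names, and `ppc['y']`
# would be an ndarray with leading dimension 200.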

class _PosteriorPredictiveSampler:
    '''The process of posterior predictive sampling is quite complicated, so this class provides a central data store.'''

    pp_trace: PosteriorPredictiveTrace

    # inputs
    vars: List[Any]
    trace: MultiTrace
    samples: int
    size: Optional[int]  # not supported!

    # other slots
    logger: logging.Logger

    # for the search
    evaluated: Dict[int, np.ndarray]
    symbolic_params: List[Tuple[int, Any]]

    # set by make_graph...
    leaf_nodes: Dict[str, Any]
    named_nodes_parents: Dict[str, Any]
    named_nodes_children: Dict[str, Any]

    def __init__(self, vars, trace, samples, size):
        self.vars = vars
        self.trace = trace
        self.samples = samples
        self.size = size
        self.pp_trace = {}
        self.logger = logging.getLogger('posterior_predictive')

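    # Overall flow: init() resolves variables that are fast-drawable or
    # already recorded in the trace; make_graph() maps parent/child
    # relations among the remaining named nodes; draw_values() then walks
    # that graph, drawing each node once its inputs are available.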
    def draw_values(self) -> None:
        vars = self.vars
        trace = self.trace
        samples = self.samples
        size = self.size

        self.context = context = _DrawValuesContext()
        self.pp_trace = {}  # type: PosteriorPredictiveTrace
        with context:
            self.init()
            self.make_graph()
        with context:
            drawn = context.drawn_vars
            # Init givens and the stack of nodes to try to `_draw_value` from
            givens = {p.name: (p, v) for (p, size), v in drawn.items()
                      if getattr(p, 'name', None) is not None}
            stack = list(self.leaf_nodes.values())  # A queue would be more appropriate
            while stack:
                next_ = stack.pop(0)
                if (next_, size) in drawn:
                    # If the node already has a givens value, skip it
                    continue
                elif isinstance(next_, (tt.TensorConstant,
                                        tt.sharedvar.SharedVariable)):
                    # If the node is a theano.tensor.TensorConstant or a
                    # theano.tensor.sharedvar.SharedVariable, its value will be
                    # available automatically in _compile_theano_function, so
                    # we can skip it. Furthermore, if this node were treated as a
                    # TensorVariable that should be compiled by theano in
                    # _compile_theano_function, it would raise a `TypeError:
                    # ('Constants not allowed in param list', ...)` for
                    # TensorConstant, and a `TypeError: Cannot use a shared
                    # variable (...) as explicit input` for SharedVariable.
                    # Because ObservedRV and MultiObservedRV instances are ViewOps
                    # of TensorConstants or SharedVariables, we must add them
                    # to the stack or risk evaluating deterministics with the
                    # wrong values (issue #3354).
                    stack.extend([node for node in self.named_nodes_parents[next_]
                                  if isinstance(node, (ObservedRV,
                                                       MultiObservedRV))
                                  and (node, size) not in drawn])
                    continue
                else:
                    # If the node does not have a givens value, try to draw it.
                    # The named node's children givens values must also be taken
                    # into account.
                    children = self.named_nodes_children[next_]
                    temp_givens = [givens[k] for k in givens if k in children]
                    try:
                        # This may fail for autotransformed RVs, which don't
                        # have the random method.
                        value = self.draw_value(next_, givens=temp_givens)
                        assert isinstance(value, np.ndarray)
                        self.pp_trace[next_.name] = value
                        givens[next_.name] = (next_, value)
                        drawn[(next_, size)] = value
                    except theano.gof.fg.MissingInputError:
                        # The node failed, so we must add the node's parents to
                        # the stack of nodes to try to draw from. We exclude the
                        # nodes in the `params` list.
                        stack.extend([node for node in self.named_nodes_parents[next_]
                                      if node is not None and
                                      (node, size) not in drawn])

            # The code below makes sure the graph is evaluated in order;
            # test_distributions_random::TestDrawValues::test_draw_order fails without it.
            # The remaining params that must be drawn are all hashable.
            to_eval = set()  # type: Set[int]
            missing_inputs = set([j for j, p in self.symbolic_params])  # type: Set[int]
            while to_eval or missing_inputs:
                if to_eval == missing_inputs:
                    raise ValueError('Cannot resolve inputs for {}'.format(
                        [str(vars[j]) for j in to_eval]))
                to_eval = set(missing_inputs)
                missing_inputs = set()
                for param_idx in to_eval:
                    param = vars[param_idx]
                    drawn = context.drawn_vars
                    if (param, size) in drawn:
                        self.evaluated[param_idx] = drawn[(param, size)]
                    else:
                        try:
                            if param in self.named_nodes_children:
                                for node in self.named_nodes_children[param]:
                                    if (node.name not in givens and
                                            (node, size) in drawn):
                                        givens[node.name] = (node,
                                                             drawn[(node, size)])
                            value = self.draw_value(param, givens=givens.values())
                            assert isinstance(value, np.ndarray)
                            self.pp_trace[param.name] = value
                            self.evaluated[param_idx] = drawn[(param, size)] = value
                            givens[param.name] = (param, value)
                        except theano.gof.fg.MissingInputError:
                            missing_inputs.add(param_idx)
        assert set(self.pp_trace.keys()) == {var.name for var in vars}

    def init(self) -> None:
        vars: List[Any] = self.vars
        trace: MultiTrace = self.trace
        samples: int = self.samples
        size: Optional[int] = self.size

        # initialization phase
        context = _DrawValuesContext.get_context()
        with context:
            drawn = context.drawn_vars
            evaluated = {}  # type: Dict[int, Any]
            symbolic_params = []
            for i, var in enumerate(vars):
                if is_fast_drawable(var):
                    evaluated[i] = self.pp_trace[var.name] = self.draw_value(var)
                    continue
                name = getattr(var, 'name', None)
                if (var, size) in drawn:
                    evaluated[i] = val = drawn[(var, size)]
                # We filter out Deterministics by checking for the `model` attribute
                elif name is not None and hasattr(var, 'model') and name in trace.varnames:
                    # param.name is recorded in the trace
                    evaluated[i] = drawn[(var, size)] = trace[name].reshape(-1)
                else:
                    # param still needs to be drawn
                    symbolic_params.append((i, var))
        self.evaluated = evaluated
        self.symbolic_params = symbolic_params

    def make_graph(self) -> None:
        # Distribution parameters may be nodes which have named node-inputs
        # specified in the point. We need to find the node-inputs, and their
        # parents and children, to replace them.
        symbolic_params = self.symbolic_params
        self.leaf_nodes = {}  # type: Dict[str, Any]
        self.named_nodes_parents = {}  # type: Dict[str, Any]
        self.named_nodes_children = {}  # type: Dict[str, Any]
        for _, param in symbolic_params:
            if hasattr(param, 'name'):
                # Get the named nodes under the `param` node
                nn, nnp, nnc = get_named_nodes_and_relations(param)
                self.leaf_nodes.update(nn)
                # Update the discovered parental relationships
                for k in nnp.keys():
                    if k not in self.named_nodes_parents.keys():
                        self.named_nodes_parents[k] = nnp[k]
                    else:
                        self.named_nodes_parents[k].update(nnp[k])
                # Update the discovered child relationships
                for k in nnc.keys():
                    if k not in self.named_nodes_children.keys():
                        self.named_nodes_children[k] = nnc[k]
                    else:
                        self.named_nodes_children[k].update(nnc[k])

    def draw_value(self, param, givens=None):
        """Draw a random value from a distribution or return a constant.

        Parameters
        ----------
        param : number, array-like, theano variable or pymc3 random variable
            The value or distribution. Constants or shared variables
            will be converted to an array and returned. Theano variables
            are evaluated. If `param` is a pymc3 random variable, a new
            value is drawn from it for each point in the trace.
        givens : dict, optional
            A dictionary from theano variables to their values. These values
            are used to evaluate `param` if it is a theano variable.

        The number of samples and the (currently unsupported) `size`
        argument are taken from this sampler object rather than passed in.
        """
        trace = self.trace
        size = self.size
        samples = self.samples
        if isinstance(param, (numbers.Number, np.ndarray)):
            return param
        elif isinstance(param, tt.TensorConstant):
            return param.value
        elif isinstance(param, tt.sharedvar.SharedVariable):
            return param.get_value()
        elif isinstance(param, (tt.TensorVariable, MultiObservedRV)):
            if hasattr(param, 'model') and param.name in trace.varnames:
                return trace[param.name]
            elif hasattr(param, 'random') and param.random is not None:
                model = modelcontext(None)
                shape = tuple(_param_shape(param, model))  # type: Tuple[int, ...]
                val = np.ndarray((samples,) + shape)
                for i, point in zip(range(samples), itertools.cycle(trace.points())):
                    val[i, :] = x = param.random(point=point)
                    assert shape == x.shape
                return val
            elif (hasattr(param, 'distribution') and
                  hasattr(param.distribution, 'random') and
                  param.distribution.random is not None):
                if hasattr(param, 'observations'):
                    # shape inspection for ObservedRV
                    dist_tmp = param.distribution
                    try:
                        distshape = tuple(param.observations.shape.eval())  # type: Tuple[int, ...]
                    except AttributeError:
                        distshape = tuple(param.observations.shape)

                    dist_tmp.shape = distshape
                    val = np.ndarray((samples,) + distshape)

                    try:
                        for i, point in zip(range(samples), itertools.cycle(trace.points())):
                            x = val[i, :] = dist_tmp.random(point=point, size=size)
                            assert x.shape == distshape
                        return val
                    except (ValueError, TypeError):
                        # Reset shape to account for shape changes
                        # with theano.shared inputs.
                        dist_tmp.shape = np.array([])
                        # We want to draw values to infer the dist_shape, but
                        # we don't want to store these drawn values in the context.
                        with _DrawValuesContextBlocker():
                            point = next(trace.points())
                            temp_val = np.atleast_1d(dist_tmp.random(point=point,
                                                                     size=None))
                        # Sometimes point may change the size of val but not the
                        # distribution's shape.
                        if point and size is not None:
                            temp_size = np.atleast_1d(size)
                            if all(temp_val.shape[:len(temp_size)] == temp_size):
                                dist_tmp.shape = temp_val.shape[len(temp_size):]
                            else:
                                dist_tmp.shape = temp_val.shape
                        val = np.ndarray((samples,) + tuple(dist_tmp.shape))
                        for i, point in zip(range(samples), itertools.cycle(trace.points())):
                            x = val[i, :] = dist_tmp.random(point=point, size=size)
                            assert x.shape == tuple(dist_tmp.shape)
                        return val
                else:  # has a distribution, but no observations
                    distshape = tuple(param.distribution.shape)
                    val = np.ndarray((samples,) + distshape)
                    for i, point in zip(range(samples), itertools.cycle(trace.points())):
                        x = val[i, :] = param.distribution.random(point=point, size=size)
                        assert x.shape == distshape
                    return val
            else:
                # Doesn't have a distribution: deterministic (?). This is
                # probably wrong, but the values here will be a list of
                # *sampled* values -- one per sample.
                if givens:
                    variables, values = list(zip(*givens))
                else:
                    variables = values = []
                # We only truly care whether the ancestors of param that were
                # given a value have the matching dshape and val.shape.
                param_ancestors = set(theano.gof.graph.ancestors([param],
                                                                 blockers=list(variables)))
                inputs = [(var, val) for var, val in
                          zip(variables, values)
                          if var in param_ancestors]
                if inputs:
                    input_vars, input_vals = list(zip(*inputs))
                else:
                    input_vars = []
                    input_vals = []
                func = _compile_theano_function(param, input_vars)
                if not input_vars:
                    output = func(*input_vals)
                    if hasattr(output, 'shape'):
                        # Replicate the single deterministic output once per sample.
                        val = np.repeat(np.expand_dims(output, 0), samples, axis=0)
                    else:
                        val = np.full(samples, output)
                else:
                    # Evaluate once per row of sampled input values.
                    val = np.array([func(*inp) for inp in zip(*input_vals)])
                return val
        raise ValueError('Unexpected type in draw_value: %s' % type(param))


def _param_shape(var_desig, model: Model) -> Tuple[int, ...]:
    if isinstance(var_desig, str):
        v = model[var_desig]
    else:
        v = var_desig
    if hasattr(v, 'observations'):
        try:
            # To get the shape of an _observed_ data container, `pm.Data`
            # (a wrapper for a theano SharedVariable), we evaluate it.
            shape = tuple(v.observations.shape.eval())
        except AttributeError:
            shape = v.observations.shape
    elif hasattr(v, 'dshape'):
        shape = v.dshape
    else:
        shape = v.tag.test_value.shape
    if shape == (1,):
        shape = tuple()
    return shape
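For context, here is a minimal end-to-end sketch of how the new entry point might be exercised. The toy model, variable names, and sample counts are illustrative assumptions, not part of this commit:

    import numpy as np
    import pymc3 as pm
    from pymc3.distributions.posterior_predictive import posterior_predictive_draw_values

    # A toy model: y ~ Normal(mu, 1) with 100 fake observations.
    y_obs = np.random.randn(100)
    with pm.Model() as model:
        mu = pm.Normal('mu', mu=0., sd=10.)
        y = pm.Normal('y', mu=mu, sd=1., observed=y_obs)
        trace = pm.sample(500)
        # 200 posterior predictive draws of y, cycling over trace points.
        ppc = posterior_predictive_draw_values([y], trace, samples=200)

    print(ppc['y'].shape)  # expected: (200, 100)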
