1 change: 1 addition & 0 deletions docs/source/api.rst
@@ -19,3 +19,4 @@ API Reference
api/math
api/data
api/model
api/shape_utils
23 changes: 23 additions & 0 deletions docs/source/api/shape_utils.rst
@@ -0,0 +1,23 @@
***********
shape_utils
***********

This submodule contains functions that apply numpy's broadcasting rules to shape tuples and to samples drawn from probability distributions.

The main challenge when broadcasting samples drawn from a generative model is that each random variate has a core shape. When we draw many i.i.d. samples from a given RV, for example by asking for `size_tuple` i.i.d. draws, the result usually has shape `size_tuple + RV_core_shape`. In the generative model's hierarchy, the downstream RVs that are conditionally dependent on those sampled values receive an array whose shape is inconsistent with the core shape they expect for their parameters. This can prevent regular broadcasting in complex hierarchical models, and thus makes prior and posterior predictive sampling difficult.

This module introduces functions that are aware of the requested `size_tuple` of i.i.d. samples and that broadcast the core shapes, transparently ignoring or moving around the prepended i.i.d. `size_tuple` axes.
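The shape problem described above can be illustrated on plain shape tuples, without any sampling. The following is only a minimal sketch, not the module's implementation: `broadcast_shapes` and `broadcast_with_size` are hypothetical helpers written for this illustration, and the rule "a shape that starts with `size_tuple` is a sample" is a simplifying assumption.

```python
from itertools import zip_longest


def broadcast_shapes(*shapes):
    """Apply numpy-style broadcasting rules to plain shape tuples (hypothetical helper)."""
    result = []
    # Align the shapes on their trailing axes; missing leading axes count as 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_unit = {d for d in dims if d != 1}
        if len(non_unit) > 1:
            raise ValueError(f"shapes {shapes} cannot be broadcast together")
        result.append(non_unit.pop() if non_unit else 1)
    return tuple(reversed(result))


def broadcast_with_size(shapes, size_tuple):
    """Broadcast core shapes only, then prepend size_tuple (hypothetical helper).

    Assumption: any shape that starts with size_tuple is treated as a sample.
    """
    n = len(size_tuple)
    cores = [s[n:] if s[:n] == size_tuple else s for s in shapes]
    return size_tuple + broadcast_shapes(*cores)


size_tuple = (100,)
sampled = size_tuple + (5,)  # size_tuple + RV_core_shape
parameter = (4, 5)           # a downstream parameter that expects core shape (5,)

# Regular broadcasting fails: the prepended axis of length 100 lines up with 4.
try:
    broadcast_shapes(sampled, parameter)
    naive_ok = True
except ValueError:
    naive_ok = False

# Size-aware broadcasting works on the core shapes and re-attaches size_tuple.
out = broadcast_with_size([sampled, parameter], size_tuple)
```

Here the naive call fails while the size-aware call yields `(100, 4, 5)`, which is the same outcome as the `broadcast_dist_samples_shape` docstring examples in this diff.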

.. currentmodule:: pymc3.distributions.shape_utils

.. autosummary::

to_tuple
shapes_broadcasting
broadcast_dist_samples_shape
get_broadcastable_dist_samples
broadcast_distribution_samples
broadcast_dist_samples_to

.. automodule:: pymc3.distributions.shape_utils
:members:
22 changes: 21 additions & 1 deletion pymc3/distributions/shape_utils.py
@@ -17,7 +17,18 @@


def to_tuple(shape):
"""Convert ints, arrays, and Nones to tuples"""
"""Convert ints, arrays, and Nones to tuples

Parameters
----------
shape : None, int or array-like
Represents the shape to convert to tuple.

Returns
-------
If `shape` is None, returns an empty tuple. If it's an int, (shape,) is
returned. If it is array-like, tuple(shape) is returned.
"""
if shape is None:
return tuple()
temp = np.atleast_1d(shape)
@@ -106,6 +117,7 @@ def broadcast_dist_samples_shape(shapes, size=None):
Examples
--------
.. code-block:: python

size = 100
shape0 = (size,)
shape1 = (size, 5)
@@ -115,6 +127,7 @@ def broadcast_dist_samples_shape(shapes, size=None):
assert out == (size, 4, 5)

.. code-block:: python

size = 100
shape0 = (size,)
shape1 = (5,)
@@ -124,6 +137,7 @@ def broadcast_dist_samples_shape(shapes, size=None):
assert out == (size, 4, 5)

.. code-block:: python

size = 100
shape0 = (1,)
shape1 = (5,)
@@ -204,6 +218,7 @@ def get_broadcastable_dist_samples(
Examples
--------
.. code-block:: python

must_bcast_with = (3, 1, 5)
size = 100
sample0 = np.random.randn(size)
@@ -222,6 +237,7 @@ def get_broadcastable_dist_samples(
assert np.all(sample2[:, None] == out[2])

.. code-block:: python

size = 100
must_bcast_with = (3, 1, 5)
sample0 = np.random.randn(size)
@@ -290,6 +306,7 @@ def broadcast_distribution_samples(samples, size=None):
Examples
--------
.. code-block:: python

size = 100
sample0 = np.random.randn(size)
sample1 = np.random.randn(size, 5)
@@ -302,6 +319,7 @@ def broadcast_distribution_samples(samples, size=None):
assert np.all(sample2 == out[2])

.. code-block:: python

size = 100
sample0 = np.random.randn(size)
sample1 = np.random.randn(5)
@@ -335,6 +353,7 @@ def broadcast_dist_samples_to(to_shape, samples, size=None):
Examples
--------
.. code-block:: python

to_shape = (3, 1, 5)
size = 100
sample0 = np.random.randn(size)
@@ -351,6 +370,7 @@ def broadcast_dist_samples_to(to_shape, samples, size=None):
assert np.all(sample2[:, None] == out[2])

.. code-block:: python

size = 100
to_shape = (3, 1, 5)
sample0 = np.random.randn(size)