Skip to content

Commit

Permalink
Improve collect_default_updates
Browse files Browse the repository at this point in the history
* It works with nested RNGs
* It raises an error if an RNG used in a SymbolicRandomVariable is not given an update
* It raises a warning if the same RNG is used in multiple nodes
  • Loading branch information
ricardoV94 committed Mar 30, 2023
1 parent a75af50 commit 3f2a1da
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 51 deletions.
4 changes: 3 additions & 1 deletion pymc/distributions/distribution.py
Expand Up @@ -593,7 +593,9 @@ class CustomSymbolicDistRV(SymbolicRandomVariable):

def update(self, node: Node):
op = node.op
inner_updates = collect_default_updates(op.inner_inputs, op.inner_outputs)
inner_updates = collect_default_updates(
op.inner_inputs, op.inner_outputs, must_be_shared=False
)

# Map inner updates to outer inputs/outputs
updates = {}
Expand Down
101 changes: 72 additions & 29 deletions pymc/pytensorf.py
Expand Up @@ -42,7 +42,6 @@
Variable,
clone_get_equiv,
graph_inputs,
vars_between,
walk,
)
from pytensor.graph.fg import FunctionGraph
Expand All @@ -51,6 +50,7 @@
from pytensor.tensor.basic import _as_tensor_variable
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.random.var import (
RandomGeneratorSharedVariable,
RandomStateSharedVariable,
Expand Down Expand Up @@ -1000,42 +1000,85 @@ def reseed_rngs(


def collect_default_updates(
    inputs: Sequence[Variable],
    outputs: Sequence[Variable],
    must_be_shared: bool = True,
) -> Dict[Variable, Variable]:
    """Collect default update expressions for shared-variable RNGs used by RVs between inputs and outputs.

    Parameters
    ----------
    inputs
        Input variables of the graph. They act as blockers for the graph
        traversal: RNGs appearing here are treated as graph inputs.
    outputs
        Output variables whose ancestral graph is searched for RNG usage.
    must_be_shared
        If False, update expressions will also be returned for non-shared
        input RNGs. This can be useful to obtain the symbolic update
        expressions from inner graphs.

    Returns
    -------
    Dict mapping each input RNG variable to its final (most downstream)
    update expression.

    Raises
    ------
    ValueError
        If an RNG is consumed by a SymbolicRandomVariable whose ``update``
        method does not provide a mapping for it.
    """
    # Avoid circular import
    from pymc.distributions.distribution import SymbolicRandomVariable

    def find_default_update(clients, rng: Variable) -> Union[None, Variable]:
        # Follow the chain of clients of `rng` until the final RNG state is found.
        rng_clients = clients.get(rng, None)

        # Root case, RNG is not used elsewhere
        if not rng_clients:
            return rng

        if len(rng_clients) > 1:
            # The same RNG feeding multiple nodes means draws are not chained
            # sequentially, so no single consistent update can be inferred.
            warnings.warn(
                f"RNG Variable {rng} has multiple clients. This is likely an inconsistent random graph.",
                UserWarning,
            )
            return None

        [client, _] = rng_clients[0]

        # RNG is an output of the function, this is not a problem
        if client == "output":
            return rng

        # RNG is used by another operator, which should output an update for the RNG
        if isinstance(client.op, RandomVariable):
            # RandomVariable first output is always the update of the input RNG
            next_rng = client.outputs[0]
        elif isinstance(client.op, SymbolicRandomVariable):
            # SymbolicRandomVariable have an explicit method that returns an
            # update mapping for their RNG(s)
            next_rng = client.op.update(client).get(rng)
            if next_rng is None:
                raise ValueError(
                    f"No update mapping found for RNG used in SymbolicRandomVariable Op {client.op}"
                )
        else:
            # We don't know how this RNG should be updated (e.g., Scan).
            # The user should provide an update manually
            return None

        # Recurse until we find final update for RNG
        return find_default_update(clients, next_rng)

    outputs = makeiter(outputs)
    fg = FunctionGraph(outputs=outputs, clone=False)
    clients = fg.clients

    rng_updates = {}
    # Iterate over input RNGs. Only consider shared RNGs if `must_be_shared==True`
    for input_rng in (
        inp
        for inp in graph_inputs(outputs, blockers=inputs)
        if (
            (not must_be_shared or isinstance(inp, SharedVariable))
            and isinstance(inp.type, RandomType)
        )
    ):
        # Even if an explicit default update is provided, we call it to
        # issue any warnings about invalid random graphs.
        default_update = find_default_update(clients, input_rng)

        # Respect default update if provided
        if getattr(input_rng, "default_update", None):
            rng_updates[input_rng] = input_rng.default_update
        else:
            if default_update is not None:
                rng_updates[input_rng] = default_update

    return rng_updates


Expand Down
127 changes: 106 additions & 21 deletions tests/test_pytensorf.py
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

from unittest import mock

import numpy as np
Expand Down Expand Up @@ -38,6 +40,7 @@
from pymc.exceptions import NotConstantValueError
from pymc.logprob.utils import ParameterValueError
from pymc.pytensorf import (
collect_default_updates,
compile_pymc,
constant_fold,
convert_observed_data,
Expand Down Expand Up @@ -406,28 +409,63 @@ def test_compile_pymc_updates_inputs(self):
# Each RV adds a shared output for its rng
assert len(fn_fgraph.outputs) == 1 + rvs_in_graph

def test_compile_pymc_symbolic_rv_update(self):
    """Test that SymbolicRandomVariable Op update methods are used by compile_pymc."""

    class NonSymbolicRV(OpFromGraph):
        def update(self, node):
            # Map the input RNG to the updated RNG output of the inner graph
            return {node.inputs[0]: node.outputs[0]}

    rng = pytensor.shared(np.random.default_rng())
    dummy_rng = rng.type()
    dummy_next_rng, dummy_x = NonSymbolicRV(
        [dummy_rng], pt.random.normal(rng=dummy_rng).owner.outputs
    )(rng)

    # Check that there are no updates at first: without an update rule the
    # shared RNG is never advanced, so two calls draw the same value.
    fn = compile_pymc(inputs=[], outputs=dummy_x)
    assert fn() == fn()

    # And they are enabled once the Op is registered as a SymbolicRV
    SymbolicRandomVariable.register(NonSymbolicRV)
    fn = compile_pymc(inputs=[], outputs=dummy_x, random_seed=431)
    assert fn() != fn()

def test_compile_pymc_symbolic_rv_missing_update(self):
    """Test that error is raised if SymbolicRandomVariable Op does not
    provide rule for updating RNG"""

    class SymbolicRV(OpFromGraph):
        def update(self, node):
            # Update is provided for rng1 but not rng2
            # (only the first input RNG is mapped to an updated state)
            return {node.inputs[0]: node.outputs[0]}

    # NOTE(review): registration is global, so SymbolicRV stays registered
    # for the remainder of the test session.
    SymbolicRandomVariable.register(SymbolicRV)

    # No problems at first, as the one RNG is given the update rule
    rng1 = pytensor.shared(np.random.default_rng())
    dummy_rng1 = rng1.type()
    dummy_next_rng1, dummy_x1 = SymbolicRV(
        [dummy_rng1],
        pt.random.normal(rng=dummy_rng1).owner.outputs,
    )(rng1)
    fn = compile_pymc(inputs=[], outputs=dummy_x1, random_seed=433)
    assert fn() != fn()

    # Now there's a problem as there is no update rule for rng2
    rng2 = pytensor.shared(np.random.default_rng())
    dummy_rng2 = rng2.type()
    dummy_next_rng1, dummy_x1, dummy_next_rng2, dummy_x2 = SymbolicRV(
        [dummy_rng1, dummy_rng2],
        [
            *pt.random.normal(rng=dummy_rng1).owner.outputs,
            *pt.random.normal(rng=dummy_rng2).owner.outputs,
        ],
    )(rng1, rng2)
    with pytest.raises(
        ValueError, match="No update mapping found for RNG used in SymbolicRandomVariable"
    ):
        compile_pymc(inputs=[], outputs=[dummy_x1, dummy_x2])

def test_random_seed(self):
seedx = pytensor.shared(np.random.default_rng(1))
Expand Down Expand Up @@ -457,15 +495,62 @@ def test_random_seed(self):
assert y3_eval == y2_eval

def test_multiple_updates_same_variable(self):
    """A shared RNG feeding several RVs warns, but explicit or default updates still apply."""
    # Raise if unexpected warning is issued
    with warnings.catch_warnings():
        warnings.simplefilter("error")

        rng = pytensor.shared(np.random.default_rng(), name="rng")
        x = pt.random.normal(rng=rng)
        y = pt.random.normal(rng=rng)

        # No warnings if only one variable is used
        assert compile_pymc([], [x])
        assert compile_pymc([], [y])

        user_warn_msg = "RNG Variable rng has multiple clients"
        with pytest.warns(UserWarning, match=user_warn_msg):
            f = compile_pymc([], [x, y], random_seed=456)
        # No update can be inferred, so the RNG is never advanced
        assert f() == f()

        # The user can provide an explicit update, but we will still issue a warning
        with pytest.warns(UserWarning, match=user_warn_msg):
            f = compile_pymc([], [x, y], updates={rng: y.owner.outputs[0]}, random_seed=456)
        assert f() != f()

        # Same with default update
        rng.default_update = x.owner.outputs[0]
        with pytest.warns(UserWarning, match=user_warn_msg):
            f = compile_pymc([], [x, y], updates={rng: y.owner.outputs[0]}, random_seed=456)
        assert f() != f()

def test_nested_updates(self):
    """Chained RVs sharing one RNG should map the RNG to the last update in the chain."""
    rng = pytensor.shared(np.random.default_rng())
    next_rng1, x = pt.random.normal(rng=rng).owner.outputs
    next_rng2, y = pt.random.normal(rng=next_rng1).owner.outputs
    next_rng3, z = pt.random.normal(rng=next_rng2).owner.outputs

    # BUG FIX: the original line was a bare `==` comparison whose result was
    # discarded, so it never actually checked anything. Assert it.
    assert collect_default_updates([], [x, y, z]) == {rng: next_rng3}

    fn = compile_pymc([], [x, y, z], random_seed=514)
    # All three draws across two calls must be distinct
    assert not set(list(np.array(fn()))) & set(list(np.array(fn())))

    # A local myopic rule (as PyMC used before) would not work properly,
    # since it repeats draws across calls
    fn = pytensor.function([], [x, y, z], updates={rng: next_rng1})
    assert set(list(np.array(fn()))) & set(list(np.array(fn())))


def test_collect_default_updates_must_be_shared():
    """Only shared RNGs receive updates by default; must_be_shared=False includes input RNGs."""
    # One RNG lives in a shared variable, the other is a plain (non-shared) input
    shared_rng = pytensor.shared(np.random.default_rng())
    nonshared_rng = shared_rng.type()

    next_rng_of_shared, x = pt.random.normal(rng=shared_rng).owner.outputs
    next_rng_of_nonshared, y = pt.random.normal(rng=nonshared_rng).owner.outputs

    # Default behavior: only the shared RNG is mapped to an update
    assert collect_default_updates(inputs=[nonshared_rng], outputs=[x, y]) == {
        shared_rng: next_rng_of_shared
    }

    # With must_be_shared=False, the non-shared input RNG is mapped as well
    assert collect_default_updates(
        inputs=[nonshared_rng], outputs=[x, y], must_be_shared=False
    ) == {
        shared_rng: next_rng_of_shared,
        nonshared_rng: next_rng_of_nonshared,
    }


def test_replace_rng_nodes():
Expand Down

0 comments on commit 3f2a1da

Please sign in to comment.