Commit 663bfa2

Introduce RandomVariable push-out optimization
Parent: e349607

2 files changed: +131 −2 lines changed
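
For context, the situation this commit addresses looks like the following, condensed from the new test added below: a `theano.scan` whose inner-graph draws samples that never reach the outer-graph.

from symbolic_pymc.theano.random_variables import CategoricalRV, NormalRV

# Inside `theano.scan`, `Y_t` is sampled conditional on `S_t`, but only `Y_t`
# is returned; before this commit, the `S_t` samples were unreachable from
# the outer-graph (e.g. when constructing a log-likelihood graph).
def scan_fn(mus_t, sigma_t, Gamma_t, rng):
    S_t = CategoricalRV(Gamma_t[0], rng=rng, name="S_t")
    Y_t = NormalRV(mus_t[S_t], sigma_t, rng=rng, name="Y_t")
    return Y_t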

symbolic_pymc/theano/opt.py

Lines changed: 51 additions & 1 deletion
@@ -5,7 +5,9 @@
 
 from functools import wraps
 
-from theano.gof.opt import LocalOptimizer
+from theano.gof.opt import LocalOptimizer, local_optimizer
+from theano.scan_module.scan_op import Scan
+from theano.scan_module.scan_utils import scan_args as ScanArgs
 
 from unification import var, variables
 
@@ -14,6 +16,7 @@
 from etuples.core import ExpressionTuple
 
 from .meta import MetaSymbol
+from .ops import RandomVariable
 
 
 def eval_and_reify_meta(x):
@@ -219,3 +222,50 @@ def transform(self, node):
             return new_node
         else:
             return False
+
+
+@local_optimizer([Scan])
+def push_out_rvs_from_scan(node):
+    """Push `RandomVariable`s out of `Scan` nodes.
+
+    When `RandomVariable`s are created within the inner-graph of a `Scan` and
+    are not output to the outer-graph, we "push" them out of the inner-graph.
+    This helps us produce an outer-graph in which all the relevant `RandomVariable`s
+    are accessible (e.g. for constructing a log-likelihood graph).
+    """
+    scan_args = ScanArgs(node.inputs, node.outputs, node.op.inputs, node.op.outputs, node.op.info)
+
+    # Find the un-output `RandomVariable`s created in the inner-graph
+    clients = {}
+    local_fgraph_topo = theano.gof.graph.io_toposort(
+        scan_args.inner_inputs, scan_args.inner_outputs, clients=clients
+    )
+    unpushed_inner_rvs = []
+    for n in local_fgraph_topo:
+        if isinstance(n.op, RandomVariable):
+            unpushed_inner_rvs.extend([c for c in clients[n] if c not in scan_args.inner_outputs])
+
+    if len(unpushed_inner_rvs) == 0:
+        return False
+
+    # Add the new outputs to the inner and outer graphs
+    scan_args.inner_out_nit_sot.extend(unpushed_inner_rvs)
+
+    assert len(scan_args.outer_in_nit_sot) > 0, "No outer-graph inputs are nit-sots!"
+
+    # Just like `theano.scan`, we simply copy/repeat the existing nit-sot
+    # outer-graph input value, which represents the actual size of the output
+    # tensors. Apparently, the value needs to be duplicated for all nit-sots.
+    # FYI: This is what increments the nit-sot values in `scan_args.info`, as
+    # well.
+    # TODO: Can we just use `scan_args.n_steps`?
+    scan_args.outer_in_nit_sot.extend(scan_args.outer_in_nit_sot[0:1] * len(unpushed_inner_rvs))
+
+    op = Scan(scan_args.inner_inputs, scan_args.inner_outputs, scan_args.info)
+    outputs = list(op(*scan_args.outer_inputs))
+
+    # Return only the replacements for the original `node.outputs`
+    new_inner_out_idx = [scan_args.inner_outputs.index(i) for i in unpushed_inner_rvs]
+    _ = [outputs.pop(op.var_mappings["outer_out_from_inner_out"][i]) for i in new_inner_out_idx]
+
+    return dict(zip(node.outputs, outputs))
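
Since `push_out_rvs_from_scan` is a local optimizer, it needs a global optimizer driver before it can be applied to a graph. A minimal usage sketch, assembled from the calls exercised in the test below (`FunctionGraph`, `optimize_graph`, and an `EquilibriumOptimizer` with `max_use_ratio=10`); `Y_rv` here stands in for the output of a `theano.scan` whose inner-graph contains un-output `RandomVariable`s:

from theano.gof.graph import inputs as tt_inputs
from theano.gof.opt import EquilibriumOptimizer

from symbolic_pymc.theano.opt import FunctionGraph, push_out_rvs_from_scan
from symbolic_pymc.theano.utils import optimize_graph

# Wrap the graph rooted at `Y_rv`; note that `FunctionGraph` clones the
# graph objects, so the optimized graph consists of new variables.
fgraph = FunctionGraph(tt_inputs([Y_rv]), [Y_rv], clone=True)

# Run the local optimizer to a fixed point over the whole graph.
pushoutrvs_opt = EquilibriumOptimizer([push_out_rvs_from_scan], max_use_ratio=10)
fgraph_opt = optimize_graph(fgraph, pushoutrvs_opt, return_graph=True)

# The rewritten `Scan` node now exposes the previously un-output inner-graph
# `RandomVariable`s as additional nit-sot outputs.
new_scan = fgraph_opt.outputs[0].owner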

tests/theano/test_opt.py

Lines changed: 80 additions & 1 deletion
@@ -1,3 +1,5 @@
+import numpy as np
+import theano
 import theano.tensor as tt
 
 from unification import var
@@ -9,10 +11,16 @@
 
 from theano.gof.opt import EquilibriumOptimizer
 from theano.gof.graph import inputs as tt_inputs
+from theano.scan_module.scan_op import Scan
 
 from symbolic_pymc.theano.meta import mt
-from symbolic_pymc.theano.opt import KanrenRelationSub, FunctionGraph
+from symbolic_pymc.theano.opt import (
+    KanrenRelationSub,
+    FunctionGraph,
+    push_out_rvs_from_scan,
+)
 from symbolic_pymc.theano.utils import optimize_graph
+from symbolic_pymc.theano.random_variables import CategoricalRV, DirichletRV, NormalRV
 
 
 def test_kanren_opt():
@@ -58,3 +66,74 @@ def distributes(in_lv, out_lv):
     assert fgraph_opt.owner.inputs[1].owner.op == tt.add
     assert isinstance(fgraph_opt.owner.inputs[1].owner.inputs[0].owner.op, tt.Dot)
     assert isinstance(fgraph_opt.owner.inputs[1].owner.inputs[1].owner.op, tt.Dot)
+
+
+def test_push_out_rvs():
+    theano.config.cxx = ""
+    theano.config.mode = "FAST_COMPILE"
+    tt.config.compute_test_value = "warn"
+
+    rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234)))
+    rng_tt = theano.shared(rng_state, name="rng", borrow=True)
+    rng_tt.tag.is_rng = True
+    rng_tt.default_update = rng_tt
+
+    N_tt = tt.iscalar("N")
+    N_tt.tag.test_value = 10
+    M_tt = tt.iscalar("M")
+    M_tt.tag.test_value = 2
+
+    mus_tt = tt.matrix("mus_t")
+    mus_tt.tag.test_value = np.stack([np.arange(0.0, 10), np.arange(0.0, -10, -1)], axis=-1).astype(
+        theano.config.floatX
+    )
+
+    sigmas_tt = tt.ones((N_tt,))
+    Gamma_rv = DirichletRV(tt.ones((M_tt, M_tt)), rng=rng_tt, name="Gamma")
+
+    # The optimizer should do nothing to this term, because it's not a `Scan`
+    fgraph = FunctionGraph(tt_inputs([Gamma_rv]), [Gamma_rv])
+    pushoutrvs_opt = EquilibriumOptimizer([push_out_rvs_from_scan], max_use_ratio=10)
+    Gamma_opt_rv = optimize_graph(fgraph, pushoutrvs_opt, return_graph=False)
+    # The `FunctionGraph` will, however, clone the graph objects, so we can't
+    # simply check that `Gamma_opt_rv == Gamma_rv`
+    assert all(type(a) == type(b) for a, b in zip(tt_inputs([Gamma_rv]), tt_inputs([Gamma_opt_rv])))
+    assert theano.scan_module.scan_utils.equal_computations(
+        [Gamma_opt_rv], [Gamma_rv], tt_inputs([Gamma_opt_rv]), tt_inputs([Gamma_rv])
+    )
+
+    # In this case, `Y_t` depends on `S_t` and `S_t` is not output. Our
+    # push-out optimization should create a new `Scan` that also outputs each
+    # `S_t`.
+    def scan_fn(mus_t, sigma_t, Gamma_t, rng):
+        S_t = CategoricalRV(Gamma_t[0], rng=rng, name="S_t")
+        Y_t = NormalRV(mus_t[S_t], sigma_t, rng=rng, name="Y_t")
+        return Y_t
+
+    Y_rv, _ = theano.scan(
+        fn=scan_fn,
+        sequences=[mus_tt, sigmas_tt],
+        non_sequences=[Gamma_rv, rng_tt],
+        outputs_info=[{}],
+        strict=True,
+        name="scan_rv",
+    )
+    Y_rv.name = "Y_rv"
+
+    orig_scan_op = Y_rv.owner.op
+    assert len(Y_rv.owner.outputs) == 2
+    assert isinstance(orig_scan_op, Scan)
+    assert len(orig_scan_op.outputs) == 2
+    assert orig_scan_op.outputs[0].owner.op == NormalRV
+    assert isinstance(orig_scan_op.outputs[1].type, tt.raw_random.RandomStateType)
+
+    fgraph = FunctionGraph(tt_inputs([Y_rv]), [Y_rv], clone=True)
+    fgraph_opt = optimize_graph(fgraph, pushoutrvs_opt, return_graph=True)
+
+    # There should now be a new output for all the `S_t`
+    new_scan = fgraph_opt.outputs[0].owner
+    assert len(new_scan.outputs) == 3
+    assert isinstance(new_scan.op, Scan)
+    assert new_scan.op.outputs[0].owner.op == NormalRV
+    assert new_scan.op.outputs[1].owner.op == CategoricalRV
+    assert isinstance(new_scan.op.outputs[2].type, tt.raw_random.RandomStateType)
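
Once the rewrite has run, the pushed-out terms can be located generically through the same `var_mappings` table the optimization itself uses to align inner- and outer-graph outputs. A sketch, reusing the `new_scan` node from the assertions above; in this test it would collect both the `Y_t` and the newly exposed `S_t` sequences, since both inner-graph outputs are produced by `RandomVariable` ops:

from symbolic_pymc.theano.ops import RandomVariable

# Indices of inner-graph outputs produced by `RandomVariable` ops.
inner_rv_idx = [
    i
    for i, out in enumerate(new_scan.op.outputs)
    if out.owner and isinstance(out.owner.op, RandomVariable)
]
# Map them to the corresponding outer-graph output positions.
outer_rv_idx = [
    new_scan.op.var_mappings["outer_out_from_inner_out"][i] for i in inner_rv_idx
]
rv_sequences = [new_scan.outputs[i] for i in outer_rv_idx]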
