@@ -0,0 +1,114 @@
import numpy as np
import pytest
from jax.tree_util import tree_leaves

from pytensor import config, function, shared
from pytensor.graph import FunctionGraph
from pytensor.loop.basic import scan
from pytensor.scan import until
from pytensor.tensor import scalar, vector, zeros
from pytensor.tensor.random import normal
from tests.link.jax.test_basic import compare_jax_and_py


def test_scan_with_single_sequence():
    xs = vector("xs")
    ys = scan(lambda x: x * 100, sequences=[xs])

    out_fg = FunctionGraph([xs], [ys])
    compare_jax_and_py(out_fg, [np.arange(10, dtype=config.floatX)])


def test_scan_with_single_sequence_shortened_by_nsteps():
    xs = vector("xs", shape=(10,))  # JAX needs the length to be constant
    ys = scan(
        lambda x: x * 100,
        sequences=[xs],
        n_steps=9,
    )

    out_fg = FunctionGraph([xs], [ys])
    compare_jax_and_py(out_fg, [np.arange(10, dtype=config.floatX)])


def test_scan_with_multiple_sequences():
    # JAX can only handle constant n_steps
    xs = vector("xs", shape=(10,))
    ys = vector("ys", shape=(10,))
    zs = scan(
        fn=lambda x, y: x * y,
        sequences=[xs, ys],
    )

    out_fg = FunctionGraph([xs, ys], [zs])
    compare_jax_and_py(
        out_fg, [np.arange(10, dtype=xs.dtype), np.arange(10, dtype=ys.dtype)]
    )


def test_scan_with_carried_and_non_carried_states():
    x = scalar("x")
    [ys1, ys2] = scan(
        fn=lambda xtm1: (xtm1 + 1, (xtm1 + 1) * 2),
        init_states=[x, None],
        n_steps=10,
    )
    out_fg = FunctionGraph([x], [ys1, ys2])
    compare_jax_and_py(out_fg, [-1])


def test_scan_with_sequence_and_carried_state():
    xs = vector("xs")
    ys = scan(
        fn=lambda x, ytm1: (ytm1 + 1) * x,
        init_states=[zeros(())],
        sequences=[xs],
    )
    out_fg = FunctionGraph([xs], [ys])
    compare_jax_and_py(out_fg, [[1, 2, 3]])


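# The set-based assertions in the next test compare the draws from two calls of
# the compiled function: with no updates the shared rng never advances, so the
# two result sets should be identical (empty symmetric difference); with the rng
# updated to the final scan state, the second call should produce entirely new
# draws (empty intersection).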
def test_scan_with_rvs():
    rng = shared(np.random.default_rng(123))

    [rngs, xs] = scan(
        fn=lambda prev_rng: normal(rng=prev_rng).owner.outputs,
        init_states=[rng, None],
        n_steps=10,
    )
    final_rng = rngs[-1]

    # First without updates
    fn = function([], xs, mode="JAX", updates=None)
    res1 = fn()
    res2 = fn()
    assert not set(tuple(np.array(res1))) ^ set(tuple(np.array(res2)))

    # Now with updates
    fn = function([], xs, mode="JAX", updates={rng: final_rng})
    res1 = fn()
    res2 = fn()
    assert not set(tuple(np.array(res1))) & set(tuple(np.array(res2)))

    # Test traced rngs
    fn = function([], [rngs, final_rng], mode="JAX")
    rngs_res, final_rng_res = fn()
    assert isinstance(rngs_res, list) and len(rngs_res) == 10
    assert [np.array(v).tolist() for v in tree_leaves(rngs_res[-1])] == [
        np.array(v).tolist() for v in tree_leaves(final_rng_res)
    ]


def test_while_scan_fails():
    xs = scan(
        fn=lambda x: (x + 1, until((x + 1) >= 9)),
        init_states=[-1],
        n_steps=20,
    )

    out_fg = FunctionGraph([], [xs])
    with pytest.raises(
        NotImplementedError,
        match="Scan ops with while condition cannot be transpiled JAX",
    ):
        compare_jax_and_py(out_fg, [])

@@ -0,0 +1,146 @@
import numpy as np

import pytensor
from pytensor import config, function, grad, shared
from pytensor.loop.basic import filter, map, reduce, scan
from pytensor.scan import until
from pytensor.tensor import arange, eq, scalar, vector, zeros
from pytensor.tensor.random import normal


def test_scan_with_sequences():
    xs = vector("xs")
    ys = vector("ys")
    zs = scan(
        fn=lambda x, y: x * y,
        sequences=[xs, ys],
    )
    pytensor.dprint(ys, print_type=True)
    np.testing.assert_almost_equal(
        zs.eval(
            {
                xs: np.arange(10, dtype=config.floatX),
                ys: np.arange(10, dtype=config.floatX),
            }
        ),
        np.arange(10) ** 2,
    )


def test_scan_with_carried_and_non_carried_states():
    x = scalar("x")
    [ys1, ys2] = scan(
        fn=lambda xtm1: (xtm1 + 1, (xtm1 + 1) * 2),
        init_states=[x, None],
        n_steps=10,
    )
    fn = function([x], [ys1, ys2])
    res = fn(-1)
    np.testing.assert_almost_equal(res[0], np.arange(10))
    np.testing.assert_almost_equal(res[1], np.arange(10) * 2)


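# The next test runs the recurrence y_t = (y_{t-1} + 1) * x_t starting from
# y_0 = 0. A plain-Python sketch of what it should compute:
#
#     y, out = 0.0, []
#     for x in [1, 2, 3]:
#         y = (y + 1) * x
#         out.append(y)  # -> [1.0, 4.0, 15.0]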
def test_scan_with_sequence_and_carried_state():
    xs = vector("xs")
    ys = scan(
        fn=lambda x, ytm1: (ytm1 + 1) * x,
        init_states=[zeros(())],
        sequences=[xs],
    )
    fn = function([xs], ys)
    np.testing.assert_almost_equal(fn([1, 2, 3]), [1, 4, 15])


def test_scan_taking_grads_wrt_non_sequence():
    # Tests sequence + non-carried state
    xs = vector("xs")
    ys = xs**2

    J = scan(
        lambda i, ys, xs: grad(ys[i], wrt=xs),
        sequences=arange(ys.shape[0]),
        non_sequences=[ys, xs],
    )

    f = pytensor.function([xs], J)
    np.testing.assert_array_equal(f([4, 4]), np.c_[[8, 0], [0, 8]])


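# The next test differentiates the same ys = xs**2 directly with respect to the
# sequence elements. With d(x_i**2)/d(x_j) = 2 * x_i when i == j and 0 otherwise,
# the expected Jacobian at xs = [4, 4] is the diagonal matrix [[8, 0], [0, 8]],
# which is what np.c_[[8, 0], [0, 8]] builds.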
def test_scan_taking_grads_wrt_sequence():
    # This is not possible with the old Scan
    xs = vector("xs")
    ys = xs**2

    J = scan(
        lambda y, xs: grad(y, wrt=xs),
        sequences=[ys],
        non_sequences=[xs],
    )

    f = pytensor.function([xs], J)
    np.testing.assert_array_equal(f([4, 4]), np.c_[[8, 0], [0, 8]])


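# The while-scan below requests n_steps=20, but the until() condition stops the
# loop once the new value reaches 9, so only the ten values actually produced
# (0 through 9) end up in the returned trace.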
def test_while_scan():
    xs = scan(
        fn=lambda x: (x + 1, until((x + 1) >= 9)),
        init_states=[-1],
        n_steps=20,
    )

    f = pytensor.function([], xs)
    np.testing.assert_array_equal(f(), np.arange(10))


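# normal(rng=prev_rng).owner.outputs is the pair (next_rng, draw); returning
# next_rng as a carried state and updating the shared rng with rngs[-1] lets each
# call of the compiled function continue the random stream, which is why the
# draws should keep matching test_rng.normal(size=5) across calls.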
def test_scan_rvs():
    rng = shared(np.random.default_rng(123))
    test_rng = np.random.default_rng(123)

    def normal_fn(prev_rng):
        next_rng, x = normal(rng=prev_rng).owner.outputs
        return next_rng, x

    [rngs, xs] = scan(
        fn=normal_fn,
        init_states=[rng, None],
        n_steps=5,
    )
    fn = function([], xs, updates={rng: rngs[-1]})

    for i in range(3):
        res = fn()
        np.testing.assert_almost_equal(res, test_rng.normal(size=5))


def test_map():
    xs = vector("xs")
    ys = map(
        fn=lambda x: x * 100,
        sequences=xs,
    )
    np.testing.assert_almost_equal(
        ys.eval({xs: np.arange(10, dtype=config.floatX)}), np.arange(10) * 100
    )


def test_reduce():
    xs = vector("xs")
    y = reduce(
        fn=lambda x, acc: acc + x,
        init_states=zeros(()),
        sequences=xs,
    )
    np.testing.assert_almost_equal(
        y.eval({xs: np.arange(10, dtype=config.floatX)}), np.arange(10).cumsum()[-1]
    )


def test_filter():
    xs = vector("xs")
    ys = filter(
        fn=lambda x: eq(x % 2, 0),
        sequences=xs,
    )
    np.testing.assert_array_equal(
        ys.eval({xs: np.arange(0, 20, dtype=config.floatX)}), np.arange(0, 20, 2)
    )

@@ -0,0 +1,186 @@
import numpy as np

import pytensor
from pytensor import config, function, shared
from pytensor.compile import DeepCopyOp
from pytensor.graph import FunctionGraph
from pytensor.graph.rewriting.basic import in2out
from pytensor.loop.op import Loop, Scan, scan_view_last_state
from pytensor.tensor import constant, empty, lscalar, scalar, vector
from pytensor.tensor.random import normal
from pytensor.tensor.random.type import RandomGeneratorType
from pytensor.tensor.subtensor import Subtensor
from pytensor.typed_list import TypedListType


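# Convention used throughout these tests: update_fg takes the current states
# (and any constants) and returns [continue_condition, *next_states]. Loop
# appears to return only the final value of each state; here the counter stops
# at 10 and x reaches 20 after ten increments of 2.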
def test_loop_basic():
    i = lscalar("i")
    x = scalar("x")
    update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2])

    loop_op = Loop(update_fg=update_fg)
    assert len(loop_op.state_types) == 2
    assert len(loop_op.const_types) == 0
    _, y = loop_op(np.array(0, dtype="int64"), x)
    assert y.eval({x: 0}) == 20


def test_loop_with_constant():
    i = lscalar("i")
    x = scalar("x")
    const = scalar("const")
    update_fg = FunctionGraph([i, x, const], [(i + 1) < 10, i + 1, x + const])

    loop_op = Loop(update_fg=update_fg)
    assert len(loop_op.state_types) == 2
    assert len(loop_op.const_types) == 1
    _, y = loop_op(np.array(0, dtype="int64"), x, const)
    assert y.eval({x: 0, const: 2}) == 20


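# From the call patterns below, Scan seems to return the final value of every
# state first, followed by the per-step trace of every state: here
# `y, ys = Scan(update_fg=update_fg)(n_iters, x)`, and in test_while_scan
# `_, y, _, ys = ...`, where the counter's final value and trace are discarded.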
def test_fori_scan():
    x = scalar("x")
    update_fg = FunctionGraph([x], [constant(np.array(True)), x + 2])

    n_iters = 10
    y, ys = Scan(update_fg=update_fg)(n_iters, x)

    fn = function([x], [y, ys])

    subtensor_nodes = tuple(
        node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Subtensor)
    )
    assert len(subtensor_nodes) == 0
    loop_nodes = tuple(
        node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Loop)
    )
    assert len(loop_nodes) == 1
    (loop_node,) = loop_nodes
    assert len(loop_node.outputs) == 3
    assert loop_node.outputs[0].type.shape == ()
    assert loop_node.outputs[1].type.shape == ()
    assert loop_node.outputs[2].type.shape == (10,)

    y_eval, ys_eval = fn(0)
    np.testing.assert_array_equal(ys_eval, np.arange(2, 22, 2))
    np.testing.assert_array_equal(ys_eval[-1], y_eval)


def test_fori_scan_shape():
    x = scalar("x")
    update_fg = FunctionGraph([x], [constant(np.array(True)), x + 2])

    n_iters = 10
    _, ys = Scan(update_fg=update_fg)(n_iters, x)

    fn = function([x], ys.shape, on_unused_input="ignore")
    nodes = tuple(fn.maker.fgraph.apply_nodes)
    assert len(nodes) == 1
    assert isinstance(nodes[0].op, DeepCopyOp)
    assert fn(0) == 10


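# For a while-scan the trace buffer is allocated for max_iters entries (shape
# (1000,) on the Loop output below), and a single Subtensor node then trims it
# to the number of steps actually executed, hence the expectation of exactly one
# Subtensor in the compiled graph.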
def test_while_scan():
    i = lscalar("i")
    x = scalar("x")
    update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2])

    max_iters = 1000
    _, y, _, ys = Scan(update_fg=update_fg)(max_iters, np.array(0, dtype="int64"), x)

    fn = function([x], [y, ys])

    subtensor_nodes = tuple(
        node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Subtensor)
    )
    assert len(subtensor_nodes) == 1
    loop_nodes = tuple(
        node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Loop)
    )
    assert len(loop_nodes) == 1
    (loop_node,) = loop_nodes
    assert len(loop_node.outputs) == 3
    assert loop_node.outputs[0].type.shape == ()
    assert loop_node.outputs[1].type.shape == ()
    assert loop_node.outputs[2].type.shape == (1000,)

    y_eval, ys_eval = fn(0)
    np.testing.assert_array_equal(ys_eval, np.arange(2, 22, 2))
    np.testing.assert_array_equal(ys_eval[-1], y_eval)


def test_while_scan_shape():
    i = lscalar("i")
    x = scalar("x")
    update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2])

    max_iters = 1000
    _, _, _, ys = Scan(update_fg=update_fg)(max_iters, np.array(0, dtype="int64"), x)

    fn = function([x], ys.shape)
    loop_nodes = tuple(
        node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Loop)
    )
    assert len(loop_nodes) == 1
    assert fn(0) == 10


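# Elementwise iteration over a sequence is expressed in the next test with an
# integer index state plus a dummy empty(()) initial value for the non-carried
# output, while xs and const enter the inner graph as constants; each step reads
# xs[idx] and advances the index.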
def test_foreach_scan():
    idx = scalar("idx", dtype="int64")
    dummy_x0 = empty(())
    xs = vector("xs")
    const = scalar("const")
    update_fg = FunctionGraph(
        [idx, dummy_x0, xs, const], [constant(np.array(True)), idx + 1, xs[idx] * const]
    )

    n_steps = xs.shape[0]
    _, _, _, ys = Scan(update_fg=update_fg)(n_steps, 0, dummy_x0, xs, const)

    fn = pytensor.function([xs, const], ys)

    np.testing.assert_almost_equal(
        fn(np.arange(10, dtype=config.floatX), 100), np.arange(10) * 100
    )


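# normal(rng=rng).owner.outputs is (next_rng, draw); the [::-1] reversal puts
# them in (draw, next_rng) order so the update_fg outputs line up with the state
# order (dummy_init, rng). The trace of rng states comes back as a typed list of
# random generators, as the isinstance checks below verify.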
def test_fori_random_scan():
    rng_test = np.random.default_rng(123)
    rng_shared = shared(np.random.default_rng(123))
    n_iters = 5

    dummy_init = empty(())
    rng = rng_shared.type()
    update_fg = FunctionGraph(
        [dummy_init, rng],
        [constant(np.array(True)), *normal(rng=rng).owner.outputs[::-1]],
    )

    last_y, last_rng, ys, rngs = Scan(update_fg=update_fg)(
        n_iters, dummy_init, rng_shared
    )
    assert isinstance(last_rng.type, RandomGeneratorType)
    assert isinstance(rngs.type, TypedListType)
    assert isinstance(rngs.type.ttype, RandomGeneratorType)

    fn = function([], [ys, rngs], updates={rng_shared: last_rng})
    for i in range(2):
        ys_res, rngs_res = fn()
        for y_res, rng_res in zip(ys_res, rngs_res):
            np.testing.assert_almost_equal(y_res, rng_test.normal())
            assert rng_res.__getstate__() == rng_test.__getstate__()


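# The scan_view_last_state rewrite replaces an indexed read of the trace, ys[-1],
# with the Scan's own final-state output (y1 here), presumably so the last value
# can be used without slicing the full trace.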
def test_scan_view_last_state():
    x = scalar("x")
    update_fg = FunctionGraph([x], [x > 5, x + 2])

    n_iters = 10
    y1, ys = Scan(update_fg=update_fg)(n_iters, x)

    y2 = ys[-1]
    fgraph = FunctionGraph(outputs=[y2, ys], clone=False)
    assert fgraph.outputs[0] is not y1
    in2out(scan_view_last_state).apply(fgraph)
    assert fgraph.outputs[0] is y1
    assert fgraph.outputs[1] is ys