478 changes: 478 additions & 0 deletions pytensor/loop/op.py

Large diffs are not rendered by default.

20 changes: 20 additions & 0 deletions pytensor/typed_list/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ def __getitem__(self, index):
def __len__(self):
return length(self)

def __bool__(self):
# Truthiness of typedLists cannot depend on length,
# just like truthiness of TensorVariables does not depend on size or contents
return True

def append(self, toAppend):
return append(self, toAppend)

Expand Down Expand Up @@ -677,3 +682,18 @@ def perform(self, node, inputs, outputs):
All PyTensor variables must have the same type.
"""


class MakeEmptyList(Op):
__props__ = ()

def make_node(self, ttype):
tl = TypedListType(ttype)()
return Apply(self, [], [tl])

def perform(self, node, inputs, outputs):
(out,) = outputs
out[0] = []


make_empty_list = MakeEmptyList()
10 changes: 4 additions & 6 deletions tests/link/jax/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
import pytest

from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op, get_test_value
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.ifelse import ifelse
from pytensor.link.jax import JAXLinker
from pytensor.raise_op import assert_op
from pytensor.tensor.type import dscalar, scalar, vector

Expand All @@ -27,9 +25,9 @@ def set_pytensor_flags():
jax = pytest.importorskip("jax")


opts = RewriteDatabaseQuery(include=["jax"], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
py_mode = Mode("py", opts)
# We assume that the JAX mode includes all the rewrites needed to transpile JAX graphs
jax_mode = get_mode("JAX")
py_mode = get_mode("FAST_COMPILE")


def compare_jax_and_py(
Expand Down
114 changes: 114 additions & 0 deletions tests/link/jax/test_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import numpy as np
import pytest
from jax.tree_util import tree_leaves

from pytensor import config, function, shared
from pytensor.graph import FunctionGraph
from pytensor.loop.basic import scan
from pytensor.scan import until
from pytensor.tensor import scalar, vector, zeros
from pytensor.tensor.random import normal
from tests.link.jax.test_basic import compare_jax_and_py


def test_scan_with_single_sequence():
xs = vector("xs")
ys = scan(lambda x: x * 100, sequences=[xs])

out_fg = FunctionGraph([xs], [ys])
compare_jax_and_py(out_fg, [np.arange(10, dtype=config.floatX)])


def test_scan_with_single_sequence_shortened_by_nsteps():
xs = vector("xs", shape=(10,)) # JAX needs the length to be constant
ys = scan(
lambda x: x * 100,
sequences=[xs],
n_steps=9,
)

out_fg = FunctionGraph([xs], [ys])
compare_jax_and_py(out_fg, [np.arange(10, dtype=config.floatX)])


def test_scan_with_multiple_sequences():
# JAX can only handle constant n_steps
xs = vector("xs", shape=(10,))
ys = vector("ys", shape=(10,))
zs = scan(
fn=lambda x, y: x * y,
sequences=[xs, ys],
)

out_fg = FunctionGraph([xs, ys], [zs])
compare_jax_and_py(
out_fg, [np.arange(10, dtype=xs.dtype), np.arange(10, dtype=ys.dtype)]
)


def test_scan_with_carried_and_non_carried_states():
x = scalar("x")
[ys1, ys2] = scan(
fn=lambda xtm1: (xtm1 + 1, (xtm1 + 1) * 2),
init_states=[x, None],
n_steps=10,
)
out_fg = FunctionGraph([x], [ys1, ys2])
compare_jax_and_py(out_fg, [-1])


def test_scan_with_sequence_and_carried_state():
xs = vector("xs")
ys = scan(
fn=lambda x, ytm1: (ytm1 + 1) * x,
init_states=[zeros(())],
sequences=[xs],
)
out_fg = FunctionGraph([xs], [ys])
compare_jax_and_py(out_fg, [[1, 2, 3]])


def test_scan_with_rvs():
rng = shared(np.random.default_rng(123))

[rngs, xs] = scan(
fn=lambda prev_rng: normal(rng=prev_rng).owner.outputs,
init_states=[rng, None],
n_steps=10,
)
final_rng = rngs[-1]

# First without updates
fn = function([], xs, mode="JAX", updates=None)
res1 = fn()
res2 = fn()
assert not set(tuple(np.array(res1))) ^ set(tuple(np.array(res2)))

# Now with updates
fn = function([], xs, mode="JAX", updates={rng: final_rng})
res1 = fn()
res2 = fn()
assert not set(tuple(np.array(res1))) & set(tuple(np.array(res2)))

# Test traced rngs
fn = function([], [rngs, final_rng], mode="JAX")
rngs_res, final_rng_res = fn()
assert isinstance(rngs_res, list) and len(rngs_res) == 10
assert [np.array(v).tolist() for v in tree_leaves(rngs_res[-1])] == [
np.array(v).tolist() for v in tree_leaves(final_rng_res)
]


def test_while_scan_fails():
xs = scan(
fn=lambda x: (x + 1, until((x + 1) >= 9)),
init_states=[-1],
n_steps=20,
)

out_fg = FunctionGraph([], [xs])
with pytest.raises(
NotImplementedError,
match="Scan ops with while condition cannot be transpiled JAX",
):
compare_jax_and_py(out_fg, [])
Empty file added tests/loop/__init__.py
Empty file.
146 changes: 146 additions & 0 deletions tests/loop/test_basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import numpy as np

import pytensor
from pytensor import config, function, grad, shared
from pytensor.loop.basic import filter, map, reduce, scan
from pytensor.scan import until
from pytensor.tensor import arange, eq, scalar, vector, zeros
from pytensor.tensor.random import normal


def test_scan_with_sequences():
xs = vector("xs")
ys = vector("ys")
zs = scan(
fn=lambda x, y: x * y,
sequences=[xs, ys],
)
pytensor.dprint(ys, print_type=True)
np.testing.assert_almost_equal(
zs.eval(
{
xs: np.arange(10, dtype=config.floatX),
ys: np.arange(10, dtype=config.floatX),
}
),
np.arange(10) ** 2,
)


def test_scan_with_carried_and_non_carried_states():
x = scalar("x")
[ys1, ys2] = scan(
fn=lambda xtm1: (xtm1 + 1, (xtm1 + 1) * 2),
init_states=[x, None],
n_steps=10,
)
fn = function([x], [ys1, ys2])
res = fn(-1)
np.testing.assert_almost_equal(res[0], np.arange(10))
np.testing.assert_almost_equal(res[1], np.arange(10) * 2)


def test_scan_with_sequence_and_carried_state():
xs = vector("xs")
ys = scan(
fn=lambda x, ytm1: (ytm1 + 1) * x,
init_states=[zeros(())],
sequences=[xs],
)
fn = function([xs], ys)
np.testing.assert_almost_equal(fn([1, 2, 3]), [1, 4, 15])


def test_scan_taking_grads_wrt_non_sequence():
# Tests sequence + non-carried state
xs = vector("xs")
ys = xs**2

J = scan(
lambda i, ys, xs: grad(ys[i], wrt=xs),
sequences=arange(ys.shape[0]),
non_sequences=[ys, xs],
)

f = pytensor.function([xs], J)
np.testing.assert_array_equal(f([4, 4]), np.c_[[8, 0], [0, 8]])


def test_scan_taking_grads_wrt_sequence():
# This is not possible with the old Scan
xs = vector("xs")
ys = xs**2

J = scan(
lambda y, xs: grad(y, wrt=xs),
sequences=[ys],
non_sequences=[xs],
)

f = pytensor.function([xs], J)
np.testing.assert_array_equal(f([4, 4]), np.c_[[8, 0], [0, 8]])


def test_while_scan():
xs = scan(
fn=lambda x: (x + 1, until((x + 1) >= 9)),
init_states=[-1],
n_steps=20,
)

f = pytensor.function([], xs)
np.testing.assert_array_equal(f(), np.arange(10))


def test_scan_rvs():
rng = shared(np.random.default_rng(123))
test_rng = np.random.default_rng(123)

def normal_fn(prev_rng):
next_rng, x = normal(rng=prev_rng).owner.outputs
return next_rng, x

[rngs, xs] = scan(
fn=normal_fn,
init_states=[rng, None],
n_steps=5,
)
fn = function([], xs, updates={rng: rngs[-1]})

for i in range(3):
res = fn()
np.testing.assert_almost_equal(res, test_rng.normal(size=5))


def test_map():
xs = vector("xs")
ys = map(
fn=lambda x: x * 100,
sequences=xs,
)
np.testing.assert_almost_equal(
ys.eval({xs: np.arange(10, dtype=config.floatX)}), np.arange(10) * 100
)


def test_reduce():
xs = vector("xs")
y = reduce(
fn=lambda x, acc: acc + x,
init_states=zeros(()),
sequences=xs,
)
np.testing.assert_almost_equal(
y.eval({xs: np.arange(10, dtype=config.floatX)}), np.arange(10).cumsum()[-1]
)


def test_filter():
xs = vector("xs")
ys = filter(
fn=lambda x: eq(x % 2, 0),
sequences=xs,
)
np.testing.assert_array_equal(
ys.eval({xs: np.arange(0, 20, dtype=config.floatX)}), np.arange(0, 20, 2)
)
186 changes: 186 additions & 0 deletions tests/loop/test_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import numpy as np

import pytensor
from pytensor import config, function, shared
from pytensor.compile import DeepCopyOp
from pytensor.graph import FunctionGraph
from pytensor.graph.rewriting.basic import in2out
from pytensor.loop.op import Loop, Scan, scan_view_last_state
from pytensor.tensor import constant, empty, lscalar, scalar, vector
from pytensor.tensor.random import normal
from pytensor.tensor.random.type import RandomGeneratorType
from pytensor.tensor.subtensor import Subtensor
from pytensor.typed_list import TypedListType


def test_loop_basic():
i = lscalar("i")
x = scalar("x")
update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2])

loop_op = Loop(update_fg=update_fg)
assert len(loop_op.state_types) == 2
assert len(loop_op.const_types) == 0
_, y = loop_op(np.array(0, dtype="int64"), x)
assert y.eval({x: 0}) == 20


def test_loop_with_constant():
i = lscalar("i")
x = scalar("x")
const = scalar("const")
update_fg = FunctionGraph([i, x, const], [(i + 1) < 10, i + 1, x + const])

loop_op = Loop(update_fg=update_fg)
assert len(loop_op.state_types) == 2
assert len(loop_op.const_types) == 1
_, y = loop_op(np.array(0, dtype="int64"), x, const)
assert y.eval({x: 0, const: 2}) == 20


def test_fori_scan():
x = scalar("x")
update_fg = FunctionGraph([x], [constant(np.array(True)), x + 2])

n_iters = 10
y, ys = Scan(update_fg=update_fg)(n_iters, x)

fn = function([x], [y, ys])

subtensor_nodes = tuple(
node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Subtensor)
)
assert len(subtensor_nodes) == 0
loop_nodes = tuple(
node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Loop)
)
assert len(loop_nodes) == 1
(loop_node,) = loop_nodes
assert len(loop_node.outputs) == 3
assert loop_node.outputs[0].type.shape == ()
assert loop_node.outputs[1].type.shape == ()
assert loop_node.outputs[2].type.shape == (10,)

y_eval, ys_eval = fn(0)
np.testing.assert_array_equal(ys_eval, np.arange(2, 22, 2))
np.testing.assert_array_equal(ys_eval[-1], y_eval)


def test_fori_scan_shape():
x = scalar("x")
update_fg = FunctionGraph([x], [constant(np.array(True)), x + 2])

n_iters = 10
_, ys = Scan(update_fg=update_fg)(n_iters, x)

fn = function([x], ys.shape, on_unused_input="ignore")
nodes = tuple(fn.maker.fgraph.apply_nodes)
assert len(nodes) == 1
assert isinstance(nodes[0].op, DeepCopyOp)
assert fn(0) == 10


def test_while_scan():
i = lscalar("i")
x = scalar("x")
update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2])

max_iters = 1000
_, y, _, ys = Scan(update_fg=update_fg)(max_iters, np.array(0, dtype="int64"), x)

fn = function([x], [y, ys])

subtensor_nodes = tuple(
node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Subtensor)
)
assert len(subtensor_nodes) == 1
loop_nodes = tuple(
node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Loop)
)
assert len(loop_nodes) == 1
(loop_node,) = loop_nodes
assert len(loop_node.outputs) == 3
assert loop_node.outputs[0].type.shape == ()
assert loop_node.outputs[1].type.shape == ()
assert loop_node.outputs[2].type.shape == (1000,)

y_eval, ys_eval = fn(0)
np.testing.assert_array_equal(ys_eval, np.arange(2, 22, 2))
np.testing.assert_array_equal(ys_eval[-1], y_eval)


def test_while_scan_shape():
i = lscalar("i")
x = scalar("x")
update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2])

max_iters = 1000
_, _, _, ys = Scan(update_fg=update_fg)(max_iters, np.array(0, dtype="int64"), x)

fn = function([x], ys.shape)
loop_nodes = tuple(
node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Loop)
)
assert len(loop_nodes) == 1
assert fn(0) == 10


def test_foreach_scan():
idx = scalar("idx", dtype="int64")
dummy_x0 = empty(())
xs = vector("xs")
const = scalar("const")
update_fg = FunctionGraph(
[idx, dummy_x0, xs, const], [constant(np.array(True)), idx + 1, xs[idx] * const]
)

n_steps = xs.shape[0]
_, _, _, ys = Scan(update_fg=update_fg)(n_steps, 0, dummy_x0, xs, const)

fn = pytensor.function([xs, const], ys)

np.testing.assert_almost_equal(
fn(np.arange(10, dtype=config.floatX), 100), np.arange(10) * 100
)


def test_fori_random_scan():
rng_test = np.random.default_rng(123)
rng_shared = shared(np.random.default_rng(123))
n_iters = 5

dummy_init = empty(())
rng = rng_shared.type()
update_fg = FunctionGraph(
[dummy_init, rng],
[constant(np.array(True)), *normal(rng=rng).owner.outputs[::-1]],
)

last_y, last_rng, ys, rngs = Scan(update_fg=update_fg)(
n_iters, dummy_init, rng_shared
)
assert isinstance(last_rng.type, RandomGeneratorType)
assert isinstance(rngs.type, TypedListType)
assert isinstance(rngs.type.ttype, RandomGeneratorType)

fn = function([], [ys, rngs], updates={rng_shared: last_rng})
for i in range(2):
ys_res, rngs_res = fn()
for y_res, rng_res in zip(ys_res, rngs_res):
np.testing.assert_almost_equal(y_res, rng_test.normal())
assert rng_res.__getstate__() == rng_test.__getstate__()


def test_scan_view_last_state():
x = scalar("x")
update_fg = FunctionGraph([x], [x > 5, x + 2])

n_iters = 10
y1, ys = Scan(update_fg=update_fg)(n_iters, x)

y2 = ys[-1]
fgraph = FunctionGraph(outputs=[y2, ys], clone=False)
assert fgraph.outputs[0] is not y1
in2out(scan_view_last_state).apply(fgraph)
assert fgraph.outputs[0] is y1
assert fgraph.outputs[1] is ys
4 changes: 4 additions & 0 deletions tests/typed_list/test_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,7 @@ def test_variable_is_Typed_List_variable(self):
)()

assert isinstance(mySymbolicVariable, TypedListVariable)

def test_any(self):
tlist = TypedListType(TensorType(dtype="int64", shape=(None,)))()
assert any([tlist])