Skip to content

Commit

Permalink
Make scan helper return sequences to match old API
Browse files Browse the repository at this point in the history
This was not possible prior to use of TypedListType for non TensorVariable sequences, as it would otherwise not be possible to represent indexing of last sequence state, which is needed e.g., for shared random generator updates.
  • Loading branch information
ricardoV94 committed Jan 16, 2023
1 parent a750fd7 commit 5f15c5e
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 32 deletions.
35 changes: 17 additions & 18 deletions pytensor/loop/basic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import functools
from typing import List, Tuple
from typing import List, Union

import numpy as np

Expand All @@ -18,7 +18,7 @@ def scan(
non_sequences=None,
n_steps=None,
go_backwards=False,
) -> Tuple[List[Variable], List[Variable]]:
) -> Union[Variable, List[Variable]]:
if sequences is None and n_steps is None:
raise ValueError("Must provide n_steps when scanning without sequences")

Expand Down Expand Up @@ -126,10 +126,11 @@ def scan(
n_steps, idx, *prev_states, *sequences, *non_sequences, *extra_fgraph_inputs
)
assert isinstance(scan_outs, list)
last_states = scan_outs[: scan_op.n_states]
traces = scan_outs[scan_op.n_states :]
# Don't return the inner index state
return last_states[1:], traces[1:]
# Don't return the last states or the trace for the inner index
traces = scan_outs[scan_op.n_states + 1 :]
if len(traces) == 1:
return traces[0]
return traces


def map(
Expand All @@ -138,14 +139,12 @@ def map(
non_sequences=None,
go_backwards=False,
):
_, traces = scan(
traces = scan(
fn=fn,
sequences=sequences,
non_sequences=non_sequences,
go_backwards=go_backwards,
)
if len(traces) == 1:
return traces[0]
return traces


Expand All @@ -156,16 +155,16 @@ def reduce(
non_sequences=None,
go_backwards=False,
):
final_states, _ = scan(
traces = scan(
fn=fn,
init_states=init_states,
sequences=sequences,
non_sequences=non_sequences,
go_backwards=go_backwards,
)
if len(final_states) == 1:
return final_states[0]
return final_states
if not isinstance(traces, list):
return traces[-1]
return [trace[-1] for trace in traces]


def filter(
Expand All @@ -177,21 +176,21 @@ def filter(
if not isinstance(sequences, (tuple, list)):
sequences = [sequences]

_, masks = scan(
masks = scan(
fn=fn,
sequences=sequences,
non_sequences=non_sequences,
go_backwards=go_backwards,
)

if not all(mask.dtype == "bool" for mask in masks):
raise TypeError("The output of filter fn should be a boolean variable")
if len(masks) == 1:
masks = [masks[0]] * len(sequences)
if not isinstance(masks, list):
masks = [masks] * len(sequences)
elif len(masks) != len(sequences):
raise ValueError(
"filter fn must return one variable or len(sequences), but it returned {len(masks)}"
)
if not all(mask.dtype == "bool" for mask in masks):
raise TypeError("The output of filter fn should be a boolean variable")

filtered_sequences = [seq[mask] for seq, mask in zip(sequences, masks)]

Expand Down
15 changes: 8 additions & 7 deletions tests/link/jax/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@

def test_scan_with_single_sequence():
xs = vector("xs")
_, [ys] = scan(lambda x: x * 100, sequences=[xs])
ys = scan(lambda x: x * 100, sequences=[xs])

out_fg = FunctionGraph([xs], [ys])
compare_jax_and_py(out_fg, [np.arange(10)])


def test_scan_with_single_sequence_shortened_by_nsteps():
xs = vector("xs", shape=(10,)) # JAX needs the length to be constant
_, [ys] = scan(
ys = scan(
lambda x: x * 100,
sequences=[xs],
n_steps=9,
Expand All @@ -35,7 +35,7 @@ def test_scan_with_multiple_sequences():
# JAX can only handle constant n_steps
xs = vector("xs", shape=(10,))
ys = vector("ys", shape=(10,))
_, [zs] = scan(
zs = scan(
fn=lambda x, y: x * y,
sequences=[xs, ys],
)
Expand All @@ -48,7 +48,7 @@ def test_scan_with_multiple_sequences():

def test_scan_with_carried_and_non_carried_states():
x = scalar("x")
_, [ys1, ys2] = scan(
[ys1, ys2] = scan(
fn=lambda xtm1: (xtm1 + 1, (xtm1 + 1) * 2),
init_states=[x, None],
n_steps=10,
Expand All @@ -59,7 +59,7 @@ def test_scan_with_carried_and_non_carried_states():

def test_scan_with_sequence_and_carried_state():
xs = vector("xs")
_, [ys] = scan(
ys = scan(
fn=lambda x, ytm1: (ytm1 + 1) * x,
init_states=[zeros(())],
sequences=[xs],
Expand All @@ -71,11 +71,12 @@ def test_scan_with_sequence_and_carried_state():
def test_scan_with_rvs():
rng = shared(np.random.default_rng(123))

[final_rng, _], [rngs, xs] = scan(
[rngs, xs] = scan(
fn=lambda prev_rng: normal(rng=prev_rng).owner.outputs,
init_states=[rng, None],
n_steps=10,
)
final_rng = rngs[-1]

# First without updates
fn = function([], xs, mode="JAX", updates=None)
Expand All @@ -99,7 +100,7 @@ def test_scan_with_rvs():


def test_while_scan_fails():
_, [xs] = scan(
xs = scan(
fn=lambda x: (x + 1, until((x + 1) >= 9)),
init_states=[-1],
n_steps=20,
Expand Down
35 changes: 28 additions & 7 deletions tests/loop/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import numpy as np

import pytensor
from pytensor import function, grad
from pytensor import function, grad, shared
from pytensor.loop.basic import filter, map, reduce, scan
from pytensor.scan import until
from pytensor.tensor import arange, eq, scalar, vector, zeros
from pytensor.tensor.random import normal


def test_scan_with_sequences():
xs = vector("xs")
ys = vector("ys")
_, [zs] = scan(
zs = scan(
fn=lambda x, y: x * y,
sequences=[xs, ys],
)
Expand All @@ -23,7 +24,7 @@ def test_scan_with_sequences():

def test_scan_with_carried_and_non_carried_states():
x = scalar("x")
_, [ys1, ys2] = scan(
[ys1, ys2] = scan(
fn=lambda xtm1: (xtm1 + 1, (xtm1 + 1) * 2),
init_states=[x, None],
n_steps=10,
Expand All @@ -36,7 +37,7 @@ def test_scan_with_carried_and_non_carried_states():

def test_scan_with_sequence_and_carried_state():
xs = vector("xs")
_, [ys] = scan(
ys = scan(
fn=lambda x, ytm1: (ytm1 + 1) * x,
init_states=[zeros(())],
sequences=[xs],
Expand All @@ -50,7 +51,7 @@ def test_scan_taking_grads_wrt_non_sequence():
xs = vector("xs")
ys = xs**2

_, [J] = scan(
J = scan(
lambda i, ys, xs: grad(ys[i], wrt=xs),
sequences=arange(ys.shape[0]),
non_sequences=[ys, xs],
Expand All @@ -65,7 +66,7 @@ def test_scan_taking_grads_wrt_sequence():
xs = vector("xs")
ys = xs**2

_, [J] = scan(
J = scan(
lambda y, xs: grad(y, wrt=xs),
sequences=[ys],
non_sequences=[xs],
Expand All @@ -76,7 +77,7 @@ def test_scan_taking_grads_wrt_sequence():


def test_while_scan():
_, [xs] = scan(
xs = scan(
fn=lambda x: (x + 1, until((x + 1) >= 9)),
init_states=[-1],
n_steps=20,
Expand All @@ -86,6 +87,26 @@ def test_while_scan():
np.testing.assert_array_equal(f(), np.arange(10))


def test_scan_rvs():
rng = shared(np.random.default_rng(123))
test_rng = np.random.default_rng(123)

def normal_fn(prev_rng):
next_rng, x = normal(rng=prev_rng).owner.outputs
return next_rng, x

[rngs, xs] = scan(
fn=normal_fn,
init_states=[rng, None],
n_steps=5,
)
fn = function([], xs, updates={rng: rngs[-1]})

for i in range(3):
res = fn()
np.testing.assert_almost_equal(res, test_rng.normal(size=5))


def test_map():
xs = vector("xs")
ys = map(
Expand Down

0 comments on commit 5f15c5e

Please sign in to comment.