Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 48 additions & 17 deletions pytensor/link/numba/dispatch/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,25 +112,46 @@ def numba_funcify_Scan(op, node, **kwargs):
# Inner-inputs are ordered as follows:
# sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
# shared-inputs + non-sequences.
temp_scalar_storage_alloc_stmts: List[str] = []
inner_in_exprs_scalar: List[str] = []
inner_in_exprs: List[str] = []

def add_inner_in_expr(
outer_in_name: str, tap_offset: Optional[int], storage_size_var: Optional[str]
outer_in_name: str,
tap_offset: Optional[int],
storage_size_var: Optional[str],
vector_slice_opt: bool,
):
"""Construct an inner-input expression."""
storage_name = outer_in_to_storage_name.get(outer_in_name, outer_in_name)
indexed_inner_in_str = (
storage_name
if tap_offset is None
else idx_to_str(
storage_name, tap_offset, size=storage_size_var, allow_scalar=False
if vector_slice_opt:
indexed_inner_in_str_scalar = idx_to_str(
storage_name, tap_offset, size=storage_size_var, allow_scalar=True
)
temp_storage = f"{storage_name}_temp_scalar_{tap_offset}"
storage_dtype = outer_in_var.type.numpy_dtype.name
temp_scalar_storage_alloc_stmts.append(
f"{temp_storage} = np.empty((), dtype=np.{storage_dtype})"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dumb question so I'm clear. is_vector refers to the values passed by outputs_info. This will be a vector if there are multiple taps requested. You're making this storage scalar as a place to break up this vector and store the individual values as they are fed into the inner function. Doing this prevents the need to call np.asarray on the output, because it's already an array, because these storage values are 0d arrays. Is that right?

Copy link
Member Author

@ricardoV94 ricardoV94 May 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. The is_vector refers to any vector input (sequence or recurrent), which must be broken into scalar arrays in each iteration. This happens regardless of the number of taps (or you could say there's always at least one tap of -1).

)
inner_in_exprs_scalar.append(
f"{temp_storage}[()] = {indexed_inner_in_str_scalar}"
)
indexed_inner_in_str = temp_storage
else:
indexed_inner_in_str = (
storage_name
if tap_offset is None
else idx_to_str(
storage_name, tap_offset, size=storage_size_var, allow_scalar=False
)
)
)
inner_in_exprs.append(indexed_inner_in_str)

for outer_in_name in outer_in_seqs_names:
# These outer-inputs are indexed without offsets or storage wrap-around
add_inner_in_expr(outer_in_name, 0, None)
outer_in_var = outer_in_names_to_vars[outer_in_name]
is_vector = outer_in_var.ndim == 1
add_inner_in_expr(outer_in_name, 0, None, vector_slice_opt=is_vector)

inner_in_names_to_input_taps: Dict[str, Tuple[int, ...]] = dict(
zip(
Expand Down Expand Up @@ -190,8 +211,8 @@ def add_output_storage_post_proc_stmt(
output_storage_post_proc_stmts.append(
dedent(
f"""
{outer_in_name}_shift = (i + {tap_size}) % ({storage_size})
if {outer_in_name}_shift > 0:
if (i + {tap_size}) > {storage_size}:
{outer_in_name}_shift = (i + {tap_size}) % ({storage_size})
{outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
{outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
{outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left))
Expand Down Expand Up @@ -232,7 +253,13 @@ def add_output_storage_post_proc_stmt(
for in_tap in input_taps:
tap_offset = in_tap + tap_storage_size
assert tap_offset >= 0
add_inner_in_expr(outer_in_name, tap_offset, storage_size_name)
is_vector = outer_in_var.ndim == 1
add_inner_in_expr(
outer_in_name,
tap_offset,
storage_size_name,
vector_slice_opt=is_vector,
)

output_taps = inner_in_names_to_output_taps.get(
outer_in_name, [tap_storage_size]
Expand All @@ -253,7 +280,7 @@ def add_output_storage_post_proc_stmt(

else:
storage_size_stmt = ""
add_inner_in_expr(outer_in_name, None, None)
add_inner_in_expr(outer_in_name, None, None, vector_slice_opt=False)
inner_out_to_outer_in_stmts.append(storage_name)

output_idx = outer_output_names.index(storage_name)
Expand Down Expand Up @@ -325,17 +352,19 @@ def add_output_storage_post_proc_stmt(
)

for name in outer_in_non_seqs_names:
add_inner_in_expr(name, None, None)
add_inner_in_expr(name, None, None, vector_slice_opt=False)

if op.info.as_while:
# The inner function will return a boolean as the last value
inner_out_to_outer_in_stmts.append("cond")

assert len(inner_in_exprs) == len(op.fgraph.inputs)

inner_scalar_in_args_to_temp_storage = "\n".join(inner_in_exprs_scalar)
inner_in_args = create_arg_string(inner_in_exprs)
inner_outputs = create_tuple_string(inner_output_names)
input_storage_block = "\n".join(storage_alloc_stmts)
input_temp_scalar_storage_block = "\n".join(temp_scalar_storage_alloc_stmts)
output_storage_post_processing_block = "\n".join(output_storage_post_proc_stmts)
inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts)

Expand All @@ -348,9 +377,13 @@ def scan({", ".join(outer_in_names)}):

{indent(input_storage_block, " " * 4)}

{indent(input_temp_scalar_storage_block, " " * 4)}

i = 0
cond = np.array(False)
while i < n_steps and not cond.item():
{indent(inner_scalar_in_args_to_temp_storage, " " * 8)}

{inner_outputs} = scan_inner_func({inner_in_args})
{indent(inner_out_post_processing_block, " " * 8)}
{indent(inner_out_to_outer_out_stmts, " " * 8)}
Expand All @@ -367,8 +400,6 @@ def scan({", ".join(outer_in_names)}):
}
global_env["np"] = np

scalar_op_fn = compile_function_src(
scan_op_src, "scan", {**globals(), **global_env}
)
scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env})

return numba_basic.numba_njit(scalar_op_fn)
return numba_basic.numba_njit(scan_op_fn)
54 changes: 53 additions & 1 deletion tests/link/numba/test_scan.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest

import pytensor
import pytensor.tensor as at
from pytensor import config, function, grad
from pytensor.compile.mode import Mode, get_mode
Expand All @@ -9,7 +10,7 @@
from pytensor.scan.basic import scan
from pytensor.scan.op import Scan
from pytensor.scan.utils import until
from pytensor.tensor import log, vector
from pytensor.tensor import log, scalar, vector
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.utils import RandomStream
from tests import unittest_tools as utt
Expand Down Expand Up @@ -442,3 +443,54 @@ def test_inner_graph_optimized():
assert isinstance(inner_scan_node.op, Elemwise) and isinstance(
inner_scan_node.op.scalar_op, Log1p
)


def test_vector_taps_benchmark(benchmark):
    """Benchmark and verify a numba Scan with vector sequence/recurrent inputs.

    When a vector input is indexed per iteration it yields a numeric scalar
    that must be wrapped back into a 0d array before being fed to the inner
    function. The numba Scan implementation optimizes this by reusing
    pre-allocated 0d scratch arrays instead of allocating a new one on every
    iteration; this test exercises that path (two sequences, a multi-tap
    mit-sot, and a single-tap sit-sot), checks numerical agreement against
    the reference mode, and benchmarks the compiled function.

    ``benchmark`` is presumably the pytest-benchmark fixture — TODO confirm
    it is registered in this test suite's plugins.
    """
    n_steps = 1000

    # Two explicit sequences; static shape lets numba specialize the loop bound.
    seq1 = vector("seq1", dtype="float64", shape=(n_steps,))
    seq2 = vector("seq2", dtype="float64", shape=(n_steps,))
    # mit-sot initial values: one entry per tap in taps=[-2, -1].
    mitsot_init = vector("mitsot_init", dtype="float64", shape=(2,))
    sitsot_init = scalar("sitsot_init", dtype="float64")

    def step(seq1, seq2, mitsot1, mitsot2, sitsot1):
        # mitsot1/mitsot2 are the -2/-1 taps of the same recurrent output.
        mitsot3 = mitsot1 + seq2 + mitsot2 + seq1
        sitsot2 = sitsot1 + mitsot3
        return mitsot3, sitsot2

    outs, _ = scan(
        fn=step,
        sequences=[seq1, seq2],
        outputs_info=[
            dict(initial=mitsot_init, taps=[-2, -1]),
            dict(initial=sitsot_init, taps=[-1]),
        ],
    )

    # Fixed seed so benchmark inputs (and any numerical comparison) are reproducible.
    rng = np.random.default_rng(474)
    test = {
        seq1: rng.normal(size=n_steps),
        seq2: rng.normal(size=n_steps),
        mitsot_init: rng.normal(size=(2,)),
        sitsot_init: rng.normal(),
    }

    numba_fn = pytensor.function(list(test.keys()), outs, mode=get_mode("NUMBA"))
    # Guard against rewrites splitting/eliminating the Scan: exactly one must remain.
    scan_nodes = [
        node for node in numba_fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
    ]
    assert len(scan_nodes) == 1
    numba_res = numba_fn(*test.values())

    # Reference run in FAST_COMPILE mode to validate the numba backend's results.
    ref_fn = pytensor.function(list(test.keys()), outs, mode=get_mode("FAST_COMPILE"))
    ref_res = ref_fn(*test.values())
    for numba_r, ref_r in zip(numba_res, ref_res):
        np.testing.assert_array_almost_equal(numba_r, ref_r)

    # Only the (already-validated) numba function is timed.
    benchmark(numba_fn, *test.values())