-
Notifications
You must be signed in to change notification settings - Fork 82
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b8831aa
commit 5b25de8
Showing
4 changed files
with
338 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,247 @@ | ||
from typing import Optional | ||
|
||
import numpy as np | ||
|
||
from pytensor import In, Out | ||
from pytensor.compile import optdb, pfunc | ||
from pytensor.graph import Apply, FunctionGraph, Op, Type, node_rewriter | ||
from pytensor.graph.rewriting.basic import in2out | ||
from pytensor.scalar import constant | ||
from pytensor.tensor import NoneConst, and_, empty, scalar, set_subtensor | ||
from pytensor.tensor.type import DenseTensorType, TensorType | ||
from pytensor.tensor.type_other import NoneTypeT | ||
|
||
|
||
def validate_loop_update_types(update):
    """Check that ``update`` looks like ``(*state, *consts) -> (bool, *state)``.

    Only ``update.inputs``, ``update.outputs`` and the ``type`` of each
    variable in them are read.

    Parameters
    ----------
    update
        Inner graph of a loop. Its first output is the "keep going" condition,
        the remaining outputs are the updated states. Its leading inputs are
        the previous states, in the same order.

    Raises
    ------
    ValueError
        If ``update`` has no outputs, or has fewer inputs than state outputs.
    TypeError
        If the first output is not of dtype ``"bool"``, or if an input state
        and the matching output state have different types.
    """
    # Explicit exceptions are used instead of ``assert`` so that the checks
    # are not stripped when Python runs with ``-O``.
    if not update.outputs:
        raise ValueError(
            "The inner loop function must have at least one output: "
            "the boolean condition."
        )

    go_on, *output_states = update.outputs

    # A type without a ``dtype`` attribute is reported as non-boolean rather
    # than failing with an unrelated AttributeError.
    go_on_dtype = getattr(go_on.type, "dtype", None)
    if go_on_dtype != "bool":
        raise TypeError(
            "The first output of the inner loop function must be a boolean "
            f"condition, got {go_on.type}."
        )

    # ``zip`` below would silently stop at the shorter sequence, so a missing
    # input state has to be detected up front.
    if len(update.inputs) < len(output_states):
        raise ValueError(
            f"The inner loop function returns {len(output_states)} states but "
            f"only has {len(update.inputs)} inputs."
        )

    # Extra trailing inputs are the constants; they have no matching output.
    for i, (input_state, output_state) in enumerate(
        zip(update.inputs, output_states)
    ):
        if input_state.type != output_state.type:
            raise TypeError(
                f"The {i}-th input and output states of the inner loop "
                f"function have different types: {input_state.type} vs "
                f"{output_state.type}."
            )
|
||
|
||
class Loop(Op):
    """Represent a do-while loop.

    The loop body is the ``update`` graph, with signature
    ``(*state, *consts) -> (bool, *state)``. The body is evaluated at least
    once and then again for as long as its first output is true. The outputs
    of the Op are the final states.

    Parameters
    ----------
    update
        Inner graph computing one iteration.
    reverse
        Optional inner graph intended for the gradient. It is only stored
        here; the gradient itself is not implemented yet.
    """

    def __init__(
        self,
        update: FunctionGraph,  # (*state, *consts) -> (bool, *state)
        reverse: Optional[FunctionGraph] = None,
    ):
        validate_loop_update_types(update)
        # Every output after the condition is a state; whatever inputs are
        # left after the states are the constants.
        self.state_types = [out.type for out in update.outputs[1:]]
        self.const_types = [inp.type for inp in update.inputs[len(self.state_types) :]]
        self.update = update
        self.reverse = reverse
        # Cache for the compiled inner function, filled in by ``update_fn``
        self._update_fn = None

    @property
    def update_fn(self):
        """Lazily compile the inner update function graph."""
        if self._update_fn is not None:
            return self._update_fn

        fgraph = self.update
        wrapped_inputs = [In(x, borrow=True) for x in fgraph.inputs]
        wrapped_outputs = [Out(x, borrow=False) for x in fgraph.outputs]

        self._update_fn = pfunc(
            wrapped_inputs,
            wrapped_outputs,
            mode="FAST_RUN",  # TODO: Figure this out
            accept_inplace=False,
            on_unused_input="ignore",
            fgraph=fgraph,
        )
        return self._update_fn

    def make_node(self, *inputs):
        """Build the Apply node for ``(*initial_states, *consts)``.

        Each input is converted to the type the inner graph expects at that
        position. The node has one output per state.
        """
        n_states = len(self.state_types)
        n_expected = n_states + len(self.const_types)
        assert (
            len(inputs) == n_expected
        ), f"Loop expected {n_expected} inputs, got {len(inputs)}"

        states = [
            inp_type.filter_variable(inp)
            for inp_type, inp in zip(self.state_types, inputs[:n_states])
        ]
        consts = [
            inp_type.filter_variable(inp)
            for inp_type, inp in zip(self.const_types, inputs[n_states:])
        ]

        return Apply(
            self,
            [*states, *consts],
            [state_type() for state_type in self.state_types],
        )

    def infer_shape(self, fgraph, node, input_shapes):
        """Report each final state as having the shape of its initial state."""
        # NOTE(review): this presumes the update never changes the shape of a
        # state between iterations - confirm.
        return input_shapes[: len(self.state_types)]

    def perform(self, node, inputs, output_storage):
        """Run the compiled update until its condition output is false."""
        update_fn = self.update_fn

        n_states = len(self.state_types)
        states = inputs[:n_states]
        consts = inputs[n_states:]
        # do-while: the body always runs once before the condition is checked
        while True:
            go_on, *states = update_fn(*states, *consts)
            if not go_on:
                break

        for i, state in enumerate(states):
            output_storage[i][0] = state

    def L_op(self, *args):
        """Gradient hook; not implemented yet."""
        if not self.reverse:
            raise NotImplementedError()
        # TODO: Use L_op of self.reverse.update
        raise NotImplementedError("Loop.L_op is not implemented yet")

    def R_op(self, *args):
        """Forward-mode derivative hook; not implemented yet."""
        # TODO: Use R_op of self.update
        raise NotImplementedError("Loop.R_op is not implemented yet")

    # The hooks were originally spelled ``L_Op``/``R_Op``, which does not match
    # the ``L_op``/``R_op`` spelling used by ``Scan`` in this module. The old
    # spellings are kept as aliases so that existing callers keep working.
    L_Op = L_op
    R_Op = R_op
|
||
|
||
class Scan(Op):
    """Represent a scan.

    This Op must always be converted to a Loop during compilation.

    The inner ``update`` graph has signature
    ``(*state, *consts) -> (bool, *state)``. The Op takes
    ``(n_steps, *initial_states, *consts)`` and returns the final states
    followed by one trace per state. For dense tensor states the trace type
    has one extra leading dimension; for any other state type the trace is a
    ``None`` placeholder.

    Parameters
    ----------
    update
        Inner graph computing one iteration.
    reverse
        Optional inner graph intended for the gradient. It is only stored
        here; the gradient itself is not implemented yet.
    """

    def __init__(
        self,
        update: FunctionGraph,  # (*state, *consts) -> (bool, *state)
        reverse: Optional[FunctionGraph] = None,
    ):
        validate_loop_update_types(update)
        self.state_types = [out.type for out in update.outputs[1:]]
        self.trace_types: list[Type] = []
        for state_type in self.state_types:
            # Accommodate SparseTensors and Scalars
            if isinstance(state_type, DenseTensorType):
                self.trace_types.append(
                    DenseTensorType(
                        shape=(None, *state_type.shape), dtype=state_type.dtype
                    )
                )
            else:
                # We can't concatenate all types of states, such as RandomTypes
                self.trace_types.append(NoneConst.type)
        # Inputs left after the states are the constants
        self.const_types = [inp.type for inp in update.inputs[len(self.state_types) :]]
        self.update = update
        self.reverse = reverse
        self._update_fn = None

    def make_node(self, n_steps, *inputs):
        """Build the Apply node for ``(n_steps, *initial_states, *consts)``.

        ``n_steps`` is converted to an int64 scalar and every other input to
        the type the inner graph expects at that position. The outputs are the
        final states followed by the traces.
        """
        n_states = len(self.state_types)
        n_expected = n_states + len(self.const_types)
        assert (
            len(inputs) == n_expected
        ), f"Scan expected {n_expected} inputs besides n_steps, got {len(inputs)}"

        n_steps = TensorType(dtype="int64", shape=()).filter_variable(n_steps)

        states = [
            inp_type.filter_variable(inp)
            for inp_type, inp in zip(self.state_types, inputs[:n_states])
        ]
        consts = [
            inp_type.filter_variable(inp)
            for inp_type, inp in zip(self.const_types, inputs[n_states:])
        ]

        return Apply(
            self,
            [n_steps, *states, *consts],
            [output_type() for output_type in self.state_types + self.trace_types],
        )

    def infer_shape(self, fgraph, node, input_shapes):
        """Return the state shapes followed by ``(n_steps, *state_shape)`` traces."""
        # NOTE(review): a scan whose condition becomes false early may yield a
        # trace shorter than ``n_steps`` - confirm this is what callers expect.
        n_steps = node.inputs[0]
        # input_shapes[0] belongs to n_steps, the states come right after it
        state_shapes = input_shapes[1 : len(self.state_types) + 1]
        trace_shapes = [
            (n_steps, *state_shape) if state_shape is not None else None
            for state_shape in state_shapes
        ]
        return state_shapes + trace_shapes

    def perform(self, node, inputs, output_storage):
        """Always fail: a Scan is expected to be rewritten before execution."""
        raise RuntimeError("Scan Op should not be present in compiled graph")

    def L_op(self, *args):
        """Gradient hook; not implemented yet."""
        # TODO: Use trace outputs
        raise NotImplementedError("Scan.L_op is not implemented yet")

    def R_op(self, *args):
        """Forward-mode derivative hook; not implemented yet."""
        # TODO: Use R_op of self.update
        raise NotImplementedError("Scan.R_op is not implemented yet")
|
||
|
||
@node_rewriter([Scan])
def scan_to_loop(fgraph, node):
    """Rewrite a Scan Op into a Loop Op.

    The inner update graph of the Scan is extended with an iteration counter
    and, for every trace that has clients in ``fgraph``, a buffer that receives
    the new state at the current index on each iteration. The loop stops when
    the original condition is false or the counter reaches ``n_steps``.

    Parameters
    ----------
    fgraph
        Graph being rewritten; used to find which traces are actually used.
    node
        Apply node of a ``Scan`` with inputs
        ``(n_steps, *initial_states, *consts)``.

    Returns
    -------
    dict
        Replacements mapping every old state output, and every used trace
        output, to the matching output of the new Loop. Traces are sliced up
        to the final value of the counter.

    Raises
    ------
    TypeError
        If a trace that cannot be represented as a sequence is used.
    """
    op: Scan = node.op  # type: ignore

    n_state_vars = len(op.state_types)

    n_steps, *outer_inputs = node.inputs
    init_states = outer_inputs[:n_state_vars]
    outer_consts = outer_inputs[n_state_vars:]

    old_states = node.outputs[:n_state_vars]
    old_traces = node.outputs[n_state_vars:]

    # Only include the intermediate states that are used elsewhere
    used_traces_idxs = [
        i for i, trace in enumerate(old_traces) if fgraph.clients[trace]
    ]

    # Traces of states that cannot be converted into sequences (such as
    # RandomTypes) are NoneTypeT placeholders. The check must look at the
    # trace output, not at the state output.
    for trace_idx in used_traces_idxs:
        if isinstance(old_traces[trace_idx].type, NoneTypeT):
            raise TypeError(
                f"The trace of state {trace_idx} of type "
                f"{old_states[trace_idx].type} cannot be converted into a "
                "sequence and must not be used"
            )

    update_fg = op.update
    prev_idx = scalar(dtype="int64", name="prev_idx")
    prev_states = update_fg.inputs[:n_state_vars]
    prev_traces = [old_traces[i].type() for i in used_traces_idxs]
    # The inner inputs are (*states, *consts): constants start right after
    # the states.
    inner_consts = update_fg.inputs[n_state_vars:]
    # The outer n_steps may be symbolic, so it cannot be referenced directly
    # from the inner graph. It is passed in as one more constant instead.
    inner_n_steps = n_steps.type()
    inner_n_steps.name = "n_steps"

    go_on, *next_states = update_fg.outputs
    next_idx = prev_idx + 1
    next_idx.name = "next_idx"
    next_traces = [
        set_subtensor(prev_trace[prev_idx], next_states[trace_idx])
        for trace_idx, prev_trace in zip(used_traces_idxs, prev_traces)
    ]
    go_on = and_(go_on, next_idx < inner_n_steps)
    go_on.name = "go_on"

    new_update_fg = FunctionGraph(
        inputs=[prev_idx, *prev_states, *prev_traces, *inner_consts, inner_n_steps],
        outputs=[go_on, next_idx, *next_states, *next_traces],
    )

    # TODO: Implement Reverse?
    loop_op = Loop(update=new_update_fg)

    init_idx = constant(np.array(0, dtype="int64"))
    # Allocate the buffers with the dtype of the trace they stand for, so
    # they match the types the Loop expects.
    # NOTE(review): the Loop body runs at least once, so n_steps == 0 would
    # write into a zero-length buffer - confirm how that case should behave.
    init_traces = [
        empty(
            (n_steps, *tuple(init_states[trace_idx].shape)),
            dtype=old_traces[trace_idx].type.dtype,
        )
        for trace_idx in used_traces_idxs
    ]
    final_idx, *new_outs = loop_op(
        init_idx, *init_states, *init_traces, *outer_consts, n_steps
    )
    new_states = new_outs[:n_state_vars]
    new_traces = new_outs[n_state_vars:]

    replacements = dict(zip(old_states, new_states))
    for trace_idx, new_trace in zip(used_traces_idxs, new_traces):
        # Entries past final_idx were never written
        replacements[old_traces[trace_idx]] = new_trace[:final_idx]
    return replacements
|
||
|
||
# Make the Scan -> Loop rewrite part of the global rewrite database.
# TODO: Create new Loop dataset
_SCAN_TO_LOOP_TAGS = ("fast_compile", "fast_run")  # tags the rewrite is filed under
optdb.register(
    "scan_to_loop",
    in2out(scan_to_loop),
    *_SCAN_TO_LOOP_TAGS,
    position=-0.1,  # TODO: When?
)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import numpy as np | ||
|
||
from pytensor import function, shared | ||
from pytensor.graph import FunctionGraph | ||
from pytensor.loop.op import Loop, Scan | ||
from pytensor.tensor import constant, lscalar, scalar | ||
from pytensor.tensor.random import normal | ||
from pytensor.tensor.type_other import NoneTypeT | ||
|
||
|
||
def test_loop_basic():
    """A Loop counting up to 10 adds 2 to ``x`` on each of its ten iterations."""
    counter = lscalar("i")
    x = scalar("x")

    keep_going = (counter + 1) < 10
    update_fg = FunctionGraph(
        inputs=[counter, x],
        outputs=[keep_going, counter + 1, x + 2],
    )

    loop_op = Loop(update=update_fg)
    start = np.array(0, dtype="int64")
    # The first output is the final counter, which is not checked here
    _, y = loop_op(start, x)

    assert y.eval({x: 0}) == 20
|
||
|
||
def test_for_scan():
    """A Scan with an always-true condition runs for exactly ``n_steps`` steps.

    Also checks that compilation replaced the Scan by a single Loop whose
    outputs are the counter, the final state and the trace.
    """
    x = scalar("x")
    update_fg = FunctionGraph([x], [constant(np.array(True)), x + 2])

    n_steps = 10
    y, ys = Scan(update=update_fg)(n_steps, x)

    fn = function([x], [y, ys])

    loop_nodes = tuple(
        node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Loop)
    )
    assert len(loop_nodes) == 1
    (loop_node,) = loop_nodes
    # Outputs: iteration counter, final x, trace of x
    assert len(loop_node.outputs) == 3
    assert loop_node.outputs[0].type.shape == ()
    assert loop_node.outputs[1].type.shape == ()
    assert loop_node.outputs[2].type.shape == (None,)  # This could be known

    y_eval, ys_eval = fn(0)
    np.testing.assert_array_equal(ys_eval, np.arange(2, 22, 2))
    np.testing.assert_array_equal(ys_eval[-1], y_eval)
|
||
|
||
def test_while_scan():
    """A Scan stops as soon as its condition is false, well before ``n_steps``.

    The inner condition allows ten iterations while ``n_steps`` is 1000, so
    the trace must hold ten entries only.
    """
    i = lscalar("i")
    x = scalar("x")
    update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2])

    n_steps = 1000
    _, y, _, ys = Scan(update=update_fg)(n_steps, np.array(0, dtype="int64"), x)

    fn = function([x], [y, ys])

    loop_nodes = tuple(
        node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Loop)
    )
    assert len(loop_nodes) == 1
    (loop_node,) = loop_nodes
    # Outputs: iteration counter, final i, final x, trace of x
    assert len(loop_node.outputs) == 4
    assert loop_node.outputs[0].type.shape == ()
    assert loop_node.outputs[1].type.shape == ()
    assert loop_node.outputs[2].type.shape == ()
    assert loop_node.outputs[3].type.shape == (None,)  # This could be known

    y_eval, ys_eval = fn(0)
    np.testing.assert_array_equal(ys_eval, np.arange(2, 22, 2))
    np.testing.assert_array_equal(ys_eval[-1], y_eval)
|
||
|
||
def test_scan_random():
    """A Scan carrying a random generator state draws like NumPy does.

    The generator is a shared variable updated with the final generator state,
    so two consecutive calls must yield consecutive batches of draws.
    """
    seed = 123
    n_steps = 5
    rng_test = np.random.default_rng(seed)
    rng_shared = shared(np.random.default_rng(seed))

    # TODO: x shouldn't be needed when the initial_state does not matter!
    x = scalar("x")
    rng = rng_shared.type()

    # The outputs of the normal node are reversed to line up with (x, rng)
    draw_and_next_rng = normal(rng=rng).owner.outputs[::-1]
    always_true = constant(np.array(True))
    update_fg = FunctionGraph([x, rng], [always_true, *draw_and_next_rng])

    scan_op = Scan(update=update_fg)
    _, new_rng, ys, rngs = scan_op(n_steps, x, rng_shared)
    # Generator states cannot be stacked, so their trace is a None placeholder
    assert isinstance(rngs.type, NoneTypeT)

    fn = function([x], ys, updates={rng_shared: new_rng})

    for _ in range(2):
        np.testing.assert_array_equal(fn(0), rng_test.normal(size=n_steps))