WIP new Loop operator
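A minimal usage sketch of the new Loop Op, mirroring test_loop_basic added in this commit: the update graph maps the current states (i, x) to a continue flag plus the next states, and the Op iterates it do-while style.

    import numpy as np

    from pytensor.graph import FunctionGraph
    from pytensor.loop.op import Loop
    from pytensor.tensor import lscalar, scalar

    i = lscalar("i")
    x = scalar("x")
    # (i, x) -> (continue?, i + 1, x + 2): keep looping while i + 1 < 10
    update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2])

    # The Op returns the final value of each loop state (final i and final x)
    _, y = Loop(update=update_fg)(np.array(0, dtype="int64"), x)
    assert y.eval({x: 0}) == 20  # 10 iterations of x + 2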
ricardoV94 committed Jan 10, 2023
1 parent b8831aa commit 5b25de8
Showing 4 changed files with 338 additions and 0 deletions.
Empty file added pytensor/loop/__init__.py
247 changes: 247 additions & 0 deletions pytensor/loop/op.py
@@ -0,0 +1,247 @@
from typing import Optional

import numpy as np

from pytensor import In, Out
from pytensor.compile import optdb, pfunc
from pytensor.graph import Apply, FunctionGraph, Op, Type, node_rewriter
from pytensor.graph.rewriting.basic import in2out
from pytensor.scalar import constant
from pytensor.tensor import NoneConst, and_, empty, scalar, set_subtensor
from pytensor.tensor.type import DenseTensorType, TensorType
from pytensor.tensor.type_other import NoneTypeT


def validate_loop_update_types(update):
assert update.outputs[0].type.dtype == "bool"
for input_state, output_state in zip(update.inputs, update.outputs[1:]):
assert input_state.type == output_state.type


class Loop(Op):
"""Represent a do-while loop."""

def __init__(
self,
update: FunctionGraph, # (*state, *consts) -> (bool, *state)
reverse: Optional[FunctionGraph] = None,
):
validate_loop_update_types(update)
self.state_types = [out.type for out in update.outputs[1:]]
self.const_types = [inp.type for inp in update.inputs[len(self.state_types) :]]
self.update = update
self.reverse = reverse
self._update_fn = None

@property
def update_fn(self):
"""Lazily compile the inner update function graph."""
if self._update_fn is not None:
return self._update_fn

fgraph = self.update
wrapped_inputs = [In(x, borrow=True) for x in fgraph.inputs]
wrapped_outputs = [Out(x, borrow=False) for x in fgraph.outputs]

self._update_fn = pfunc(
wrapped_inputs,
wrapped_outputs,
mode="FAST_RUN", # TODO: Figure this out
accept_inplace=False,
on_unused_input="ignore",
fgraph=fgraph,
)
return self._update_fn

def make_node(self, *inputs):
assert len(inputs) == len(self.state_types) + len(self.const_types)

states = inputs[: len(self.state_types)]
states = [
inp_type.filter_variable(inp)
for inp_type, inp in zip(self.state_types, states)
]

consts = inputs[len(self.state_types) :]
consts = [
inp_type.filter_variable(inp)
for inp_type, inp in zip(self.const_types, consts)
]

return Apply(
self,
[*states, *consts],
[state_type() for state_type in self.state_types],
)

def infer_shape(self, fgraph, node, input_shapes):
return input_shapes[: len(self.state_types)]

def perform(self, node, inputs, output_storage):
update_fn = self.update_fn

states = inputs[: len(self.state_types)]
consts = inputs[len(self.state_types) :]
while True:
go_on, *states = update_fn(*states, *consts)
if not go_on:
break

for i, state in enumerate(states):
output_storage[i][0] = state

    def L_op(self, *args):
        if not self.reverse:
            raise NotImplementedError()
        # Use L_op of self.reverse.update
        ...

    def R_op(self, *args):
        # Use R_op of self.update
        ...


class Scan(Op):
"""Represent a scan.
This Op must always be converted to a Loop during compilation
"""

def __init__(
self,
update: FunctionGraph, # (*state, *consts) -> (bool, *state)
reverse: Optional[FunctionGraph] = None,
):
validate_loop_update_types(update)
self.state_types = [out.type for out in update.outputs[1:]]
self.trace_types: list[Type] = []
for state_type in self.state_types:
# Accommodate SparseTensors and Scalars
if isinstance(state_type, DenseTensorType):
self.trace_types.append(
DenseTensorType(
shape=(None, *state_type.shape), dtype=state_type.dtype
)
)
else:
# We can't concatenate all types of states, such as RandomTypes
self.trace_types.append(NoneConst.type)
self.const_types = [inp.type for inp in update.inputs[len(self.state_types) :]]
self.update = update
self.reverse = reverse
self._update_fn = None

def make_node(self, n_steps, *inputs):
assert len(inputs) == len(self.state_types) + len(self.const_types)

n_steps = TensorType(dtype="int64", shape=()).filter_variable(n_steps)

states = inputs[: len(self.state_types)]
states = [
inp_type.filter_variable(inp)
for inp_type, inp in zip(self.state_types, states)
]

consts = inputs[len(self.state_types) :]
consts = [
inp_type.filter_variable(inp)
for inp_type, inp in zip(self.const_types, consts)
]

return Apply(
self,
[n_steps, *states, *consts],
[output_type() for output_type in self.state_types + self.trace_types],
)

def infer_shape(self, fgraph, node, input_shapes):
n_steps = node.inputs[0]
state_shapes = input_shapes[1 : len(self.state_types) + 1]
trace_shapes = [
(n_steps, *state_shape) if state_shape is not None else None
for state_shape in state_shapes
]
return state_shapes + trace_shapes

def perform(self, node, inputs, output_storage):
raise RuntimeError("Loop Op should not be present in compiled graph")

def L_op(self, *args):
# Use trace outputs
...

def R_op(self, *args):
# Use R_op of self.update
...


@node_rewriter([Scan])
def scan_to_loop(fgraph, node):
"""Rewrite a Scan Op into a Loop Op"""
op: Scan = node.op # type: ignore

n_steps = node.inputs[0]

n_state_vars = len(op.state_types)
old_states = node.outputs[:n_state_vars]
old_traces = node.outputs[n_state_vars:]

# Only include the intermediate states that are used elsewhere
used_traces_idxs = [
i
for i, trace in enumerate(node.outputs[n_state_vars:])
if fgraph.clients[trace]
]

# Check that outputs that cannot be converted into sequences (such as RandomTypes) are not being referenced
for trace_idx in used_traces_idxs:
        assert not isinstance(old_traces[trace_idx].type, NoneTypeT)

update_fg = op.update
prev_idx = scalar(dtype="int64", name="prev_idx")
prev_states = update_fg.inputs[:n_state_vars]
prev_traces = [old_traces[i].type() for i in used_traces_idxs]
    consts = update_fg.inputs[n_state_vars:]

go_on, *next_states = update_fg.outputs
next_idx = prev_idx + 1
next_idx.name = "next_idx"
next_traces = [
set_subtensor(prev_trace[prev_idx], next_states[trace_idx])
for trace_idx, prev_trace in zip(used_traces_idxs, prev_traces)
]
go_on = and_(go_on, next_idx < n_steps)
go_on.name = "go_on"

new_update_fg = FunctionGraph(
inputs=[prev_idx, *prev_states, *prev_traces, *consts],
outputs=[go_on, next_idx, *next_states, *next_traces],
)

# TODO: Implement Reverse?
loop_op = Loop(update=new_update_fg)

init_idx = constant(np.array(0, dtype="int64"))
init_states = node.inputs[1 : len(op.state_types) + 1]
init_traces = [
        empty(
            (n_steps, *tuple(init_states[trace_idx].shape)),
            dtype=init_states[trace_idx].type.dtype,
        )
for trace_idx in used_traces_idxs
]
    # Constants (if any) are passed through unchanged to the Loop
    outer_consts = node.inputs[1 + n_state_vars :]
    final_idx, *new_outs = loop_op(init_idx, *init_states, *init_traces, *outer_consts)
new_states = new_outs[:n_state_vars]
new_traces = new_outs[n_state_vars:]

replacements = dict(zip(old_states, new_states))
for trace_idx, new_trace in zip(used_traces_idxs, new_traces):
replacements[old_traces[trace_idx]] = new_trace[:final_idx]
return replacements


# TODO: Create a new Loop rewrite database
optdb.register(
"scan_to_loop",
in2out(scan_to_loop),
"fast_compile",
"fast_run",
position=-0.1, # TODO: When?
)
Empty file added tests/loop/__init__.py
91 changes: 91 additions & 0 deletions tests/loop/test_op.py
@@ -0,0 +1,91 @@
import numpy as np

from pytensor import function, shared
from pytensor.graph import FunctionGraph
from pytensor.loop.op import Loop, Scan
from pytensor.tensor import constant, lscalar, scalar
from pytensor.tensor.random import normal
from pytensor.tensor.type_other import NoneTypeT


def test_loop_basic():
i = lscalar("i")
x = scalar("x")
update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2])

_, y = Loop(update=update_fg)(np.array(0, dtype="int64"), x)
assert y.eval({x: 0}) == 20


def test_for_scan():
x = scalar("x")
update_fg = FunctionGraph([x], [constant(np.array(True)), x + 2])

n_steps = 10
y, ys = Scan(update=update_fg)(n_steps, x)

fn = function([x], [y, ys])

loop_nodes = tuple(
node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Loop)
)
assert len(loop_nodes) == 1
(loop_node,) = loop_nodes
assert len(loop_node.outputs) == 3
assert loop_node.outputs[0].type.shape == ()
assert loop_node.outputs[1].type.shape == ()
assert loop_node.outputs[2].type.shape == (None,) # This could be known

y_eval, ys_eval = fn(0)
np.testing.assert_array_equal(ys_eval, np.arange(2, 22, 2))
np.testing.assert_array_equal(ys_eval[-1], y_eval)


def test_while_scan():
i = lscalar("i")
x = scalar("x")
update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2])

n_steps = 1000
_, y, _, ys = Scan(update=update_fg)(n_steps, np.array(0, dtype="int64"), x)

fn = function([x], [y, ys])

loop_nodes = tuple(
node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Loop)
)
assert len(loop_nodes) == 1
(loop_node,) = loop_nodes
assert len(loop_node.outputs) == 4
assert loop_node.outputs[0].type.shape == ()
assert loop_node.outputs[1].type.shape == ()
assert loop_node.outputs[2].type.shape == ()
assert loop_node.outputs[3].type.shape == (None,) # This could be known

y_eval, ys_eval = fn(0)
np.testing.assert_array_equal(ys_eval, np.arange(2, 22, 2))
np.testing.assert_array_equal(ys_eval[-1], y_eval)


def test_scan_random():
rng_test = np.random.default_rng(123)
rng_shared = shared(np.random.default_rng(123))
n_steps = 5

x = scalar(
"x"
) # TODO: x shouldn't be needed when the initial_state does not matter!
rng = rng_shared.type()
update_fg = FunctionGraph(
[x, rng], [constant(np.array(True)), *normal(rng=rng).owner.outputs[::-1]]
)

_, new_rng, ys, rngs = Scan(update=update_fg)(n_steps, x, rng_shared)
assert isinstance(rngs.type, NoneTypeT)

fn = function([x], ys, updates={rng_shared: new_rng})

np.testing.assert_array_equal(fn(0), rng_test.normal(size=5))
np.testing.assert_array_equal(fn(0), rng_test.normal(size=5))
