-
Notifications
You must be signed in to change notification settings - Fork 21.6k
/
while_loop.py
268 lines (214 loc) · 10.2 KB
/
while_loop.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
# mypy: allow-untyped-defs
from typing import Callable, Tuple, Union
import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
_set_compilation_env,
autograd_not_implemented,
reenter_make_fx,
UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
class WhileLoopOp(HigherOrderOperator):
def __init__(self):
super().__init__("while_loop")
def __call__(
self,
cond_fn: Callable,
body_fn: Callable,
carried_inputs: Tuple[Union[torch.Tensor, int, float, bool]],
additional_inputs: Tuple[Union[torch.Tensor, int, float, bool]],
/,
):
if not isinstance(carried_inputs, tuple):
raise RuntimeError(
f"carried_inputs must be a tuple, got {type(carried_inputs)}"
)
if not isinstance(additional_inputs, tuple):
raise RuntimeError(
f"additional_inputs must be a tuple, got {type(additional_inputs)}"
)
if not all(
isinstance(t, (torch.Tensor, int, float, bool)) for t in carried_inputs
):
raise RuntimeError(
"carried_inputs must be a tuple of tensors, ints, floats, or bools, got "
f"{carried_inputs}"
)
if not all(
isinstance(t, (torch.Tensor, int, float, bool)) for t in additional_inputs
):
raise RuntimeError(
"additional_inputs must be a tuple of tensors, ints, floats, or bools, got "
f"{additional_inputs}"
)
return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs)
# The single module-level instance of the operator; registrations below
# (py_impl / py_functionalize_impl) attach dispatch-key implementations to it.
while_loop_op = WhileLoopOp()
# Override while_loop_op.__module__ to "torch.ops.higher_order" so that in the generated
# graph module, while_loop node's target is correctly printed as torch.ops.higher_order.while_loop
while_loop_op.__module__ = "torch.ops.higher_order"
def while_loop(cond_fn, body_fn, carried_inputs):
r"""
Run body_fn(*carried_inputs) while cond_fn(*carried_inputs) returns a True scalar tensor. Returns the output of body_fn or
initial carried_inputs.
.. warning::
`torch.while_loop` is a prototype feature in PyTorch. It has limited support for input and output types and
doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch.
Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
`while_loop` is a structured control flow operator. It preserves the loop semantic across the torch.compile and torch.export.
`while_loop` is equivalent to the following:
def while_loop(cond_fn, body_fn, carried_inputs):
val = carried_inputs
while cond_fn(*val):
val = body_fn(*val)
return val
Args:
cond_fn (Callable): A callable function that returns a boolean Scalar tensor.
body_fn (Callable): A callable function that takes the same inputs as `cond_fn` and returns a tuple of tensors
carried_inputs (Tuple of possibly nested dict/list/tuple of tensors): A tuple of inputs to cond_fn and body_fn. It's also
the initial value of states that are carried across iterations.
Example:
def cond_fn(iter, x):
return iter.sum() < 10
def body_fn(iter, x):
return iter + 1, x.sin()
while_loop(cond_fn, body_fn, (torch.zeros(1), torch.randn(3, 4)))
Restrictions:
- body_fn must return tensors with the same metadata (e.g.shape, dtype) as inputs.
- body_fn and cond_fn must not in-place mutate the carried_inputs. A clone before the mutation is required.
- body_fn and cond_fn must not mutate python varialbles (e.g. list/dict) created outside of the body_fn.
- body_fn and cond_fn's output cannot aliase any of the inputs. A clone is required.
.. warning::
Temporal Limitations:
- 'while_loop' only supports **inference** right now. Autograd will be supported in the future.
"""
# Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo.
# parameters and buffers accessed in cond_fn or body_fn or tensor closures will become additional_inputs.
additional_inputs: Tuple = tuple()
if torch.compiler.is_dynamo_compiling():
return while_loop_op(cond_fn, body_fn, carried_inputs, additional_inputs)
def _validate_input(cond_fn, body_fn, carried_inputs):
if not callable(cond_fn) or not callable(body_fn):
raise RuntimeError("Expect cond_fn and body_fn to be callbale.")
if not isinstance(carried_inputs, (tuple, list)) or pytree.tree_any(
lambda t: not isinstance(t, torch.Tensor), carried_inputs
):
raise RuntimeError(
"Expect carried_inputs to be a tuple of possibly nested dict/list/tuple that only"
f"consists of tensor leaves, but got {carried_inputs}."
)
_validate_input(cond_fn, body_fn, carried_inputs)
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
return torch.compile(while_loop_op, backend="eager", fullgraph=True)(
cond_fn, body_fn, carried_inputs, additional_inputs
)
@while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):
    """Eager implementation: run the loop directly in Python.

    Repeatedly evaluates cond_fn on the current carried state (plus
    additional_inputs) and, while it yields a truthy boolean scalar tensor,
    replaces the state with body_fn's output tuple.
    """
    if not isinstance(carried_inputs, tuple):
        raise RuntimeError(
            f"carried_inputs must be a tuple but got {type(carried_inputs)}"
        )

    def _looks_like_bool_scalar(candidate):
        # cond_fn's result must be a 0-dim tensor of dtype bool.
        if not isinstance(candidate, torch.Tensor):
            return False
        return candidate.size() == torch.Size([]) and candidate.dtype == torch.bool

    state = carried_inputs
    while True:
        pred = cond_fn(*state, *additional_inputs)
        if not pred:
            # Falsy predicate terminates the loop (matching the truthiness
            # test performed before the type check).
            break
        if not _looks_like_bool_scalar(pred):
            raise RuntimeError(
                f"cond_fn must return a boolean scalar tensor but got {pred}"
            )
        next_state = body_fn(*state, *additional_inputs)
        assert isinstance(
            next_state, tuple
        ), f"body_fn should return a tuple but got {type(next_state)}"
        assert len(next_state) == len(
            carried_inputs
        ), "body_fn should return the same number of elements as carried_inputs"
        state = next_state
    return state
# Autograd is not supported for while_loop; register a kernel that raises.
# deferred_error=True defers the failure until the op actually runs under
# autograd, rather than at registration time.
while_loop_op.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(while_loop_op, deferred_error=True)
)
@while_loop_op.py_impl(ProxyTorchDispatchMode)
def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs):
    """Proxy-mode implementation: trace cond_fn/body_fn into subgraphs and
    emit a single `while_loop` call_function node into the traced graph.
    """
    def _trace_while_loop(
        proxy_mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs
    ):
        # Re-enter make_fx to turn both callables into graph modules, traced
        # with the same example inputs the loop itself receives.
        cond_graph = reenter_make_fx(cond_fn)(*carried_inputs, *additional_inputs)
        body_graph = reenter_make_fx(body_fn)(*carried_inputs, *additional_inputs)
        # Find the first unused "while_loop_cond_graph_{i}" attribute name on
        # the tracer root; the body graph reuses the same index suffix.
        next_name = None
        i = 0
        while not next_name:
            candidate = f"while_loop_cond_graph_{i}"
            if hasattr(proxy_mode.tracer.root, candidate):
                i += 1
            else:
                next_name = candidate
        cond_graph_name = next_name
        body_graph_name = f"while_loop_body_graph_{i}"
        assert not hasattr(proxy_mode.tracer.root, body_graph_name)
        # Attach both subgraphs as submodules of the graph being traced.
        proxy_mode.tracer.root.register_module(cond_graph_name, cond_graph)
        proxy_mode.tracer.root.register_module(body_graph_name, body_graph)
        args = (cond_graph, body_graph, carried_inputs, additional_inputs)
        # Replace tensors with their proxies so the node's args reference
        # values already in the graph.
        proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
        out_proxy = proxy_mode.tracer.create_proxy(
            "call_function", while_loop_op, proxy_args, {}, name="while_loop"
        )
        # body_fn return output with the same pytree and tensor meta data as carried_inputs
        # so we could just return the output after one iteration.
        out = body_fn(*carried_inputs, *additional_inputs)
        return track_tensor_tree(
            out, out_proxy, constant=None, tracer=proxy_mode.tracer
        )
    if mode.enable_tracing:
        return _trace_while_loop(
            mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs
        )
    else:
        # Tracing disabled: fall through to the next dispatch implementation.
        return while_loop_op(cond_fn, body_fn, carried_inputs, additional_inputs)
@while_loop_op.py_impl(FakeTensorMode)
def while_loop_fake_tensor_mode(
    mode, cond_fn, body_fn, carried_inputs, additional_inputs
):
    """Fake-tensor implementation: run body_fn once under the fake mode.

    body_fn is required to produce outputs with the same pytree structure and
    tensor metadata as carried_inputs, so a single unrolled call yields the
    correct fake output for any number of iterations.
    """
    all_inputs = (*carried_inputs, *additional_inputs)
    with mode:
        return body_fn(*all_inputs)
@while_loop_op.py_functionalize_impl
def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs):
    """Functionalization implementation: unwrap functional tensors, reject
    branches that mutate or alias their inputs, then redispatch.
    """
    unwrapped_carried_inputs = ctx.unwrap_tensors(carried_inputs)
    unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
    # Both branches see carried and additional inputs concatenated.
    unwrapped_inputs = unwrapped_carried_inputs + unwrapped_additional_inputs
    with ctx.redispatch_to_next() as m:
        functional_cond_fn = ctx.functionalize(cond_fn)
        functional_body_fn = ctx.functionalize(body_fn)
        # NOTE(review): presumably only pre-dispatch functionalization
        # contexts carry a `mode` attribute — confirm against the ctx class.
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        # Speculatively run each branch to detect input mutation/aliasing,
        # both of which are unsupported (see while_loop's Restrictions).
        for fn, fn_name in [
            (functional_cond_fn, "cond_fn"),
            (functional_body_fn, "body_fn"),
        ]:
            if _has_potential_branch_input_mutation(
                fn, unwrapped_inputs, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    f"torch.while_loop's {fn_name} might be modifying the input!"
                )
            if _has_potential_branch_input_alias(
                fn, unwrapped_inputs, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    f"torch.while_loop's {fn_name} might be aliasing the input!"
                )
        ret = while_loop_op(
            functional_cond_fn,
            functional_body_fn,
            unwrapped_carried_inputs,
            unwrapped_additional_inputs,
        )
    # Re-wrap outputs as functional tensors for the caller.
    return ctx.wrap_tensors(ret)