Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dynamo] add infinite generators itertools.{count, repeat, cycle} #110967

Closed
97 changes: 97 additions & 0 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7578,6 +7578,103 @@ def fn(x):
self.assertEqual(list(eager), list(compiled))
self.assertEqual(len(counters["graph_break"]), 0)

def test_itertools_infinite_repeat(self):
jon-chuang marked this conversation as resolved.
Show resolved Hide resolved
counters.clear()

def fn(x):
r = itertools.repeat(100.0)
idx = 0
for i in r:
x += i
idx += 1
if idx > 10:
break
return x

x = torch.randn([2, 5])
eager = fn(x)

compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
compiled = compiled_fn(x)

self.assertEqual(list(eager), list(compiled))
self.assertEqual(len(counters["graph_break"]), 0)

def test_itertools_infinite_repeat_mutation(self):
counters.clear()

def fn(x):
r = itertools.repeat(x)
idx = 0
for i in r:
x += i
i += 1
idx += 1
if idx > 10:
break
return x

x = torch.randn([2, 5])
eager = fn(x)

compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
compiled = compiled_fn(x)

self.assertEqual(list(eager), list(compiled))
self.assertEqual(len(counters["graph_break"]), 0)
jon-chuang marked this conversation as resolved.
Show resolved Hide resolved

def test_itertools_infinite_count(self):
for args in ([], [10], [5, -1]):
counters.clear()

def fn(x):
r = itertools.count(*args)
idx = 0
for i in r:
x += i
idx += 1
if idx > 10:
break
return x

x = torch.randn([2, 5])
eager = fn(x)

compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
compiled = compiled_fn(x)

self.assertEqual(list(eager), list(compiled))
self.assertEqual(len(counters["graph_break"]), 0)

def test_itertools_infinite_cycle(self):
counters.clear()

def fn(x):
for iterator in (
iter([]),
iter([10, 11.0]),
itertools.repeat(-1, 3),
itertools.count(10),
):
r = itertools.cycle(iterator)
idx = 0
x += 1
for i in r:
x += i
idx += 1
if idx > 10:
break
return x

x = torch.randn([2, 5])
eager = fn(x)

compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
compiled = compiled_fn(x)

self.assertEqual(list(eager), list(compiled))
self.assertEqual(len(counters["graph_break"]), 0)

def test_itertools_accumulate_symint_default_sum(self):
# https://github.com/pytorch/pytorch/issues/110287
counters.clear()
Expand Down
12 changes: 6 additions & 6 deletions torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,11 +1065,10 @@ def POP_FINALLY(self, inst):

def FOR_ITER(self, inst):
it = self.pop()
if isinstance(it, ListIteratorVariable):
if isinstance(it, (variables.ListIteratorVariable, variables.IteratorVariable)):
self.output.guards.update(it.guards)
try:
val, next_iter = it.next_variables()
self.replace_all(it, next_iter)
jon-chuang marked this conversation as resolved.
Show resolved Hide resolved
val, next_iter = it.next_variables(self)
self.push(next_iter)
self.push(val)
except StopIteration:
Expand Down Expand Up @@ -2559,11 +2558,12 @@ def YIELD_FROM(self, inst):
if isinstance(tos, ConstantVariable) and tos.value is None:
self.pop()
return
if isinstance(tos, ListIteratorVariable):
if isinstance(
tos, (variables.ListIteratorVariable, variables.IteratorVariable)
):
self.output.guards.update(tos.guards)
try:
val, next_iter = tos.next_variables()
self.replace_all(tos, next_iter)
val, next_iter = tos.next_variables(self)
self.push(val)
# TODO(voz): Unclear if we need the push None in YIELD_VALUE?
self.YIELD_VALUE(inst)
Expand Down
10 changes: 10 additions & 0 deletions torch/_dynamo/variables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@
UserMethodVariable,
)
from .higher_order_ops import TorchHigherOrderOperatorVariable
from .iter import (
CountIteratorVariable,
CycleIteratorVariable,
IteratorVariable,
RepeatIteratorVariable,
)
from .lists import (
BaseListVariable,
ListIteratorVariable,
Expand Down Expand Up @@ -79,6 +85,10 @@
"GetAttrVariable",
"GradModeVariable",
"InspectSignatureVariable",
"IteratorVariable",
jon-chuang marked this conversation as resolved.
Show resolved Hide resolved
"RepeatIteratorVariable",
"CountIteratorVariable",
"CycleIteratorVariable",
"LambdaVariable",
"ListIteratorVariable",
"ListVariable",
Expand Down
13 changes: 10 additions & 3 deletions torch/_dynamo/variables/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,12 @@ def _dyn_proxy(self, tx, *args, **kwargs):
def _call_iter_tuple_list(self, tx, obj=None, *args, **kwargs):
if self._dynamic_args(*args, **kwargs):
return self._dyn_proxy(tx, *args, **kwargs)

if isinstance(obj, variables.IteratorVariable):
# For non-list iterators, we will guard on vars that
# determine the control flow
jon-chuang marked this conversation as resolved.
Show resolved Hide resolved
return obj

# TODO This should probably be treated as a dict, or dicts should also be treated here
if self.fn == set:
cls = SetVariable
Expand Down Expand Up @@ -965,9 +971,10 @@ def call_super(self, tx, a, b):
return variables.SuperVariable(a, b)

def call_next(self, tx, arg):
if isinstance(arg, variables.ListIteratorVariable):
val, next_iter = arg.next_variables()
tx.replace_all(arg, next_iter)
if isinstance(
arg, (variables.ListIteratorVariable, variables.IteratorVariable)
):
val, next_iter = arg.next_variables(tx)
jon-chuang marked this conversation as resolved.
Show resolved Hide resolved
return val
elif isinstance(arg, variables.BaseListVariable):
return arg.items[0].add_options(self, arg)
Expand Down
101 changes: 101 additions & 0 deletions torch/_dynamo/variables/iter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
MAX_CYCLE = 3000

from typing import List, Optional

from ..exc import unimplemented

from .base import VariableTracker
from .constant import ConstantVariable


class IteratorVariable(VariableTracker):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def next_variables(self, tx):
unimplemented("abstract method, must implement")


class RepeatIteratorVariable(IteratorVariable):
def __init__(self, item: VariableTracker, **kwargs):
super().__init__(**kwargs)
self.item = item

# Repeat needs no mutation, clone self
def next_variables(self, tx):
# add_options will clone self.item
jon-chuang marked this conversation as resolved.
Show resolved Hide resolved
return self.item.add_options(self), self


class CountIteratorVariable(IteratorVariable):
def __init__(self, item: int = 0, step: int = 1, **kwargs):
super().__init__(**kwargs)
if not isinstance(item, VariableTracker):
item = ConstantVariable.create(item)
if not isinstance(step, VariableTracker):
step = ConstantVariable.create(step)
self.item = item
self.step = step

def next_variables(self, tx):
assert self.mutable_local
next_item = self.item.call_method(tx, "__add__", [self.step], {})
next_iter = self.clone(item=next_item)
tx.replace_all(self, next_iter)
return self.item.add_options(self), next_iter


class CycleIteratorVariable(IteratorVariable):
def __init__(
self,
iterator: IteratorVariable,
saved: List[VariableTracker] = None,
saved_index: int = 0,
item: Optional[VariableTracker] = None,
**kwargs,
):
if saved is None:
saved = []
super().__init__(**kwargs)
self.iterator = iterator
self.saved = saved
self.saved_index = saved_index
self.item = item

def next_variables(self, tx):
assert self.mutable_local

if self.iterator is not None:
try:
new_item, next_inner_iter = self.iterator.next_variables(tx)
tx.replace_all(self.iterator, next_inner_iter)
if len(self.saved) > MAX_CYCLE:
unimplemented(
"input iterator to itertools.cycle has too many items"
)
next_iter = self.clone(
iterator=next_inner_iter,
saved=self.saved + [new_item],
item=new_item,
)

tx.replace_all(self, next_iter)
if self.item is None:
return next_iter.next_variables(tx)
return self.item.add_options(self), next_iter
except StopIteration:
jon-chuang marked this conversation as resolved.
Show resolved Hide resolved
next_iter = self.clone(iterator=None)
# this is redundant as next_iter will do the same
# but we do it anyway for safety
tx.replace_all(self, next_iter)
return next_iter.next_variables(tx)
elif len(self.saved) > 0:
next_iter = self.clone(
saved_index=(self.saved_index + 1) % len(self.saved),
item=self.saved[self.saved_index],
)
tx.replace_all(self, next_iter)
return self.item.add_options(self), next_iter
else:
raise StopIteration
return self.item.add_options(self), next_iter
6 changes: 4 additions & 2 deletions torch/_dynamo/variables/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,16 +683,18 @@ def __init__(self, items, index: int = 0, **kwargs):
self.items = items
self.index = index

def next_variables(self):
def next_variables(self, tx):
assert self.mutable_local
if self.index >= len(self.items):
raise StopIteration()
return self.items[self.index].add_options(self), ListIteratorVariable(
next_iter = ListIteratorVariable(
self.items,
self.index + 1,
mutable_local=MutableLocal(),
**VariableTracker.propagate([self]),
)
tx.replace_all(self, next_iter)
ezyang marked this conversation as resolved.
Show resolved Hide resolved
return self.items[self.index].add_options(self), next_iter

def as_python_constant(self):
if self.index > 0:
Expand Down
14 changes: 9 additions & 5 deletions torch/_dynamo/variables/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,16 +885,20 @@ def wraps(fn):
fn, args=rest_args, keywords=kwargs, **options
)
elif self.value is itertools.repeat:
from .builder import SourcelessBuilder

if len(args) < 2:
# We cannot risk infinite generator being consumed to exhaustion by dynamo
# (i.e. infinite loop)
unimplemented("Infinite repeat is not supported")
return variables.RepeatIteratorVariable(
*args, mutable_local=MutableLocal()
)

from .builder import SourcelessBuilder

return tx.inline_user_function_return(
SourcelessBuilder()(tx, polyfill.repeat), args, kwargs
)
elif self.value is itertools.count:
return variables.CountIteratorVariable(*args, mutable_local=MutableLocal())
elif self.value is itertools.cycle:
return variables.CycleIteratorVariable(*args, mutable_local=MutableLocal())
else:
try:
path = inspect.getfile(self.value)
Expand Down
Loading