Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions torch/_dynamo/variables/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,7 @@ def call_method(
{},
)
elif name == "__iter__":
return ListIteratorVariable(
list(self.items), mutation_type=ValueMutationNew()
)
return ListIteratorVariable(self.items, mutation_type=ValueMutationNew())

return super().call_method(tx, name, args, kwargs)

Expand Down Expand Up @@ -1589,14 +1587,16 @@ def __init__(self, items, index: int = 0, **kwargs) -> None:
# assert all(isinstance(x, VariableTracker) for x in items)
self.items = items
self.index = index
self.is_exhausted = False

def __repr__(self) -> str:
return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})"

def next_variable(self, tx):
assert self.is_mutable()
old_index = self.index
if old_index >= len(self.items):
if old_index >= len(self.items) or self.is_exhausted:
self.is_exhausted = True
raise_observed_exception(StopIteration, tx)

tx.output.side_effects.mutation(self)
Expand All @@ -1618,15 +1618,19 @@ def has_unpack_var_sequence(self, tx):
return True

def unpack_var_sequence(self, tx):
r = list(self.items[self.index :])
self.index = len(self.items)
return r
if self.is_exhausted:
return []
self.is_exhausted = True
return list(self.items[self.index :])

def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
return self.unpack_var_sequence(tx)

def reconstruct(self, codegen: "PyCodegen") -> None:
remaining_items = self.items[self.index :]
if not self.is_exhausted:
remaining_items = self.items[self.index :]
else:
remaining_items = []
codegen.foreach(remaining_items)
codegen.extend_output(
[
Expand Down
Loading