From da2bacd53e27e61185ea53d4df7c6ada68651d14 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 9 May 2022 15:29:09 -0700 Subject: [PATCH 1/2] Correct ListVariable source --- tests/test_misc.py | 12 ++++++++++++ torchdynamo/utils.py | 2 +- torchdynamo/variables/lists.py | 9 ++++++--- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/tests/test_misc.py b/tests/test_misc.py index ee00084da0..aec57857f2 100755 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -1157,3 +1157,15 @@ def fn(): inst = dis.get_instructions(fn) result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno) self.assertTrue(result[1] == fn.__code__.co_lnotab) + + def test_python_slice(self): + def fn(input): + y = 0 + for i, x in enumerate(input[2:], 1): + y = y + x + return y + + cnts = torchdynamo.testing.CompileCounter() + with torchdynamo.optimize(cnts): + z = fn([1, 2, 3, 5]) + self.assertEqual(z, 8) diff --git a/torchdynamo/utils.py b/torchdynamo/utils.py index 38763f953a..3fe52239a7 100644 --- a/torchdynamo/utils.py +++ b/torchdynamo/utils.py @@ -319,7 +319,7 @@ def rot_n_helper(n): def is_safe_constant(v): if istype(v, (tuple, frozenset)): return all(map(is_safe_constant, v)) - return istype(v, (types.CodeType, int, float, bool, str, bytes, type(None))) + return istype(v, (types.CodeType, int, float, bool, str, bytes, type(None), slice)) def check_constant_args(args, kwargs): diff --git a/torchdynamo/variables/lists.py b/torchdynamo/variables/lists.py index b0892aeb05..5d9cc80bc2 100644 --- a/torchdynamo/variables/lists.py +++ b/torchdynamo/variables/lists.py @@ -9,6 +9,7 @@ from ..utils import namedtuple_fields from .base import MutableLocal from .base import VariableTracker +from ..source import GetItemSource class BaseListVariable(VariableTracker): @@ -40,9 +41,11 @@ def as_proxy(self): def getitem_const(self, arg: VariableTracker): index = arg.as_python_constant() if isinstance(index, slice): - return self.clone(items=self.items[index], mutable_local=None).add_options( - arg, self - ) + return self.clone( + items=self.items[index], + source=GetItemSource(self.source, index), + mutable_local=None, + ).add_options(arg, self) else: assert isinstance(index, int) return self.items[index].add_options(arg, self) From 374c27d1c8770055cf8cea4ccb56cca91fc10d54 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 9 May 2022 16:15:39 -0700 Subject: [PATCH 2/2] Fix lint --- torchdynamo/variables/lists.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchdynamo/variables/lists.py b/torchdynamo/variables/lists.py index 5d9cc80bc2..3c4c79c7a7 100644 --- a/torchdynamo/variables/lists.py +++ b/torchdynamo/variables/lists.py @@ -6,10 +6,10 @@ from .. import variables from ..bytecode_transformation import create_instruction from ..exc import unimplemented +from ..source import GetItemSource from ..utils import namedtuple_fields from .base import MutableLocal from .base import VariableTracker -from ..source import GetItemSource class BaseListVariable(VariableTracker):