Skip to content

Commit

Permalink
[Dynamo] Support SET_UPDATE (#126243)
Browse files Browse the repository at this point in the history
Fixes #ISSUE_NUMBER

Pull Request resolved: #126243
Approved by: https://github.com/anijain2305, https://github.com/Skylion007, https://github.com/jansel
  • Loading branch information
yanboliang authored and ZelboK committed May 19, 2024
1 parent 45a699a commit b24a9e3
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 0 deletions.
26 changes: 26 additions & 0 deletions test/dynamo/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,6 +1164,32 @@ def test_tuple_contains(a, b):
return a + b
return a - b

@unittest.skipIf(
sys.version_info < (3, 9),
"SET_UPDATE was added at Python 3.9",
)
@make_test
def test_set_update_bytecode(x):
# This produces bytecode SET_UPDATE since python 3.9
var = {"apple", "banana", "cherry"}
if isinstance(var, set):
return x + 1
else:
return x - 1

@unittest.skipIf(
sys.version_info < (3, 9),
"SET_UPDATE was added at Python 3.9",
)
@make_test
def test_set_update_list_with_duplicated_items(x):
list1 = ["apple", "banana", "apple"]
list2 = ["orange", "banana"]
if len({*list1, *list2}) == 3:
return x + 1
else:
return x - 1

@make_test
def test_set_contains(a, b):
vals = set(["a", "b", "c"])
Expand Down
8 changes: 8 additions & 0 deletions torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1504,6 +1504,14 @@ def SET_ADD(self, inst):
assert obj.mutable_local
return obj.call_method(self, "add", [v], {})

def SET_UPDATE(self, inst):
v = self.pop()
assert inst.argval > 0
obj = self.stack[-inst.arg]
assert isinstance(obj, SetVariable)
assert obj.mutable_local
obj.call_method(self, "update", [v], {})

def LIST_APPEND(self, inst):
v = self.pop()
assert inst.argval > 0
Expand Down
20 changes: 20 additions & 0 deletions torch/_dynamo/variables/dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,8 @@ def call_method(
args: List[VariableTracker],
kwargs: Dict[str, VariableTracker],
) -> "VariableTracker":
from . import ListVariable, TupleVariable

# We foward the calls to the dictionary model
if name == "add":
assert not kwargs
Expand All @@ -426,6 +428,24 @@ def call_method(
return variables.UserFunctionVariable(
polyfill.set_isdisjoint
).call_function(tx, [self, args[0]], {})
elif (
name == "update"
and len(args) == 1
and isinstance(
args[0],
(
SetVariable,
ListVariable,
TupleVariable,
),
)
and self.mutable_local
):
if isinstance(args[0], (ListVariable, TupleVariable)):
arg = SetVariable(args[0].unpack_var_sequence(tx))
else:
arg = args[0]
return super().call_method(tx, "update", (arg,), kwargs)
return super().call_method(tx, name, args, kwargs)

def getitem_const(self, arg: VariableTracker):
Expand Down

0 comments on commit b24a9e3

Please sign in to comment.