diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index fdb23e3f590f3..472e9c56bae63 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1164,6 +1164,32 @@ def test_tuple_contains(a, b): return a + b return a - b + @unittest.skipIf( + sys.version_info < (3, 9), + "SET_UPDATE was added at Python 3.9", + ) + @make_test + def test_set_update_bytecode(x): + # This produces bytecode SET_UPDATE since python 3.9 + var = {"apple", "banana", "cherry"} + if isinstance(var, set): + return x + 1 + else: + return x - 1 + + @unittest.skipIf( + sys.version_info < (3, 9), + "SET_UPDATE was added at Python 3.9", + ) + @make_test + def test_set_update_list_with_duplicated_items(x): + list1 = ["apple", "banana", "apple"] + list2 = ["orange", "banana"] + if len({*list1, *list2}) == 3: + return x + 1 + else: + return x - 1 + @make_test def test_set_contains(a, b): vals = set(["a", "b", "c"]) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index d6fb3e2145b73..4b4d6d3de6755 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1504,6 +1504,14 @@ def SET_ADD(self, inst): assert obj.mutable_local return obj.call_method(self, "add", [v], {}) + def SET_UPDATE(self, inst): + v = self.pop() + assert inst.argval > 0 + obj = self.stack[-inst.arg] + assert isinstance(obj, SetVariable) + assert obj.mutable_local + obj.call_method(self, "update", [v], {}) + def LIST_APPEND(self, inst): v = self.pop() assert inst.argval > 0 diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index c8eabc2c88799..0724a80621f76 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -407,6 +407,8 @@ def call_method( args: List[VariableTracker], kwargs: Dict[str, VariableTracker], ) -> "VariableTracker": + from . import ListVariable, TupleVariable + # We foward the calls to the dictionary model if name == "add": assert not kwargs @@ -426,6 +428,24 @@ def call_method( return variables.UserFunctionVariable( polyfill.set_isdisjoint ).call_function(tx, [self, args[0]], {}) + elif ( + name == "update" + and len(args) == 1 + and isinstance( + args[0], + ( + SetVariable, + ListVariable, + TupleVariable, + ), + ) + and self.mutable_local + ): + if isinstance(args[0], (ListVariable, TupleVariable)): + arg = SetVariable(args[0].unpack_var_sequence(tx)) + else: + arg = args[0] + return super().call_method(tx, "update", (arg,), kwargs) return super().call_method(tx, name, args, kwargs) def getitem_const(self, arg: VariableTracker):