Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dynamo] support dict.copy() / OrderedDict.copy() / defaultdict.copy() #115012

Closed
wants to merge 7 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ hf_T5,pass,0



hf_T5_generate,pass,20
hf_T5_generate,pass,18



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ hf_T5,pass,0



hf_T5_generate,fail_to_run,10
hf_T5_generate,fail_to_run,9



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ hf_T5,pass,0



hf_T5_generate,fail_to_run,10
hf_T5_generate,fail_to_run,9



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ hf_T5,pass,0



hf_T5_generate,pass,20
hf_T5_generate,pass,18



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ hf_T5,pass,0



hf_T5_generate,pass,20
hf_T5_generate,pass,18



Expand Down
14 changes: 14 additions & 0 deletions test/dynamo/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,20 @@ def test_dict_fromkeys(x, y):
d3 = collections.OrderedDict.fromkeys(tuple(lst), value=y)
return d1["a"] * d2["b"] + d2["a"] + d1["b"] + d3["a"] + d3["b"] + 1

@make_test
def test_dict_copy(x):
my_list = [("a", x), ("b", x + 1), ("c", x + 2)]
d1 = dict(my_list)
d1["a"] = x + 10
d2 = d1.copy()
d2["a"] = x - 5
d2["b"] = x + 3
d3 = collections.OrderedDict(my_list)
d3["c"] = x + 20
d4 = d3.copy()
d4["c"] = x - 10
return d1["a"] * d2["a"] + d2["b"] + d3["c"] * d4["c"] + 1
Comment on lines +820 to +832
Copy link
Collaborator Author

@XuehaiPan XuehaiPan Dec 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test for defaultdict.copy is not added. Because the collections module is ignored in SKIP_DIRS. Also seems that the implementation for DefaultDictVariable is incomplete. Currently, we only support creating an empty defaultdict with the default_factory.

dd = defaultdict(int)

We do not support passing a mapping / seq2 as the second argument. We also do not support keywords:

dd1 = defaultdict(int, {'a': 1, 'b': 2})
dd2 = defaultdict(int, zip('ab', range(2)))
dd3 = defaultdict(int, a=1, b=2)
dd4 = defaultdict(int, {'a': 1, 'b': 2}, c=3, d=4)


@make_test
def test_dict_update(x, y, z):
d = {"a": x, "b": y}
Expand Down
5 changes: 3 additions & 2 deletions torch/_dynamo/variables/dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def call_method(
if name == "__getitem__":
assert len(args) == 1
return self.getitem_const(args[0])

elif name == "items":
assert not (args or kwargs)
return TupleVariable(
Expand Down Expand Up @@ -118,10 +117,12 @@ def call_method(
],
mutable_local=MutableLocal(),
)

elif name == "values":
assert not (args or kwargs)
return TupleVariable(list(val.values()))
elif name == "copy":
assert not (args or kwargs)
return self.modifed(self.items.copy(), mutable_local=MutableLocal())
elif name == "__len__":
assert not (args or kwargs)
return ConstantVariable.create(len(self.items))
Expand Down
Loading