Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dynamo] support dict.fromkeys() / OrderedDict.fromkeys() / defaultdict.fromkeys() #115010

Closed
wants to merge 9 commits into from
9 changes: 9 additions & 0 deletions test/dynamo/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,15 @@ def test_call_dict5(x):
d2["c"] = x + 20
return d1["a"] + d2["c"] + 1

@make_test
def test_dict_fromkeys(x, y):
lst = ["a", "b"]
d = dict.fromkeys(lst)
d1 = dict.fromkeys(d, x + 1)
d2 = collections.defaultdict.fromkeys(iter(d1), x - 2)
d3 = collections.OrderedDict.fromkeys(tuple(lst), value=y)
return d1["a"] * d2["b"] + d2["a"] + d1["b"] + d3["a"] + d3["b"] + 1

@make_test
def test_min_max(a, b):
c = a + b
Expand Down
53 changes: 51 additions & 2 deletions torch/_dynamo/variables/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import math
import operator
import types
from collections import defaultdict, OrderedDict
from typing import Dict, List

import torch
Expand Down Expand Up @@ -38,7 +39,7 @@
from .base import MutableLocal, typestr, VariableTracker
from .constant import ConstantVariable
from .ctx_manager import EventVariable, StreamVariable
from .dicts import ConstDictVariable, SetVariable
from .dicts import ConstDictVariable, DefaultDictVariable, SetVariable
from .lists import (
BaseListVariable,
ListIteratorVariable,
Expand Down Expand Up @@ -661,6 +662,17 @@ def call_function(
)
return super().call_function(tx, args, kwargs)

def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if self.fn == dict and name == "fromkeys":
return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)
return super().call_method(tx, name, args, kwargs)

def _call_min_max(self, tx, *args):
if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
# expand iterable
Expand Down Expand Up @@ -897,7 +909,44 @@ def call_custom_dict(tx, user_cls, *args, **kwargs):
return variables.ConstDictVariable(
dict(kwargs), user_cls=user_cls, mutable_local=MutableLocal()
)
unimplemented(f"dict(): {args} {kwargs}")
unimplemented(f"{user_cls.__name__}(): {args} {kwargs}")

@staticmethod
def call_custom_dict_fromkeys(tx, user_cls, *args, **kwargs):
assert user_cls in {dict, OrderedDict, defaultdict}
if kwargs:
# Only `OrderedDict.fromkeys` accepts `value` passed by keyword
assert user_cls is OrderedDict
assert len(args) == 1 and len(kwargs) == 1 and "value" in kwargs
args = (*args, kwargs.pop("value"))
if len(args) == 0:
raise UserError(TypeError, "fromkeys expected at least 1 argument, got 0")
if len(args) == 1:
args = (*args, ConstantVariable.create(None))
assert len(args) == 2
arg, value = args
DictVariableType = (
ConstDictVariable if user_cls is not defaultdict else DefaultDictVariable
)

if isinstance(arg, dict):
return DictVariableType(
dict.fromkeys(arg, value), user_cls, mutable_local=MutableLocal()
)
elif isinstance(
arg,
(
ConstDictVariable,
ListVariable,
TupleVariable,
ListIteratorVariable,
),
):
keys = [DictVariableType.get_key(x) for x in arg.unpack_var_sequence(tx)]
return DictVariableType(
dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal()
)
unimplemented(f"{user_cls.__name__}.fromkeys(): {args} {kwargs}")

def call_zip(self, tx, *args, **kwargs):
if kwargs:
Expand Down
18 changes: 18 additions & 0 deletions torch/_dynamo/variables/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,24 @@ def wraps(fn):
msg += f"', {self.reason}'" if self.reason else ""
unimplemented(msg)

def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if (
self.value in {collections.OrderedDict, collections.defaultdict}
and name == "fromkeys"
):
from .builtin import BuiltinVariable

return BuiltinVariable.call_custom_dict_fromkeys(
tx, self.value, *args, **kwargs
)
return super().call_method(tx, name, args, kwargs)


class TypingVariable(VariableTracker):
def __init__(self, value, **kwargs):
Expand Down
Loading