Make variables in dict LazyTrackers (not lazily guarded yet) and avoid using DICT_KEYS guard #117625
Changes from all commits
fa40a45
125bd38
fdb6e07
becb54f
f2d5fd5
4ea56b9
64c30dd
921dc04
eb1fa85
2c67f6a
6812a60
1002aca
b7a7c00
bfc393d
5149aab
61d4ea1
f38b105
10f480f
74d81fd
9786535
4be4a98
3f05363
2ecd958
267b853
5fc6f32
f0173b2
6128e6b
510baf4
f8cd51d
49ee382
01b93b0
16281ef
9931d42
@@ -95,7 +95,6 @@
     DataClassVariable,
     DefaultDictVariable,
     HFPretrainedConfigVariable,
-    is_hashable_python_var,
     PythonSysModulesVariable,
     SetVariable,
 )
@@ -412,9 +411,7 @@ class Autotuner:
             return ConstDictVariable(result, type(value))
         elif value is sys.modules:
             return PythonSysModulesVariable(source=self.source)
-        elif istype(
-            value, (dict, collections.defaultdict, collections.OrderedDict)
-        ) and all(is_hashable_python_var(k) for k in value.keys()):
+        elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)):
             if not value and self.get_source().is_nn_module():
                 # It is faster to guard on 'false' property than to guard
                 # on actual dict keys, but we can't do this fast guard in general because
@@ -425,26 +422,22 @@ class Autotuner:
                 # but not completely secure job ensuring a property wasn't changed.
                 self.install_guards(GuardBuilder.BOOL_FALSE)
             else:
-                self.install_guards(GuardBuilder.DICT_KEYS)
+                self.install_guards(GuardBuilder.LIST_LENGTH)

Review comment: Does it make sense to delete …

Reply: That's the plan, yes. I'll put up a PR next week.

-            idx = 0
-
-            def build_key_value(k, v):
-                nonlocal idx
-                if ConstantVariable.is_literal(k):
-                    key = ConstantVariable.create(k)
-                    source_key = k
-                else:
-                    source_key = ConstDictKeySource(self.get_source(), idx)
-                    key = VariableBuilder(self.tx, source_key)(k)
+            # We need all the keys to be hashable. We do this within the
+            # _HashableTracker class in dicts.py
+            def build_key_value(i, k, v):
+                source_key = ConstDictKeySource(self.get_source(), i)
+                key = LazyVariableTracker.create(k, source_key)
peterbell10 marked this conversation as resolved.

Review comment: Okay, so the only substantive difference is that where before we had a single …

Review comment: Thereby making guard evaluation …

Reply: It is true that this is now O(n²), but it is also true that the previous approach was completely broken. The previous approach was alright when we just had constant keys, but now we have everything and anything, so trying to generate a printable version of each object in DICT_KEYS is just too broken. Let's see what the benchmarks have to say about the compilation times of this approach, though. If it's an issue, I could add a pass where, if all objects within the keys of the dict are sourceless, we replace all those checks with a check similar to the previous one. Another way forward would be to actually implement laziness on the keys of a dict. This would most probably offset the O(n) issue on its own.

Review comment: Ideally I think we would have individual guards for each key, but the guard checking code would group them together and only iterate over the dictionary's keys once.

Review comment: If this performs okay and it fixes a real bug then it's okay for now though, I guess.

Review comment: Curious: We are installing …

Reply: The guards generated for the keys themselves will look something like:

    guard_0(___dict_keys_getitem(dict, 0))
    ...
    guard_i(___dict_keys_getitem(dict, i))

where each key's guard evaluation calls …. If our guard codegen was smarter though, we could just iterate over the key set once.

Review comment: Oh I see. That makes sense. I missed that …

                 source_value = GetItemSource(self.get_source(), source_key)
                 value = LazyVariableTracker.create(v, source_value)

-                idx += 1
                 return key, value

-            result = dict(build_key_value(k, v) for k, v in value.items())
+            result = dict(
+                build_key_value(i, k, v) for i, (k, v) in enumerate(value.items())
+            )

             if istype(value, collections.defaultdict):
                 result = DefaultDictVariable(
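
The refactor in the last hunk replaces a `nonlocal idx` counter with `enumerate`, so each key gets a source derived from its position and each value gets a source derived from that key source. The sketch below illustrates only this pattern outside of Dynamo. `KeySource`, `ItemSource`, `build_sources`, and the `name()` string format are hypothetical stand-ins and are not the real `ConstDictKeySource` or `GetItemSource` classes.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class KeySource:
    """Hypothetical stand-in for a source that addresses a key by position."""
    base: str
    index: int

    def name(self) -> str:
        return f"list({self.base}.keys())[{self.index}]"


@dataclass(frozen=True)
class ItemSource:
    """Hypothetical stand-in for a source that addresses a value via its key source."""
    base: str
    key: KeySource

    def name(self) -> str:
        return f"{self.base}[{self.key.name()}]"


def build_sources(base, value):
    """Map each key to the names of its (key source, value source) pair."""

    def build_key_value(i, k, v):
        # The position comes from enumerate, so no mutable counter is needed.
        source_key = KeySource(base, i)
        source_value = ItemSource(base, source_key)
        return k, (source_key.name(), source_value.name())

    return dict(build_key_value(i, k, v) for i, (k, v) in enumerate(value.items()))


if __name__ == "__main__":
    for key, names in build_sources("d", {"x": 1, (1, 2): 2}).items():
        print(key, names)
```

Passing the index explicitly keeps `build_key_value` free of shared mutable state, which is the same design choice the diff makes.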