-
Notifications
You must be signed in to change notification settings - Fork 25.2k
[fx] Move map_aggregate to C++ #148243
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[fx] Move map_aggregate to C++ #148243
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148243
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 8348536 with merge base b958890 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if there is any plan to reimplement map_aggregate
to reuse the pytree infra.
elif isinstance(a, list): | ||
result = immutable_list([map_aggregate(elem, fn) for elem in a]) | ||
elif isinstance(a, dict): | ||
result = immutable_dict([(k, map_aggregate(v, fn)) for k, v in a.items()]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One blocker to use the pytree tree_map
is that pytree uses type(a) is list
/ type(a) is dict
rather than isinstance(a, list)
/ isinstance(a, dict)
. How do we want to support user subclasses, such as collections.UserList
/ collections.UserDict
? Or we can just support list
(immutable_list
), dict
(immutable_dict
), defaultdict
, OrderedDict
only.
@XuehaiPan map_aggregate is a backwards compatibility surface for FX and the public API and behavior can't be changed. It implicitly converts all list subclasses to immutable_list. |
Starting merge as part of PR stack under #148261 |
Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before: ``` 25203549 function calls (24403352 primitive calls) in 12.090 seconds ``` after: ``` 24303536 function calls (23503339 primitive calls) in 10.726 seconds ``` Pull Request resolved: #148260 Approved by: https://github.com/oulgen ghstack dependencies: #148243
Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before: ``` 24303536 function calls (23503339 primitive calls) in 10.726 seconds ``` after: ``` 20003454 function calls (19203257 primitive calls) in 8.936 seconds ``` Pull Request resolved: #148261 Approved by: https://github.com/oulgen ghstack dependencies: #148243, #148260
This reverts commit edaff88. Reverted #148243 on behalf of https://github.com/jovianjaison due to breaking internal builds [T216910920] ([comment](#148243 (comment)))
@jansel your PR has been successfully reverted. |
…ytorch#148303) Pull Request resolved: pytorch#148303 Approved by: https://github.com/Skylion007 ghstack dependencies: pytorch#148243, pytorch#148260, pytorch#148261
Before:  After:  Pull Request resolved: pytorch#148288 Approved by: https://github.com/oulgen ghstack dependencies: pytorch#148243, pytorch#148260, pytorch#148261, pytorch#148303
…ch#148292) Before: 19502951 function calls (18702776 primitive calls) in 8.533 seconds After: 16402551 function calls (15602452 primitive calls) in 7.701 seconds Pull Request resolved: pytorch#148292 Approved by: https://github.com/oulgen ghstack dependencies: pytorch#148243, pytorch#148260, pytorch#148261, pytorch#148303, pytorch#148288
Starting merge as part of PR stack under #148292 |
Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before: ``` 25203549 function calls (24403352 primitive calls) in 12.090 seconds ``` after: ``` 24303536 function calls (23503339 primitive calls) in 10.726 seconds ``` Pull Request resolved: #148260 Approved by: https://github.com/oulgen ghstack dependencies: #148243
Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before: ``` 24303536 function calls (23503339 primitive calls) in 10.726 seconds ``` after: ``` 20003454 function calls (19203257 primitive calls) in 8.936 seconds ``` Pull Request resolved: #148261 Approved by: https://github.com/oulgen ghstack dependencies: #148243, #148260
Before:  After:  Pull Request resolved: #148288 Approved by: https://github.com/oulgen ghstack dependencies: #148243, #148260, #148261
Stack from ghstack (oldest at bottom):
Microbenchmarking
fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))
, before:after:
cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @chenyang78 @kadeng @chauhang @amjames