Collaborator

@peterbell10 peterbell10 commented Oct 30, 2023

Stack from ghstack (oldest at bottom):

We commonly do some variation of `tree_leaves((args, kwargs))`. This adds a new
function `arg_tree_leaves(*args, **kwargs)` which takes advantage of the known
structure of `args` and `kwargs` to skip their `flatten_fn`.

I see a ~1 µs improvement per call for `args` + `kwargs`, or a 0.5 µs improvement
when passing just one of `args` or `kwargs`. For shallow structures, this can be
proportionally quite significant. For example, the `empty_strided` call I've been
using as a benchmark:

args = ((100, 100), (100, 1))
kwargs = dict(device="cuda")

This call sees a 30% speedup from the change.
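
To illustrate the idea, here is a minimal pure-Python sketch. It does not use the real `torch.utils._pytree` code: `toy_tree_leaves` and `toy_arg_tree_leaves` are hypothetical stand-ins that only handle tuples, lists and dicts. The point is that the outer `(args, kwargs)` pair, the `args` tuple and the `kwargs` dict have a known shape, so they can be iterated directly rather than going through the generic per-container dispatch.

```python
from typing import Any, List


def toy_tree_leaves(tree: Any) -> List[Any]:
    """Generic flatten: dispatch on the container type at every level."""
    if isinstance(tree, (tuple, list)):
        leaves: List[Any] = []
        for child in tree:
            leaves.extend(toy_tree_leaves(child))
        return leaves
    if isinstance(tree, dict):
        leaves = []
        for child in tree.values():
            leaves.extend(toy_tree_leaves(child))
        return leaves
    return [tree]  # anything else is treated as a leaf


def toy_arg_tree_leaves(*args: Any, **kwargs: Any) -> List[Any]:
    """Hypothetical fast path: same leaves as toy_tree_leaves((args, kwargs)).

    The outer pair, the args tuple and the kwargs dict are walked directly,
    so the generic dispatch only runs on their children.
    """
    leaves: List[Any] = []
    for arg in args:
        leaves.extend(toy_tree_leaves(arg))
    for value in kwargs.values():
        leaves.extend(toy_tree_leaves(value))
    return leaves


args = ((100, 100), (100, 1))
kwargs = dict(device="cuda")

# Both spellings produce the same flat list of leaves.
assert toy_arg_tree_leaves(*args, **kwargs) == toy_tree_leaves((args, kwargs))
print(toy_arg_tree_leaves(*args, **kwargs))  # [100, 100, 100, 1, 'cuda']
```

The sketch only demonstrates that the two calls agree on their output; it makes no claim about the timings quoted above, which come from the real implementation.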

cc @zou3519

pytorch-bot bot commented Oct 30, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112393

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ca986a8 with merge base 29844ad:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@peterbell10 peterbell10 marked this pull request as ready for review October 30, 2023 17:52
peterbell10 added a commit to peterbell10/pytorch that referenced this pull request Oct 31, 2023
ghstack-source-id: 3f5310d
Pull Request resolved: pytorch#112393
pytorchmergebot pushed a commit that referenced this pull request Oct 31, 2023
Pull Request resolved: #112394
Approved by: https://github.com/lezcano
ghstack dependencies: #112391, #112392, #112393
pytorchmergebot pushed a commit that referenced this pull request Oct 31, 2023
Wherever we discard the output of `tree_map` it's better to call `tree_map_`
which doesn't unflatten the mapped results and so is a lot cheaper.
Pull Request resolved: #112417
Approved by: https://github.com/lezcano
ghstack dependencies: #112391, #112392, #112393, #112394
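
As a rough illustration of why skipping the unflatten step is cheaper, here is a pure-Python sketch. `toy_tree_map` and `toy_tree_map_` are hypothetical stand-ins, not the real `torch.utils._pytree` functions, and they only handle tuples, lists and dicts. Returning the input tree from `toy_tree_map_` is a choice made for this sketch.

```python
from typing import Any, Callable, List


def toy_tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply fn to every leaf and rebuild a container of the same shape."""
    if isinstance(tree, tuple):
        return tuple(toy_tree_map(fn, child) for child in tree)
    if isinstance(tree, list):
        return [toy_tree_map(fn, child) for child in tree]
    if isinstance(tree, dict):
        return {key: toy_tree_map(fn, value) for key, value in tree.items()}
    return fn(tree)


def toy_tree_map_(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply fn to every leaf for its side effects only.

    No new containers are built and the results of fn are discarded;
    the original tree is handed back unchanged.
    """
    if isinstance(tree, (tuple, list)):
        for child in tree:
            toy_tree_map_(fn, child)
    elif isinstance(tree, dict):
        for value in tree.values():
            toy_tree_map_(fn, value)
    else:
        fn(tree)
    return tree


tree = {"a": (1, 2), "b": [3]}

# When the mapped result is needed, the structure has to be rebuilt.
doubled = toy_tree_map(lambda x: x * 2, tree)
assert doubled == {"a": (2, 4), "b": [6]}

# When only the side effect matters, nothing needs to be rebuilt.
seen: List[int] = []
same = toy_tree_map_(seen.append, tree)
assert seen == [1, 2, 3]
assert same is tree
```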
@facebook-github-bot facebook-github-bot deleted the gh/peterbell10/648/head branch November 4, 2023 14:26
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Pull Request resolved: pytorch#112393
Approved by: https://github.com/lezcano
ghstack dependencies: pytorch#112391, pytorch#112392
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Pull Request resolved: pytorch#112417
Approved by: https://github.com/lezcano
ghstack dependencies: pytorch#112391, pytorch#112392, pytorch#112393, pytorch#112394
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Pull Request resolved: pytorch#112393
Approved by: https://github.com/lezcano
ghstack dependencies: pytorch#112391, pytorch#112392
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Pull Request resolved: pytorch#112417
Approved by: https://github.com/lezcano
ghstack dependencies: pytorch#112391, pytorch#112392, pytorch#112393, pytorch#112394