
Make step() faster by passing in a tensor vs scalar 1 #111084

Closed · wants to merge 9 commits

Conversation

@janeyx99 (Contributor) commented Oct 11, 2023

This is the culmination of #110954 (comment).

We are making the code slightly more complicated to gain some perf by minimizing calls to `.copy_()` and `.to()`.
### Code
```
import torch
with torch.cuda.device(0):
    steps = [torch.zeros((), device="cpu", dtype=torch.float32) for i in range(1000)]

    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ]
    ) as p:
        # New code:
        # step_device = steps[0].device
        # one = torch.tensor(1.0, device=step_device) if str(step_device) == "cpu" else 1
        # torch._foreach_add_(steps, one, 1.0)

        # Old code:
        torch._foreach_add_(steps, 1)

    print(p.key_averages().table(sort_by="cpu_time_total"))
```
### Profiles
**with old code**
```
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
      aten::_foreach_add_        35.31%      52.089ms        99.99%     147.495ms     147.495ms             1  
               aten::add_        25.05%      36.949ms        64.68%      95.406ms      95.406us          1000  
                 aten::to         3.97%       5.852ms        39.63%      58.457ms      58.457us          1000  
           aten::_to_copy        10.11%      14.917ms        35.66%      52.605ms      52.605us          1000  
              aten::copy_        21.65%      31.939ms        21.65%      31.939ms      31.939us          1000  
      aten::empty_strided         3.90%       5.749ms         3.90%       5.749ms       5.749us          1000  
    cudaDeviceSynchronize         0.01%      18.000us         0.01%      18.000us      18.000us             1  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 147.513ms
```

**with new code**
```
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
      aten::_foreach_add_        55.06%      49.963ms        99.86%      90.625ms      90.625ms             1  
               aten::add_        44.81%      40.662ms        44.81%      40.662ms      40.662us          1000  
            aten::detach_         0.01%       8.000us         0.05%      45.000us      45.000us             1  
                  detach_         0.04%      37.000us         0.04%      37.000us      37.000us             1  
              aten::empty         0.03%      30.000us         0.03%      30.000us      30.000us             1  
                 aten::to         0.03%      23.000us         0.03%      23.000us      23.000us             1  
    cudaDeviceSynchronize         0.02%      22.000us         0.02%      22.000us      22.000us             1  
         aten::lift_fresh         0.01%       6.000us         0.01%       6.000us       6.000us             1  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 90.751ms
```
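
For context, here is a minimal sketch of the pattern from the commented-out "New code" above as it might look inside an optimizer's `step()` increment. The helper name `_increment_steps` and the usage below are illustrative, not the actual diff in this PR, and it assumes a PyTorch build that includes the `_foreach_add_.Tensor` overload added earlier in this stack:

```python
import torch

def _increment_steps(state_steps):
    # Illustrative helper: when the step counters live on CPU, hand
    # _foreach_add_ a single 0-dim CPU tensor so it does not fall back to a
    # per-tensor to()/_to_copy/copy_ of the scalar; on CUDA, a plain Python
    # scalar is already the fast path.
    if not state_steps:
        return
    step_device = state_steps[0].device
    one = torch.tensor(1.0, device=step_device) if str(step_device) == "cpu" else 1
    torch._foreach_add_(state_steps, one)

# The `step` state kept by torch.optim optimizers is a list of 0-dim tensors.
steps = [torch.zeros((), dtype=torch.float32) for _ in range(3)]
_increment_steps(steps)
print([s.item() for s in steps])  # [1.0, 1.0, 1.0]
```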

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler

@pytorch-bot (bot) commented Oct 11, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111084

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4bdbd95 with merge base 74f6f7a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

janeyx99 added a commit that referenced this pull request Oct 11, 2023
ghstack-source-id: 8e014f0dfb818105b2a2c5a20b2b6eb9df657bb6
Pull Request resolved: #111084
Review comment on torch/optim/adagrad.py (outdated, resolved)
janeyx99 added a commit that referenced this pull request Oct 12, 2023
ghstack-source-id: 3500ad7b2ceb60620afe2491bed7814fad5db100
Pull Request resolved: #111084
@albanD (Collaborator) left a comment:

nit on style but sgtm

Review comments on torch/optim/adamax.py and torch/optim/adagrad.py (outdated, resolved)
janeyx99 added a commit that referenced this pull request Oct 13, 2023
ghstack-source-id: b7b9bface8ee1e69d746ce8a2527eb0ff6951bce
Pull Request resolved: #111084
Review comment on torch/_inductor/lowering.py (outdated, resolved)
janeyx99 added a commit that referenced this pull request Oct 13, 2023
ghstack-source-id: 454df8b1fd06cf9d5b69f39b4d9db6fee7d0f282
Pull Request resolved: #111084
janeyx99 added a commit that referenced this pull request Oct 14, 2023
ghstack-source-id: c8f38a8291ef99ce2a177cd00d38ee5340a80291
Pull Request resolved: #111084
@janeyx99 added the ciflow/trunk label (Trigger trunk jobs on your pull request) Oct 14, 2023
torch/_inductor/lowering.py
```diff
@@ -4733,6 +4733,7 @@ def register_pointwise_numeric_ldf64(op):

 register_foreach_pointwise(aten._foreach_add.List, add, allow_alpha=True)
 register_foreach_pointwise(aten._foreach_add.Scalar, add, allow_alpha=True)
+register_foreach_pointwise(aten._foreach_add.Tensor, add, allow_alpha=True)
```
Contributor: If we are adding this registration, can you add a test as well in test_foreach.py?

@janeyx99 (Author): Do you mean for eager? If so, the test is already added in the earlier PR of this stack (where the overload is added).

Contributor: No, I mean for inductor.

@janeyx99 (Author): Done in #111600
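
For reference, a rough sketch of the kind of coverage being discussed: exercising the `_foreach_add.Tensor` overload both eagerly and under `torch.compile`, so the inductor lowering registered above gets hit. The real test lives in #111600; the function and variable names here are made up, and this assumes a PyTorch build with the Tensor overload from this stack:

```python
import torch

def inc(xs, one):
    # Out-of-place foreach add with a 0-dim tensor `other` dispatches to
    # aten._foreach_add.Tensor, the overload whose lowering is registered above.
    return torch._foreach_add(xs, one)

xs = [torch.zeros((), dtype=torch.float32) for _ in range(4)]
one = torch.tensor(1.0)

eager_out = inc(xs, one)
compiled_out = torch.compile(inc)(xs, one)  # exercises the inductor lowering
for e, c in zip(eager_out, compiled_out):
    torch.testing.assert_close(e, c)
```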

janeyx99 added a commit that referenced this pull request Oct 19, 2023
ghstack-source-id: 4d0c7a4b36f1b3bf3d837e02da7fe7621a5a468b
Pull Request resolved: #111084
@janeyx99 (Author) commented:

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@facebook-github-bot deleted the gh/janeyx99/98/head branch October 23, 2023 14:24
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Pull Request resolved: pytorch#111084
Approved by: https://github.com/albanD
ghstack dependencies: pytorch#111079
pytorchmergebot pushed a commit that referenced this pull request Jul 24, 2024
… during graph tracing (#130909)

Hey folks, I was using the `stateless_func` [here](https://github.com/pytorch/pytorch/blob/7c45476d38176c8d5b19fb379fc073dc21beba64/torch/distributed/_spmd/api.py#L435), which worked well before [this commit](#111084); that change introduced a `_tensor_constant0` and made the func non-stateless. Since there is no way to retrieve this constant tensor before compilation, and performance is not an issue when tracing a graph, I think it might be good to fall back to the other branch.
![image](https://github.com/user-attachments/assets/6ee4487d-456b-47e0-8c1d-66cb5a641d47)

![image](https://github.com/user-attachments/assets/1ed46502-e50e-45c4-9751-49aa5a4590ae)

Pull Request resolved: #130909
Approved by: https://github.com/mlazos
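
A sketch of the kind of fallback #130909 describes: take the tensor-`1.0` path only in eager mode, and stay on the scalar overload while a graph is being traced so no `_tensor_constant0` gets baked into the graph. `torch.compiler.is_compiling()` is assumed here as the tracing check, and the helper name is illustrative; the actual predicate in torch.optim may differ:

```python
import torch

def _increment_steps(state_steps):
    # Eager on CPU: use a 0-dim tensor to avoid per-tensor .to()/.copy_() of
    # the scalar. While tracing/compiling: use the scalar overload so the
    # traced graph stays free of baked-in tensor constants (#130909).
    if not torch.compiler.is_compiling() and state_steps[0].device.type == "cpu":
        torch._foreach_add_(state_steps, torch.tensor(1.0), alpha=1.0)
    else:
        torch._foreach_add_(state_steps, 1)
```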
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
Pull Request resolved: pytorch#130909
Approved by: https://github.com/mlazos