MaxPool1d without indices optimization #43745

Closed

@heitorschueroff (Contributor) commented Aug 27, 2020

Stack from ghstack:

This is part of a larger effort to refactor and optimize the pooling code. I previously started working on MaxPool2d in #43267, but since it uses MaxPool1d as a subroutine, it made more sense to work on 1D first, get it tested and optimized, and then move up to 2D and then 3D.

Below are some benchmarking results; the Python script I used follows the results.

Benchmarking

Name (time in us)                            Min                   Max                Mean             StdDev              Median                 IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_googlenet[(3, 2, 0, 1, 0)-new]      79.7659 (1.03)     1,059.6327 (5.32)      90.6280 (1.01)     19.1196 (1.41)      84.2176 (1.01)       2.4289 (1.0)     1079;2818       11.0341 (0.99)       9055           1
test_googlenet[(3, 2, 0, 1, 0)-old]     505.1531 (6.55)       830.8962 (4.17)     563.4763 (6.29)     65.3974 (4.81)     538.3361 (6.43)      80.5371 (33.16)      242;99        1.7747 (0.16)       1742           1
test_googlenet[(3, 2, 0, 1, 1)-new]      80.2949 (1.04)       233.0020 (1.17)      97.6498 (1.09)     19.1228 (1.41)      89.2282 (1.07)      18.5743 (7.65)     1858;741       10.2407 (0.92)       9587           1
test_googlenet[(3, 2, 0, 1, 1)-old]     513.5350 (6.66)       977.4677 (4.91)     594.4559 (6.63)     69.9372 (5.15)     577.9080 (6.90)      79.8218 (32.86)      503;84        1.6822 (0.15)       1675           1
test_googlenet[(3, 2, 1, 1, 0)-new]      77.1061 (1.0)        199.1168 (1.0)       89.6529 (1.0)      13.5864 (1.0)       83.7557 (1.0)        7.5139 (3.09)    1419;1556       11.1541 (1.0)        7434           1
test_googlenet[(3, 2, 1, 1, 0)-old]     543.6055 (7.05)       964.5708 (4.84)     636.9867 (7.11)     84.0732 (6.19)     616.7777 (7.36)     100.4562 (41.36)      434;65        1.5699 (0.14)       1552           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_inception[(3, 2, 0, 1, 0)-new]      84.5827 (1.00)       184.2827 (1.0)       90.5438 (1.01)      9.6324 (1.0)       89.3027 (1.05)      4.5672 (1.03)      637;759       11.0444 (0.99)       6274           1
test_inception[(3, 2, 0, 1, 0)-old]     641.2268 (7.59)     1,704.8977 (9.25)     686.9383 (7.65)     57.2499 (5.94)     682.5905 (8.01)     58.3753 (13.17)       86;21        1.4557 (0.13)        802           1
test_inception[(3, 2, 0, 1, 1)-new]      84.5008 (1.0)      1,093.6335 (5.93)      89.8233 (1.0)      14.0443 (1.46)      85.2682 (1.0)       4.4331 (1.0)      802;1106       11.1330 (1.0)        9190           1
test_inception[(3, 2, 0, 1, 1)-old]     643.7078 (7.62)       851.4188 (4.62)     687.4905 (7.65)     41.1116 (4.27)     685.1386 (8.04)     60.2733 (13.60)      286;14        1.4546 (0.13)       1300           1
test_inception[(3, 2, 1, 1, 0)-new]     106.0739 (1.26)       258.5649 (1.40)     115.3597 (1.28)     17.5436 (1.82)     106.9643 (1.25)      5.5470 (1.25)     894;1402        8.6685 (0.78)       7635           1
test_inception[(3, 2, 1, 1, 0)-old]     651.0504 (7.70)       955.2278 (5.18)     698.0295 (7.77)     45.5097 (4.72)     692.8109 (8.13)     64.6794 (14.59)      145;15        1.4326 (0.13)        909           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_large_batch_size[new]       2.9608 (1.0)        5.1127 (1.0)        3.3096 (1.0)      0.1936 (1.0)        3.3131 (1.0)      0.2093 (1.0)          71;6  302.1515 (1.0)         297           1
test_large_batch_size[old]     130.6583 (44.13)    152.9521 (29.92)    137.1385 (41.44)    7.4352 (38.40)    135.1784 (40.80)    5.1358 (24.53)         1;1    7.2919 (0.02)          7           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_large_channel_size[new]      2.9696 (1.0)       5.5595 (1.0)       3.5997 (1.0)      0.5836 (1.0)       3.3497 (1.0)      0.3445 (1.0)         58;54  277.8014 (1.0)         277           1
test_large_channel_size[old]     19.6838 (6.63)     22.6637 (4.08)     21.1775 (5.88)     0.8610 (1.48)     21.3739 (6.38)     1.4930 (4.33)         13;0   47.2199 (0.17)         36           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_large_width[new]      1.7714 (1.0)       2.4104 (1.0)       1.8988 (1.0)      0.0767 (1.0)       1.8911 (1.0)      0.0885 (1.0)         86;13  526.6454 (1.0)         373           1
test_large_width[old]     19.5708 (11.05)    22.8755 (9.49)     20.7987 (10.95)    0.7009 (9.14)     20.6623 (10.93)    0.8584 (9.70)         14;1   48.0799 (0.09)         46           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_multithreaded[new]      15.0560 (1.0)       24.2891 (1.0)       16.1627 (1.0)      1.5657 (1.0)       15.7182 (1.0)      0.7598 (1.0)           4;6  61.8709 (1.0)          65           1
test_multithreaded[old]     115.7614 (7.69)     120.9670 (4.98)     118.3004 (7.32)     1.6259 (1.04)     118.4164 (7.53)     1.9613 (2.58)          2;0   8.4531 (0.14)          8           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Legend:
  Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile.
  OPS: Operations Per Second, computed as 1 / Mean
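
As a quick sanity check of the legend, the OPS column can be recomputed from the Mean column. The sketch below uses two rows copied from the table. The first two groups are labeled in microseconds with OPS in Kops/s. For the later groups the header row is missing, so treating them as milliseconds with plain ops/s is an assumption, made because it is the unit choice under which 1 / Mean matches the reported OPS.

```python
def ops_per_second(mean_seconds: float) -> float:
    """OPS as defined in the legend: 1 / Mean, with Mean expressed in seconds."""
    return 1.0 / mean_seconds


# test_googlenet[(3, 2, 0, 1, 0)-new]: Mean 90.6280 us, reported OPS 11.0341 Kops/s.
googlenet_kops = ops_per_second(90.6280e-6) / 1e3

# test_large_batch_size[new]: Mean 3.3096, reported OPS 302.1515.
# Assumption: this group is in milliseconds and plain ops/s (its header was dropped).
large_batch_ops = ops_per_second(3.3096e-3)

print(f"googlenet new:   {googlenet_kops:.4f} Kops/s")
print(f"large batch new: {large_batch_ops:.4f} ops/s")
```

Both recomputed values agree with the table to the printed precision, which supports reading the later groups as milliseconds. The old/new ratios in parentheses are unaffected by the unit either way.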

Benchmarking script

To run the benchmark, make sure pytest-benchmark is installed (pip install pytest-benchmark), then run: pytest benchmark.py --benchmark-sort='name'

import torch
import pytest


def _test_speedup(benchmark, batches=1, channels=32, width=32,
                  kernel_size=2, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False):
    torch.set_num_threads(1)
    x = torch.randn((batches, channels, width))
    model = torch.nn.MaxPool1d(kernel_size, stride, padding, dilation, return_indices, ceil_mode)
    benchmark(model, x)


@pytest.mark.benchmark(group="inception")
@pytest.mark.parametrize("return_indices", [True, False], ids=["old", "new"])
@pytest.mark.parametrize("params", [(3, 2), (3, 2, 0, 1, True), (3, 2, 1)],
                         ids=["(3, 2, 0, 1, 0)",
                              "(3, 2, 0, 1, 1)",
                              "(3, 2, 1, 1, 0)"])
def test_inception(benchmark, params, return_indices):
    _test_speedup(benchmark, 10, 64, 147, *params, return_indices=return_indices)


@pytest.mark.benchmark(group="googlenet")
@pytest.mark.parametrize("return_indices", [True, False], ids=["old", "new"])
@pytest.mark.parametrize("params", [(3, 2), (3, 2, 0, 1, True), (3, 2, 1)],
                         ids=["(3, 2, 0, 1, 0)",
                              "(3, 2, 0, 1, 1)",
                              "(3, 2, 1, 1, 0)"])
def test_googlenet(benchmark, params, return_indices):
    _test_speedup(benchmark, 10, 64, 112, *params, return_indices=return_indices)


@pytest.mark.benchmark(group="large batch size")
@pytest.mark.parametrize("return_indices", [True, False], ids=["old", "new"])
def test_large_batch_size(benchmark, return_indices):
    _test_speedup(benchmark, 100000, 1, 32, return_indices=return_indices)


@pytest.mark.benchmark(group="large channel size")
@pytest.mark.parametrize("return_indices", [True, False], ids=["old", "new"])
def test_large_channel_size(benchmark, return_indices):
    _test_speedup(benchmark, 1, 100000, 32, return_indices=return_indices)


@pytest.mark.benchmark(group="large width")
@pytest.mark.parametrize("return_indices", [True, False], ids=["old", "new"])
def test_large_width(benchmark, return_indices):
    _test_speedup(benchmark, 1, 32, 100000, return_indices=return_indices)


@pytest.mark.benchmark(group="multithreading")
@pytest.mark.parametrize("return_indices", [True, False], ids=["old", "new"])
def test_multithreaded(benchmark, return_indices):
    x = torch.randn((40, 10000, 32))
    model = torch.nn.MaxPool1d(2, return_indices=return_indices)
    benchmark(model, x)

Discussion

On the GoogLeNet- and Inception-sized inputs above, the new algorithm is on average about 7x faster than the old one (comparing mean times). Because the old algorithm had issues with how it parallelized the work and used the cache, some input shapes widen the gap considerably: with a large batch size, the new algorithm is about 41x faster than the original one.
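
As a rough check of the average quoted here, the speedups can be recomputed from the Mean column of the benchmark table. This is only a sketch: the numbers are copied from the table, the dictionary and function names are illustrative, and the ratio is unit-independent.

```python
# (old mean, new mean) pairs copied from the Mean column of the benchmark table.
network_means = {
    "googlenet (3, 2, 0, 1, 0)": (563.4763, 90.6280),
    "googlenet (3, 2, 0, 1, 1)": (594.4559, 97.6498),
    "googlenet (3, 2, 1, 1, 0)": (636.9867, 89.6529),
    "inception (3, 2, 0, 1, 0)": (686.9383, 90.5438),
    "inception (3, 2, 0, 1, 1)": (687.4905, 89.8233),
    "inception (3, 2, 1, 1, 0)": (698.0295, 115.3597),
}
stress_means = {
    "large_batch_size": (137.1385, 3.3096),
    "large_channel_size": (21.1775, 3.5997),
    "large_width": (20.7987, 1.8988),
    "multithreaded": (118.3004, 16.1627),
}


def speedup(old_mean: float, new_mean: float) -> float:
    """How many times faster the new implementation is, based on mean time."""
    return old_mean / new_mean


network_speedups = {name: speedup(*pair) for name, pair in network_means.items()}
stress_speedups = {name: speedup(*pair) for name, pair in stress_means.items()}
average_network_speedup = sum(network_speedups.values()) / len(network_speedups)

for name, value in {**network_speedups, **stress_speedups}.items():
    print(f"{name:28s} {value:6.2f}x")
print(f"average over network-sized inputs: {average_network_speedup:.2f}x")
```

The six network-sized cases average a little under 7x, while the large batch case comes out above 41x, which is the outlier the discussion refers to.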

Differential Revision: D23425348

dr-ci bot commented Aug 27, 2020

💊 CI failures summary and remediations

As of commit 8896c9d (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment has been revised 17 times.

This is part of a larger effort to refactor and optimize the pooling code. Previously I started working on MaxPool2d here #43267 but since it uses MaxPool1d as a subroutine, it made more sense to work on 1D first and get it tested and optimized and then move up to 2D and then 3D.

TODO: I'll add some bigger tests and some early benchmarking code and results here.

[ghstack-poisoned]
heitorschueroff added a commit that referenced this pull request Aug 27, 2020
ghstack-source-id: 5b4a558280f2f3ca2988702b0e872d235240e35f
Pull Request resolved: #43745
heitorschueroff added a commit that referenced this pull request Aug 28, 2020
ghstack-source-id: f25c08c3dd3fae792b18db1b6fb32ec5c7106b6c
Pull Request resolved: #43745
Contributor

@glaringlee glaringlee left a comment

This generally looks good to me. Easy to expand; I left some minor comments.

aten/src/ATen/native/cpu/MaxPooling.cpp (resolved)
aten/src/ATen/native/MaxPooling.cpp (outdated, resolved)
aten/src/ATen/native/MaxPooling.h (outdated, resolved)
aten/src/ATen/native/MaxPooling.h (outdated, resolved)
heitorschueroff added a commit that referenced this pull request Aug 28, 2020
ghstack-source-id: e5024e815b4fe591fa7f78bdfeaba8e5e9207c1f
Pull Request resolved: #43745
@glaringlee
Contributor

I am fine with this change now.

heitorschueroff added a commit that referenced this pull request Aug 29, 2020
ghstack-source-id: 5924c6eaad0a9583f6de91728e3ee5c069463530
Pull Request resolved: #43745
@heitorschueroff changed the title from "New MaxPool1d without indices implementation" to "New MaxPool1d without indices optimization" on Aug 29, 2020
@heitorschueroff changed the title from "New MaxPool1d without indices optimization" to "MaxPool1d without indices optimization" on Aug 29, 2020
codecov bot commented Aug 29, 2020

Codecov Report

Merging #43745 into gh/heitorschueroff/6/base will not change coverage.
The diff coverage is n/a.

@@                    Coverage Diff                     @@
##           gh/heitorschueroff/6/base   #43745   +/-   ##
==========================================================
  Coverage                      69.32%   69.32%           
==========================================================
  Files                            378      378           
  Lines                          46745    46745           
==========================================================
  Hits                           32404    32404           
  Misses                         14341    14341           
Impacted Files Coverage Δ
torch/nn/modules/pooling.py 97.55% <ø> (ø)

Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 5021ec8...8896c9d. Read the comment docs.

Collaborator

@mruberry mruberry left a comment

Looks good! Nice benchmark numbers.

heitorschueroff added a commit that referenced this pull request Aug 31, 2020
ghstack-source-id: 58436789e92c98b18f0846b3fd72d2adadea563c
Pull Request resolved: #43745
@facebook-github-bot
Contributor

@heitorschueroff merged this pull request in 13a48ac.

Collaborator

@ngimel ngimel left a comment

Are there any single-threaded benchmarks available? I suspect most of the speed-up is due to parallelization, and inference is usually single-threaded.

IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode) {
if (self.requires_grad() || !self.device().is_cpu()) {
Collaborator

This should also check if GradMode is enabled, otherwise even if self required grad, indices are not needed.

Contributor Author

I'll update this in the next PR for max_pool1d.

Contributor

did this get addressed?

Contributor Author

#46767. I was going to address it with the optimizations to with_indices and backwards but then switched focus and forgot this. Thanks for reminding.
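
To make the suggestion in this thread concrete, here is a minimal Python paraphrase of the dispatch test quoted from the diff, `self.requires_grad() || !self.device().is_cpu()`, with the proposed GradMode check added. The function name and its boolean parameters are hypothetical and exist only for illustration. This is not the code that landed in #46767.

```python
def needs_indices_path(requires_grad: bool, grad_mode_enabled: bool, is_cpu: bool) -> bool:
    """Hypothetical paraphrase of the dispatch test with the GradMode check added.

    Indices are only needed when a backward pass can actually happen, that is,
    when the input requires grad AND gradient recording is currently enabled.
    Non-CPU inputs keep using the existing (with indices) implementation.
    """
    backward_possible = requires_grad and grad_mode_enabled
    return backward_possible or not is_cpu


# Input requires grad, but gradient recording is off (for example under no_grad):
# the faster path without indices can be used.
print(needs_indices_path(requires_grad=True, grad_mode_enabled=False, is_cpu=True))

# Input requires grad and gradient recording is on: indices are needed for backward.
print(needs_indices_path(requires_grad=True, grad_mode_enabled=True, is_cpu=True))
```

Under the condition as originally written, the first case would still take the slower path, which is the inefficiency the review comment points out.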

scalar_t* C10_RESTRICT op,
const scalar_t* C10_RESTRICT ip,
const PoolingParams1D& p) {
for (int64_t kj = 0; kj < p.KW; ++kj) {
Collaborator

Did you check that this actually benefits from auto vectorization? If it does not, it should not be in the cpu folder.

Contributor Author

Yes, I did. It does benefit, especially when using C10_RESTRICT, which can only be used in the cpu folder.

int64_t ij = p.index(kj, oj);
for (; oj < oe; ++oj, ij += p.SJ) {
scalar_t val = ip[ij];
bool update_max = std::isnan(val) || op[oj] < val;
Collaborator

way to blow the cache - instead of computing max in registers within a window, you are constantly accessing full output line.

Contributor Author

True, and I have plans for a future PR to compute output in blocks that fit within the cache. For now I wanted to keep the code simple so I can have confidence it works. This way the compiler auto vectorized the loop for me.
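
The two loop orders discussed in this thread can be sketched in plain Python. This is a simplified illustration, not the ATen kernel: it assumes no padding and no dilation, the function names are made up, and the write-back after the `update_max` test is assumed to be a plain assignment because the quoted hunk does not show it.

```python
import math


def max_pool1d_kernel_outer(ip, kernel_size, stride):
    """Loop order of the reviewed snippet: one sweep over the whole output row
    per kernel offset, so every output element is revisited kernel_size times."""
    out_len = (len(ip) - kernel_size) // stride + 1
    op = [-math.inf] * out_len
    for kj in range(kernel_size):
        ij = kj  # input index matching output position 0 at this kernel offset
        for oj in range(out_len):
            val = ip[ij]
            # Same test as the quoted line: NaN always wins, otherwise keep the larger value.
            if math.isnan(val) or op[oj] < val:
                op[oj] = val
            ij += stride
    return op


def max_pool1d_window_inner(ip, kernel_size, stride):
    """Reviewer's alternative: reduce each window in a local accumulator
    (a register in compiled code) and write each output element once."""
    out_len = (len(ip) - kernel_size) // stride + 1
    op = []
    for oj in range(out_len):
        acc = -math.inf
        for kj in range(kernel_size):
            val = ip[oj * stride + kj]
            if math.isnan(val) or acc < val:
                acc = val
        op.append(acc)
    return op


x = [1.0, 3.0, 2.0, 5.0, 4.0, 0.0, 7.0]
print(max_pool1d_kernel_outer(x, kernel_size=3, stride=2))
print(max_pool1d_window_inner(x, kernel_size=3, stride=2))
```

Both functions return the same values. The first keeps the inner loop a simple strided sweep, which is the shape the author relied on for auto-vectorization, at the cost of touching the full output row once per kernel offset. The second touches each output once, and blocking the output so it fits in cache is the middle ground mentioned for a future PR.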

@heitorschueroff
Contributor Author

> Are there any single-threaded benchmarks available? I suspect most of the speed-up is due to parallelization, and inference is usually single-threaded.

All benchmark results are single-threaded except the one named test_multithreaded.

@facebook-github-bot facebook-github-bot deleted the gh/heitorschueroff/6/head branch September 5, 2020 14:16