-
Notifications
You must be signed in to change notification settings - Fork 24.6k
Don't implicitly convert to channels-first in MaxPool3D on CUDA #80748
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful links
✅ No Failures (0 Pending)As of commit 2551fea (more details on the Dr. CI page): Expand to see more💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Please report bugs/suggestions to the (internal) Dr. CI Users group. |
4b65a00
to
db94c1d
Compare
@@ -263,45 +347,62 @@ void max_pool3d_with_indices_out_cuda_template( | |||
otime, oheight, owidth, | |||
"max_pool3d_with_indices_out_cuda_template()"); | |||
|
|||
bool channels_last = input.ndimension() == 5 && input.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d; | |||
|
|||
if (input.ndimension() == 4) { | |||
output.resize_({ nslices, otime, oheight, owidth}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shouldn't you also treat 4d tensors in the same way (by unsqueezing them to have batch dimension first, and treating them as channels-last if appropriate?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added, although in this process I encountered some suspicious behavior from suggest_memory_format
where the channels-last batch size 1 input after going through squeeze and unsqueeze was suggested as contiguous
even though is_contiguous
was false
and is_contiguous(at::MemoryFormat::ChannelsLast3D)
was true
after the unsqueeze.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @jjsjann123 for suggest_memory_format
behavior
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approving, but lint is real, please fix and ping me when tests pass (or probably you can land yourself?)
@pytorchbot merge |
@pytorchbot successfully started a merge job. Check the current status here |
Hey @eqy. |
…) (#80748) Summary: MaxPool3D currently converts inputs implicitly to channels-first (via `.contiguous()`) which may yield unexpected regressions in workloads that expect a full channels-last path. This PR preserves the channels-last format in MaxPool3D while attempting to avoid seriously regressing performance. Currently, typical case (kernel size == 2 == stride) looks good, but larger kernel sizes (>4) or the unusual case of stride 1 can sometimes be slower than converting to channels-first before doing MaxPool3D. Additionally, this PR adds a test for 64bit-indexing backwards as testing of these changes uncovered an IMA for large tensors when doing the backwards pass with MaxPool3D. Performance comparison on A6000: ``` [------------------------------------- max_pool3d ---------------------------------------------------------] | channels_last=False | curr ch_last=True | new ch_last=True 1 threads: ---------------------------------------------------------------------------- --------------------- [64, 256, 32, 32, 32] 4x4 stride 4 | 20093.5 | 34823.4 | 20640.0 [64, 256, 32, 32, 32] 4x4 stride 2 | 28623.7 | 42625.6 | 27935.5 [64, 256, 32, 32, 32] 4x4 stride 1 | 68177.5 | 79147.2 | 85604.8 [64, 256, 32, 32, 32] 2x2 stride 4 | 17237.7 | 32071.3 | 16641.6 [64, 256, 32, 32, 32] 2x2 stride 2 | 25252.5 | 39993.2 | 25054.8 [64, 256, 32, 32, 32] 2x2 stride 1 | 43185.2 | 58164.6 | 48416.9 [64, 256, 16, 16, 16] 4x4 stride 4 | 3017.7 | 3952.4 | 2593.8 [64, 256, 16, 16, 16] 4x4 stride 2 | 4581.5 | 5384.3 | 3294.3 [64, 256, 16, 16, 16] 4x4 stride 1 | 11334.1 | 11534.7 | 8651.1 [64, 256, 16, 16, 16] 2x2 stride 4 | 2346.9 | 3304.6 | 2098.8 [64, 256, 16, 16, 16] 2x2 stride 2 | 3550.8 | 4526.5 | 3143.6 [64, 256, 16, 16, 16] 2x2 stride 1 | 6898.1 | 7816.0 | 5820.8 [64, 256, 4, 4, 4] 4x4 stride 4 | 191.5 | 176.3 | 77.5 [64, 256, 4, 4, 4] 4x4 stride 2 | 191.8 | 176.8 | 94.1 [64, 256, 4, 4, 4] 4x4 stride 1 | 191.3 | 176.4 | 97.3 [64, 256, 4, 4, 4] 2x2 stride 4 | 96.4 | 114.4 | 93.6 [64, 256, 4, 4, 4] 2x2 stride 2 | 172.1 | 178.6 | 93.7 [64, 256, 4, 4, 4] 2x2 stride 1 | 263.0 | 279.4 | 92.4 [64, 64, 32, 32, 32] 4x4 stride 4 | 5033.2 | 7208.3 | 5167.5 [64, 64, 32, 32, 32] 4x4 stride 2 | 7216.1 | 9218.7 | 6637.1 [64, 64, 32, 32, 32] 4x4 stride 1 | 17192.1 | 18392.9 | 20489.0 [64, 64, 32, 32, 32] 2x2 stride 4 | 4318.0 | 6511.2 | 4193.1 [64, 64, 32, 32, 32] 2x2 stride 2 | 6324.4 | 8657.7 | 6263.6 [64, 64, 32, 32, 32] 2x2 stride 1 | 10855.0 | 13040.2 | 12055.9 [64, 64, 16, 16, 16] 4x4 stride 4 | 764.1 | 975.6 | 671.3 [64, 64, 16, 16, 16] 4x4 stride 2 | 1163.1 | 1333.4 | 833.6 [64, 64, 16, 16, 16] 4x4 stride 1 | 2890.0 | 2898.5 | 2209.8 [64, 64, 16, 16, 16] 2x2 stride 4 | 593.5 | 811.2 | 536.3 [64, 64, 16, 16, 16] 2x2 stride 2 | 895.9 | 1112.3 | 794.5 [64, 64, 16, 16, 16] 2x2 stride 1 | 1742.5 | 1968.0 | 1475.2 [64, 64, 4, 4, 4] 4x4 stride 4 | 101.1 | 112.2 | 93.4 [64, 64, 4, 4, 4] 4x4 stride 2 | 96.7 | 114.6 | 92.5 [64, 64, 4, 4, 4] 4x4 stride 1 | 98.9 | 111.9 | 96.5 [64, 64, 4, 4, 4] 2x2 stride 4 | 100.1 | 107.1 | 94.2 [64, 64, 4, 4, 4] 2x2 stride 2 | 96.6 | 108.0 | 94.5 [64, 64, 4, 4, 4] 2x2 stride 1 | 96.7 | 107.9 | 95.2 [64, 3, 32, 32, 32] 4x4 stride 4 | 250.1 | 326.6 | 278.0 [64, 3, 32, 32, 32] 4x4 stride 2 | 350.4 | 414.0 | 323.2 [64, 3, 32, 32, 32] 4x4 stride 1 | 825.6 | 846.9 | 982.5 [64, 3, 32, 32, 32] 2x2 stride 4 | 213.3 | 289.8 | 219.9 [64, 3, 32, 32, 32] 2x2 stride 2 | 308.2 | 384.9 | 305.9 [64, 3, 32, 32, 32] 2x2 stride 1 | 523.5 | 594.7 | 589.9 [64, 3, 16, 16, 16] 4x4 stride 4 | 103.8 | 116.7 | 93.0 [64, 3, 16, 16, 16] 4x4 stride 2 | 100.9 | 108.3 | 93.3 [64, 3, 16, 16, 16] 4x4 stride 1 | 139.4 | 140.7 | 104.8 [64, 3, 16, 16, 16] 2x2 stride 4 | 97.5 | 114.7 | 92.7 [64, 3, 16, 16, 16] 2x2 stride 2 | 97.4 | 108.8 | 91.7 [64, 3, 16, 16, 16] 2x2 stride 1 | 99.9 | 108.0 | 94.1 [64, 3, 4, 4, 4] 4x4 stride 4 | 97.2 | 110.2 | 94.7 [64, 3, 4, 4, 4] 4x4 stride 2 | 105.7 | 107.4 | 92.8 [64, 3, 4, 4, 4] 4x4 stride 1 | 98.0 | 110.0 | 93.7 [64, 3, 4, 4, 4] 2x2 stride 4 | 98.3 | 116.7 | 93.0 [64, 3, 4, 4, 4] 2x2 stride 2 | 98.6 | 107.5 | 92.8 [64, 3, 4, 4, 4] 2x2 stride 1 | 100.6 | 110.3 | 94.0 [16, 256, 32, 32, 32] 4x4 stride 4 | 5034.2 | 8838.0 | 5165.9 [16, 256, 32, 32, 32] 4x4 stride 2 | 7236.3 | 10869.9 | 7038.2 [16, 256, 32, 32, 32] 4x4 stride 1 | 17385.4 | 21401.6 | 21900.7 [16, 256, 32, 32, 32] 2x2 stride 4 | 4318.7 | 8101.2 | 4172.9 [16, 256, 32, 32, 32] 2x2 stride 2 | 6324.0 | 10147.5 | 6279.7 [16, 256, 32, 32, 32] 2x2 stride 1 | 10899.7 | 14826.0 | 12256.3 [16, 256, 16, 16, 16] 4x4 stride 4 | 765.4 | 1012.7 | 675.6 [16, 256, 16, 16, 16] 4x4 stride 2 | 1162.8 | 1376.9 | 843.4 [16, 256, 16, 16, 16] 4x4 stride 1 | 2928.9 | 2969.8 | 2222.5 [16, 256, 16, 16, 16] 2x2 stride 4 | 593.5 | 845.8 | 534.2 [16, 256, 16, 16, 16] 2x2 stride 2 | 896.9 | 1152.2 | 796.9 [16, 256, 16, 16, 16] 2x2 stride 1 | 1750.2 | 2009.4 | 1481.8 [16, 256, 4, 4, 4] 4x4 stride 4 | 96.6 | 107.1 | 92.7 [16, 256, 4, 4, 4] 4x4 stride 2 | 97.9 | 114.9 | 93.8 [16, 256, 4, 4, 4] 4x4 stride 1 | 98.2 | 115.6 | 94.0 [16, 256, 4, 4, 4] 2x2 stride 4 | 97.0 | 106.7 | 93.8 [16, 256, 4, 4, 4] 2x2 stride 2 | 96.8 | 108.1 | 93.3 [16, 256, 4, 4, 4] 2x2 stride 1 | 95.8 | 120.9 | 95.7 [16, 64, 32, 32, 32] 4x4 stride 4 | 1266.4 | 1815.4 | 1312.3 [16, 64, 32, 32, 32] 4x4 stride 2 | 1818.5 | 2328.0 | 1678.9 [16, 64, 32, 32, 32] 4x4 stride 1 | 4352.9 | 4649.3 | 5204.6 [16, 64, 32, 32, 32] 2x2 stride 4 | 1090.0 | 1631.2 | 1060.8 [16, 64, 32, 32, 32] 2x2 stride 2 | 1589.4 | 2141.1 | 1576.4 [16, 64, 32, 32, 32] 2x2 stride 1 | 2733.5 | 3286.0 | 3041.6 [16, 64, 16, 16, 16] 4x4 stride 4 | 201.7 | 259.6 | 175.0 [16, 64, 16, 16, 16] 4x4 stride 2 | 301.0 | 350.1 | 226.3 [16, 64, 16, 16, 16] 4x4 stride 1 | 740.1 | 748.7 | 570.6 [16, 64, 16, 16, 16] 2x2 stride 4 | 156.0 | 214.8 | 140.8 [16, 64, 16, 16, 16] 2x2 stride 2 | 232.3 | 292.3 | 208.7 [16, 64, 16, 16, 16] 2x2 stride 1 | 449.1 | 504.0 | 382.1 [16, 64, 4, 4, 4] 4x4 stride 4 | 97.5 | 111.4 | 94.5 [16, 64, 4, 4, 4] 4x4 stride 2 | 98.8 | 111.9 | 94.4 [16, 64, 4, 4, 4] 4x4 stride 1 | 98.2 | 112.0 | 95.2 [16, 64, 4, 4, 4] 2x2 stride 4 | 99.7 | 111.0 | 94.0 [16, 64, 4, 4, 4] 2x2 stride 2 | 100.3 | 110.0 | 93.2 [16, 64, 4, 4, 4] 2x2 stride 1 | 97.5 | 107.6 | 93.5 [16, 3, 32, 32, 32] 4x4 stride 4 | 100.5 | 117.1 | 95.7 [16, 3, 32, 32, 32] 4x4 stride 2 | 97.5 | 121.3 | 92.5 [16, 3, 32, 32, 32] 4x4 stride 1 | 216.0 | 227.4 | 258.4 [16, 3, 32, 32, 32] 2x2 stride 4 | 97.1 | 109.0 | 91.9 [16, 3, 32, 32, 32] 2x2 stride 2 | 95.8 | 108.5 | 92.9 [16, 3, 32, 32, 32] 2x2 stride 1 | 139.4 | 161.2 | 157.8 [16, 3, 16, 16, 16] 4x4 stride 4 | 96.4 | 113.6 | 91.9 [16, 3, 16, 16, 16] 4x4 stride 2 | 97.4 | 108.1 | 93.5 [16, 3, 16, 16, 16] 4x4 stride 1 | 99.0 | 107.5 | 92.1 [16, 3, 16, 16, 16] 2x2 stride 4 | 96.9 | 118.1 | 93.4 [16, 3, 16, 16, 16] 2x2 stride 2 | 97.3 | 106.7 | 95.8 [16, 3, 16, 16, 16] 2x2 stride 1 | 98.8 | 109.2 | 93.8 [16, 3, 4, 4, 4] 4x4 stride 4 | 97.8 | 108.0 | 94.2 [16, 3, 4, 4, 4] 4x4 stride 2 | 92.7 | 108.0 | 93.9 [16, 3, 4, 4, 4] 4x4 stride 1 | 97.8 | 107.6 | 93.5 [16, 3, 4, 4, 4] 2x2 stride 4 | 100.3 | 107.7 | 94.3 [16, 3, 4, 4, 4] 2x2 stride 2 | 97.2 | 107.5 | 96.1 [16, 3, 4, 4, 4] 2x2 stride 1 | 98.1 | 111.1 | 93.8 Times are in microseconds (us). ``` Performance comparison on V100: (these times have been updated after working around some noisy measurements in my setup) ``` [------------------------------------- max_pool3d ---------------------------------------------------------] | channels_last=False | curr ch_last=True | new ch_last=True 1 threads: ------------------------------------------------------------------------------------------------- [64, 256, 32, 32, 32] 4x4 stride 4 | 15810.7 | 33807.7 | 16452.9 [64, 256, 32, 32, 32] 4x4 stride 2 | 24422.7 | 42515.3 | 27700.3 [64, 256, 32, 32, 32] 4x4 stride 1 | 71756.0 | 89916.5 | 106464.0 [64, 256, 32, 32, 32] 2x2 stride 4 | 12102.9 | 30210.4 | 11319.8 [64, 256, 32, 32, 32] 2x2 stride 2 | 19101.7 | 37210.8 | 20373.3 [64, 256, 32, 32, 32] 2x2 stride 1 | 41418.0 | 59650.5 | 53009.2 [64, 256, 16, 16, 16] 4x4 stride 4 | 2362.0 | 4210.3 | 2114.0 [64, 256, 16, 16, 16] 4x4 stride 2 | 4102.4 | 5897.4 | 3179.7 [64, 256, 16, 16, 16] 4x4 stride 1 | 11339.3 | 13116.6 | 10032.6 [64, 256, 16, 16, 16] 2x2 stride 4 | 1709.7 | 3506.7 | 1423.6 [64, 256, 16, 16, 16] 2x2 stride 2 | 2966.6 | 4760.8 | 2499.3 [64, 256, 16, 16, 16] 2x2 stride 1 | 6998.4 | 8797.3 | 6152.0 [64, 256, 4, 4, 4] 4x4 stride 4 | 173.0 | 176.3 | 127.9 [64, 256, 4, 4, 4] 4x4 stride 2 | 149.1 | 176.3 | 125.5 [64, 256, 4, 4, 4] 4x4 stride 1 | 150.0 | 177.2 | 125.6 [64, 256, 4, 4, 4] 2x2 stride 4 | 158.0 | 192.7 | 127.9 [64, 256, 4, 4, 4] 2x2 stride 2 | 169.7 | 199.2 | 125.3 [64, 256, 4, 4, 4] 2x2 stride 1 | 289.6 | 318.2 | 116.5 [64, 64, 32, 32, 32] 4x4 stride 4 | 3914.4 | 6993.3 | 4141.4 [64, 64, 32, 32, 32] 4x4 stride 2 | 6107.4 | 9186.4 | 6378.5 [64, 64, 32, 32, 32] 4x4 stride 1 | 17920.0 | 20993.5 | 23891.1 [64, 64, 32, 32, 32] 2x2 stride 4 | 3029.7 | 6112.6 | 2895.6 [64, 64, 32, 32, 32] 2x2 stride 2 | 4787.8 | 7870.6 | 4724.8 [64, 64, 32, 32, 32] 2x2 stride 1 | 10366.4 | 13446.4 | 12603.8 [64, 64, 16, 16, 16] 4x4 stride 4 | 605.8 | 962.9 | 499.7 [64, 64, 16, 16, 16] 4x4 stride 2 | 1037.0 | 1394.8 | 791.6 [64, 64, 16, 16, 16] 4x4 stride 1 | 2835.4 | 3191.8 | 2484.3 [64, 64, 16, 16, 16] 2x2 stride 4 | 438.6 | 795.7 | 368.6 [64, 64, 16, 16, 16] 2x2 stride 2 | 749.1 | 1108.0 | 612.0 [64, 64, 16, 16, 16] 2x2 stride 1 | 1756.4 | 2112.2 | 1538.5 [64, 64, 4, 4, 4] 4x4 stride 4 | 132.6 | 163.9 | 115.4 [64, 64, 4, 4, 4] 4x4 stride 2 | 129.3 | 153.7 | 117.8 [64, 64, 4, 4, 4] 4x4 stride 1 | 128.0 | 153.8 | 117.6 [64, 64, 4, 4, 4] 2x2 stride 4 | 128.2 | 154.1 | 117.5 [64, 64, 4, 4, 4] 2x2 stride 2 | 130.5 | 157.3 | 117.6 [64, 64, 4, 4, 4] 2x2 stride 1 | 128.8 | 156.4 | 120.6 [64, 3, 32, 32, 32] 4x4 stride 4 | 200.4 | 261.0 | 228.8 [64, 3, 32, 32, 32] 4x4 stride 2 | 305.3 | 366.5 | 344.4 [64, 3, 32, 32, 32] 4x4 stride 1 | 860.9 | 922.1 | 1136.0 [64, 3, 32, 32, 32] 2x2 stride 4 | 157.0 | 216.9 | 158.1 [64, 3, 32, 32, 32] 2x2 stride 2 | 240.5 | 300.9 | 247.7 [64, 3, 32, 32, 32] 2x2 stride 1 | 503.5 | 565.1 | 609.8 [64, 3, 16, 16, 16] 4x4 stride 4 | 136.0 | 159.0 | 120.3 [64, 3, 16, 16, 16] 4x4 stride 2 | 131.2 | 156.9 | 120.0 [64, 3, 16, 16, 16] 4x4 stride 1 | 146.6 | 158.5 | 123.8 [64, 3, 16, 16, 16] 2x2 stride 4 | 133.8 | 158.4 | 117.1 [64, 3, 16, 16, 16] 2x2 stride 2 | 132.1 | 160.8 | 117.9 [64, 3, 16, 16, 16] 2x2 stride 1 | 133.7 | 174.4 | 118.0 [64, 3, 4, 4, 4] 4x4 stride 4 | 156.8 | 166.2 | 119.4 [64, 3, 4, 4, 4] 4x4 stride 2 | 126.8 | 150.4 | 118.2 [64, 3, 4, 4, 4] 4x4 stride 1 | 125.2 | 151.7 | 117.8 [64, 3, 4, 4, 4] 2x2 stride 4 | 127.3 | 152.7 | 116.2 [64, 3, 4, 4, 4] 2x2 stride 2 | 128.6 | 153.3 | 114.6 [64, 3, 4, 4, 4] 2x2 stride 1 | 128.6 | 153.5 | 114.7 [16, 256, 32, 32, 32] 4x4 stride 4 | 3921.7 | 8445.7 | 4064.7 [16, 256, 32, 32, 32] 4x4 stride 2 | 6111.7 | 10630.0 | 6944.4 [16, 256, 32, 32, 32] 4x4 stride 1 | 17938.9 | 22896.8 | 26648.7 [16, 256, 32, 32, 32] 2x2 stride 4 | 3029.6 | 7552.7 | 2840.9 [16, 256, 32, 32, 32] 2x2 stride 2 | 4788.0 | 9322.1 | 5110.5 [16, 256, 32, 32, 32] 2x2 stride 1 | 10363.7 | 14885.9 | 13213.6 [16, 256, 16, 16, 16] 4x4 stride 4 | 606.0 | 1059.1 | 535.9 [16, 256, 16, 16, 16] 4x4 stride 2 | 1037.5 | 1491.5 | 822.3 [16, 256, 16, 16, 16] 4x4 stride 1 | 2835.4 | 3306.8 | 2522.8 [16, 256, 16, 16, 16] 2x2 stride 4 | 438.6 | 892.3 | 369.0 [16, 256, 16, 16, 16] 2x2 stride 2 | 749.2 | 1203.7 | 638.7 [16, 256, 16, 16, 16] 2x2 stride 1 | 1756.1 | 2212.5 | 1547.0 [16, 256, 4, 4, 4] 4x4 stride 4 | 159.6 | 187.6 | 117.6 [16, 256, 4, 4, 4] 4x4 stride 2 | 161.1 | 185.5 | 117.3 [16, 256, 4, 4, 4] 4x4 stride 1 | 160.0 | 148.1 | 117.8 [16, 256, 4, 4, 4] 2x2 stride 4 | 123.9 | 148.3 | 117.6 [16, 256, 4, 4, 4] 2x2 stride 2 | 126.0 | 151.7 | 117.4 [16, 256, 4, 4, 4] 2x2 stride 1 | 127.1 | 152.3 | 117.9 [16, 64, 32, 32, 32] 4x4 stride 4 | 983.5 | 1756.7 | 1067.8 [16, 64, 32, 32, 32] 4x4 stride 2 | 1542.4 | 2315.2 | 1621.5 [16, 64, 32, 32, 32] 4x4 stride 1 | 4498.7 | 5273.4 | 6006.7 [16, 64, 32, 32, 32] 2x2 stride 4 | 767.2 | 1543.4 | 736.7 [16, 64, 32, 32, 32] 2x2 stride 2 | 1207.8 | 1981.5 | 1197.0 [16, 64, 32, 32, 32] 2x2 stride 1 | 2603.3 | 3367.5 | 3161.9 [16, 64, 16, 16, 16] 4x4 stride 4 | 169.5 | 264.6 | 142.8 [16, 64, 16, 16, 16] 4x4 stride 2 | 274.6 | 368.9 | 216.8 [16, 64, 16, 16, 16] 4x4 stride 1 | 723.3 | 820.4 | 643.2 [16, 64, 16, 16, 16] 2x2 stride 4 | 131.4 | 216.0 | 116.1 [16, 64, 16, 16, 16] 2x2 stride 2 | 199.9 | 295.0 | 166.8 ``` CC ptrblck Pull Request resolved: #80748 Approved by: https://github.com/ngimel Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/3b78c5682b483086f66d875749f94b7551072a05 Reviewed By: mehtanirav Differential Revision: D37717260 Pulled By: mehtanirav fbshipit-source-id: a6a00635d1c2bd89fd8bb7a2d40071b2add5e044
MaxPool3D currently converts inputs implicitly to channels-first (via
.contiguous()
) which may yield unexpected regressions in workloads that expect a full channels-last path. This PR preserves the channels-last format in MaxPool3D while attempting to avoid seriously regressing performance.Currently, typical case (kernel size == 2 == stride) looks good, but larger kernel sizes (>4) or the unusual case of stride 1 can sometimes be slower than converting to channels-first before doing MaxPool3D.
Additionally, this PR adds a test for 64bit-indexing backwards as testing of these changes uncovered an IMA for large tensors when doing the backwards pass with MaxPool3D.
Performance comparison on A6000:
Performance comparison on V100:
(these times have been updated after working around some noisy measurements in my setup)
CC @ptrblck