Fixed arange decomp for float dtype #123445
Closed
Description:

The arange graph and the generated C++ code are not optimal when the arange is created directly with the float32 dtype:

```python
import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s, dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))
```

Graph on `main`:

```
===== Forward graph 0 =====
/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:8 in func, code: a = torch.arange(s, dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f64[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float64);  iota = None
        mul: "f64[10]" = torch.ops.aten.mul.Tensor(convert_element_type, 1);  convert_element_type = None
        add: "f64[10]" = torch.ops.aten.add.Tensor(mul, 0);  mul = None
        convert_element_type_1: "f32[10]" = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None

        # File: check_arange_decomp.py:9 in func, code: return s + a
        add_1: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type_1, 10);  convert_element_type_1 = None
        return (add_1,)
```

and the C++:

```c++
extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<double>(tmp0);   // <---- useless ops
            auto tmp2 = static_cast<double>(1.0);     // <----
            auto tmp3 = decltype(tmp1)(tmp1 * tmp2);  // <----
            auto tmp4 = static_cast<double>(0.0);     // <----
            auto tmp5 = decltype(tmp3)(tmp3 + tmp4);  // <----
            auto tmp6 = c10::convert<float>(tmp5);
            auto tmp7 = static_cast<float>(10.0);
            auto tmp8 = decltype(tmp6)(tmp6 + tmp7);
            out_ptr0[static_cast<long>(x0)] = tmp8;
        }
    }
}
```

However, if we manually create the arange on int64 and then cast it to float32, the generated graph and C++ code are more natural and benefit from a speed-up:

```python
import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s).to(dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))
```

Graph on `main`:

```
===== Forward graph 0 =====
/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:14 in func, code: a = torch.arange(s).to(dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f32[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float32);  iota = None

        # File: check_arange_decomp.py:15 in func, code: return s + a
        add: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type, 10);  convert_element_type = None
        return (add,)
```

C++ on `main`:

```c++
extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<float>(tmp0);
            auto tmp2 = static_cast<float>(10.0);
            auto tmp3 = decltype(tmp1)(tmp1 + tmp2);
            out_ptr0[static_cast<long>(x0)] = tmp3;
        }
    }
}
```

For example, the speed-up seen for upsample_nearest2d on CPU:

```
[----------------------------------------------------------- Interpolate, cpu -----------------------------------------------------------]
                            |  Eager (2.3.0a0+gitb4324ed) PR  |  Compiled (2.3.0a0+gitb4324ed) PR  |  Compiled (2.3.0a0+git0d1e705) Nightly  |  speed-up PR vs Nightly  |  Eager (2.3.0a0+git0d1e705) Nightly
1 threads:
Input (1, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)   | 287.988 (+-10.399) | 200.034 (+-8.630) | 285.143 (+-8.412) | 1.425 (+-0.000) | 287.991 (+-11.302)
Input (1, 3, 500, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)       | 697.206 (+-27.033) | 171.650 (+-7.381) | 193.280 (+-5.840) | 1.126 (+-0.000) | 701.642 (+-26.461)
Input (1, 3, 500, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256) | 149.149 (+-6.045) | 222.780 (+-6.852) | 299.968 (+-12.354) | 1.346 (+-0.000) | 145.055 (+-7.232)
Input (1, 3, 500, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)     | 596.741 (+-27.970) | 205.923 (+-8.648) | 233.912 (+-7.742) | 1.136 (+-0.000) | 598.000 (+-25.630)
Input (4, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)   | 1095.734 (+-51.658) | 700.850 (+-24.852) | 1044.255 (+-38.216) | 1.490 (+-0.000) | 1097.977 (+-35.521)
Input (4, 3, 500, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)       | 2741.813 (+-122.917) | 583.073 (+-16.998) | 665.029 (+-36.331) | 1.141 (+-0.000) | 2722.388 (+-116.263)
Input (4, 3, 500, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256) | 578.183 (+-37.266) | 833.295 (+-42.264) | 1131.341 (+-54.710) | 1.358 (+-0.000) | 584.953 (+-45.549)
Input (4, 3, 500, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)     | 2332.508 (+-103.556) | 840.194 (+-47.664) | 935.625 (+-47.467) | 1.114 (+-0.000) | 2334.314 (+-91.644)
Input (1, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)   | 272.631 (+-11.348) | 195.988 (+-5.748) | 274.021 (+-9.475) | 1.398 (+-0.000) | 272.752 (+-12.716)
Input (1, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)       | 640.409 (+-25.465) | 164.773 (+-7.372) | 185.018 (+-8.349) | 1.123 (+-0.000) | 639.390 (+-30.761)
Input (1, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300) | 158.602 (+-6.593) | 220.478 (+-6.809) | 286.376 (+-8.981) | 1.299 (+-0.000) | 158.557 (+-6.143)
Input (1, 3, 1200, 1300), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)     | 548.903 (+-22.889) | 202.788 (+-9.158) | 227.404 (+-8.995) | 1.121 (+-0.000) | 554.096 (+-21.330)
Input (4, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)   | 1036.061 (+-35.285) | 680.728 (+-30.925) | 986.254 (+-42.732) | 1.449 (+-0.000) | 1038.718 (+-43.070)
Input (4, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)       | 2504.520 (+-125.805) | 550.067 (+-21.383) | 628.000 (+-27.589) | 1.142 (+-0.000) | 2523.134 (+-113.336)
Input (4, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300) | 1058.188 (+-57.853) | 1216.427 (+-76.160) | 1380.231 (+-98.939) | 1.135 (+-0.000) | 1057.031 (+-66.075)
Input (4, 3, 1200, 1300), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)     | 2305.911 (+-116.864) | 1080.189 (+-79.934) | 1141.561 (+-67.959) | 1.057 (+-0.000) | 2306.606 (+-121.544)
Input (1, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)   | 1689.489 (+-60.579) | 1077.401 (+-44.948) | 1634.264 (+-64.340) | 1.517 (+-0.000) | 1693.945 (+-67.998)
Input (1, 3, 300, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)       | 4198.368 (+-179.096) | 886.656 (+-30.355) | 1028.568 (+-46.310) | 1.160 (+-0.000) | 4174.351 (+-141.020)
Input (1, 3, 300, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700) | 716.572 (+-51.954) | 1175.864 (+-52.191) | 1674.373 (+-51.815) | 1.424 (+-0.000) | 715.724 (+-41.104)
Input (1, 3, 300, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)     | 3604.989 (+-132.489) | 1096.933 (+-54.290) | 1270.347 (+-60.932) | 1.158 (+-0.000) | 3601.864 (+-140.218)
Input (4, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)   | 6721.610 (+-355.997) | 4203.213 (+-134.362) | 6423.763 (+-225.311) | 1.528 (+-0.000) | 6715.626 (+-288.233)
Input (4, 3, 300, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)       | 16695.467 (+-709.620) | 3460.013 (+-149.456) | 4001.810 (+-218.093) | 1.157 (+-0.000) | 16621.138 (+-713.320)
Input (4, 3, 300, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700) | 3020.017 (+-147.314) | 4743.164 (+-135.850) | 6709.494 (+-281.025) | 1.415 (+-0.000) | 3015.602 (+-105.852)
Input (4, 3, 300, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)     | 14456.688 (+-752.839) | 5150.893 (+-201.571) | 5737.315 (+-138.011) | 1.114 (+-0.000) | 14464.472 (+-720.027)

Times are in microseconds (us).
```

This PR fixes the arange decomp such that `arange(s, dtype=torch.float32)` directly produces the better IR and generated code.

Code:

```python
import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s, dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))
```

Graph on this PR:

```
===== Forward graph 0 =====
/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:8 in func, code: a = torch.arange(s, dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f32[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float32);  iota = None

        # File: check_arange_decomp.py:9 in func, code: return s + a
        add: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type, 10);  convert_element_type = None
        return (add,)
```

and C++ on this PR:

```c++
extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<float>(tmp0);
            auto tmp2 = static_cast<float>(10.0);
            auto tmp3 = decltype(tmp1)(tmp1 + tmp2);
            out_ptr0[static_cast<long>(x0)] = tmp3;
        }
    }
}
```

[ghstack-poisoned]
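The fixed decomposition can be sketched in plain PyTorch: a single integer iota, a scale and shift in int64, and one cast to the target dtype, with no float64 intermediate and no redundant mul-by-1 / add-0. This is an illustrative sketch, not the actual decomp code; the helper name and the ceil-division length computation are assumptions:

```python
import torch

def arange_decomp_sketch(end, start=0, step=1, dtype=torch.float32):
    # Number of elements: ceil((end - start) / step), via negated floor division.
    n = -(-(end - start) // step)
    # Integer iota (stand-in for prims.iota), scaled/shifted in int64,
    # then a single cast to the requested dtype.
    iota = torch.arange(n, dtype=torch.int64)
    return (iota * step + start).to(dtype)

print(arange_decomp_sketch(6, start=1, step=2))  # values 1, 3, 5 in float32
```

For integer start/step this matches `torch.arange(start, end, step, dtype=...)`; for floating-point arguments the real decomposition must also handle rounding of `start + i*step`, so this sketch only covers the integer-argument case.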
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/123445
Note: links to docs will display an error until the docs builds have completed.

❌ 9 New Failures, 3 Unrelated Failures — as of commit e791f44 with merge base 5b0ce8f:
- NEW FAILURES: the following jobs have failed
- FLAKY: the following jobs failed but were likely due to flakiness present on trunk
- UNSTABLE: the following job failed but was likely due to flakiness present on trunk and has been marked as unstable

This comment was automatically generated by Dr. CI and updates every 15 minutes.
vfdev-5 added a commit that referenced this pull request on Apr 5, 2024
ghstack-source-id: b2cd195b3bed03784df46632e5f17d3ff7dabf13
Pull Request resolved: #123445
leslie-fang-intel added a commit that referenced this pull request on Apr 28, 2024
…t per tensor and refactor quant pattern"

**Summary**
Per the discussion in #123444, the `decomposed quant/dequant` patterns changed after #123445, so we can move the optimization of `decomposed quant/dequant` from the inductor decomposition into the lowering phase to avoid the changes. In this way, we can:
- Avoid the pattern matcher failure introduced in #123445
- Make the quantization pattern clearer in the pattern matcher phase, since the `quant/dequant` nodes have not been decomposed.

**Changes in this PR**
- Move the optimization of `decomposed quant/dequant` from the inductor decomposition into the lowering phase.
- Corresponding changes in the quantization pattern matcher to ensure no BC-breaking.

**TestPlan**
```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_q
```

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
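For context, the per-tensor `decomposed quant/dequant` pattern being matched can be sketched with the standard affine quantization formulas. This is a hedged sketch: the helper names are illustrative stand-ins for the `quantized_decomposed` ops, not their real signatures.

```python
import torch

def quantize_per_tensor_sketch(x, scale, zero_point, qmin=-128, qmax=127):
    # q = clamp(round(x / scale) + zero_point, qmin, qmax), stored as int8.
    return torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)

def dequantize_per_tensor_sketch(q, scale, zero_point):
    # x_hat = (q - zero_point) * scale, back in float32.
    return (q.to(torch.float32) - zero_point) * scale

x = torch.tensor([0.0, 0.1, -0.1, 1.0])
q = quantize_per_tensor_sketch(x, scale=0.01, zero_point=0)
x_hat = dequantize_per_tensor_sketch(q, scale=0.01, zero_point=0)
```

Because the matcher sees these as sequences of `round`/`clamp`/`mul`/`add` nodes once decomposed, changing where the decomposition happens (as #123445 did for arange-like ops) changes the node sequence the patterns must match.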
leslie-fang-intel added a commit that referenced this pull request on Apr 28, 2024
leslie-fang-intel added a commit that referenced this pull request on May 6, 2024
leslie-fang-intel added a commit that referenced this pull request on May 6, 2024
leslie-fang-intel added a commit that referenced this pull request on May 7, 2024
leslie-fang-intel added a commit that referenced this pull request on May 7, 2024
pytorchmergebot pushed a commit that referenced this pull request on May 9, 2024
…uant pattern (#124041)

Pull Request resolved: #124041
Approved by: https://github.com/peterbell10, https://github.com/jgong5
leslie-fang-intel added a commit that referenced this pull request on May 9, 2024
leslie-fang-intel added a commit that referenced this pull request on May 9, 2024
pytorchmergebot pushed a commit that referenced this pull request on May 9, 2024
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang