Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Description: Arange graph and C++ generated code are not optimal when arange is created directly using float32 dtype: ```python import torch def func(x): s = x.shape[-1] a = torch.arange(s, dtype=torch.float32) return s + a c_func = torch.compile(func) out = c_func(torch.rand(10)) ``` Graph on `main`: ``` ===== Forward graph 0 ===== /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module): def forward(self): # File: check_arange_decomp.py:8 in func, code: a = torch.arange(s, dtype=torch.float32) iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False) convert_element_type: "f64[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float64); iota = None mul: "f64[10]" = torch.ops.aten.mul.Tensor(convert_element_type, 1); convert_element_type = None add: "f64[10]" = torch.ops.aten.add.Tensor(mul, 0); mul = None convert_element_type_1: "f32[10]" = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None # File: check_arange_decomp.py:9 in func, code: return s + a add_1: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type_1, 10); convert_element_type_1 = None return (add_1,) ``` and C++ ```c++ extern "C" void kernel(float* out_ptr0) { { #pragma GCC ivdep for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L)) { auto tmp0 = c10::convert<long>(x0); auto tmp1 = c10::convert<double>(tmp0); // <---- useless ops auto tmp2 = static_cast<double>(1.0); // <---- auto tmp3 = decltype(tmp1)(tmp1 * tmp2); // <---- auto tmp4 = static_cast<double>(0.0); // <---- auto tmp5 = decltype(tmp3)(tmp3 + tmp4); // <---- auto tmp6 = c10::convert<float>(tmp5); auto tmp7 = static_cast<float>(10.0); auto tmp8 = decltype(tmp6)(tmp6 + tmp7); out_ptr0[static_cast<long>(x0)] = tmp8; } } } ``` However, if we manually create the arange on i64 and then cast it to float32, the generated graph and C++ code are more natural and benefit from a
speed-up. ```python import torch def func(x): s = x.shape[-1] a = torch.arange(s).to(dtype=torch.float32) return s + a c_func = torch.compile(func) out = c_func(torch.rand(10)) ``` Graph on `main`: ``` ===== Forward graph 0 ===== /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module): def forward(self): # File: check_arange_decomp.py:14 in func, code: a = torch.arange(s).to(dtype=torch.float32) iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False) convert_element_type: "f32[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float32); iota = None # File: check_arange_decomp.py:15 in func, code: return s + a add: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type, 10); convert_element_type = None return (add,) ``` C++ on `main` ```c++ extern "C" void kernel(float* out_ptr0) { { #pragma GCC ivdep for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L)) { auto tmp0 = c10::convert<long>(x0); auto tmp1 = c10::convert<float>(tmp0); auto tmp2 = static_cast<float>(10.0); auto tmp3 = decltype(tmp1)(tmp1 + tmp2); out_ptr0[static_cast<long>(x0)] = tmp3; } } } ``` For example, the speed-up seen on upsample_nearest2d on cpu: ``` [----------------------------------------------------------------------------------------------------------------------------------------------- Interpolate, cpu ----------------------------------------------------------------------------------------------------------------------------------------------] | Eager (2.3.0a0+gitb4324ed) PR | Compiled (2.3.0a0+gitb4324ed) PR | Compiled (2.3.0a0+git0d1e705) Nightly | speed-up PR vs Nightly | Eager (2.3.0a0+git0d1e705) Nightly 1 threads: 
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ Input (1, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256) | 287.988 (+-10.399) | 200.034 (+-8.630) | 285.143 (+-8.412) | 1.425 (+-0.000) | 287.991 (+-11.302) Input (1, 3, 500, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256) | 697.206 (+-27.033) | 171.650 (+-7.381) | 193.280 (+-5.840) | 1.126 (+-0.000) | 701.642 (+-26.461) Input (1, 3, 500, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256) | 149.149 (+-6.045) | 222.780 (+-6.852) | 299.968 (+-12.354) | 1.346 (+-0.000) | 145.055 (+-7.232) Input (1, 3, 500, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256) | 596.741 (+-27.970) | 205.923 (+-8.648) | 233.912 (+-7.742) | 1.136 (+-0.000) | 598.000 (+-25.630) Input (4, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256) | 1095.734 (+-51.658) | 700.850 (+-24.852) | 1044.255 (+-38.216) | 1.490 (+-0.000) | 1097.977 (+-35.521) Input (4, 3, 500, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256) | 2741.813 (+-122.917) | 583.073 (+-16.998) | 665.029 (+-36.331) | 1.141 (+-0.000) | 2722.388 (+-116.263) Input (4, 3, 500, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256) | 578.183 (+-37.266) | 833.295 (+-42.264) | 1131.341 (+-54.710) | 1.358 (+-0.000) | 584.953 (+-45.549) Input (4, 3, 500, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256) | 2332.508 (+-103.556) | 840.194 (+-47.664) | 935.625 
(+-47.467) | 1.114 (+-0.000) | 2334.314 (+-91.644) Input (1, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300) | 272.631 (+-11.348) | 195.988 (+-5.748) | 274.021 (+-9.475) | 1.398 (+-0.000) | 272.752 (+-12.716) Input (1, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300) | 640.409 (+-25.465) | 164.773 (+-7.372) | 185.018 (+-8.349) | 1.123 (+-0.000) | 639.390 (+-30.761) Input (1, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300) | 158.602 (+-6.593) | 220.478 (+-6.809) | 286.376 (+-8.981) | 1.299 (+-0.000) | 158.557 (+-6.143) Input (1, 3, 1200, 1300), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300) | 548.903 (+-22.889) | 202.788 (+-9.158) | 227.404 (+-8.995) | 1.121 (+-0.000) | 554.096 (+-21.330) Input (4, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300) | 1036.061 (+-35.285) | 680.728 (+-30.925) | 986.254 (+-42.732) | 1.449 (+-0.000) | 1038.718 (+-43.070) Input (4, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300) | 2504.520 (+-125.805) | 550.067 (+-21.383) | 628.000 (+-27.589) | 1.142 (+-0.000) | 2523.134 (+-113.336) Input (4, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300) | 1058.188 (+-57.853) | 1216.427 (+-76.160) | 1380.231 (+-98.939) | 1.135 (+-0.000) | 1057.031 (+-66.075) Input (4, 3, 1200, 1300), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300) | 2305.911 (+-116.864) | 1080.189 (+-79.934) | 1141.561 (+-67.959) | 1.057 (+-0.000) | 2306.606 (+-121.544) Input (1, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700) | 1689.489 (+-60.579) | 1077.401 (+-44.948) | 1634.264 (+-64.340) 
| 1.517 (+-0.000) | 1693.945 (+-67.998) Input (1, 3, 300, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700) | 4198.368 (+-179.096) | 886.656 (+-30.355) | 1028.568 (+-46.310) | 1.160 (+-0.000) | 4174.351 (+-141.020) Input (1, 3, 300, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700) | 716.572 (+-51.954) | 1175.864 (+-52.191) | 1674.373 (+-51.815) | 1.424 (+-0.000) | 715.724 (+-41.104) Input (1, 3, 300, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700) | 3604.989 (+-132.489) | 1096.933 (+-54.290) | 1270.347 (+-60.932) | 1.158 (+-0.000) | 3601.864 (+-140.218) Input (4, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700) | 6721.610 (+-355.997) | 4203.213 (+-134.362) | 6423.763 (+-225.311) | 1.528 (+-0.000) | 6715.626 (+-288.233) Input (4, 3, 300, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700) | 16695.467 (+-709.620) | 3460.013 (+-149.456) | 4001.810 (+-218.093) | 1.157 (+-0.000) | 16621.138 (+-713.320) Input (4, 3, 300, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700) | 3020.017 (+-147.314) | 4743.164 (+-135.850) | 6709.494 (+-281.025) | 1.415 (+-0.000) | 3015.602 (+-105.852) Input (4, 3, 300, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700) | 14456.688 (+-752.839) | 5150.893 (+-201.571) | 5737.315 (+-138.011) | 1.114 (+-0.000) | 14464.472 (+-720.027) Times are in microseconds (us). ``` This PR fixes the arange decomp such that `arange(s, dtype=torch.float32)` directly produces better IR and generated code.
Code: ```python import torch def func(x): s = x.shape[-1] a = torch.arange(s, dtype=torch.float32) return s + a c_func = torch.compile(func) out = c_func(torch.rand(10)) ``` Graph on this PR: ``` ===== Forward graph 0 ===== /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module): def forward(self): # File: check_arange_decomp.py:8 in func, code: a = torch.arange(s, dtype=torch.float32) iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False) convert_element_type: "f32[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float32); iota = None # File: check_arange_decomp.py:9 in func, code: return s + a add: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type, 10); convert_element_type = None return (add,) ``` and C++ on this PR: ```c++ extern "C" void kernel(float* out_ptr0) { { #pragma GCC ivdep for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L)) { auto tmp0 = c10::convert<long>(x0); auto tmp1 = c10::convert<float>(tmp0); auto tmp2 = static_cast<float>(10.0); auto tmp3 = decltype(tmp1)(tmp1 + tmp2); out_ptr0[static_cast<long>(x0)] = tmp3; } } } ``` [ghstack-poisoned]
- Loading branch information