Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed arange decomp for float dtype #123445

Closed
wants to merge 1 commit into from

Conversation

vfdev-5
Copy link
Collaborator

@vfdev-5 vfdev-5 commented Apr 5, 2024

Stack from ghstack (oldest at bottom):

Description:

Arange graph and C++ generated code are not optimal when arange is created directly using float32 dtype:

import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s, dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))

Graph on main:

 ===== Forward graph 0 =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:8 in func, code: a = torch.arange(s, dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f64[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float64);  iota = None
        mul: "f64[10]" = torch.ops.aten.mul.Tensor(convert_element_type, 1);  convert_element_type = None
        add: "f64[10]" = torch.ops.aten.add.Tensor(mul, 0);  mul = None
        convert_element_type_1: "f32[10]" = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None

        # File: check_arange_decomp.py:9 in func, code: return s + a
        add_1: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type_1, 10);  convert_element_type_1 = None
        return (add_1,)

and C++

extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<double>(tmp0);   // <---- useless ops
            auto tmp2 = static_cast<double>(1.0);     // <----
            auto tmp3 = decltype(tmp1)(tmp1 * tmp2);  // <----
            auto tmp4 = static_cast<double>(0.0);     // <----
            auto tmp5 = decltype(tmp3)(tmp3 + tmp4);  // <----
            auto tmp6 = c10::convert<float>(tmp5);
            auto tmp7 = static_cast<float>(10.0);
            auto tmp8 = decltype(tmp6)(tmp6 + tmp7);
            out_ptr0[static_cast<long>(x0)] = tmp8;
        }
    }
}

However, if we manually create arange on i64 and then put to float32, generated graph and C++ code are more natural and benefit of a speed-up.

import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s).to(dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))

Graph on main:

 ===== Forward graph 0 =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:14 in func, code: a = torch.arange(s).to(dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f32[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float32);  iota = None

        # File: check_arange_decomp.py:15 in func, code: return s + a
        add: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type, 10);  convert_element_type = None
        return (add,)

C++ on main

extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<float>(tmp0);
            auto tmp2 = static_cast<float>(10.0);
            auto tmp3 = decltype(tmp1)(tmp1 + tmp2);
            out_ptr0[static_cast<long>(x0)] = tmp3;
        }
    }
}

For example, the speed-up seen on upsample_nearest2d on cpu:

[----------------------------------------------------------------------------------------------------------------------------------------------- Interpolate, cpu ----------------------------------------------------------------------------------------------------------------------------------------------]
                                                                                                                                |  Eager (2.3.0a0+gitb4324ed) PR  |  Compiled (2.3.0a0+gitb4324ed) PR  |  Compiled (2.3.0a0+git0d1e705) Nightly  |  speed-up PR vs Nightly  |  Eager (2.3.0a0+git0d1e705) Nightly
1 threads: ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      Input (1, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)      |        287.988 (+-10.399)       |         200.034 (+-8.630)          |            285.143 (+-8.412)            |     1.425 (+-0.000)      |          287.991 (+-11.302)
      Input (1, 3, 500, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)          |        697.206 (+-27.033)       |         171.650 (+-7.381)          |            193.280 (+-5.840)            |     1.126 (+-0.000)      |          701.642 (+-26.461)
      Input (1, 3, 500, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)    |        149.149 (+-6.045)        |         222.780 (+-6.852)          |            299.968 (+-12.354)           |     1.346 (+-0.000)      |          145.055 (+-7.232)
      Input (1, 3, 500, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)        |        596.741 (+-27.970)       |         205.923 (+-8.648)          |            233.912 (+-7.742)            |     1.136 (+-0.000)      |          598.000 (+-25.630)
      Input (4, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)      |       1095.734 (+-51.658)       |         700.850 (+-24.852)         |           1044.255 (+-38.216)           |     1.490 (+-0.000)      |         1097.977 (+-35.521)
      Input (4, 3, 500, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)          |       2741.813 (+-122.917)      |         583.073 (+-16.998)         |            665.029 (+-36.331)           |     1.141 (+-0.000)      |         2722.388 (+-116.263)
      Input (4, 3, 500, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)    |        578.183 (+-37.266)       |         833.295 (+-42.264)         |           1131.341 (+-54.710)           |     1.358 (+-0.000)      |          584.953 (+-45.549)
      Input (4, 3, 500, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)        |       2332.508 (+-103.556)      |         840.194 (+-47.664)         |            935.625 (+-47.467)           |     1.114 (+-0.000)      |         2334.314 (+-91.644)
      Input (1, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)    |        272.631 (+-11.348)       |         195.988 (+-5.748)          |            274.021 (+-9.475)            |     1.398 (+-0.000)      |          272.752 (+-12.716)
      Input (1, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)        |        640.409 (+-25.465)       |         164.773 (+-7.372)          |            185.018 (+-8.349)            |     1.123 (+-0.000)      |          639.390 (+-30.761)
      Input (1, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)  |        158.602 (+-6.593)        |         220.478 (+-6.809)          |            286.376 (+-8.981)            |     1.299 (+-0.000)      |          158.557 (+-6.143)
      Input (1, 3, 1200, 1300), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)      |        548.903 (+-22.889)       |         202.788 (+-9.158)          |            227.404 (+-8.995)            |     1.121 (+-0.000)      |          554.096 (+-21.330)
      Input (4, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)    |       1036.061 (+-35.285)       |         680.728 (+-30.925)         |            986.254 (+-42.732)           |     1.449 (+-0.000)      |         1038.718 (+-43.070)
      Input (4, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)        |       2504.520 (+-125.805)      |         550.067 (+-21.383)         |            628.000 (+-27.589)           |     1.142 (+-0.000)      |         2523.134 (+-113.336)
      Input (4, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)  |       1058.188 (+-57.853)       |        1216.427 (+-76.160)         |           1380.231 (+-98.939)           |     1.135 (+-0.000)      |         1057.031 (+-66.075)
      Input (4, 3, 1200, 1300), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)      |       2305.911 (+-116.864)      |        1080.189 (+-79.934)         |           1141.561 (+-67.959)           |     1.057 (+-0.000)      |         2306.606 (+-121.544)
      Input (1, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)      |       1689.489 (+-60.579)       |        1077.401 (+-44.948)         |           1634.264 (+-64.340)           |     1.517 (+-0.000)      |         1693.945 (+-67.998)
      Input (1, 3, 300, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)          |       4198.368 (+-179.096)      |         886.656 (+-30.355)         |           1028.568 (+-46.310)           |     1.160 (+-0.000)      |         4174.351 (+-141.020)
      Input (1, 3, 300, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)    |        716.572 (+-51.954)       |        1175.864 (+-52.191)         |           1674.373 (+-51.815)           |     1.424 (+-0.000)      |          715.724 (+-41.104)
      Input (1, 3, 300, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)        |       3604.989 (+-132.489)      |        1096.933 (+-54.290)         |           1270.347 (+-60.932)           |     1.158 (+-0.000)      |         3601.864 (+-140.218)
      Input (4, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)      |       6721.610 (+-355.997)      |        4203.213 (+-134.362)        |           6423.763 (+-225.311)          |     1.528 (+-0.000)      |         6715.626 (+-288.233)
      Input (4, 3, 300, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)          |      16695.467 (+-709.620)      |        3460.013 (+-149.456)        |           4001.810 (+-218.093)          |     1.157 (+-0.000)      |        16621.138 (+-713.320)
      Input (4, 3, 300, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)    |       3020.017 (+-147.314)      |        4743.164 (+-135.850)        |           6709.494 (+-281.025)          |     1.415 (+-0.000)      |         3015.602 (+-105.852)
      Input (4, 3, 300, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)        |      14456.688 (+-752.839)      |        5150.893 (+-201.571)        |           5737.315 (+-138.011)          |     1.114 (+-0.000)      |        14464.472 (+-720.027)

Times are in microseconds (us).

This PR fixes arrange decomp such that arange(s, dtype=torch.float32) directly provides better IR and generated code.

Code:

import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s, dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))

Graph on this PR:

 ===== Forward graph 0 =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:8 in func, code: a = torch.arange(s, dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f32[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float32);  iota = None

        # File: check_arange_decomp.py:9 in func, code: return s + a
        add: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type, 10);  convert_element_type = None
        return (add,)

and C++ on this PR:

extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<float>(tmp0);
            auto tmp2 = static_cast<float>(10.0);
            auto tmp3 = decltype(tmp1)(tmp1 + tmp2);
            out_ptr0[static_cast<long>(x0)] = tmp3;
        }
    }
}

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang

Description:

Arange graph and C++ generated code are not optimal when arange is created directly using float32 dtype:
```python
import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s, dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))
```

Graph on `main`:
```
 ===== Forward graph 0 =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:8 in func, code: a = torch.arange(s, dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f64[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float64);  iota = None
        mul: "f64[10]" = torch.ops.aten.mul.Tensor(convert_element_type, 1);  convert_element_type = None
        add: "f64[10]" = torch.ops.aten.add.Tensor(mul, 0);  mul = None
        convert_element_type_1: "f32[10]" = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None

        # File: check_arange_decomp.py:9 in func, code: return s + a
        add_1: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type_1, 10);  convert_element_type_1 = None
        return (add_1,)
```
and C++
```c++
extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<double>(tmp0);   // <---- useless ops
            auto tmp2 = static_cast<double>(1.0);     // <----
            auto tmp3 = decltype(tmp1)(tmp1 * tmp2);  // <----
            auto tmp4 = static_cast<double>(0.0);     // <----
            auto tmp5 = decltype(tmp3)(tmp3 + tmp4);  // <----
            auto tmp6 = c10::convert<float>(tmp5);
            auto tmp7 = static_cast<float>(10.0);
            auto tmp8 = decltype(tmp6)(tmp6 + tmp7);
            out_ptr0[static_cast<long>(x0)] = tmp8;
        }
    }
}
```

However, if we manually create arange on i64 and then put to float32, generated graph and C++ code are more natural and benefit of a speed-up.
```python
import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s).to(dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))
```

Graph on `main`:
```
 ===== Forward graph 0 =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:14 in func, code: a = torch.arange(s).to(dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f32[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float32);  iota = None

        # File: check_arange_decomp.py:15 in func, code: return s + a
        add: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type, 10);  convert_element_type = None
        return (add,)
```

C++ on `main`
```c++
extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<float>(tmp0);
            auto tmp2 = static_cast<float>(10.0);
            auto tmp3 = decltype(tmp1)(tmp1 + tmp2);
            out_ptr0[static_cast<long>(x0)] = tmp3;
        }
    }
}
```

For example, the speed-up seen on upsample_nearest2d on cpu:
```
[----------------------------------------------------------------------------------------------------------------------------------------------- Interpolate, cpu ----------------------------------------------------------------------------------------------------------------------------------------------]
                                                                                                                                |  Eager (2.3.0a0+gitb4324ed) PR  |  Compiled (2.3.0a0+gitb4324ed) PR  |  Compiled (2.3.0a0+git0d1e705) Nightly  |  speed-up PR vs Nightly  |  Eager (2.3.0a0+git0d1e705) Nightly
1 threads: ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      Input (1, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)      |        287.988 (+-10.399)       |         200.034 (+-8.630)          |            285.143 (+-8.412)            |     1.425 (+-0.000)      |          287.991 (+-11.302)
      Input (1, 3, 500, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)          |        697.206 (+-27.033)       |         171.650 (+-7.381)          |            193.280 (+-5.840)            |     1.126 (+-0.000)      |          701.642 (+-26.461)
      Input (1, 3, 500, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)    |        149.149 (+-6.045)        |         222.780 (+-6.852)          |            299.968 (+-12.354)           |     1.346 (+-0.000)      |          145.055 (+-7.232)
      Input (1, 3, 500, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)        |        596.741 (+-27.970)       |         205.923 (+-8.648)          |            233.912 (+-7.742)            |     1.136 (+-0.000)      |          598.000 (+-25.630)
      Input (4, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)      |       1095.734 (+-51.658)       |         700.850 (+-24.852)         |           1044.255 (+-38.216)           |     1.490 (+-0.000)      |         1097.977 (+-35.521)
      Input (4, 3, 500, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)          |       2741.813 (+-122.917)      |         583.073 (+-16.998)         |            665.029 (+-36.331)           |     1.141 (+-0.000)      |         2722.388 (+-116.263)
      Input (4, 3, 500, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)    |        578.183 (+-37.266)       |         833.295 (+-42.264)         |           1131.341 (+-54.710)           |     1.358 (+-0.000)      |          584.953 (+-45.549)
      Input (4, 3, 500, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)        |       2332.508 (+-103.556)      |         840.194 (+-47.664)         |            935.625 (+-47.467)           |     1.114 (+-0.000)      |         2334.314 (+-91.644)
      Input (1, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)    |        272.631 (+-11.348)       |         195.988 (+-5.748)          |            274.021 (+-9.475)            |     1.398 (+-0.000)      |          272.752 (+-12.716)
      Input (1, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)        |        640.409 (+-25.465)       |         164.773 (+-7.372)          |            185.018 (+-8.349)            |     1.123 (+-0.000)      |          639.390 (+-30.761)
      Input (1, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)  |        158.602 (+-6.593)        |         220.478 (+-6.809)          |            286.376 (+-8.981)            |     1.299 (+-0.000)      |          158.557 (+-6.143)
      Input (1, 3, 1200, 1300), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)      |        548.903 (+-22.889)       |         202.788 (+-9.158)          |            227.404 (+-8.995)            |     1.121 (+-0.000)      |          554.096 (+-21.330)
      Input (4, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)    |       1036.061 (+-35.285)       |         680.728 (+-30.925)         |            986.254 (+-42.732)           |     1.449 (+-0.000)      |         1038.718 (+-43.070)
      Input (4, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)        |       2504.520 (+-125.805)      |         550.067 (+-21.383)         |            628.000 (+-27.589)           |     1.142 (+-0.000)      |         2523.134 (+-113.336)
      Input (4, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)  |       1058.188 (+-57.853)       |        1216.427 (+-76.160)         |           1380.231 (+-98.939)           |     1.135 (+-0.000)      |         1057.031 (+-66.075)
      Input (4, 3, 1200, 1300), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)      |       2305.911 (+-116.864)      |        1080.189 (+-79.934)         |           1141.561 (+-67.959)           |     1.057 (+-0.000)      |         2306.606 (+-121.544)
      Input (1, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)      |       1689.489 (+-60.579)       |        1077.401 (+-44.948)         |           1634.264 (+-64.340)           |     1.517 (+-0.000)      |         1693.945 (+-67.998)
      Input (1, 3, 300, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)          |       4198.368 (+-179.096)      |         886.656 (+-30.355)         |           1028.568 (+-46.310)           |     1.160 (+-0.000)      |         4174.351 (+-141.020)
      Input (1, 3, 300, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)    |        716.572 (+-51.954)       |        1175.864 (+-52.191)         |           1674.373 (+-51.815)           |     1.424 (+-0.000)      |          715.724 (+-41.104)
      Input (1, 3, 300, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)        |       3604.989 (+-132.489)      |        1096.933 (+-54.290)         |           1270.347 (+-60.932)           |     1.158 (+-0.000)      |         3601.864 (+-140.218)
      Input (4, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)      |       6721.610 (+-355.997)      |        4203.213 (+-134.362)        |           6423.763 (+-225.311)          |     1.528 (+-0.000)      |         6715.626 (+-288.233)
      Input (4, 3, 300, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)          |      16695.467 (+-709.620)      |        3460.013 (+-149.456)        |           4001.810 (+-218.093)          |     1.157 (+-0.000)      |        16621.138 (+-713.320)
      Input (4, 3, 300, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)    |       3020.017 (+-147.314)      |        4743.164 (+-135.850)        |           6709.494 (+-281.025)          |     1.415 (+-0.000)      |         3015.602 (+-105.852)
      Input (4, 3, 300, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)        |      14456.688 (+-752.839)      |        5150.893 (+-201.571)        |           5737.315 (+-138.011)          |     1.114 (+-0.000)      |        14464.472 (+-720.027)

Times are in microseconds (us).
```

This PR fixes arrange decomp such that `arange(s, dtype=torch.float32)` directly provides better IR and generated code.

Code:
```python
import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s, dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))
```

Graph on this PR:
```
 ===== Forward graph 0 =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:8 in func, code: a = torch.arange(s, dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f32[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float32);  iota = None

        # File: check_arange_decomp.py:9 in func, code: return s + a
        add: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type, 10);  convert_element_type = None
        return (add,)
```
and C++ on this PR:
```c++
extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<float>(tmp0);
            auto tmp2 = static_cast<float>(10.0);
            auto tmp3 = decltype(tmp1)(tmp1 + tmp2);
            out_ptr0[static_cast<long>(x0)] = tmp3;
        }
    }
}
```

[ghstack-poisoned]
Copy link

pytorch-bot bot commented Apr 5, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/123445

Note: Links to docs will display an error until the docs builds have been completed.

❌ 9 New Failures, 3 Unrelated Failures

As of commit e791f44 with merge base 5b0ce8f (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

vfdev-5 added a commit that referenced this pull request Apr 5, 2024
Description:

Arange graph and C++ generated code are not optimal when arange is created directly using float32 dtype:
```python
import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s, dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))
```

Graph on `main`:
```
 ===== Forward graph 0 =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:8 in func, code: a = torch.arange(s, dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f64[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float64);  iota = None
        mul: "f64[10]" = torch.ops.aten.mul.Tensor(convert_element_type, 1);  convert_element_type = None
        add: "f64[10]" = torch.ops.aten.add.Tensor(mul, 0);  mul = None
        convert_element_type_1: "f32[10]" = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None

        # File: check_arange_decomp.py:9 in func, code: return s + a
        add_1: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type_1, 10);  convert_element_type_1 = None
        return (add_1,)
```
and C++
```c++
extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<double>(tmp0);   // <---- useless ops
            auto tmp2 = static_cast<double>(1.0);     // <----
            auto tmp3 = decltype(tmp1)(tmp1 * tmp2);  // <----
            auto tmp4 = static_cast<double>(0.0);     // <----
            auto tmp5 = decltype(tmp3)(tmp3 + tmp4);  // <----
            auto tmp6 = c10::convert<float>(tmp5);
            auto tmp7 = static_cast<float>(10.0);
            auto tmp8 = decltype(tmp6)(tmp6 + tmp7);
            out_ptr0[static_cast<long>(x0)] = tmp8;
        }
    }
}
```

However, if we manually create arange on i64 and then put to float32, generated graph and C++ code are more natural and benefit of a speed-up.
```python
import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s).to(dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))
```

Graph on `main`:
```
 ===== Forward graph 0 =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:14 in func, code: a = torch.arange(s).to(dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f32[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float32);  iota = None

        # File: check_arange_decomp.py:15 in func, code: return s + a
        add: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type, 10);  convert_element_type = None
        return (add,)
```

C++ on `main`
```c++
extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<float>(tmp0);
            auto tmp2 = static_cast<float>(10.0);
            auto tmp3 = decltype(tmp1)(tmp1 + tmp2);
            out_ptr0[static_cast<long>(x0)] = tmp3;
        }
    }
}
```

For example, the speed-up seen on upsample_nearest2d on cpu:
```
[----------------------------------------------------------------------------------------------------------------------------------------------- Interpolate, cpu ----------------------------------------------------------------------------------------------------------------------------------------------]
                                                                                                                                |  Eager (2.3.0a0+gitb4324ed) PR  |  Compiled (2.3.0a0+gitb4324ed) PR  |  Compiled (2.3.0a0+git0d1e705) Nightly  |  speed-up PR vs Nightly  |  Eager (2.3.0a0+git0d1e705) Nightly
1 threads: ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      Input (1, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)      |        287.988 (+-10.399)       |         200.034 (+-8.630)          |            285.143 (+-8.412)            |     1.425 (+-0.000)      |          287.991 (+-11.302)
      Input (1, 3, 500, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)          |        697.206 (+-27.033)       |         171.650 (+-7.381)          |            193.280 (+-5.840)            |     1.126 (+-0.000)      |          701.642 (+-26.461)
      Input (1, 3, 500, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)    |        149.149 (+-6.045)        |         222.780 (+-6.852)          |            299.968 (+-12.354)           |     1.346 (+-0.000)      |          145.055 (+-7.232)
      Input (1, 3, 500, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)        |        596.741 (+-27.970)       |         205.923 (+-8.648)          |            233.912 (+-7.742)            |     1.136 (+-0.000)      |          598.000 (+-25.630)
      Input (4, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)      |       1095.734 (+-51.658)       |         700.850 (+-24.852)         |           1044.255 (+-38.216)           |     1.490 (+-0.000)      |         1097.977 (+-35.521)
      Input (4, 3, 500, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)          |       2741.813 (+-122.917)      |         583.073 (+-16.998)         |            665.029 (+-36.331)           |     1.141 (+-0.000)      |         2722.388 (+-116.263)
      Input (4, 3, 500, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)    |        578.183 (+-37.266)       |         833.295 (+-42.264)         |           1131.341 (+-54.710)           |     1.358 (+-0.000)      |          584.953 (+-45.549)
      Input (4, 3, 500, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)        |       2332.508 (+-103.556)      |         840.194 (+-47.664)         |            935.625 (+-47.467)           |     1.114 (+-0.000)      |         2334.314 (+-91.644)
      Input (1, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)    |        272.631 (+-11.348)       |         195.988 (+-5.748)          |            274.021 (+-9.475)            |     1.398 (+-0.000)      |          272.752 (+-12.716)
      Input (1, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)        |        640.409 (+-25.465)       |         164.773 (+-7.372)          |            185.018 (+-8.349)            |     1.123 (+-0.000)      |          639.390 (+-30.761)
      Input (1, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)  |        158.602 (+-6.593)        |         220.478 (+-6.809)          |            286.376 (+-8.981)            |     1.299 (+-0.000)      |          158.557 (+-6.143)
      Input (1, 3, 1200, 1300), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)      |        548.903 (+-22.889)       |         202.788 (+-9.158)          |            227.404 (+-8.995)            |     1.121 (+-0.000)      |          554.096 (+-21.330)
      Input (4, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)    |       1036.061 (+-35.285)       |         680.728 (+-30.925)         |            986.254 (+-42.732)           |     1.449 (+-0.000)      |         1038.718 (+-43.070)
      Input (4, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)        |       2504.520 (+-125.805)      |         550.067 (+-21.383)         |            628.000 (+-27.589)           |     1.142 (+-0.000)      |         2523.134 (+-113.336)
      Input (4, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)  |       1058.188 (+-57.853)       |        1216.427 (+-76.160)         |           1380.231 (+-98.939)           |     1.135 (+-0.000)      |         1057.031 (+-66.075)
      Input (4, 3, 1200, 1300), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)      |       2305.911 (+-116.864)      |        1080.189 (+-79.934)         |           1141.561 (+-67.959)           |     1.057 (+-0.000)      |         2306.606 (+-121.544)
      Input (1, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)      |       1689.489 (+-60.579)       |        1077.401 (+-44.948)         |           1634.264 (+-64.340)           |     1.517 (+-0.000)      |         1693.945 (+-67.998)
      Input (1, 3, 300, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)          |       4198.368 (+-179.096)      |         886.656 (+-30.355)         |           1028.568 (+-46.310)           |     1.160 (+-0.000)      |         4174.351 (+-141.020)
      Input (1, 3, 300, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)    |        716.572 (+-51.954)       |        1175.864 (+-52.191)         |           1674.373 (+-51.815)           |     1.424 (+-0.000)      |          715.724 (+-41.104)
      Input (1, 3, 300, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)        |       3604.989 (+-132.489)      |        1096.933 (+-54.290)         |           1270.347 (+-60.932)           |     1.158 (+-0.000)      |         3601.864 (+-140.218)
      Input (4, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)      |       6721.610 (+-355.997)      |        4203.213 (+-134.362)        |           6423.763 (+-225.311)          |     1.528 (+-0.000)      |         6715.626 (+-288.233)
      Input (4, 3, 300, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)          |      16695.467 (+-709.620)      |        3460.013 (+-149.456)        |           4001.810 (+-218.093)          |     1.157 (+-0.000)      |        16621.138 (+-713.320)
      Input (4, 3, 300, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)    |       3020.017 (+-147.314)      |        4743.164 (+-135.850)        |           6709.494 (+-281.025)          |     1.415 (+-0.000)      |         3015.602 (+-105.852)
      Input (4, 3, 300, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)        |      14456.688 (+-752.839)      |        5150.893 (+-201.571)        |           5737.315 (+-138.011)          |     1.114 (+-0.000)      |        14464.472 (+-720.027)

Times are in microseconds (us).
```

This PR fixes arrange decomp such that `arange(s, dtype=torch.float32)` directly provides better IR and generated code.

Code:
```python
import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s, dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))
```

Graph on this PR:
```
 ===== Forward graph 0 =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:8 in func, code: a = torch.arange(s, dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f32[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float32);  iota = None

        # File: check_arange_decomp.py:9 in func, code: return s + a
        add: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type, 10);  convert_element_type = None
        return (add,)
```
and C++ on this PR:
```c++
extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<float>(tmp0);
            auto tmp2 = static_cast<float>(10.0);
            auto tmp3 = decltype(tmp1)(tmp1 + tmp2);
            out_ptr0[static_cast<long>(x0)] = tmp3;
        }
    }
}
```

ghstack-source-id: b2cd195b3bed03784df46632e5f17d3ff7dabf13
Pull Request resolved: #123445
@vfdev-5 vfdev-5 marked this pull request as draft April 9, 2024 08:52
@vfdev-5 vfdev-5 closed this Apr 23, 2024
leslie-fang-intel added a commit that referenced this pull request Apr 28, 2024
…t per tensor and refactor quant pattern"


**Summary**
Per the discussion in #123444, the `decomposed quant/dequant` patterns changed after #123445, we can move the optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase to avoid the changes. In this way, we can:

- Avoid the pattern matcher failure introduced in #123445
- Make the quantization pattern clearer in the pattern matcher phase, since the `quant/dequant` nodes have not been decomposed.

**Changes in this PR**

- Move optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase.
- Corresponding changes in the quantization pattern matcher to ensure no bc-breaking.

**TestPlan**
```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_q
```


cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
leslie-fang-intel added a commit that referenced this pull request Apr 28, 2024
… refactor quant pattern"


**Summary**
Per the discussion in #123444, the `decomposed quant/dequant` patterns changed after #123445, we can move the optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase to avoid the changes. In this way, we can:

- Avoid the pattern matcher failure introduced in #123445
- Make the quantization pattern clearer in the pattern matcher phase, since the `quant/dequant` nodes have not been decomposed.

**Changes in this PR**

- Move optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase.
- Corresponding changes in the quantization pattern matcher to ensure no bc-breaking.

**TestPlan**
```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_q
```


cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
leslie-fang-intel added a commit that referenced this pull request May 6, 2024
…t per tensor and refactor quant pattern"


**Summary**
Per the discussion in #123444, the `decomposed quant/dequant` patterns changed after #123445, we can move the optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase to avoid the changes. In this way, we can:

- Avoid the pattern matcher failure introduced in #123445
- Make the quantization pattern clearer in the pattern matcher phase, since the `quant/dequant` nodes have not been decomposed.

**Changes in this PR**

- Move optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase.
- Corresponding changes in the quantization pattern matcher to ensure no bc-breaking.

**TestPlan**
```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_q
```


cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
leslie-fang-intel added a commit that referenced this pull request May 6, 2024
… refactor quant pattern"


**Summary**
Per the discussion in #123444, the `decomposed quant/dequant` patterns changed after #123445, we can move the optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase to avoid the changes. In this way, we can:

- Avoid the pattern matcher failure introduced in #123445
- Make the quantization pattern clearer in the pattern matcher phase, since the `quant/dequant` nodes have not been decomposed.

**Changes in this PR**

- Move optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase.
- Corresponding changes in the quantization pattern matcher to ensure no bc-breaking.

**TestPlan**
```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_q
```


cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
leslie-fang-intel added a commit that referenced this pull request May 7, 2024
…t per tensor and refactor quant pattern"


**Summary**
Per the discussion in #123444, the `decomposed quant/dequant` patterns changed after #123445, we can move the optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase to avoid the changes. In this way, we can:

- Avoid the pattern matcher failure introduced in #123445
- Make the quantization pattern clearer in the pattern matcher phase, since the `quant/dequant` nodes have not been decomposed.

**Changes in this PR**

- Move optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase.
- Corresponding changes in the quantization pattern matcher to ensure no bc-breaking.

**TestPlan**
```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_q
```


cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
leslie-fang-intel added a commit that referenced this pull request May 7, 2024
… refactor quant pattern"


**Summary**
Per the discussion in #123444, the `decomposed quant/dequant` patterns changed after #123445, we can move the optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase to avoid the changes. In this way, we can:

- Avoid the pattern matcher failure introduced in #123445
- Make the quantization pattern clearer in the pattern matcher phase, since the `quant/dequant` nodes have not been decomposed.

**Changes in this PR**

- Move optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase.
- Corresponding changes in the quantization pattern matcher to ensure no bc-breaking.

**TestPlan**
```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_q
```


cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request May 9, 2024
…uant pattern (#124041)

**Summary**
Per the discussion in #123444, the `decomposed quant/dequant` patterns changed after #123445, we can move the optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase to avoid the changes. In this way, we can:

- Avoid the pattern matcher failure introduced in #123445
- Make the quantization pattern clearer in the pattern matcher phase, since the `quant/dequant` nodes have not been decomposed.

**Changes in this PR**

- Move optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase.
- Corresponding changes in the quantization pattern matcher to ensure no bc-breaking.

**TestPlan**
```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_q
```

Pull Request resolved: #124041
Approved by: https://github.com/peterbell10, https://github.com/jgong5
leslie-fang-intel added a commit that referenced this pull request May 9, 2024
…t per tensor and refactor quant pattern"


**Summary**
Per the discussion in #123444, the `decomposed quant/dequant` patterns changed after #123445, we can move the optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase to avoid the changes. In this way, we can:

- Avoid the pattern matcher failure introduced in #123445
- Make the quantization pattern clearer in the pattern matcher phase, since the `quant/dequant` nodes have not been decomposed.

**Changes in this PR**

- Move optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase.
- Corresponding changes in the quantization pattern matcher to ensure no bc-breaking.

**TestPlan**
```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_q
```


cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
leslie-fang-intel added a commit that referenced this pull request May 9, 2024
… refactor quant pattern"


**Summary**
Per the discussion in #123444, the `decomposed quant/dequant` patterns changed after #123445, we can move the optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase to avoid the changes. In this way, we can:

- Avoid the pattern matcher failure introduced in #123445
- Make the quantization pattern clearer in the pattern matcher phase, since the `quant/dequant` nodes have not been decomposed.

**Changes in this PR**

- Move optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase.
- Corresponding changes in the quantization pattern matcher to ensure no bc-breaking.

**TestPlan**
```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_q
```


cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request May 9, 2024
…uant pattern (#124041)

**Summary**
Per the discussion in #123444, the `decomposed quant/dequant` patterns changed after #123445, we can move the optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase to avoid the changes. In this way, we can:

- Avoid the pattern matcher failure introduced in #123445
- Make the quantization pattern clearer in the pattern matcher phase, since the `quant/dequant` nodes have not been decomposed.

**Changes in this PR**

- Move optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase.
- Corresponding changes in the quantization pattern matcher to ensure no bc-breaking.

**TestPlan**
```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_q
```

Pull Request resolved: #124041
Approved by: https://github.com/peterbell10, https://github.com/jgong5
@github-actions github-actions bot deleted the gh/vfdev-5/18/head branch May 30, 2024 02:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants