Fixed arange decomp for float dtype
Description:

The arange graph and the generated C++ code are not optimal when the arange is created directly with a float32 dtype:
```python
import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s, dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))
```

Graph on `main`:
```
 ===== Forward graph 0 =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:8 in func, code: a = torch.arange(s, dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f64[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float64);  iota = None
        mul: "f64[10]" = torch.ops.aten.mul.Tensor(convert_element_type, 1);  convert_element_type = None
        add: "f64[10]" = torch.ops.aten.add.Tensor(mul, 0);  mul = None
        convert_element_type_1: "f32[10]" = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None

        # File: check_arange_decomp.py:9 in func, code: return s + a
        add_1: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type_1, 10);  convert_element_type_1 = None
        return (add_1,)
```
and C++ on `main`:
```c++
extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<double>(tmp0);   // <---- useless ops
            auto tmp2 = static_cast<double>(1.0);     // <----
            auto tmp3 = decltype(tmp1)(tmp1 * tmp2);  // <----
            auto tmp4 = static_cast<double>(0.0);     // <----
            auto tmp5 = decltype(tmp3)(tmp3 + tmp4);  // <----
            auto tmp6 = c10::convert<float>(tmp5);
            auto tmp7 = static_cast<float>(10.0);
            auto tmp8 = decltype(tmp6)(tmp6 + tmp7);
            out_ptr0[static_cast<long>(x0)] = tmp8;
        }
    }
}
```

However, if we manually create the arange on int64 and then cast it to float32, the generated graph and C++ code are simpler and benefit from a speed-up.
```python
import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s).to(dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))
```

Graph on `main`:
```
 ===== Forward graph 0 =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:14 in func, code: a = torch.arange(s).to(dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f32[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float32);  iota = None

        # File: check_arange_decomp.py:15 in func, code: return s + a
        add: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type, 10);  convert_element_type = None
        return (add,)
```

C++ on `main`:
```c++
extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<float>(tmp0);
            auto tmp2 = static_cast<float>(10.0);
            auto tmp3 = decltype(tmp1)(tmp1 + tmp2);
            out_ptr0[static_cast<long>(x0)] = tmp3;
        }
    }
}
```

For example, here is the speed-up seen for `upsample_nearest2d` on CPU:
```
[----------------------------------------------------------------------------------------------------------------------------------------------- Interpolate, cpu ----------------------------------------------------------------------------------------------------------------------------------------------]
                                                                                                                                |  Eager (2.3.0a0+gitb4324ed) PR  |  Compiled (2.3.0a0+gitb4324ed) PR  |  Compiled (2.3.0a0+git0d1e705) Nightly  |  speed-up PR vs Nightly  |  Eager (2.3.0a0+git0d1e705) Nightly
1 threads: ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      Input (1, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)      |        287.988 (+-10.399)       |         200.034 (+-8.630)          |            285.143 (+-8.412)            |     1.425 (+-0.000)      |          287.991 (+-11.302)
      Input (1, 3, 500, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)          |        697.206 (+-27.033)       |         171.650 (+-7.381)          |            193.280 (+-5.840)            |     1.126 (+-0.000)      |          701.642 (+-26.461)
      Input (1, 3, 500, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)    |        149.149 (+-6.045)        |         222.780 (+-6.852)          |            299.968 (+-12.354)           |     1.346 (+-0.000)      |          145.055 (+-7.232)
      Input (1, 3, 500, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)        |        596.741 (+-27.970)       |         205.923 (+-8.648)          |            233.912 (+-7.742)            |     1.136 (+-0.000)      |          598.000 (+-25.630)
      Input (4, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)      |       1095.734 (+-51.658)       |         700.850 (+-24.852)         |           1044.255 (+-38.216)           |     1.490 (+-0.000)      |         1097.977 (+-35.521)
      Input (4, 3, 500, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)          |       2741.813 (+-122.917)      |         583.073 (+-16.998)         |            665.029 (+-36.331)           |     1.141 (+-0.000)      |         2722.388 (+-116.263)
      Input (4, 3, 500, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)    |        578.183 (+-37.266)       |         833.295 (+-42.264)         |           1131.341 (+-54.710)           |     1.358 (+-0.000)      |          584.953 (+-45.549)
      Input (4, 3, 500, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)        |       2332.508 (+-103.556)      |         840.194 (+-47.664)         |            935.625 (+-47.467)           |     1.114 (+-0.000)      |         2334.314 (+-91.644)
      Input (1, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)    |        272.631 (+-11.348)       |         195.988 (+-5.748)          |            274.021 (+-9.475)            |     1.398 (+-0.000)      |          272.752 (+-12.716)
      Input (1, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)        |        640.409 (+-25.465)       |         164.773 (+-7.372)          |            185.018 (+-8.349)            |     1.123 (+-0.000)      |          639.390 (+-30.761)
      Input (1, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)  |        158.602 (+-6.593)        |         220.478 (+-6.809)          |            286.376 (+-8.981)            |     1.299 (+-0.000)      |          158.557 (+-6.143)
      Input (1, 3, 1200, 1300), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)      |        548.903 (+-22.889)       |         202.788 (+-9.158)          |            227.404 (+-8.995)            |     1.121 (+-0.000)      |          554.096 (+-21.330)
      Input (4, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)    |       1036.061 (+-35.285)       |         680.728 (+-30.925)         |            986.254 (+-42.732)           |     1.449 (+-0.000)      |         1038.718 (+-43.070)
      Input (4, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)        |       2504.520 (+-125.805)      |         550.067 (+-21.383)         |            628.000 (+-27.589)           |     1.142 (+-0.000)      |         2523.134 (+-113.336)
      Input (4, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)  |       1058.188 (+-57.853)       |        1216.427 (+-76.160)         |           1380.231 (+-98.939)           |     1.135 (+-0.000)      |         1057.031 (+-66.075)
      Input (4, 3, 1200, 1300), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)      |       2305.911 (+-116.864)      |        1080.189 (+-79.934)         |           1141.561 (+-67.959)           |     1.057 (+-0.000)      |         2306.606 (+-121.544)
      Input (1, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)      |       1689.489 (+-60.579)       |        1077.401 (+-44.948)         |           1634.264 (+-64.340)           |     1.517 (+-0.000)      |         1693.945 (+-67.998)
      Input (1, 3, 300, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)          |       4198.368 (+-179.096)      |         886.656 (+-30.355)         |           1028.568 (+-46.310)           |     1.160 (+-0.000)      |         4174.351 (+-141.020)
      Input (1, 3, 300, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)    |        716.572 (+-51.954)       |        1175.864 (+-52.191)         |           1674.373 (+-51.815)           |     1.424 (+-0.000)      |          715.724 (+-41.104)
      Input (1, 3, 300, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)        |       3604.989 (+-132.489)      |        1096.933 (+-54.290)         |           1270.347 (+-60.932)           |     1.158 (+-0.000)      |         3601.864 (+-140.218)
      Input (4, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)      |       6721.610 (+-355.997)      |        4203.213 (+-134.362)        |           6423.763 (+-225.311)          |     1.528 (+-0.000)      |         6715.626 (+-288.233)
      Input (4, 3, 300, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)          |      16695.467 (+-709.620)      |        3460.013 (+-149.456)        |           4001.810 (+-218.093)          |     1.157 (+-0.000)      |        16621.138 (+-713.320)
      Input (4, 3, 300, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)    |       3020.017 (+-147.314)      |        4743.164 (+-135.850)        |           6709.494 (+-281.025)          |     1.415 (+-0.000)      |         3015.602 (+-105.852)
      Input (4, 3, 300, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)        |      14456.688 (+-752.839)      |        5150.893 (+-201.571)        |           5737.315 (+-138.011)          |     1.114 (+-0.000)      |        14464.472 (+-720.027)

Times are in microseconds (us).
```
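A minimal sketch of how such a comparison could be reproduced with `torch.utils.benchmark` (an assumed reproduction script, not the exact benchmark used to generate the table above):
```python
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

def interp(x):
    return F.interpolate(x, size=(256, 256), mode="nearest")

c_interp = torch.compile(interp)

x = torch.rand(1, 3, 500, 400, dtype=torch.float32)
c_interp(x)  # trigger compilation before timing

for label, fn in [("Eager", interp), ("Compiled", c_interp)]:
    timer = benchmark.Timer(
        stmt="fn(x)", globals={"fn": fn, "x": x}, label=label, num_threads=1
    )
    print(timer.blocked_autorange(min_run_time=1))
```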

This PR fixes the arange decomposition so that `arange(s, dtype=torch.float32)` directly produces the better IR and generated code.

Code:
```python
import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s, dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))
```

Graph on this PR:
```
 ===== Forward graph 0 =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:8 in func, code: a = torch.arange(s, dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f32[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float32);  iota = None

        # File: check_arange_decomp.py:9 in func, code: return s + a
        add: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type, 10);  convert_element_type = None
        return (add,)
```
and C++ on this PR:
```c++
extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<float>(tmp0);
            auto tmp2 = static_cast<float>(10.0);
            auto tmp3 = decltype(tmp1)(tmp1 + tmp2);
            out_ptr0[static_cast<long>(x0)] = tmp3;
        }
    }
}
```
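
As a quick sanity check (a hypothetical snippet, not part of this PR), the compiled output can be compared against eager to confirm the new decomposition stays numerically equivalent:
```python
import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s, dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
x = torch.rand(10)
# Compiled and eager results should match for this simple graph.
torch.testing.assert_close(c_func(x), func(x))
```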

[ghstack-poisoned]
vfdev-5 committed Apr 5, 2024
1 parent 86ddd88 commit e791f44
Showing 5 changed files with 68 additions and 13 deletions.
13 changes: 7 additions & 6 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -1208,7 +1208,7 @@ def matcher_check_fn():
# [qconv2d_pointwise_default_1, convert_element_type_5, sub_2, mul_5, add_3, relu, mul_6, round_4, add_4 (optional),
# clamp_min_3, clamp_max_3, convert_element_type_6]
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_count"], 1)
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 12)
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 11)

self._test_common(
mod,
@@ -1730,15 +1730,16 @@ def forward(self, x):
# 2. Dequant-conv pattern matched in quantization weight prepack * 1
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
# 3. qconv2d_relu fusion in post-grad fusion pass * 1
# [qconv2d_pointwise_default, relu, mul_2, round_2, add_1 (optional), clamp_min_1, clamp_max_1, convert_element_type_2]
# [qconv2d_pointwise_default, relu, mul_2, round_2, add_1 (optional), clamp_min_1, clamp_max_1,
# convert_element_type_2]
# 4. qmaxpool2d * 1
# [convert_element_type_3, sub_1 (optional), mul_3, max_pool2d_with_indices, getitem, mul_4, round_3, add_2 (optional),
# clamp_min_2, clamp_max_2, convert_element_type_4]
# [convert_element_type_3, sub_1 (optional), mul_3, max_pool2d_with_indices, getitem, mul_4,
# round_3, add_2 (optional), clamp_min_2, clamp_max_2, convert_element_type_4]
self._test_common(
mod,
(v,),
6,
31,
28,
check_quantization=True,
)

@@ -1827,7 +1828,7 @@ def forward(self, x):
mod,
(v,),
10,
49,
48,
check_quantization=True,
)

2 changes: 2 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -4553,6 +4553,8 @@ def matmul_with_op(x, y, fn):
/ torch.ones(
[256, 256], dtype=torch.float32, device=x.device
), # noqa: E731
lambda x: x * 1.0 + 0.0, # noqa: E731
lambda x: x * 1 + 0, # noqa: E731
)

inps = [torch.rand([256, 256], device=self.device) for _ in range(2)]
40 changes: 40 additions & 0 deletions test/test_decomp.py
@@ -38,6 +38,7 @@
import itertools
import functools
from functools import partial
import re
import unittest
import sys

@@ -631,6 +632,45 @@ def test_batch_norm_unflatten_weight_bias(self, device):
res = torch._decomp.decompositions.native_batch_norm(input, weight, bias, mean, var, False, 1, 1e-05)
self.assertEqual(shape, res[0].shape)

def test_arange_graph(self, device):
from torch.fx.experimental.proxy_tensor import make_fx

def func(x, start):
le = x.shape[-1]
if start is None:
a = torch.arange(le, dtype=torch.float32, device=x.device)
else:
a = torch.arange(start, le, dtype=torch.float32, device=x.device)
return a

pattern = r", device = device\(.+\), requires_grad = False"

cfunc = make_fx(func, decomposition_table=decomposition_table)
fx_g = cfunc(torch.rand(10, device=device), None)
fx_g_code = fx_g.code.strip()
# Remove device and requires_grad
fx_g_code = re.sub(pattern, "", fx_g_code)
self.assertExpectedInline(fx_g_code, """\
def forward(self, x_1, start_1):
iota = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64)
mul = torch.ops.prims.mul.default(iota, 1); iota = None
add = torch.ops.prims.add.default(mul, 0); mul = None
convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None
return convert_element_type""")

fx_g = cfunc(torch.rand(10, device=device), 1)
fx_g_code = fx_g.code.strip()
# Remove device and requires_grad
fx_g_code = re.sub(pattern, "", fx_g_code)
self.assertExpectedInline(fx_g_code, """\
def forward(self, x_1, start_1):
iota = torch.ops.prims.iota.default(9, start = 0, step = 1, dtype = torch.int64)
mul = torch.ops.prims.mul.default(iota, 1); iota = None
add = torch.ops.prims.add.default(mul, 1); mul = None
convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None
return convert_element_type""")


class DecompCrossRefMode(TorchDispatchMode):
def __init__(self, test_case, saved_precision, saved_rel_tol, dtype, run_all):
self.test_case = test_case
16 changes: 12 additions & 4 deletions torch/_inductor/fx_passes/joint_graph.py
@@ -1,7 +1,7 @@
import logging
import typing
from collections import Counter
from typing import Dict, List, Set
from typing import Dict, List, Set, Union

import torch
import torch._guards
@@ -37,12 +37,18 @@ def lazy_init():

@torch.utils._python_dispatch._disable_current_modes()
def remove_no_ops(
gm: torch.fx.GraphModule, zeros: Set[torch.fx.Node], ones: Set[torch.fx.Node]
gm: torch.fx.GraphModule,
zeros: Set[Union[torch.fx.Node, int]],
ones: Set[Union[torch.fx.Node, int]],
):
"Removes no-ops: (+ 0, - 0, * 1, / 1)"
aten = torch.ops.aten
graph = gm.graph

# Add python builtins 0 and 1 to be able to remove from the graph the following ops: x * 1 + 0 -> x
zeros.add(0)
ones.add(1)

def fake_tensors_eq(t1, t2, fields=("shape", "dtype", "device")):
if any(not isinstance(t, torch.Tensor) for t in (t1, t2)):
return False
@@ -57,7 +63,9 @@ def replace_no_op(node, replace_input_index):
# https://github.com/pytorch/pytorch/issues/86128 causes
# non-Tensor inputs even for ops with only Tensor inputs.
# TODO - decompose/type promote to avoid this
if not all(isinstance(arg, torch.fx.Node) for arg in node.args):
if not isinstance(replacement, torch.fx.Node):
return
if not all(isinstance(arg, (torch.fx.Node, int, float)) for arg in node.args):
return

if not fake_tensors_eq(node.meta["val"], replacement.meta["val"]):
@@ -257,7 +265,7 @@ def constant_fold_uniform_value(gm: torch.fx.GraphModule):
):
torch._check(runtime_size == compile_time_size)

# zeros, and ones just get traced into full, so we insert those
# zeros and ones just get traced into full, so we insert those
new_node = graph.call_function(
aten.full.default,
args=(node_replacements_shapes[node], value),
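To illustrate the effect of adding the Python scalars 0 and 1 to `remove_no_ops` (a hedged example; the exact post-grad graph text can differ between builds), expressions such as `x * 1 + 0` should now be stripped down to an identity in the joint graph:
```python
import torch

def func(x):
    # mul-by-1 and add-0 come out of decompositions such as arange;
    # with this change they are removed as no-ops by the joint graph pass.
    return x * 1 + 0

c_func = torch.compile(func)
x = torch.rand(4)
# Numerically this is just the identity, and the simplified graph reflects that.
torch.testing.assert_close(c_func(x), x)
```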
10 changes: 7 additions & 3 deletions torch/_refs/__init__.py
@@ -4932,9 +4932,10 @@ def is_finite(x):
lambda: f"step must be finite but got {step}",
)

args = (start, end, step)
integer_args = builtins.all(isinstance(arg, IntLike) for arg in args)

if dtype is None:
args = (start, end, step)
integer_args = builtins.all(isinstance(arg, IntLike) for arg in args)
dtype = torch.int64 if integer_args else torch.get_default_dtype()

is_integer = utils.is_integer_dtype(dtype)
@@ -4962,7 +4963,6 @@ def is_finite(x):
requires_grad=requires_grad,
)

computation_dtype = utils.get_acc_type(dtype, device)
index = prims.iota(
length,
start=0,
@@ -4971,6 +4971,10 @@
device=device,
requires_grad=False,
)

computation_dtype = (
torch.long if integer_args else utils.get_acc_type(dtype, device)
)
index = _maybe_convert_to_dtype(index, computation_dtype)
result = start + step * index
result = _maybe_convert_to_dtype(result, dtype)
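
Conceptually, the fixed decomposition behaves like the sketch below (a simplified illustration, not the actual `torch._refs.arange` code; the helper name and length computation are assumptions and only cover a positive integer step):
```python
import torch

def arange_decomp_sketch(end, *, start=0, step=1, dtype=torch.float32, device="cpu"):
    # Simplified model of the decomposition: iota, then start + step * index, then one cast.
    length = (end - start + step - 1) // step  # assumes a positive integer step
    index = torch.arange(length, dtype=torch.int64, device=device)  # stands in for prims.iota

    integer_args = all(isinstance(a, int) for a in (start, end, step))
    # Key change: when start/end/step are all integers, keep the arithmetic in int64
    # instead of promoting to an accumulation dtype such as float64.
    computation_dtype = torch.int64 if integer_args else torch.float64
    result = start + step * index.to(computation_dtype)
    return result.to(dtype)  # single conversion to the requested dtype

# arange_decomp_sketch(10, dtype=torch.float32) matches torch.arange(10, dtype=torch.float32)
```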
