Enable Lowering Channels last Conv1x1 when max autotune is set (#107004)
This can lead to a large speedup when max autotune is set (e.g., ResNet goes from a 2.1x to a 2.5x speedup), particularly in combination with freezing.

Pull Request resolved: #107004
Approved by: https://github.com/jansel, https://github.com/shunting314, https://github.com/int3
ghstack dependencies: #106911, #106912
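
For context, a minimal sketch of how a user might hit this path. The model, input shape, and use of torchvision are illustrative assumptions rather than part of the commit; the ingredients that matter are a channels-last model/input, max autotune, and (optionally) Inductor freezing.

```python
import torch
import torch._inductor.config as inductor_config
import torchvision  # assumption: any channels-last CNN with 1x1 convs would do

# Freezing folds weights into the graph as constants; together with
# max-autotune it lets Inductor benchmark Triton templates, including the
# channels-last conv1x1-as-matmul lowering enabled by this change.
inductor_config.freezing = True

model = (
    torchvision.models.resnet50()
    .eval()
    .cuda()
    .to(memory_format=torch.channels_last)
)
inp = torch.randn(8, 3, 224, 224, device="cuda").to(memory_format=torch.channels_last)

compiled = torch.compile(model, mode="max-autotune")
with torch.no_grad():
    out = compiled(inp)
```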
eellison authored and pytorchmergebot committed Aug 17, 2023
1 parent f96617f commit 8298720
Showing 2 changed files with 49 additions and 1 deletion.
test/inductor/test_max_autotune.py (36 additions, 0 deletions)
@@ -8,12 +8,16 @@
from torch._inductor.ir import Buffer, FixedLayout
from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
from torch._inductor.select_algorithm import AlgorithmSelectorCache, ChoiceCaller
from torch._inductor.utils import run_and_get_code
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    skipIfRocm,
)

from torch.testing._internal.inductor_utils import HAS_CUDA

torch.set_float32_matmul_precision("high")
@@ -214,6 +218,38 @@ def addmm(x, a, b):
        with config.patch({"max_autotune": True}):
            torch.compile(addmm, dynamic=dynamic)(x, a, b)

    @skipIfRocm
    def test_autotune_conv1x1(self):
        # Define the 1x1 convolutional layer
        # Assuming input has 3 channels and we want to produce 16 channels as output
        conv1x1 = (
            torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1)
            .to(memory_format=torch.channels_last)
            .cuda()
        )

        # Example input tensor: batch size = 4, channels = 3, height = 32, width = 32
        # The memory format is set to `channels_last`
        input_tensor = (
            torch.randn(4, 3, 32, 32)
            .contiguous(memory_format=torch.channels_last)
            .cuda()
        )

        with config.patch(
            {"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
        ):

            @torch.compile()
            def foo(mod, x):
                return mod(x)

            with torch.no_grad():
                out, code = run_and_get_code(foo, conv1x1, input_tensor)

            FileCheck().check_not("extern_kernels.convolution").run(code[0])
            self.assertEqual(conv1x1(input_tensor), out, atol=1e-2, rtol=0)

    def test_cat_addmm(self):
        def fn(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
            return torch.cat(
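
The FileCheck assertion above passes only when Inductor emits no fallback `extern_kernels.convolution` call, meaning the 1x1 convolution was picked up by a Triton GEMM template instead. The test also depends on the input really being channels-last; as a quick, standalone illustration of what that means at the stride level (plain PyTorch, not part of the diff):

```python
import torch

# For an NCHW shape of (4, 3, 32, 32), channels_last storage makes the channel
# dimension the fastest-varying one: strides become (H*W*C, 1, W*C, C).
x = torch.randn(4, 3, 32, 32).contiguous(memory_format=torch.channels_last)
print(x.stride())  # (3072, 1, 96, 3)
print(x.is_contiguous(memory_format=torch.channels_last))  # True
```

This is the layout that the `channels_last_conv()` check in `conv.py` below looks for via `ir.NHWC_STRIDE_ORDER`.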
torch/_inductor/kernel/conv.py (13 additions, 1 deletion)
@@ -308,8 +308,20 @@ def convolution(
    dilation = pad_listlike(dilation, ndim)
    output_padding = pad_listlike(output_padding, ndim)

    def channels_last_conv():
        if V.graph.layout_opt and ndim == 2:
            return True

        layout = conv_layout(x, weight, None, **kwargs)
        req_stride_order = ir.get_stride_order(
            V.graph.sizevars.size_hints(layout.stride)
        )
        return req_stride_order == ir.NHWC_STRIDE_ORDER

    autotuning_gemm = config.max_autotune or config.max_autotune_gemm

    if (
-        config.conv_1x1_as_mm
+        (config.conv_1x1_as_mm or (autotuning_gemm and channels_last_conv()))
        and is_ones(kernel_shape)
        and is_ones(stride)
        and is_zeros(padding)
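
For intuition about why the lowering is sound: with a 1x1 kernel, unit stride, and zero padding, every output pixel depends only on the input channels at the same spatial location, so on a channels-last tensor the convolution is exactly a matmul over the flattened (N*H*W, C_in) positions. A small self-contained sketch (shapes chosen arbitrarily, not taken from the commit):

```python
import torch

N, Cin, H, W, Cout = 4, 3, 32, 32, 16
x = torch.randn(N, Cin, H, W).contiguous(memory_format=torch.channels_last)
w = torch.randn(Cout, Cin, 1, 1)

# Reference: the 1x1 convolution itself (no bias, stride 1, no padding).
conv_out = torch.nn.functional.conv2d(x, w)

# Equivalent matmul: flatten the spatial positions, multiply by the
# (Cin, Cout) weight matrix, then restore the NCHW view.
mm_out = x.permute(0, 2, 3, 1).reshape(-1, Cin) @ w.reshape(Cout, Cin).t()
mm_out = mm_out.reshape(N, H, W, Cout).permute(0, 3, 1, 2)

torch.testing.assert_close(conv_out, mm_out, rtol=1e-4, atol=1e-4)
```

The new `channels_last_conv()` helper gates this path on the layout actually being NHWC (either because Inductor's layout optimization forces it, or because the required stride order already matches `ir.NHWC_STRIDE_ORDER`), so under max autotune the conv is only rewritten as a matmul when the data is already in channels-last order.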
