
Merge branch 'pytorch:main' into main
valentinandrei committed Aug 17, 2023
2 parents d542a33 + 8298720 commit 19d91ad
Showing 5 changed files with 52 additions and 2 deletions.
1 change: 1 addition & 0 deletions .github/workflows/_docs.yml
@@ -39,6 +39,7 @@ jobs:
# Don't run on forked repos.
if: github.repository_owner == 'pytorch'
runs-on: ${{ matrix.runner }}
environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'pytorchbot-env' || '' }}
strategy:
fail-fast: false
matrix:
1 change: 1 addition & 0 deletions .github/workflows/upload-test-stats.yml
@@ -87,6 +87,7 @@ jobs:
if: ${{ always() }}
runs-on: [self-hosted, linux.2xlarge]
continue-on-error: true
environment: pytorchbot-env
steps:
- name: Get our GITHUB_TOKEN API limit usage
env:
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/Normalization.cu
@@ -211,7 +211,7 @@ Tensor batch_norm_elementwise_backward_train(
auto weight_nd = weight.defined() ? as_nd(weight) :
at::scalar_tensor(1.0, input.options().dtype(mean.scalar_type()));

-  Tensor grad_input = at::empty(input.sizes(), grad_out.options());
+  Tensor grad_input = at::empty(input.sizes(), grad_out.options().memory_format(input.suggest_memory_format()));
auto iter = TensorIteratorConfig()
.add_output(grad_input)
.add_input(grad_out)
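For context, the change above allocates grad_input using the input's suggested memory format instead of the default contiguous layout, so a channels-last input gets a channels-last gradient buffer. A minimal Python sketch of the difference (illustrative only; the actual fix is in the ATen CUDA kernel above, where suggest_memory_format is the C++ helper):

import torch

# A channels_last (NHWC) input has a channel stride of 1.
x = torch.randn(2, 8, 16, 16).to(memory_format=torch.channels_last)

# Previous behavior: grad_input was allocated with default NCHW-contiguous strides.
grad_nchw = torch.empty(x.shape)

# New behavior: grad_input follows the input's memory format (channels_last here),
# so the elementwise backward does not produce a layout mismatch with the input.
grad_nhwc = torch.empty(x.shape, memory_format=torch.channels_last)

print(x.stride())          # (2048, 1, 128, 8)
print(grad_nchw.stride())  # (2048, 256, 16, 1)
print(grad_nhwc.stride())  # (2048, 1, 128, 8)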
36 changes: 36 additions & 0 deletions test/inductor/test_max_autotune.py
@@ -8,12 +8,16 @@
from torch._inductor.ir import Buffer, FixedLayout
from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
from torch._inductor.select_algorithm import AlgorithmSelectorCache, ChoiceCaller
from torch._inductor.utils import run_and_get_code
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
skipIfRocm,
)

from torch.testing._internal.inductor_utils import HAS_CUDA

torch.set_float32_matmul_precision("high")
@@ -214,6 +218,38 @@ def addmm(x, a, b):
with config.patch({"max_autotune": True}):
torch.compile(addmm, dynamic=dynamic)(x, a, b)

@skipIfRocm
def test_autotune_conv1x1(self):
# Define the 1x1 convolutional layer
# Assuming input has 3 channels and we want to produce 16 channels as output
conv1x1 = (
torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1)
.to(memory_format=torch.channels_last)
.cuda()
)

# Example input tensor: batch size = 4, channels = 3, height = 32, width = 32
# The memory format is set to `channels_last`
input_tensor = (
torch.randn(4, 3, 32, 32)
.contiguous(memory_format=torch.channels_last)
.cuda()
)

with config.patch(
{"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
):

@torch.compile()
def foo(mod, x):
return mod(x)

with torch.no_grad():
out, code = run_and_get_code(foo, conv1x1, input_tensor)

FileCheck().check_not("extern_kernels.convolution").run(code[0])
self.assertEqual(conv1x1(input_tensor), out, atol=1e-2, rtol=0)

def test_cat_addmm(self):
def fn(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
return torch.cat(
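The new test compiles a channels-last 1x1 convolution with max-autotune restricted to the TRITON backend and uses FileCheck to assert that the generated code contains no extern_kernels.convolution call, i.e. the conv was lowered to a matmul. As a rough standalone sketch of the identity being exploited (illustrative, not part of the diff): a 1x1 convolution with unit stride and no padding is a plain matrix multiply over the flattened spatial positions.

import torch

x = torch.randn(4, 3, 32, 32).contiguous(memory_format=torch.channels_last)
w = torch.randn(16, 3, 1, 1)  # out_channels=16, in_channels=3, 1x1 kernel

ref = torch.nn.functional.conv2d(x, w)

# In NHWC order the same computation is a (N*H*W, C_in) @ (C_in, C_out) matmul.
mm = x.permute(0, 2, 3, 1).reshape(-1, 3) @ w.reshape(16, 3).t()
out = mm.reshape(4, 32, 32, 16).permute(0, 3, 1, 2)

torch.testing.assert_close(out, ref)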
14 changes: 13 additions & 1 deletion torch/_inductor/kernel/conv.py
@@ -308,8 +308,20 @@ def convolution(
dilation = pad_listlike(dilation, ndim)
output_padding = pad_listlike(output_padding, ndim)

def channels_last_conv():
if V.graph.layout_opt and ndim == 2:
return True

layout = conv_layout(x, weight, None, **kwargs)
req_stride_order = ir.get_stride_order(
V.graph.sizevars.size_hints(layout.stride)
)
return req_stride_order == ir.NHWC_STRIDE_ORDER

autotuning_gemm = config.max_autotune or config.max_autotune_gemm

if (
-        config.conv_1x1_as_mm
+        (config.conv_1x1_as_mm or (autotuning_gemm and channels_last_conv()))
and is_ones(kernel_shape)
and is_ones(stride)
and is_zeros(padding)
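The new channels_last_conv() helper above treats a convolution as channels-last either because graph-level layout optimization is enabled, or because the stride order of the inferred conv output layout matches NHWC. A rough standalone sketch of that stride-order check (the helper below and the NHWC order value are assumptions for illustration, not copied from torch/_inductor/ir.py):

import torch

def get_stride_order(strides):
    # Rank of each dimension's stride, smallest stride -> 0 (ties ignored here).
    sorted_dims = sorted(range(len(strides)), key=lambda d: strides[d])
    order = [0] * len(strides)
    for rank, dim in enumerate(sorted_dims):
        order[dim] = rank
    return order

NHWC_STRIDE_ORDER = [3, 0, 2, 1]  # assumed: channel stride smallest, batch stride largest

x = torch.randn(4, 3, 32, 32).contiguous(memory_format=torch.channels_last)
print(get_stride_order(x.stride()) == NHWC_STRIDE_ORDER)  # True -> treat this conv as channels-last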
