This repository was archived by the owner on Aug 1, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 129
This repository was archived by the owner on Aug 1, 2025. It is now read-only.
[inductor] Convolution channels last fails for TIMM workloads #1675
Copy link
Closed
Labels
Description
Repro
import torch
from torch import tensor, device
import torch.fx as fx
from torchdynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx
# torch version: 1.14.0a0+git25725fd
# torch cuda version: 11.6
# torch git version: 25725fd62448165b91647304c26d676db22b6955
# CUDA Info:
# nvcc: NVIDIA (R) Cuda compiler driver
# Copyright (c) 2005-2022 NVIDIA Corporation
# Built on Thu_Feb_10_18:23:41_PST_2022
# Cuda compilation tools, release 11.6, V11.6.112
# Build cuda_11.6.r11.6/compiler.30978841_0
# GPU Hardware Info:
# NVIDIA A100-SXM4-40GB : 8
from torch.nn import *
class Repro(torch.nn.Module):
    """Single aten convolution extracted from a TIMM workload for an inductor repro."""

    def __init__(self):
        super().__init__()

    def forward(self, arg9_1, getitem):
        """Convolve input ``getitem`` with weight ``arg9_1``.

        Calls ``aten.convolution`` directly with no bias, stride 1, zero
        padding, dilation 1, non-transposed, zero output padding and a
        single group.

        Returns:
            A one-element tuple holding the convolution output.
        """
        stride, padding, dilation = [1, 1], [0, 0], [1, 1]
        output_padding = [0, 0]
        result = torch.ops.aten.convolution.default(
            getitem, arg9_1, None,
            stride, padding, dilation,
            False, output_padding, 1,
        )
        # Drop local references early, mirroring the FX-generated original.
        del getitem, arg9_1
        return (result,)
# (shape, stride, dtype, device) for each input, in forward() order:
# arg9_1 (weight) first, then getitem (input). Both are channels-last strided,
# e.g. the input's strides (341056, 1, 4672, 64) equal (C*H*W, 1, W*C, C)
# for (N, C, H, W) = (128, 64, 73, 73).
args = [
    rand_strided(shape, stride, dtype, dev)
    for shape, stride, dtype, dev in (
        ((80, 64, 1, 1), (64, 1, 64, 64), torch.float32, 'cuda'),
        ((128, 64, 73, 73), (341056, 1, 4672, 64), torch.float32, 'cuda'),
    )
]

# Trace the eager module into an FX graph of aten ops.
mod = make_fx(Repro().to(device="cuda"))(*args)

from torchinductor.compile_fx import compile_fx_inner
from torchdynamo.debug_utils import same_two_models

compiled = compile_fx_inner(mod, args)
# Reported to fail inside the generated code on
# `assert buf0.stride() == (426320, 5329, 73, 1)`: the compiled graph expects a
# contiguous (N, C, H, W) = (128, 80, 73, 73) output, but the stride produced
# for these channels-last inputs differed.
compiled(args)
Error
  File "/scratch/anijain/work/torchdynamo/torchdynamo/utils.py", line 76, in time_wrapper
    r = func(*args, **kwargs)
  File "/scratch/anijain/work/torchdynamo/torchinductor/compile_fx.py", line 180, in cudagraphify
    return cudagraphify_impl(model, inputs, static_input_idxs)
  File "/scratch/anijain/work/torchdynamo/torchinductor/compile_fx.py", line 246, in cudagraphify_impl
    model(list(static_inputs))
  File "/tmp/torchinductor_anijain/jl/cjlibzuf5bvagnwdwvkdvfcn36per3zt4pg4qfg4ju7e7utoxccu.py", line 62, in call
    assert buf0.stride() == (426320, 5329, 73, 1)
AssertionError