Skip to content
This repository was archived by the owner on Aug 1, 2025. It is now read-only.

[inductor] Channels-last convolution fails for TIMM workloads #1675

@anijain2305

Description

@anijain2305

Repro


import torch
from torch import tensor, device
import torch.fx as fx
from torchdynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx

# torch version: 1.14.0a0+git25725fd
# torch cuda version: 11.6
# torch git version: 25725fd62448165b91647304c26d676db22b6955


# CUDA Info: 
# nvcc: NVIDIA (R) Cuda compiler driver 
# Copyright (c) 2005-2022 NVIDIA Corporation 
# Built on Thu_Feb_10_18:23:41_PST_2022 
# Cuda compilation tools, release 11.6, V11.6.112 
# Build cuda_11.6.r11.6/compiler.30978841_0 

# GPU Hardware Info: 
# NVIDIA A100-SXM4-40GB : 8 


from torch.nn import *
class Repro(torch.nn.Module):
    """Minimal module wrapping one ``aten.convolution`` call, used to reproduce the failure."""

    def __init__(self):
        super().__init__()

    def forward(self, arg9_1, getitem):
        """Convolve ``getitem`` (the input) with ``arg9_1`` (the weight).

        The convolution has no bias, stride 1, zero padding, dilation 1,
        is not transposed, and uses a single group.

        Returns:
            A 1-tuple holding the convolution output.
        """
        stride = [1, 1]
        padding = [0, 0]
        dilation = [1, 1]
        output_padding = [0, 0]
        result = torch.ops.aten.convolution.default(
            getitem, arg9_1, None, stride, padding, dilation, False, output_padding, 1
        )
        # Drop the local references to the inputs once the op has been issued.
        del getitem, arg9_1
        return (result,)
        
# Each entry is (shape, stride, dtype, device). Entry 0 is passed to Repro.forward
# as arg9_1 (the conv weight), entry 1 as getitem (the conv input).
# Both stride tuples describe channels-last (NHWC) layouts:
#   input (128, 64, 73, 73) -> (64*73*73, 1, 64*73, 64) == (341056, 1, 4672, 64)
#   weight (80, 64, 1, 1)   -> (64, 1, 64, 64)
args = [((80, 64, 1, 1), (64, 1, 64, 64), torch.float32, 'cuda'), ((128, 64, 73, 73), (341056, 1, 4672, 64), torch.float32, 'cuda')]
# rand_strided (torchdynamo.testing helper, not shown here) presumably allocates a
# random tensor with exactly the given shape/stride/dtype/device.
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
# Trace Repro into an FX graph module with make_fx.
mod = make_fx(Repro().to(device="cuda"))(*args)

from torchinductor.compile_fx import compile_fx_inner
# NOTE(review): same_two_models is imported but not used below.
from torchdynamo.debug_utils import same_two_models

# Compile the traced graph with inductor; the compiled callable takes the inputs as a list.
# Per the traceback in this issue, this call fails the generated stride check
# `buf0.stride() == (426320, 5329, 73, 1)` -- the contiguous NCHW strides of the
# (128, 80, 73, 73) output -- presumably because the convolution actually returned
# a channels-last output for these channels-last inputs.
compiled = compile_fx_inner(mod, args)
compiled(args)

Error

  File "/scratch/anijain/work/torchdynamo/torchdynamo/utils.py", line 76, in time_wrapper
    r = func(*args, **kwargs)
  File "/scratch/anijain/work/torchdynamo/torchinductor/compile_fx.py", line 180, in cudagraphify
    return cudagraphify_impl(model, inputs, static_input_idxs)
  File "/scratch/anijain/work/torchdynamo/torchinductor/compile_fx.py", line 246, in cudagraphify_impl
    model(list(static_inputs))
  File "/tmp/torchinductor_anijain/jl/cjlibzuf5bvagnwdwvkdvfcn36per3zt4pg4qfg4ju7e7utoxccu.py", line 62, in call
    assert buf0.stride() == (426320, 5329, 73, 1)
AssertionError

Metadata

Assignees

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions