Skip to content

Commit

Permalink
[inductor] Fixed conv issue with dynamic shapes (#114351)
Browse files Browse the repository at this point in the history
EDIT: fixes #114354

Description:
The following code is failing:
```python
import torch

def func(x, w):
    return torch.nn.functional.conv2d(x, w, groups=int(w.shape[0]))

x = torch.rand(1, 3, 64, 64)
w = torch.rand(3, 1, 3, 3)
y1 = func(x, w)
cfunc = torch.compile(func, fullgraph=True, dynamic=True)
y2 = cfunc(x, w)

torch.testing.assert_close(y1, y2)
```
with the error:
```
  File "/pytorch/torch/_inductor/kernel/conv.py", line 315, in convolution
    assert isinstance(groups, int)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AssertionError:
  target: aten.convolution.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg3_1', layout=FixedLayout('cpu', torch.float32, size=[1, s0, s1, s1], stride=[s0*s1**2, s1**2, s1, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cpu', torch.float32, size=[s0, 1, s0, s0], stride=[s0**2, s0**2, s0, 1]))
  ))
  args[2]: None
  args[3]: [1, 1]
  args[4]: [0, 0]
  args[5]: [1, 1]
  args[6]: False
  args[7]: [0, 0]
  args[8]: s0
```
where `groups` argument is a symbol but expected to be `int`.

This PR specializes `group` to its int value and fixes the problem.

Context: Failing tests in torchvision with gaussian blur and adjust_sharpness ops
- https://github.com/pytorch/vision/actions/runs/6955843968/job/18926393710?pr=8127

Pull Request resolved: #114351
Approved by: https://github.com/ezyang
  • Loading branch information
vfdev-5 authored and pytorchmergebot committed Nov 23, 2023
1 parent 01366ef commit 85aa372
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
14 changes: 14 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2514,6 +2514,20 @@ def test_convolution3(self):
rtol=0.001,
)

@skipIfRocm
def test_convolution4(self):
def fn(x, w):
x = F.conv2d(x, w, groups=w.shape[0])
return x.sum()

self.common(
fn,
(
torch.randn([2, 3, 16, 20]),
torch.randn([3, 1, 5, 5]),
),
)

def test_conv2d_channels_last(self):
if self.device == "cuda":
raise unittest.SkipTest("only support cpu conv2d channels_last")
Expand Down
2 changes: 2 additions & 0 deletions torch/_inductor/kernel/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,8 @@ def convolution(
padding = tuple(padding)
dilation = tuple(dilation)
output_padding = tuple(output_padding)
if not isinstance(groups, int):
groups = V.graph.sizevars.evaluate_static_shape(groups)
assert isinstance(groups, int)
kwargs: ConvLayoutParams = {
"stride": stride,
Expand Down

0 comments on commit 85aa372

Please sign in to comment.