45 changes: 45 additions & 0 deletions backends/cadence/aot/ref_implementations.py
@@ -933,6 +933,51 @@ def quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
def quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


@impl(m, "convolution")
def convolution(
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: tuple[int, int],
padding: tuple[int, int],
dilation: tuple[int, int],
groups: int,
channel_last: bool = False,
) -> torch.Tensor:
conv_is_1d = len(input_tensor.shape) == 3
if channel_last:
if conv_is_1d:
input_tensor = input_tensor.movedim(-1, 1).contiguous()
if len(weight.shape) != 3:
raise ValueError("Weight tensor must be 3D if input is 3D")
weight = weight.movedim(-1, 1).contiguous()
else:
input_tensor = input_tensor.movedim(-1, -3)
if len(weight.shape) != 4:
raise ValueError("Weight tensor must be 4D if input is nd > 3")
weight = torch.permute(weight, (0, -1, 1, 2)).contiguous()

_stride: tuple[int, int] | int = stride
_padding: tuple[int, int] | int = padding
_dilation: tuple[int, int] | int = dilation
if conv_is_1d:
conv = torch.nn.functional.conv1d
_stride = stride[0]
_padding = padding[0]
_dilation = dilation[0]
else:
conv = torch.nn.functional.conv2d

conv_out = conv(input_tensor, weight, bias, _stride, _padding, _dilation, groups)
if channel_last:
if conv_is_1d:
conv_out = conv_out.movedim(1, -1).contiguous()
else:
conv_out = conv_out.movedim(-3, -1).contiguous()

return conv_out


def quantized_relu_common(
    X: torch.Tensor,
    X_zero_point: torch.Tensor | int,
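For reviewers who want to try the new op locally, here is a minimal smoke-test sketch. It assumes this module has been imported so the reference implementation is registered and callable as torch.ops.cadence.convolution (the same entry point the tests below use); the shapes and variable names are made up for illustration. It checks the channel-last path against a plain torch.nn.functional.conv2d on the equivalent channel-first tensors:

import torch

# Hypothetical shapes for illustration: NHWC input and channel-last weight.
x_nhwc = torch.randn(1, 5, 5, 3)  # N, H, W, C
w_nhwc = torch.randn(4, 2, 2, 3)  # O, kH, kW, I
b = torch.zeros(4)

out_nhwc = torch.ops.cadence.convolution(
    x_nhwc, w_nhwc, b, (1, 1), (0, 0), (1, 1), 1, True
)

# Reference: the same convolution in channel-first (NCHW) layout.
x_nchw = x_nhwc.movedim(-1, 1)       # N, C, H, W
w_oihw = w_nhwc.permute(0, 3, 1, 2)  # O, I, kH, kW
out_nchw = torch.nn.functional.conv2d(x_nchw, w_oihw, b)

assert torch.allclose(out_nhwc, out_nchw.movedim(1, -1), atol=1e-6)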
277 changes: 277 additions & 0 deletions backends/cadence/aot/tests/test_ref_implementations.py
@@ -1256,3 +1256,280 @@ def test_rope(
            torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
        )

    @expand(
        [
            # Test case 1: Basic 2D convolution (NCHW format)
            (
                "basic_2d_nchw",
                torch.tensor(
                    [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32
                ),  # input: 1x1x2x2
                torch.tensor(
                    [[[[1.0, 0.0], [0.0, 1.0]]]], dtype=torch.float32
                ),  # weight: 1x1x2x2 (identity-like filter)
                torch.tensor([0.0], dtype=torch.float32),  # bias
                (1, 1),  # stride
                (0, 0),  # padding
                (1, 1),  # dilation
                1,  # groups
                False,  # channel_last
                torch.tensor(
                    [[[[5.0]]]], dtype=torch.float32
                ),  # expected: 1*1 + 4*1 = 5
            ),
            # Test case 2: Basic 2D convolution (NHWC format)
            (
                "basic_2d_nhwc",
                torch.tensor(
                    [[[[1.0], [2.0]], [[3.0], [4.0]]]], dtype=torch.float32
                ),  # input: 1x2x2x1 (NHWC)
                torch.tensor(
                    [[[[1.0], [0.0]], [[0.0], [1.0]]]], dtype=torch.float32
                ),  # weight: 1x2x2x1 (NHWC format)
                torch.tensor([0.0], dtype=torch.float32),  # bias
                (1, 1),  # stride
                (0, 0),  # padding
                (1, 1),  # dilation
                1,  # groups
                True,  # channel_last
                torch.tensor(
                    [[[[5.0]]]], dtype=torch.float32
                ),  # expected: 1*1 + 4*1 = 5
            ),
            # Test case 3: 2D convolution with stride=2
            (
                "conv2d_stride2",
                torch.tensor(
                    [
                        [
                            [
                                [1.0, 2.0, 3.0, 4.0],
                                [5.0, 6.0, 7.0, 8.0],
                                [9.0, 10.0, 11.0, 12.0],
                                [13.0, 14.0, 15.0, 16.0],
                            ]
                        ]
                    ],
                    dtype=torch.float32,
                ),  # input: 1x1x4x4
                torch.tensor(
                    [[[[1.0, 1.0], [1.0, 1.0]]]], dtype=torch.float32
                ),  # weight: 1x1x2x2 (sum filter)
                torch.tensor([0.0], dtype=torch.float32),  # bias
                (2, 2),  # stride=2
                (0, 0),  # padding
                (1, 1),  # dilation
                1,  # groups
                False,  # channel_last
                torch.tensor([[[[14.0, 22.0], [46.0, 54.0]]]], dtype=torch.float32),
            ),
            # Test case 4: 2D convolution with padding=1
            (
                "conv2d_padding1",
                torch.tensor(
                    [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32
                ),  # input: 1x1x2x2
                torch.tensor(
                    [[[[1.0, 0.0], [0.0, 1.0]]]], dtype=torch.float32
                ),  # weight: 1x1x2x2
                torch.tensor([0.0], dtype=torch.float32),  # bias
                (1, 1),  # stride
                (1, 1),  # padding=1
                (1, 1),  # dilation
                1,  # groups
                False,  # channel_last
                torch.tensor(
                    [[[[1.0, 2.0, 0.0], [3.0, 5.0, 2.0], [0.0, 3.0, 4.0]]]],
                    dtype=torch.float32,
                ),  # expected with padding
            ),
            # Test case 5: 2D convolution with dilation=2
            (
                "conv2d_dilation2",
                torch.tensor(
                    [
                        [
                            [
                                [1.0, 2.0, 3.0, 4.0],
                                [5.0, 6.0, 7.0, 8.0],
                                [9.0, 10.0, 11.0, 12.0],
                                [13.0, 14.0, 15.0, 16.0],
                            ]
                        ]
                    ],
                    dtype=torch.float32,
                ),  # input: 1x1x4x4
                torch.tensor(
                    [[[[1.0, 1.0], [1.0, 1.0]]]], dtype=torch.float32
                ),  # weight: 1x1x2x2
                torch.tensor([0.0], dtype=torch.float32),  # bias
                (1, 1),  # stride
                (0, 0),  # padding
                (2, 2),  # dilation=2
                1,  # groups
                False,  # channel_last
                torch.tensor([[[[24.0, 28.0], [40.0, 44.0]]]], dtype=torch.float32),
            ),
            # Test case 6: 2D grouped convolution (groups=2)
            (
                "conv2d_groups2",
                torch.tensor(
                    [
                        [
                            [[1.0, 2.0], [3.0, 4.0]],  # first input channel
                            [[5.0, 6.0], [7.0, 8.0]],  # second input channel
                        ]
                    ],
                    dtype=torch.float32,
                ),  # input: 1x2x2x2
                torch.tensor(
                    [
                        [[[1.0, 1.0], [1.0, 1.0]]],  # first group weight
                        [[[0.5, 0.5], [0.5, 0.5]]],  # second group weight
                    ],
                    dtype=torch.float32,
                ),  # weight: 2x1x2x2
                torch.tensor([0.0, 1.0], dtype=torch.float32),  # bias
                (1, 1),  # stride
                (0, 0),  # padding
                (1, 1),  # dilation
                2,  # groups=2
                False,  # channel_last
                torch.tensor([[[[10.0]], [[14.0]]]], dtype=torch.float32),
            ),
            # Test case 7: 1D convolution (NCL format)
            (
                "conv1d_ncl",
                torch.tensor(
                    [[[1.0, 2.0, 3.0, 4.0]]], dtype=torch.float32
                ),  # input: 1x1x4
                torch.tensor([[[1.0, 1.0]]], dtype=torch.float32),  # weight: 1x1x2
                torch.tensor([0.0], dtype=torch.float32),  # bias
                (1, 1),  # stride (only stride[0] is used for 1D)
                (0, 0),  # padding (only padding[0] is used for 1D)
                (1, 1),  # dilation (only dilation[0] is used for 1D)
                1,  # groups
                False,  # channel_last
                torch.tensor(
                    [[[3.0, 5.0, 7.0]]], dtype=torch.float32
                ),  # expected: [1+2, 2+3, 3+4]
            ),
            # Test case 8: 1D convolution (NLC format)
            (
                "conv1d_nlc",
                torch.tensor(
                    [[[1.0], [2.0], [3.0], [4.0]]], dtype=torch.float32
                ),  # input: 1x4x1 (NLC)
                torch.tensor(
                    [[[1.0], [1.0]]], dtype=torch.float32
                ),  # weight: 1x2x1 (NLC)
                torch.tensor([0.0], dtype=torch.float32),  # bias
                (1, 1),  # stride
                (0, 0),  # padding
                (1, 1),  # dilation
                1,  # groups
                True,  # channel_last
                torch.tensor([[[3.0], [5.0], [7.0]]], dtype=torch.float32),
            ),
            # Test case 9: Multi-channel input and output
            (
                "multi_channel",
                torch.tensor(
                    [
                        [
                            [[1.0, 2.0], [3.0, 4.0]],  # first input channel
                            [[0.5, 1.0], [1.5, 2.0]],  # second input channel
                        ]
                    ],
                    dtype=torch.float32,
                ),  # input: 1x2x2x2
                torch.tensor(
                    [
                        [  # first output channel
                            [[1.0, 0.0], [0.0, 1.0]],  # weights for first input channel
                            [
                                [2.0, 0.0],
                                [0.0, 2.0],
                            ],  # weights for second input channel
                        ],
                        [  # second output channel
                            [[0.5, 0.5], [0.5, 0.5]],  # weights for first input channel
                            [
                                [1.0, 1.0],
                                [1.0, 1.0],
                            ],  # weights for second input channel
                        ],
                    ],
                    dtype=torch.float32,
                ),  # weight: 2x2x2x2
                torch.tensor([0.0, 1.0], dtype=torch.float32),  # bias
                (1, 1),  # stride
                (0, 0),  # padding
                (1, 1),  # dilation
                1,  # groups
                False,  # channel_last
                torch.tensor([[[[10.0]], [[11.0]]]], dtype=torch.float32),
            ),
            # Test case 10: Convolution with non-zero bias
            (
                "conv2d_with_bias",
                torch.tensor(
                    [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32
                ),  # input: 1x1x2x2
                torch.tensor(
                    [[[[1.0, 0.0], [0.0, 1.0]]]], dtype=torch.float32
                ),  # weight: 1x1x2x2
                torch.tensor([10.0], dtype=torch.float32),  # bias=10
                (1, 1),  # stride
                (0, 0),  # padding
                (1, 1),  # dilation
                1,  # groups
                False,  # channel_last
                torch.tensor(
                    [[[[15.0]]]], dtype=torch.float32
                ),  # expected: 5 + 10 = 15
            ),
        ]
    )
    def test_convolution(
        self,
        name: str,
        input_tensor: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        stride: tuple[int, int],
        padding: tuple[int, int],
        dilation: tuple[int, int],
        groups: int,
        channel_last: bool,
        expected_output: torch.Tensor,
    ) -> None:
        output = torch.ops.cadence.convolution(
            input_tensor,
            weight,
            bias,
            stride,
            padding,
            dilation,
            groups,
            channel_last,
        )

        # Verify output properties
        self.assertEqual(
            output.dtype,
            input_tensor.dtype,
            f"Output dtype should match input dtype in {name}",
        )
        self.assertEqual(
            output.shape,
            expected_output.shape,
            f"Output shape should match expected shape in {name}",
        )

        # Verify output matches expected values
        self.assertTrue(
            torch.equal(output, expected_output),
            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
        )
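And a quick hand-check of the 1D path, which collapses each (stride, padding, dilation) pair to its first element before dispatching to torch.nn.functional.conv1d — a sketch under the same registration assumption as above, reproducing the conv1d_ncl case from the tests:

import torch

x = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])  # input: 1x1x4 (NCL)
w = torch.tensor([[[1.0, 1.0]]])            # weight: 1x1x2, sliding sum
b = torch.tensor([0.0])

out = torch.ops.cadence.convolution(x, w, b, (1, 1), (0, 0), (1, 1), 1, False)
ref = torch.nn.functional.conv1d(x, w, b, stride=1, padding=0, dilation=1)

# [1+2, 2+3, 3+4] = [3, 5, 7]
assert torch.equal(ref, torch.tensor([[[3.0, 5.0, 7.0]]]))
assert torch.equal(out, ref)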