Skip to content

Commit 0c787df

Browse files
Andrew Grebenisan, facebook-github-bot
authored and committed
Support im2row (#14729)
Summary: Continued support of cadence ops. Added support for both im2row and im2row_per_tensor. Differential Revision: D83620790
1 parent 6e991bf commit 0c787df

File tree

2 files changed

+347
-0
lines changed

2 files changed

+347
-0
lines changed

backends/cadence/aot/ref_implementations.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,3 +1303,100 @@ def rope(
13031303
[x0 * cos_tensor - x1 * sin_tensor, x0 * sin_tensor + x1 * cos_tensor], dim=-1
13041304
)
13051305
return rotated.view(original_shape)
1306+
1307+
1308+
@impl(m, "im2row")
def im2row(
    input_tensor: torch.Tensor,
    kernel_size: tuple[int, int],
    dilation: tuple[int, int],
    padding: tuple[int, int],
    stride: tuple[int, int],
    in_zero_point: torch.Tensor,
    channel_last: bool = False,
) -> torch.Tensor:
    """
    Converts an input tensor into a 2D matrix where each row is a flattened sliding window (patch)
    from the input, suitable for use in convolution as a matrix multiplication (im2row).

    Args:
        - input_tensor: 4D input tensor of shape (N, C, H, W), or (N, H, W, C) if channel_last.
        - kernel_size: (kH, kW) size of the convolution kernel.
        - dilation: (dH, dW) dilation of the convolution kernel.
        - padding: (pH, pW) zero padding applied to both sides of H and W. Padding is added
          after the zero point has been subtracted, so padded elements are 0 in the output.
        - stride: (sH, sW) stride of the convolution.
        - in_zero_point: int32 zero point subtracted from the input before patch extraction.
          Must hold a single element or have shape (N,) (one value per batch element).
          None is also accepted, in which case nothing is subtracted.
        - channel_last: If True, input is in NHWC format, else NCHW.

    Returns:
        - Tensor of shape (N * num_patches, C * kH * kW) with the same dtype as the input.
          Rows are ordered batch-major; within a row, elements are ordered (C, kH, kW).

    Raises:
        - ValueError: if in_zero_point has more than one element and its shape is not (N,),
          if in_zero_point is not int32, or if the input is not 4-dimensional.
    """
    # Validate the zero point against the batch dimension (dim 0 in both layouts).
    if in_zero_point is not None:
        if in_zero_point.numel() != 1 and in_zero_point.shape != (
            input_tensor.shape[0],
        ):
            raise ValueError(
                f"Input zero point must be a scalar or broadcastable to input shape {input_tensor.shape}"
            )
        if in_zero_point.dtype != torch.int32:
            raise ValueError("Input zero point must be an int32 tensor")

    # Move to NCHW for processing if needed
    if channel_last:
        input_tensor = input_tensor.movedim(-1, -3).contiguous()  # NHWC -> NCHW

    # Unpacking also enforces that the input is 4-dimensional.
    _, C, _, _ = input_tensor.shape
    kH, kW = kernel_size
    dH, dW = dilation
    pH, pW = padding
    sH, sW = stride

    # Subtract in_zero_point if needed (broadcast)
    if in_zero_point is not None:
        if len(in_zero_point.shape) == 1:  # If shape is (), we skip this logic
            # (N,) -> (N, 1, 1, 1) so each batch element gets its own zero point.
            in_zero_point = in_zero_point.reshape([-1] + [1] * (input_tensor.dim() - 1))

        # The difference is cast back to the input dtype.
        input_tensor = (input_tensor - in_zero_point).to(input_tensor.dtype)

    # Use unfold to extract sliding local blocks.
    # torch.nn.functional.unfold returns (N, C*kH*kW, L), where L = number of sliding windows.
    # unfold is not implemented for integer dtypes, so round-trip through float64:
    # unlike float32 (exact only up to 2**24), float64 represents every int8/int16/int32
    # and float16/bfloat16/float32 value exactly, so the copy is lossless.
    patches = torch.nn.functional.unfold(
        input_tensor.to(torch.float64),
        kernel_size=(kH, kW),
        dilation=(dH, dW),
        padding=(pH, pW),
        stride=(sH, sW),
    ).to(
        input_tensor.dtype
    )  # (N, C*kH*kW, L)

    # Transpose to (N, L, C*kH*kW)
    patches = patches.transpose(1, 2).contiguous()

    # Reshape to (N*L, C*kH*kW)
    patches = patches.view(-1, C * kH * kW)

    # The row layout is (C, kH, kW) for both input layouts, as produced by unfold.
    # NOTE(review): for channel_last inputs a consumer may expect (kH, kW, C) ordering
    # within each row; confirm against the target kernel.
    return patches
1382+
1383+
1384+
@impl(m, "im2row.per_tensor")
def im2row_per_tensor(
    input_tensor: torch.Tensor,
    kernel_size: tuple[int, int],
    dilation: tuple[int, int],
    padding: tuple[int, int],
    stride: tuple[int, int],
    in_zero_point: int,
    channel_last: bool = False,
) -> torch.Tensor:
    """
    Per-tensor variant of im2row that takes the input zero point as a Python int.

    The integer is wrapped in a 0-dim int32 tensor and every argument is forwarded
    unchanged to im2row, whose result is returned as-is.
    """
    # im2row requires the zero point to be an int32 tensor.
    zero_point_tensor = torch.tensor(in_zero_point, dtype=torch.int32)
    return im2row(
        input_tensor,
        kernel_size,
        dilation,
        padding,
        stride,
        zero_point_tensor,
        channel_last,
    )

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1843,3 +1843,253 @@ def test_avg_pool2d(
18431843
torch.equal(output, expected_output),
18441844
f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
18451845
)
1846+
1847+
@expand(
    [
        # Basic 2x2 kernel, stride 1, no padding, NCHW
        (
            "nchw_basic_2x2",
            torch.tensor(
                [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32
            ),  # (N=1, C=1, H=3, W=3)
            (2, 2),  # kernel_size
            (1, 1),  # dilation
            (0, 0),  # padding
            (1, 1),  # stride
            None,  # in_zero_point
            False,  # channel_last
            False,  # per_tensor
            torch.tensor(
                [
                    [1, 2, 4, 5],
                    [2, 3, 5, 6],
                    [4, 5, 7, 8],
                    [5, 6, 8, 9],
                ],
                dtype=torch.float32,
            ),
        ),
        # 2x2 kernel, stride 2, no padding, NCHW
        (
            "nchw_stride2",
            torch.tensor(
                [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32
            ),
            (2, 2),
            (1, 1),
            (0, 0),
            (2, 2),
            None,
            False,
            False,
            torch.tensor(
                [
                    [1, 2, 4, 5],
                ],
                dtype=torch.float32,  # Only every other patch in each dim
            ),
        ),
        # 2x2 kernel, dilation 2, stride 1, no padding, NCHW:
        # the dilated window spans the whole 3x3 input and samples its corners.
        (
            "nchw_dilation2",
            torch.tensor(
                [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32
            ),
            (2, 2),
            (2, 2),
            (0, 0),
            (1, 1),
            None,
            False,
            False,
            torch.tensor(
                [
                    [1, 3, 7, 9],
                ],
                dtype=torch.float32,
            ),
        ),
        # 2x2 kernel, stride 1, padding 1, NCHW
        (
            "nchw_padding1",
            torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.float32),  # (1,1,2,2)
            (2, 2),
            (1, 1),
            (1, 1),
            (1, 1),
            None,
            False,
            False,
            torch.tensor(
                [
                    [0, 0, 0, 1],
                    [0, 0, 1, 2],
                    [0, 0, 2, 0],
                    [0, 1, 0, 3],
                    [1, 2, 3, 4],
                    [2, 0, 4, 0],
                    [0, 3, 0, 0],
                    [3, 4, 0, 0],
                    [4, 0, 0, 0],
                ],
                dtype=torch.float32,
            ),
        ),
        # 2x2 kernel, stride 1, no padding, NHWC
        (
            "nhwc_basic_2x2",
            torch.tensor(
                [[[[1], [2], [3]], [[4], [5], [6]], [[7], [8], [9]]]],
                dtype=torch.float32,
            ),  # (N=1, H=3, W=3, C=1)
            (2, 2),
            (1, 1),
            (0, 0),
            (1, 1),
            None,
            True,
            False,
            torch.tensor(
                [
                    [1, 2, 4, 5],
                    [2, 3, 5, 6],
                    [4, 5, 7, 8],
                    [5, 6, 8, 9],
                ],
                dtype=torch.float32,
            ),
        ),
        # 2x2 kernel, stride 1, no padding, NCHW, in_zero_point=1
        (
            "nchw_in_zero_point",
            torch.tensor([[[[2, 3, 4], [5, 6, 7], [8, 9, 10]]]], dtype=torch.int8),
            (2, 2),
            (1, 1),
            (0, 0),
            (1, 1),
            torch.tensor(1, dtype=torch.int32),
            False,
            False,
            torch.tensor(
                [
                    [1, 2, 4, 5],
                    [2, 3, 5, 6],
                    [4, 5, 7, 8],
                    [5, 6, 8, 9],
                ],
                dtype=torch.int8,
            ),
        ),
        # 2x2 kernel, stride 1, no padding, NHWC, in_zero_point=2
        (
            "nhwc_in_zero_point",
            torch.tensor(
                [[[[3], [4], [5]], [[6], [7], [8]], [[9], [10], [11]]]],
                dtype=torch.int8,
            ),
            (2, 2),
            (1, 1),
            (0, 0),
            (1, 1),
            torch.tensor(2, dtype=torch.int32),
            True,
            False,
            torch.tensor(
                [
                    [1, 2, 4, 5],
                    [2, 3, 5, 6],
                    [4, 5, 7, 8],
                    [5, 6, 8, 9],
                ],
                dtype=torch.int8,
            ),
        ),
        # Multi-channel input, 2x2 kernel, stride 1, no padding, NCHW
        (
            "nchw_multi_channel",
            torch.tensor(
                [
                    [
                        [[1, 2, 3], [4, 5, 6], [7, 8, 9]],  # channel 0
                        [[10, 11, 12], [13, 14, 15], [16, 17, 18]],  # channel 1
                    ]
                ],
                dtype=torch.float32,
            ),  # (1,2,3,3)
            (2, 2),
            (1, 1),
            (0, 0),
            (1, 1),
            None,
            False,
            False,
            torch.tensor(
                [
                    [1, 2, 4, 5, 10, 11, 13, 14],
                    [2, 3, 5, 6, 11, 12, 14, 15],
                    [4, 5, 7, 8, 13, 14, 16, 17],
                    [5, 6, 8, 9, 14, 15, 17, 18],
                ],
                dtype=torch.float32,
            ),
        ),
        # Batch of two single-channel inputs (N=2, C=1, H=1, W=3) with one
        # zero point per batch element; rows of both batches are concatenated.
        (
            "nchw_multi_channel_and_zero_point",
            torch.tensor([[[[1, 2, 3]]], [[[4, 5, 6]]]], dtype=torch.int32),
            (1, 2),
            (1, 1),
            (0, 0),
            (1, 1),
            torch.tensor([1, 2], dtype=torch.int32),
            False,
            False,
            torch.tensor([[0, 1], [1, 2], [2, 3], [3, 4]], dtype=torch.int32),
        ),
        # im2row.per_tensor overload: zero point is a plain int, NHWC input
        (
            "per_tensor",
            torch.tensor(
                [[[[3], [4], [5]], [[6], [7], [8]], [[9], [10], [11]]]],
                dtype=torch.int8,
            ),
            (2, 2),
            (1, 1),
            (0, 0),
            (1, 1),
            2,
            True,
            True,
            torch.tensor(
                [
                    [1, 2, 4, 5],
                    [2, 3, 5, 6],
                    [4, 5, 7, 8],
                    [5, 6, 8, 9],
                ],
                dtype=torch.int8,
            ),
        ),
    ]
)
def test_im2row(
    self,
    name: str,
    input_tensor: torch.Tensor,
    kernel_size: tuple[int, int],
    dilation: tuple[int, int],
    padding: tuple[int, int],
    stride: tuple[int, int],
    in_zero_point: torch.Tensor | int | None,
    channel_last: bool,
    per_tensor: bool,
    expected_output: torch.Tensor,
) -> None:
    """
    Check cadence.im2row and cadence.im2row.per_tensor against hand-computed patches.

    When per_tensor is True the per_tensor overload is called and in_zero_point is a
    plain int; otherwise the default op is called with a tensor (or None) zero point.
    Both the output shape and the exact values must match expected_output.
    """
    # Both overloads take the same positional arguments, so only the callee differs.
    im2row_op = (
        torch.ops.cadence.im2row.per_tensor
        if per_tensor
        else torch.ops.cadence.im2row
    )
    output = im2row_op(
        input_tensor,
        kernel_size,
        dilation,
        padding,
        stride,
        in_zero_point,
        channel_last,
    )
    self.assertEqual(
        output.shape,
        expected_output.shape,
        f"im2row output shape mismatch in {name}",
    )
    self.assertTrue(
        torch.equal(output, expected_output),
        f"im2row output mismatch in {name}: got {output}, expected {expected_output}",
    )

0 commit comments

Comments
 (0)