Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pt2] add metas for avg_pool3d and avg_pool3d_backward #103392

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 0 additions & 3 deletions test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2808,7 +2808,6 @@ def forward(self, x):
xfail('nn.functional.adaptive_max_pool1d', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbo...
xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2...
xfail('nn.functional.avg_pool3d', ''), # aten.avg_pool3d.default - couldn't find symbolic meta function/...
skip('nn.functional.batch_norm', ''), # '0 is not tracked with proxy for <torch.fx.experimental.proxy_te..
xfail('nn.functional.bilinear', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.binary_cross_entropy', ''), # aten.fill_.Scalar - couldn't find symbolic meta funct...
Expand Down Expand Up @@ -3053,10 +3052,8 @@ def test_aot_autograd_symbolic_exhaustive(self, device, dtype, op):
torch.nn.AdaptiveMaxPool3d, # Cannot call sizes() on tensor with symbolic sizes/strides
torch.nn.GroupNorm, # in native_group_norm_backward cpg, _rem = divmod(C, group)
# TypeError: unsupported operand type(s) for divmod(): 'SymInt' and 'int'
torch.nn.LocalResponseNorm, # Cannot call sizes() on tensor with symbolic sizes/strides
torch.nn.FractionalMaxPool2d, # int() argument must be a string, a bytes-like object or a number, not 'SymFloat'
torch.nn.FractionalMaxPool3d, # int() argument must be a string, a bytes-like object or a number, not 'SymFloat'
torch.nn.AvgPool3d, # Cannot call sizes() on tensor with symbolic sizes/strides
torch.nn.MaxPool1d, # Cannot call sizes() on tensor with symbolic sizes/strides
torch.nn.MaxPool3d, # torch._subclasses.fake_tensor.UnsupportedOperatorException:
# aten.max_pool3d_with_indices.default
Expand Down
1 change: 0 additions & 1 deletion test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1501,7 +1501,6 @@ def f(a, b, c, d, e):
xfail('nn.functional.adaptive_max_pool1d', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbolic meta funct...
xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2) must be tupl...
xfail('nn.functional.avg_pool3d', ''), # aten.avg_pool3d.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.bilinear', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.binary_cross_entropy', ''), # aten.new_empty.default - couldn't find symbolic meta function/decom...
xfail('nn.functional.cosine_similarity', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
Expand Down
313 changes: 313 additions & 0 deletions torch/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1582,6 +1582,172 @@ def meta_avg_pool2d_backward(
)


@register_meta(aten.avg_pool3d)
@out_wrapper()
def meta_avg_pool3d(
input,
kernel_size,
stride=(),
padding=(0,),
ceil_mode=False,
count_include_pad=True,
divisor_override=None,
):
check(
len(kernel_size) in (1, 3),
lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
)
kT = kernel_size[0]
kH = kT if len(kernel_size) == 1 else kernel_size[1]
kW = kT if len(kernel_size) == 1 else kernel_size[2]

check(
not stride or len(stride) in (1, 3),
lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
)
dT = kT if not stride else stride[0]
dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
dW = kW if not stride else (dT if len(stride) == 1 else stride[2])

check(
len(padding) in (1, 3),
lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
)
padT = padding[0]
padH = padT if len(padding) == 1 else padding[1]
padW = padT if len(padding) == 1 else padding[2]

check(
input.ndim in (4, 5),
lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
)

check(
not divisor_override or divisor_override != 0,
lambda: "divisor must be not zero",
)

nbatch = input.size(0)
nslices = input.size(-4)
itime = input.size(-3)
iheight = input.size(-2)
iwidth = input.size(-1)

otime = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode)
oheight = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode)
owidth = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode)

pool3d_shape_check(
input,
nslices,
kT,
kH,
kW,
dT,
dH,
dW,
padT,
padH,
padW,
1,
1,
1,
itime,
iheight,
iwidth,
otime,
oheight,
owidth,
"avg_pool3d()",
check_input_size=True,
)

if input.ndim == 4:
return input.new_empty((nslices, otime, oheight, owidth))
else:
return input.new_empty((nbatch, nslices, otime, oheight, owidth))


@register_meta(aten.avg_pool3d_backward)
@out_wrapper()
def meta_avg_pool3d_backward(
grad_output,
input,
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
divisor_override,
):
check(
len(kernel_size) in (1, 3),
lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
)
kT = kernel_size[0]
kH = kT if len(kernel_size) == 1 else kernel_size[1]
kW = kT if len(kernel_size) == 1 else kernel_size[2]

check(
not stride or len(stride) in (1, 3),
lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
)
dT = kT if not stride else stride[0]
dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
dW = kW if not stride else (dT if len(stride) == 1 else stride[2])

check(
len(padding) in (1, 3),
lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
)
padT = padding[0]
padH = padT if len(padding) == 1 else padding[1]
padW = padT if len(padding) == 1 else padding[2]

check(
input.ndim in (4, 5),
lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
)

check(
not divisor_override or divisor_override != 0,
lambda: "divisor must be not zero",
)

nslices = input.size(-4)
itime = input.size(-3)
iheight = input.size(-2)
iwidth = input.size(-1)

otime_for_shape_check = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode)
oheight_for_shape_check = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode)
owidth_for_shape_check = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode)

avg_pool3d_backward_shape_check(
input,
grad_output,
nslices,
kT,
kH,
kW,
dT,
dH,
dW,
padT,
padH,
padW,
itime,
iheight,
iwidth,
otime_for_shape_check,
oheight_for_shape_check,
owidth_for_shape_check,
"avg_pool3d_backward()",
)

return input.new_empty(input.shape)


@register_meta(aten._adaptive_avg_pool2d.default)
def meta_adaptive_avg_pool2d(self, output_size):
check(
Expand Down Expand Up @@ -2530,6 +2696,153 @@ def pool2d_shape_check(
)


def pool3d_shape_check(
input: Tensor,
nslices: int,
kT: int,
kH: int,
kW: int,
dT: int,
dH: int,
dW: int,
pT: int,
pH: int,
pW: int,
dilationT: int,
dilationH: int,
dilationW: int,
itime: int,
iheight: int,
iwidth: int,
otime: int,
oheight: int,
owidth: int,
fn_name: str,
check_input_size: bool = False,
):
ndim = input.ndim

check(
kT > 0 and kW > 0 and kH > 0,
lambda: (
f"kernel size should be greater than zero, but got "
f"kT: {kT}, kH: {kH}, kW: {kW}"
),
)
check(
dT > 0 and dW > 0 and dH > 0,
lambda: (
f"stride should be greater than zero, but got "
f"dT: {dT}, dH: {dH}, dW: {dW}"
),
)
check(
dilationT > 0 and dilationW > 0 and dilationH > 0,
lambda: (
f"dilation should be greater than zero, but got "
f"dilationT: {dilationT}, dilationH: {dilationH}, dilationW: {dilationW}"
),
)

check(
ndim in (4, 5),
lambda: f"{fn_name}: Expected 4D or 5D tensor for input, but got: {input.shape}",
)

for i in range(ndim):
if ndim == 5 and i == 0:
# size of batch-dim can be 0.
continue
check(
input.size(i) > 0,
lambda: (
f"{fn_name}: Expected input's non-batch dimensions to have positive length,"
f" but input has a shape of {input.shape}"
f" and non-batch dimension {input.size(i)} has length zero!"
),
)

if check_input_size: # AveragePool3d
check(
itime >= kT and iheight >= kH and iwidth >= kW,
lambda: (
f"input image (T: {itime} H: {iheight} W: {iwidth}) smaller than "
f"kernel size (kT: {kT} kH: {kH} kW: {kW})"
),
)

check(
kT / 2 >= pT and kW / 2 >= pW and kH / 2 >= pH,
lambda: (
f"pad should be smaller than or equal to half of kernel size, but got "
f"kT: {kT} kW: {kW} kH: {kH} padT: {pT} padW: {pW} padH: {pH}"
),
)

check(
otime >= 1 and owidth >= 1 and oheight >= 1,
lambda: (
f"Given input size: ({nslices}x{itime}x{iheight}x{iwidth}). "
f"Calculated output size: ({nslices}x{otime}x{oheight}x{owidth}). "
f"Output size is too small"
),
)


def avg_pool3d_backward_shape_check(
input: Tensor,
grad_output: Tensor,
nslices: int,
kT: int,
kH: int,
kW: int,
dT: int,
dH: int,
dW: int,
pT: int,
pH: int,
pW: int,
itime: int,
iheight: int,
iwidth: int,
otime: int,
oheight: int,
owidth: int,
fn_name: str,
):
ndim = input.ndim

pool3d_shape_check(
input,
nslices,
kT,
kH,
kW,
dT,
dH,
dW,
pT,
pH,
pW,
1,
1,
1,
itime,
iheight,
iwidth,
otime,
oheight,
owidth,
fn_name,
True,
)

check_dim_size(grad_output, ndim, ndim - 4, nslices)
check_dim_size(grad_output, ndim, ndim - 3, otime)
check_dim_size(grad_output, ndim, ndim - 2, oheight)
check_dim_size(grad_output, ndim, ndim - 1, owidth)


def max_pool2d_checks_and_compute_shape(
input, kernel_size, stride, padding, dilation, ceil_mode
):
Expand Down