Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion backends/arm/operators/op_bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
"""Provide a visitor for lowering batched matmul (BMM) to TOSA."""

from typing import Any, List

import torch
Expand All @@ -30,6 +32,13 @@

@register_node_visitor
class BMMVisitor(NodeVisitor):
"""Provide a visitor that lowers ``aten.bmm`` to TOSA ``MATMUL``.

INT8 accumulates into INT32; add a rescale to INT8 using SINGLE_ROUND
rounding and output zero-point.

"""

target = "aten.bmm.default"

tosa_specs = [
Expand All @@ -47,7 +56,7 @@ def define_node(
inputs: List[TosaArg],
output: TosaArg,
) -> None:

"""Define the TOSA ``MATMUL`` operator and optional rescale."""
import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 2)
Expand Down
44 changes: 36 additions & 8 deletions backends/arm/operators/op_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
"""Provide a visitor for lowering 2D convolution to TOSA (INT/FP)."""

import itertools
from typing import Any, List

Expand All @@ -28,6 +30,12 @@

@register_node_visitor
class Conv2dVisitor(NodeVisitor):
"""Provide a visitor that lowers ``aten.convolution`` to TOSA.

Map to ``CONV2D`` or ``DEPTHWISE_CONV2D`` as appropriate.

"""

target = "aten.convolution.default"

tosa_specs = [
Expand All @@ -38,13 +46,32 @@ class Conv2dVisitor(NodeVisitor):
def __init__(self, *args):
super().__init__(*args)

# torch.nn.Conv2d does not require the result of
# `(input + 2 * pad - dilation * (weight - 1) - 1) / stride`
# to be an integer, but tosa currently strictly require this property.
# This function adjusts the pad value to meet the requirement.
def adjust_pad_if_needed(
self, input_size: int, input_weight: int, stride: int, pad: int, dilation: int
) -> int:
"""Adjust padding to satisfy TOSA's integer output-size requirement.

Torch ``Conv2d`` does not require the result of
``(input + 2 * pad - dilation * (weight - 1) - 1) / stride`` to be an
integer, but TOSA does. This helper reduces the provided padding so
that the expression becomes divisible by ``stride``.

Args:
input_size (int): Spatial input size along the dimension (H or W).
input_weight (int): Kernel size along the same dimension.
stride (int): Stride along the same dimension.
pad (int): Padding value to adjust (bottom or right after duplication).
dilation (int): Dilation along the same dimension.

Returns:
int: Adjusted padding value that yields an integer output size.

Raises:
RuntimeError: If the required adjustment exceeds the provided
padding, which should be handled by the ``SizeAdjustInputPass``
pass instead.

"""
mod_remainder = (
input_size + 2 * pad - dilation * (input_weight - 1) - 1
) % stride
Expand All @@ -55,7 +82,8 @@ def adjust_pad_if_needed(

if mod_remainder > pad:
raise RuntimeError(
"This case should be handled by the SizeAdjustConv2d pass, is it enabled?"
"This case should be handled by the SizeAdjustInputPass pass, "
"is it enabled?"
)
return pad - mod_remainder

Expand All @@ -66,7 +94,7 @@ def define_node(
inputs: List[TosaArg],
output: TosaArg,
) -> None:

"""Define the TOSA CONV2D/DEPTHWISE_CONV2D operator and post-rescale."""
import serializer.tosa_serializer as ts # type: ignore
from tosa.RoundingMode import RoundingMode # type: ignore

Expand Down Expand Up @@ -133,7 +161,7 @@ def define_node(
in_channels = input.shape[1]
out_channels = weight.shape[0]
if (in_channels == group.number) and (out_channels % in_channels) == 0:
"""Depthwise convolution case"""
"""Depthwise convolution case."""
# Reshape torch shape format of weight tensor to tosa required format.
# https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d
m_length = int(out_channels / in_channels)
Expand Down Expand Up @@ -178,7 +206,7 @@ def define_node(
acc_type=acc_type,
)
else:
"""Regular convolution case"""
"""Regular convolution case."""
tosa_op = ts.TosaOp.Op().CONV2D
weight_name = weight.name

Expand Down
Loading