Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .decompose_maxpool3d import DecomposeMaxPool3d
from .decompose_minmaxdim import DecomposeMinMaxDim
from .decompose_reciprocal import DecomposeReciprocal
from .decompose_remainder import DecomposeRemainder
from .decompose_roll import DecomposeRoll
from .decompose_silu import DecomposeSilu
from .decompose_threshold import DecomposeThreshold
Expand Down Expand Up @@ -80,6 +81,7 @@
DecomposeMaxPool3d,
DecomposeMinMaxDim,
DecomposeReciprocal,
DecomposeRemainder,
DecomposeRoll,
DecomposeSilu,
DecomposeThreshold,
Expand Down
103 changes: 103 additions & 0 deletions backends/qualcomm/_passes/decompose_remainder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, PassResult
from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix

from .utils import copy_meta, get_const_node


class DecomposeRemainder(ExportPass):
    """
    Decompose remainder.Scalar and remainder.Tensor via the floored-modulo
    identity used by torch.remainder:

        remainder(x, y) = x - floor(x / y) * y
    """

    def __init__(self):
        super().__init__()
        # Both the pre-autograd aten overloads and the edge-dialect overloads
        # are matched, since this pass runs in the annotation pipeline as
        # well as in the capture-program pipeline.
        self.remainder_targets = {
            torch.ops.aten.remainder.Scalar,
            torch.ops.aten.remainder.Tensor,
            exir_ops.edge.aten.remainder.Scalar,
            exir_ops.edge.aten.remainder.Tensor,
        }

    @staticmethod
    def _select_ops(is_edge):
        """Return the (div, floor, mul, sub) overloads for the node's dialect."""
        if is_edge:
            return (
                exir_ops.edge.aten.div.Tensor,
                exir_ops.edge.aten.floor.default,
                exir_ops.edge.aten.mul.Tensor,
                exir_ops.edge.aten.sub.Tensor,
            )
        return (
            torch.ops.aten.div.Tensor,
            torch.ops.aten.floor.default,
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.sub.Tensor,
        )

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        # Map scalar value -> const node, so a scalar divisor shared by
        # several remainder ops registers only a single buffer.
        const_cache = {}

        for node in list(graph.nodes):
            if node.op != "call_function" or node.target not in self.remainder_targets:
                continue

            dividend, divisor = node.args[0], node.args[1]
            is_edge = isinstance(node.target, EdgeOpOverload)
            meta = node.meta
            div_op, floor_op, mul_op, sub_op = self._select_ops(is_edge)

            if not isinstance(divisor, torch.fx.Node) and is_edge:
                # Edge dialect needs a tensor operand: materialize the scalar
                # divisor as a constant node, reusing any cached one.
                if divisor not in const_cache:
                    attr_name = get_new_attr_name_with_prefix("_remainder_const_")(
                        graph_module
                    )
                    const_cache[divisor] = get_const_node(
                        graph, graph_module, attr_name, divisor, node
                    )
                divisor = const_cache[divisor]

            with graph.inserting_before(node):
                quotient = graph.create_node(
                    "call_function", div_op, (dividend, divisor)
                )
                quotient.meta = copy_meta(meta)

                floored = graph.create_node("call_function", floor_op, (quotient,))
                floored.meta = copy_meta(meta)

                product = graph.create_node(
                    "call_function", mul_op, (floored, divisor)
                )
                product.meta = copy_meta(meta)

                result = graph.create_node(
                    "call_function", sub_op, (dividend, product)
                )
                result.meta = copy_meta(meta)

            for user in node.users.copy():
                user.replace_input_with(node, result)

        # The original remainder nodes now have no users; drop them.
        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
3 changes: 3 additions & 0 deletions backends/qualcomm/_passes/qnn_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
DecomposeMaxPool3d,
DecomposeMinMaxDim,
DecomposeReciprocal,
DecomposeRemainder,
DecomposeRoll,
DecomposeSilu,
DecomposeThreshold,
Expand Down Expand Up @@ -106,6 +107,7 @@ def get_capture_program_passes():
(DecomposeLogVariants, True),
(DecomposeMaxPool3d, True),
(DecomposeMinMaxDim, True),
(DecomposeRemainder, True),
(DecomposeTrunc, True),
(ExpandBroadcastTensorShape, True),
(FixedLinearKeepDim, True),
Expand Down Expand Up @@ -239,6 +241,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
# Decompose Reciprocal into Div for these 2 backend
# TODO: Skip this pass for CPU backend (Dependency: Backend-aware passes manager)
self.add_pass(DecomposeReciprocal())
self.add_pass(DecomposeRemainder())
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
self.add_pass(DecomposeLogVariants())
self.add_pass(ReplaceInfValues())
Expand Down
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def get_passes_dependency_for_capture_program():
DecomposeLinalgVectorNorm,
DecomposeLogVariants,
DecomposeMaxPool3d,
DecomposeRemainder,
DecomposeTrunc,
ExpandBroadcastTensorShape,
FixedLinearKeepDim,
Expand Down Expand Up @@ -101,6 +102,7 @@ def get_passes_dependency_for_capture_program():
DecomposeLinalgVectorNorm: [RemoveRedundancy],
DecomposeLogVariants: [RemoveRedundancy],
DecomposeMaxPool3d: [RemoveRedundancy],
DecomposeRemainder: [RemoveRedundancy],
DecomposeTrunc: [RemoveRedundancy],
ExpandBroadcastTensorShape: [FoldQDQ],
FixedLinearKeepDim: [FoldQDQ],
Expand Down
24 changes: 24 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1889,6 +1889,30 @@ def forward(self, x):
return x.repeat(1, 2, 3, 4)


class RemainderScalar(torch.nn.Module):
    """Test model: elementwise remainder against a fixed scalar divisor."""

    def forward(self, x):
        # 3.0 is an arbitrary non-zero divisor exercising remainder.Scalar.
        return torch.remainder(x, 3.0)


class RemainderTensor(torch.nn.Module):
    """Test model: elementwise remainder against a tensor divisor."""

    def forward(self, x, y):
        # Exercises the remainder.Tensor overload.
        return torch.remainder(x, y)


class RemainderMultiNode(torch.nn.Module):
    """Test model: two remainder ops (scalar and tensor divisor) in one graph,
    exercising const-node reuse / multi-node decomposition in the pass."""

    def forward(self, x, y):
        scalar_mod = torch.remainder(x, 3.0)
        tensor_mod = torch.remainder(x, y)
        return scalar_mod, tensor_mod


class ReWriteObs(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
Loading
Loading