9 changes: 7 additions & 2 deletions backends/arm/_passes/annotate_decomposed_matmul.py
@@ -70,7 +70,10 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 node for node in partition.nodes if node.target in matmul_targets
             ][0]
 
-            if quantized_input:
+            if quantized_input and not all(
+                input_node.target in DQ_OPS
+                for input_node in matmul_node.all_input_nodes
+            ):
                 matmul_args = matmul_node.all_input_nodes
                 for node in matmul_args:
                     # Find the dq-node connected to this mm/bmm arg
@@ -96,7 +99,9 @@ def call(self, graph_module: GraphModule) -> PassResult:
 
             partition_output = list(partition.output_nodes[0].users)[0]
             quantized_output = partition_output.target in Q_OPS
-            if quantized_output:
+            if quantized_output and not all(
+                user.target in Q_OPS for user in matmul_node.users
+            ):
                 with graph_module.graph.inserting_after(matmul_node):
                     # Create q-node after matmul
                     q_node = create_node(
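In effect, when the mm/bmm node is already fed directly by dequantize nodes (or already consumed directly by quantize nodes), the pass now skips re-inserting dq/q nodes around it. Below is a minimal sketch of the two guard predicates as hypothetical free functions, purely for illustration; in the pass itself the checks are inline, and DQ_OPS/Q_OPS are the dequantize/quantize target sets the pass already imports.

def inputs_already_dequantized(matmul_node, dq_ops) -> bool:
    # True when every producer feeding the mm/bmm node is already a dequantize op,
    # in which case no extra dq node needs to be inserted in front of it.
    return all(inp.target in dq_ops for inp in matmul_node.all_input_nodes)


def outputs_already_quantized(matmul_node, q_ops) -> bool:
    # True when every consumer of the mm/bmm node is already a quantize op,
    # in which case no extra q node needs to be inserted after it.
    return all(user.target in q_ops for user in matmul_node.users)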
7 changes: 7 additions & 0 deletions backends/arm/test/ops/test_matmul.py
@@ -22,6 +22,7 @@
 
 class MatMul(torch.nn.Module):
     test_data_generators = {
+        "rand_rand_2d": lambda: (torch.rand(5, 5), torch.rand(5, 2)),
         "rand_rand_3d": lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)),
         "rand_rand_4d": lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)),
     }
@@ -32,6 +33,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
 
 class MatMulSingleInput(torch.nn.Module):
     test_data_generators = {
+        "rand_2d": lambda: (torch.rand(5, 5),),
         "rand_3d": lambda: (torch.rand(2, 5, 5),),
         "rand_4d": lambda: (torch.rand(1, 2, 5, 5),),
     }
@@ -42,6 +44,11 @@ def forward(self, x: torch.Tensor):
 
 class MatMulCombo(torch.nn.Module):
     test_data_generators = {
+        "rand_rand_rand_2d": lambda: (
+            torch.rand(5, 5),
+            torch.rand(5, 2),
+            torch.rand(2, 5),
+        ),
         "rand_rand_rand_3d": lambda: (
             torch.rand(2, 5, 5),
             torch.rand(2, 5, 2),
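The new entries add rank-2 inputs alongside the existing 3-D and 4-D cases, presumably so the tests also cover matmuls that lower to a plain mm rather than a bmm. A quick, standalone sanity check of the shapes these generators produce (plain PyTorch, independent of the test harness; the chained-combo forward is an assumption, the real one lives in test_matmul.py):

import torch

# Shapes mirroring the new 2-D generators added above.
x, y = torch.rand(5, 5), torch.rand(5, 2)                         # "rand_rand_2d"
print(torch.matmul(x, y).shape)                                    # torch.Size([5, 2])

a, b, c = torch.rand(5, 5), torch.rand(5, 2), torch.rand(2, 5)     # "rand_rand_rand_2d"
# Chaining the three operands the way a combo module plausibly would.
print(torch.matmul(torch.matmul(a, b), c).shape)                   # torch.Size([5, 5])

s = torch.rand(5, 5)                                                # "rand_2d"
# A square single input can be multiplied with itself, as a single-input module might do.
print(torch.matmul(s, s).shape)                                     # torch.Size([5, 5])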