diff --git a/backends/arm/_passes/annotate_decomposed_matmul.py b/backends/arm/_passes/annotate_decomposed_matmul.py index 81b7b36cc0b..6a1acb8d71d 100644 --- a/backends/arm/_passes/annotate_decomposed_matmul.py +++ b/backends/arm/_passes/annotate_decomposed_matmul.py @@ -70,7 +70,10 @@ def call(self, graph_module: GraphModule) -> PassResult: node for node in partition.nodes if node.target in matmul_targets ][0] - if quantized_input: + if quantized_input and not all( + input_node.target in DQ_OPS + for input_node in matmul_node.all_input_nodes + ): matmul_args = matmul_node.all_input_nodes for node in matmul_args: # Find the dq-node connected to this mm/bmm arg @@ -96,7 +99,9 @@ def call(self, graph_module: GraphModule) -> PassResult: partition_output = list(partition.output_nodes[0].users)[0] quantized_output = partition_output.target in Q_OPS - if quantized_output: + if quantized_output and not all( + user.target in Q_OPS for user in matmul_node.users + ): with graph_module.graph.inserting_after(matmul_node): # Create q-node after matmul q_node = create_node( diff --git a/backends/arm/test/ops/test_matmul.py b/backends/arm/test/ops/test_matmul.py index d1a21684325..1486dc97f4c 100644 --- a/backends/arm/test/ops/test_matmul.py +++ b/backends/arm/test/ops/test_matmul.py @@ -22,6 +22,7 @@ class MatMul(torch.nn.Module): test_data_generators = { + "rand_rand_2d": lambda: (torch.rand(5, 5), torch.rand(5, 2)), "rand_rand_3d": lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)), "rand_rand_4d": lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)), } @@ -32,6 +33,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): class MatMulSingleInput(torch.nn.Module): test_data_generators = { + "rand_2d": lambda: (torch.rand(5, 5),), "rand_3d": lambda: (torch.rand(2, 5, 5),), "rand_4d": lambda: (torch.rand(1, 2, 5, 5),), } @@ -42,6 +44,11 @@ def forward(self, x: torch.Tensor): class MatMulCombo(torch.nn.Module): test_data_generators = { + "rand_rand_rand_2d": lambda: ( + torch.rand(5, 5), + torch.rand(5, 2), + torch.rand(2, 5), + ), "rand_rand_rand_3d": lambda: ( torch.rand(2, 5, 5), torch.rand(2, 5, 2),