diff --git a/backends/arm/tosa_partitioner.py b/backends/arm/tosa_partitioner.py index 738c5ab8204..de85dfae92f 100644 --- a/backends/arm/tosa_partitioner.py +++ b/backends/arm/tosa_partitioner.py @@ -96,7 +96,9 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool: # De-tag outmost q-nodes upwards and dq-nodes downwards. # De-tag if at least one input/ output is not part of partition. - for node in partition.nodes: + for node in exported_program.graph_module.graph.nodes: + if not is_partitioned(node): + continue if is_quant_node(node): for input in node.all_input_nodes: if not is_partitioned(input):