From 48a1b42df8726a6b01bc5178069291a3765273fc Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Wed, 4 Dec 2024 15:10:09 -0800 Subject: [PATCH] Fix pyre Differential Revision: D66787624 --- backends/arm/_passes/arm_pass_utils.py | 2 +- backends/arm/_passes/keep_dims_false_to_squeeze_pass.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index 78ee6e265ca..7377d401aba 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -156,7 +156,7 @@ def get_node_arg(args: list | dict, key: int | str | type, default_value=None): f"Out of bounds index {key} for getting value in args (of size {len(args)})" ) elif isinstance(key, str): - return args.get(key, default_value) + return args.get(key, default_value) # pyre-ignore[16] elif isclass(key): for arg in args: if isinstance(arg, key): diff --git a/backends/arm/_passes/keep_dims_false_to_squeeze_pass.py b/backends/arm/_passes/keep_dims_false_to_squeeze_pass.py index 736c627d914..f4d369a5040 100644 --- a/backends/arm/_passes/keep_dims_false_to_squeeze_pass.py +++ b/backends/arm/_passes/keep_dims_false_to_squeeze_pass.py @@ -64,12 +64,17 @@ def call(self, graph_module: torch.fx.GraphModule): continue sum_node = cast(torch.fx.Node, node) - keep_dim = get_node_arg(sum_node.args, keep_dim_index, False) + keep_dim = get_node_arg( + # pyre-ignore[6] + sum_node.args, + keep_dim_index, + False, + ) if keep_dim: continue - dim_list = get_node_arg(sum_node.args, 1, [0]) + dim_list = get_node_arg(sum_node.args, 1, [0]) # pyre-ignore[6] # Add keep_dim = True arg to sum node. set_node_arg(sum_node, 2, True)