diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index e3e0a873020..4d4c0ee75b1 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -94,6 +94,8 @@ def call_operator(self, op, args, kwargs, meta): input_shape = list(x.data.shape) output_shape = list(meta["val"].shape) dims_to_reduce = get_node_arg(args, 1) + if dims_to_reduce is None: + dims_to_reduce = range(len(input_shape)) dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce] dims_to_reduce = [dim for dim in dims_to_reduce if input_shape[dim] != 1] diff --git a/backends/arm/test/ops/test_mean_dim.py b/backends/arm/test/ops/test_mean_dim.py index 96ec7793551..656f35fb17f 100644 --- a/backends/arm/test/ops/test_mean_dim.py +++ b/backends/arm/test/ops/test_mean_dim.py @@ -115,7 +115,7 @@ class MeanDim(torch.nn.Module): test_data_suite: dict[str, tuple] = { "rank_1_keepdim": lambda: ( torch.rand(7), - (0), + 0, True, ), "rank_2_keepdim": lambda: ( @@ -168,6 +168,11 @@ class MeanDim(torch.nn.Module): (0, 1, 2, 3), True, ), + "rand_none_keepdim": lambda: ( + torch.rand(1, 5, 7, 3), + None, + True, + ), "rank_1": lambda: ( torch.rand(7), (-1), @@ -280,7 +285,6 @@ def test_mean_dim_tosa_INT(test_data): (test_data,), [], # Might be sum, avgpool, or both symmetric_io_quantization=True, - custom_path="MEANDIM", ) pipeline.run()