diff --git a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py index 8f740a7bf4c..eda9dd28bf9 100644 --- a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py +++ b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py @@ -98,15 +98,20 @@ def call_operator(self, op, args, kwargs, meta, updated=False): avg_pool2d_op, pool_args, kwargs, meta, True ) row.append(pooled) - - # Concatenate row results along width (dim=3) - row_tensor = super().call_operator( - cat_op, (row, 3), kwargs, meta_with_no_qparams, True - ) + # Concatenate row results along width (dim=3) if more than one. + if len(row) > 1: + row_tensor = super().call_operator( + cat_op, (row, 3), kwargs, meta_with_no_qparams, True + ) + else: + row_tensor = row[0] res.append(row_tensor) - # Concatenate all rows along height (dim=2) - out = super().call_operator( - cat_op, (res, 2), kwargs, meta_with_no_qparams, True - ) + # Concatenate all rows along height (dim=2) if more than one. + if len(res) > 1: + out = super().call_operator( + cat_op, (res, 2), kwargs, meta_with_no_qparams, True + ) + else: + out = res[0] return out