From 24483d22a6925d8e2f821da352cd5e4d9e1f0d9f Mon Sep 17 00:00:00 2001 From: Nathanael See Date: Thu, 6 Feb 2025 10:27:20 -0800 Subject: [PATCH] [ET-VK] fix index error bug in ViewCopyToSqueezeUnsqueezePass See T214560872 https://github.com/pytorch/executorch/pull/8226 added the pass to the partition preprocess pass list, so now it runs on all exports. This uncovered a bug in the squeeze dims finding function in the mobilenet test case. Differential Revision: [D69254910](https://our.internmc.facebook.com/intern/diff/D69254910/) [ghstack-poisoned] --- .../view_copy_to_squeeze_unsqueeze.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/backends/transforms/view_copy_to_squeeze_unsqueeze.py b/backends/transforms/view_copy_to_squeeze_unsqueeze.py index 094ec6a3340..eccb62f0d49 100644 --- a/backends/transforms/view_copy_to_squeeze_unsqueeze.py +++ b/backends/transforms/view_copy_to_squeeze_unsqueeze.py @@ -47,15 +47,18 @@ def find_squeeze_dims( j = 0 idx = [] while i < len(input_shape): - if input_shape[i] != view_shape[j]: - if input_shape[i] == 1: - idx.append(i) - j -= 1 - # continue to check remaining dims are equal - else: - return None - i += 1 - j += 1 + if j < len(view_shape) and input_shape[i] == view_shape[j]: + i += 1 + j += 1 + elif input_shape[i] == 1: + # squeeze axis on i and check next dim + idx.append(i) + i += 1 + else: + return None + # If there are remaining dimensions in view_shape, shapes do not match + if j < len(view_shape): + return None return idx def find_unsqueeze_dim(