Arm backend: Fix torch.matmul() failures for 2D tensor inputs #14845
ConvertMmToBmmPass converts an MM node to a BMM node, turns the input and output tensors from rank 2 to rank 3 via unsqueeze/squeeze, and inserts q-dq pairs before and after the BMM node when necessary.
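To illustrate the rewrite (using NumPy as a stand-in, not the backend's actual graph machinery): a 2D matmul is equivalent to a batched matmul on rank-3 tensors obtained via unsqueeze, with the batch dimension squeezed away afterwards.

```python
import numpy as np

# Illustration only: MM on rank-2 tensors equals BMM on the same tensors
# unsqueezed to rank 3, followed by a squeeze back to rank 2.
a = np.arange(12.0).reshape(3, 4)
b = np.arange(20.0).reshape(4, 5)

mm = a @ b                                   # rank-2 matmul (MM)
bmm = np.matmul(a[None, ...], b[None, ...])  # unsqueeze to rank 3, batched matmul (BMM)
out = bmm.squeeze(0)                         # squeeze back to rank 2

assert out.shape == (3, 5)
assert np.allclose(mm, out)
```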
After ConvertMmToBmmPass, if the original matmul was 2D, the BMM therefore already has DQ nodes on its inputs and a Q node on its output. If AnnotateDecomposedMatmulPass (Arm backend: Add support for single input matmul, #10654) is applied again in this case, it produces illegal sequences such as: x -> q -> unsqueeze -> q_2 (invalid).
Fix this by checking whether the BMM node is already surrounded by DQ nodes on its inputs and a Q node on its output, and skipping the annotation if so.
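A minimal sketch of that gating check, using a toy node class rather than the real torch.fx / ExecuTorch Arm backend API; the class, function, and target names here are assumptions for illustration only.

```python
from dataclasses import dataclass, field

# Hypothetical op names standing in for the real quantize/dequantize targets.
Q = "quantized_decomposed.quantize_per_tensor"
DQ = "quantized_decomposed.dequantize_per_tensor"

@dataclass
class Node:
    """Toy stand-in for a graph node: an op target, its inputs, and its users."""
    target: str
    args: tuple = ()
    users: list = field(default_factory=list)

def is_already_quantized(bmm: Node) -> bool:
    """True if every input of the BMM is produced by a DQ node and every user is a Q node."""
    inputs_dq = all(isinstance(a, Node) and a.target == DQ for a in bmm.args)
    outputs_q = all(u.target == Q for u in bmm.users)
    return inputs_dq and outputs_q and bool(bmm.users)

# A BMM already surrounded by DQ inputs and a Q output: the annotation
# pass would skip it, avoiding the invalid x -> q -> unsqueeze -> q_2 chain.
dq0, dq1 = Node(DQ), Node(DQ)
bmm = Node("aten.bmm.default", args=(dq0, dq1))
bmm.users.append(Node(Q))
```

In the annotation pass, the check would simply gate the existing logic: `if is_already_quantized(bmm): return`.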
Change-Id: I9949d59b0b4a96fa34a88b0734014567ea6f24cc
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218