Fix GRU w8a32 operator (#17226)
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17226
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures, 2 Unrelated Failures
As of commit 5d8ce13 with merge base b5cf3c3:
NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull request overview
This pull request fixes the GRU w8a32 operator by correcting the output shape in both the reference implementation and meta kernel, and enhancing pattern matching with safer parameter checks.
Changes:
- Fixed GRU w8a32 operator output shape from `(2, hidden_dim)` to `(2, batch, input_dim, hidden_dim)` to properly reflect the expected dimensions
- Enhanced pattern matching safety by using the `.get()` method instead of direct dictionary access for tensor metadata (see the sketch after this list)
- Added SharedQuantizationSpec for GRU biases to ensure consistent quantization scales
- Added metadata propagation for transposed tensors in fusion pass
- Added input shape validation for conv operator
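
As a hedged illustration of the `.get()`-based metadata access this PR adopts: a minimal runnable sketch in which `node_meta` is a plain dict standing in for an FX node's `.meta` mapping (hypothetical data, not the actual pass code):

```python
# Minimal sketch: dict.get() returns None instead of raising KeyError
# when the "val" entry is absent, so the matcher can bail out gracefully.
node_meta = {"other_key": 123}  # hypothetical node metadata without "val"

val = node_meta.get("val")  # None rather than a KeyError
if val is None:
    print("bail out of pattern matching: shape metadata unavailable")
```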
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| backends/cadence/aot/tests/test_ref_implementations.py | Updated test expectations to match corrected output shape |
| backends/cadence/aot/ref_implementations.py | Fixed output shape calculation by expanding hidden state instead of flattening |
| backends/cadence/aot/quantizer/patterns.py | Added safer metadata access, input validation, and shared bias quantization spec |
| backends/cadence/aot/quantizer/fusion_pass.py | Added val metadata propagation for transposed inputs and weights |
| backends/cadence/aot/ops_registrations.py | Updated meta kernel to return correct output shape with improved documentation |
```python
seq_len = inputs.shape[1]
assert seq_len == 1
# inputs comes in shape [batch, seq_len, input_size]
# hidden comes in shape [batch, seq_len, hidden_size]
# weights_inputs comes in shape [3 * hidden_size, input_size]
# weights_hidden comes in shape [3 * hidden_size, hidden_size]
# output comes in empty with shape [2, batch, seq_len, hidden_size]
# The first dimension stacks the output and the new hidden state
return hidden.new_empty((2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32)
```
The assertion `assert seq_len == 1` will fail with the 2D inputs used in the test cases. Looking at test_ref_implementations.py, the tests use 2D inputs like `torch.tensor([[1.0, 2.0]], dtype=torch.float32)` with shape `[1, 2]`, where `inputs.shape[1]` is 2, not 1. Either the assertion logic needs to check whether `inputs` is 3D before asserting `seq_len == 1`, or the comments describing the expected shapes are incorrect and should indicate 2D inputs `[batch, input_size]` instead of 3D `[batch, seq_len, input_size]`.
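
For illustration, a minimal repro of the mismatch (hedged; it mirrors the test fixture quoted above, not the operator code):

```python
import torch

# With the 2D test input of shape [1, 2], shape[1] is the input size (2),
# not a sequence length of 1, so `assert seq_len == 1` would fail.
inputs = torch.tensor([[1.0, 2.0]], dtype=torch.float32)
seq_len = inputs.shape[1]
print(seq_len)  # 2 -> the assertion in the op would trip
```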
Suggested change:

```diff
-seq_len = inputs.shape[1]
-assert seq_len == 1
-# inputs comes in shape [batch, seq_len, input_size]
-# hidden comes in shape [batch, seq_len, hidden_size]
-# weights_inputs comes in shape [3 * hidden_size, input_size]
-# weights_hidden comes in shape [3 * hidden_size, hidden_size]
-# output comes in empty with shape [2, batch, seq_len, hidden_size]
-# The first dimension stacks the output and the new hidden state
-return hidden.new_empty((2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32)
+# inputs may come in shape [batch, seq_len, input_size] or [batch, input_size] (for seq_len == 1)
+if inputs.dim() == 3:
+    batch = inputs.shape[0]
+    seq_len = inputs.shape[1]
+    assert seq_len == 1
+elif inputs.dim() == 2:
+    batch = inputs.shape[0]
+    seq_len = 1
+else:
+    raise AssertionError(f"Unsupported inputs.dim() for quantized_w8a32_gru_meta: {inputs.dim()}")
+# hidden is expected to have hidden_size in its last dimension
+hidden_size = hidden.shape[-1]
+# weights_inputs comes in shape [3 * hidden_size, input_size]
+# weights_hidden comes in shape [3 * hidden_size, hidden_size]
+# output comes in empty with shape [2, batch, seq_len, hidden_size]
+# The first dimension stacks the output and the new hidden state
+return hidden.new_empty((2, batch, seq_len, hidden_size), dtype=torch.float32)
```
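
A quick, self-contained check of the branching suggested above (a sketch under the assumption that only seq_len == 1 is supported; not the actual ExecuTorch meta kernel):

```python
import torch

# Sketch: compute the output shape for both the 3D and the 2D input layouts.
def gru_meta_out_shape(inputs: torch.Tensor, hidden: torch.Tensor) -> tuple:
    if inputs.dim() == 3:
        batch, seq_len = inputs.shape[0], inputs.shape[1]
        assert seq_len == 1  # only single-step GRU is supported here
    elif inputs.dim() == 2:
        batch, seq_len = inputs.shape[0], 1
    else:
        raise AssertionError(f"Unsupported inputs.dim(): {inputs.dim()}")
    return (2, batch, seq_len, hidden.shape[-1])

print(gru_meta_out_shape(torch.zeros(1, 2), torch.zeros(1, 3)))        # (2, 1, 1, 3)
print(gru_meta_out_shape(torch.zeros(1, 1, 2), torch.zeros(1, 1, 3)))  # (2, 1, 1, 3)
```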
```python
seq_len = inputs.shape[1]
assert seq_len == 1
# inputs comes in shape [batch, seq_len, input_size]
# hidden comes in shape [batch, seq_len, hidden_size]
# weights_inputs comes in shape [3 * hidden_size, input_size]
# weights_hidden comes in shape [3 * hidden_size, hidden_size]
# output comes in empty with shape [2, batch, seq_len, hidden_size]
# The first dimension stacks the output and the new hidden state
return hidden.new_empty((2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32)
```
The comments describe `inputs` as 3D `[batch, seq_len, input_size]` and `hidden` as 3D `[batch, seq_len, hidden_size]`, but the test cases in test_ref_implementations.py use 2D tensors `[batch, input_size]` and `[batch, hidden_size]`. The comments should be updated to reflect the actual expected shapes, or the implementation should be modified to match the documented shapes.
Suggested change:

```diff
-seq_len = inputs.shape[1]
-assert seq_len == 1
-# inputs comes in shape [batch, seq_len, input_size]
-# hidden comes in shape [batch, seq_len, hidden_size]
-# weights_inputs comes in shape [3 * hidden_size, input_size]
-# weights_hidden comes in shape [3 * hidden_size, hidden_size]
-# output comes in empty with shape [2, batch, seq_len, hidden_size]
-# The first dimension stacks the output and the new hidden state
-return hidden.new_empty((2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32)
+# inputs comes in shape [batch, input_size]
+# hidden comes in shape [batch, hidden_size]
+# weights_inputs comes in shape [3 * hidden_size, input_size]
+# weights_hidden comes in shape [3 * hidden_size, hidden_size]
+# output comes in empty with shape [2, batch, hidden_size]
+# The first dimension stacks the output and the new hidden state
+assert len(inputs.shape) == 2
+assert len(hidden.shape) == 2
+assert inputs.shape[0] == hidden.shape[0]
+return hidden.new_empty((2, inputs.shape[0], hidden.shape[-1]), dtype=torch.float32)
```
Summary:
# Context
This diff fixes the reference implementation of the w8a32 GRU operator and enhances the operator's pattern matching.
# Mitigation
The reference implementation now has the right output dimension, and pattern matching now uses a safer check for the operator parameters.
Reviewed By: hsharma35
Differential Revision: D90437262
Force-pushed from 6f6ff6a to 62b059f
Force-pushed from 62b059f to e0f3688
Force-pushed from e0f3688 to f4936ca
Pull request overview
Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.
```python
# Bail if length != kernel size - Not yet supported
if inputs_shape[-1] != cnn_weights_shape[2]:
```
In `MixedW8A32ConvPattern`, this new guard compares the input tensor's last dimension (which is the conv1d length for NCL inputs) against the kernel size. That will reject common valid conv1d shapes (e.g., length=5, kernel=3) even though the w8a32 conv ref/meta implementations support them. If the intent is to validate shapes, this should instead check that the input channel dimension matches the weight's in_channels and that length >= kernel_size (not length == kernel_size).
Suggested change:

```diff
-# Bail if length != kernel size - Not yet supported
-if inputs_shape[-1] != cnn_weights_shape[2]:
+# Bail if the input channels do not match the weight's in_channels
+# or if the input length is smaller than the kernel size.
+if (
+    len(inputs_shape) < 3
+    or inputs_shape[1] != cnn_weights_shape[1]
+    or inputs_shape[-1] < cnn_weights_shape[2]
+):
```
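
A worked example with hypothetical shapes may make the distinction concrete (a sketch, not the pattern-matching code itself):

```python
# NCL conv1d input [batch, in_channels, length] vs weight
# [out_channels, in_channels, kernel_size]; shapes are made up.
inputs_shape = (1, 4, 5)
cnn_weights_shape = (8, 4, 3)

# The guard as written rejects this valid case (length 5 != kernel 3) ...
rejected_by_guard = inputs_shape[-1] != cnn_weights_shape[2]

# ... while the compatibility conv1d actually needs is:
compatible = (
    len(inputs_shape) == 3
    and inputs_shape[1] == cnn_weights_shape[1]   # in_channels match
    and inputs_shape[-1] >= cnn_weights_shape[2]  # length >= kernel_size
)
print(rejected_by_guard, compatible)  # True True
```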
```diff
 assert len(dequants_biases) == 2
 w_i_scale = dequants_weights[0].args[1]
 w_h_scale = dequants_weights[1].args[1]
-b_i_scale = dequants_biases[0].args[1]
-b_h_scale = dequants_biases[1].args[1]
+b_scale = dequants_biases[0].args[1]
```
`get_args_and_kwargs_mixed_w8a32_gru` still indexes `dequants_weights[0]` and `[1]`, but the defensive `assert len(dequants_weights) == 2` was removed. If the partition ever produces an unexpected number of dequant nodes, this will fail with an IndexError and be harder to diagnose. Please restore the assert (or otherwise validate the length) before indexing.
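
A minimal sketch of the defensive check being requested (variable names mirror the fusion pass; the node objects here are hypothetical stand-ins):

```python
# Stand-ins for the two weight dequant FX nodes the partition should yield.
dequants_weights = ["dequant_w_ih", "dequant_w_hh"]

# Validate the length before indexing so a malformed partition fails loudly.
assert len(dequants_weights) == 2, (
    f"expected 2 weight dequant nodes, got {len(dequants_weights)}"
)
w_i, w_h = dequants_weights  # safe to unpack after the check
print(w_i, w_h)
```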
```python
seq_len = inputs.shape[1]
assert seq_len == 1
# inputs comes in shape [batch, seq_len, input_size]
# hidden comes in shape [batch, seq_len, hidden_size]
# weights_inputs comes in shape [3 * hidden_size, input_size]
# weights_hidden comes in shape [3 * hidden_size, hidden_size]
# output comes in empty with shape [2, batch, seq_len, hidden_size]
# The first dimension stacks the output and the new hidden state
return hidden.new_empty(
    (2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32
)
```
The meta kernel computes `seq_len = inputs.shape[1]` and asserts `seq_len == 1`, but the ref implementation and unit tests treat `inputs` as a 2D tensor shaped like `[batch, input_dim]` (e.g., 1x2, 1x3). With the current meta logic, any input_dim != 1 will trip this assert during fake tensor shape propagation / export. Please remove/relax this assertion and update the shape comments to match the actual operator contract (and keep the output shape computation consistent with `ref_implementations.quantized_w8a32_gru`).
Suggested change:

```diff
-seq_len = inputs.shape[1]
-assert seq_len == 1
-# inputs comes in shape [batch, seq_len, input_size]
-# hidden comes in shape [batch, seq_len, hidden_size]
-# weights_inputs comes in shape [3 * hidden_size, input_size]
-# weights_hidden comes in shape [3 * hidden_size, hidden_size]
-# output comes in empty with shape [2, batch, seq_len, hidden_size]
-# The first dimension stacks the output and the new hidden state
-return hidden.new_empty(
-    (2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32
+# inputs comes in shape [batch, input_size]
+# hidden comes in shape [batch, hidden_size]
+# weights_inputs comes in shape [3 * hidden_size, input_size]
+# weights_hidden comes in shape [3 * hidden_size, hidden_size]
+# output comes in empty with shape [2, batch, hidden_size]
+# The first dimension stacks the output and the new hidden state
+return hidden.new_empty(
+    (2, inputs.shape[0], hidden.shape[-1]), dtype=torch.float32
 )
```
Force-pushed from f4936ca to 6d04470
Force-pushed from 6d04470 to f3eb5de
Force-pushed from f3eb5de to 54e29f5
Pull request overview
Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.
```python
seq_len = inputs.shape[1]
assert seq_len == 1
```
`quantized_w8a32_gru_meta` indexes `inputs.shape[1]` without first validating the inputs rank, which can raise an IndexError (and produce a confusing failure) if the op is ever invoked with a 1D/2D input. Add an explicit shape/rank check (e.g., `inputs.dim() == 3` / `hidden.dim() == 3`) before reading `shape[1]`, and then assert `seq_len == 1` with a clear message.
Suggested change:

```diff
-seq_len = inputs.shape[1]
-assert seq_len == 1
+assert inputs.dim() == 3, (
+    "quantized_w8a32_gru expects inputs to have shape "
+    "[batch, seq_len, input_size]"
+)
+assert hidden.dim() == 3, (
+    "quantized_w8a32_gru expects hidden to have shape "
+    "[batch, seq_len, hidden_size]"
+)
+seq_len = inputs.shape[1]
+assert (
+    seq_len == 1
+), "quantized_w8a32_gru fake kernel only supports seq_len == 1"
```
```diff
+expected_shape = (2, inputs.shape[0], inputs.shape[1], hidden.shape[-1])
 self.assertEqual(
     output.shape,
-    (2, *hidden.shape),
-    f"Output shape should match {(2, *hidden.shape)} in {name}",
+    expected_shape,
+    f"Output shape should match {expected_shape} in {name}",
```
The updated expected shape uses `inputs.shape[1]` as the sequence-length dimension, but the test vectors still pass 2D inputs/hidden tensors. This means the test isn't exercising the (documented in `quantized_w8a32_gru_meta`) `[batch, seq_len, input_size]` / `[batch, seq_len, hidden_size]` path and can mask shape bugs for the intended 3D case. Consider updating these fixtures (and/or adding an additional case) to use inputs and hidden with an explicit seq_len dimension (typically seq_len == 1) so the test matches the operator contract.
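
To make the asserted shape concrete, here is what the expression evaluates to for the existing 2D fixtures (a hedged illustration; the values are made up and only the shapes matter):

```python
import torch

inputs = torch.tensor([[1.0, 2.0]])       # [batch=1, input_dim=2]
hidden = torch.tensor([[0.1, 0.2, 0.3]])  # [batch=1, hidden_dim=3]

expected_shape = (2, inputs.shape[0], inputs.shape[1], hidden.shape[-1])
print(expected_shape)  # (2, 1, 2, 3)
```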
```python
batch_size = inputs.shape[0]
input_dim = inputs.shape[1]
hidden_dim = hidden.shape[-1]

new_hidden_expanded = new_hidden.unsqueeze(1).expand(batch_size, input_dim, hidden_dim)

return torch.stack([new_hidden_expanded, new_hidden_expanded], dim=0)
```
`new_hidden` is already shaped like the computed GRU output (e.g., `[batch, seq_len, hidden]` when inputs is 3D), but the new `unsqueeze(1).expand(batch_size, input_dim, hidden_dim)` path will throw for 3D inputs/hidden because it adds an extra dimension and then calls `expand` with too few sizes. This currently only works when `new_hidden` is 2D.
Consider normalizing `new_hidden` to `[batch, seq_len, hidden]` (e.g., only unsqueeze when it's missing the seq_len dim) and then stacking directly, rather than expanding based on `inputs.shape[1]`.
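
A small repro of the described failure mode (hedged sketch; the shapes are hypothetical):

```python
import torch

# When new_hidden is already 3D [batch, seq_len, hidden], unsqueeze(1)
# makes it 4D, and expand() with only three sizes then raises.
new_hidden = torch.zeros(1, 1, 4)
try:
    new_hidden.unsqueeze(1).expand(1, 1, 4)
except RuntimeError as e:
    print("expand failed:", e)
```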
Suggested change:

```diff
-batch_size = inputs.shape[0]
-input_dim = inputs.shape[1]
-hidden_dim = hidden.shape[-1]
-new_hidden_expanded = new_hidden.unsqueeze(1).expand(batch_size, input_dim, hidden_dim)
-return torch.stack([new_hidden_expanded, new_hidden_expanded], dim=0)
+if new_hidden.dim() == 1:
+    new_hidden_normalized = new_hidden.unsqueeze(0).unsqueeze(0)
+elif new_hidden.dim() == 2:
+    new_hidden_normalized = new_hidden.unsqueeze(1)
+elif new_hidden.dim() == 3:
+    new_hidden_normalized = new_hidden
+else:
+    raise ValueError(
+        f"Hidden state must be 1D, 2D, or 3D, got shape {tuple(new_hidden.shape)}"
+    )
+return torch.stack([new_hidden_normalized, new_hidden_normalized], dim=0)
```
Force-pushed from 54e29f5 to 341bff3
Force-pushed from 341bff3 to 5b16405
Pull request overview
Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.
```python
batch_size = inputs.shape[0]
input_dim = inputs.shape[1]
hidden_dim = hidden.shape[-1]

new_hidden_expanded = new_hidden.unsqueeze(1).expand(batch_size, input_dim, hidden_dim)
```
`new_hidden` is asserted to have the same shape as `original_hidden_shape`, and the function comment says hidden can be 3D (e.g., `(1, 1, hidden_dim)`). In that case `new_hidden` will already be 3D, so `new_hidden.unsqueeze(1)` becomes 4D and the subsequent `expand(batch_size, input_dim, hidden_dim)` (3 sizes) will raise a runtime error. Consider normalizing `new_hidden` to the desired 3D shape `(batch, seq_len, hidden_dim)` (only unsqueeze when it is 2D), or simply `torch.stack([new_hidden, new_hidden], dim=0)` when `new_hidden` already matches the output layout.
Suggested change:

```diff
-batch_size = inputs.shape[0]
-input_dim = inputs.shape[1]
-hidden_dim = hidden.shape[-1]
-new_hidden_expanded = new_hidden.unsqueeze(1).expand(batch_size, input_dim, hidden_dim)
+if new_hidden.dim() == 1:
+    new_hidden_3d = new_hidden.view(1, 1, -1)
+elif new_hidden.dim() == 2:
+    new_hidden_3d = new_hidden.unsqueeze(1)
+elif new_hidden.dim() == 3:
+    new_hidden_3d = new_hidden
+else:
+    raise ValueError(
+        f"Hidden state must be 1D, 2D, or 3D, got {new_hidden.dim()}D"
+    )
+batch_size = inputs.shape[0]
+input_dim = inputs.shape[1]
+hidden_dim = hidden.shape[-1]
+new_hidden_expanded = new_hidden_3d.expand(batch_size, input_dim, hidden_dim)
```
```diff
+expected_shape = (2, inputs.shape[0], inputs.shape[1], hidden.shape[-1])
 self.assertEqual(
     output.shape,
-    (2, *hidden.shape),
-    f"Output shape should match {(2, *hidden.shape)} in {name}",
+    expected_shape,
+    f"Output shape should match {expected_shape} in {name}",
 )
```
The test only validates the new 4D output shape against `inputs.shape[:2]`, but it still uses 2D inputs / 2D hidden fixtures. Since this op is produced from `aten.gru` fusion, it's important to add a case with the expected real input ranks (e.g., 3D inputs and 3D hidden/state) so the output-shape logic is exercised for those ranks as well (this would also catch shape errors like an unsqueeze/expand mismatch).
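
For example, a 3D fixture along these lines would exercise the rank the fusion actually produces (hypothetical values, seq_len == 1):

```python
import torch

inputs = torch.tensor([[[1.0, 2.0]]])       # [batch=1, seq_len=1, input_dim=2]
hidden = torch.tensor([[[0.1, 0.2, 0.3]]])  # [batch=1, seq_len=1, hidden_dim=3]
print(inputs.shape, hidden.shape)  # torch.Size([1, 1, 2]) torch.Size([1, 1, 3])
```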
```python
# Using SharedQuantizationSpec so that bias_hh has the same observer as bias_ih
# Both biases get the same quantization scale to match the cpp operator
bias_ih_node = wrapper.args[2]
bias_ih_edge = (bias_ih_node, gru_layer)
shared_bias_qspec = SharedQuantizationSpec(edge_or_node=bias_ih_edge)
```
This code assumes the GRU params tuple contains both biases and immediately indexes `wrapper.args[2]`/`[3]`. For `aten.gru.input`, `has_biases` can be false (e.g., `nn.GRU(..., bias=False)`), in which case the params list may not contain bias entries and this will raise an IndexError during pattern matching. Consider explicitly checking the `has_biases` argument (and/or `len(wrapper.args) >= 4`) and returning `empty=True` anchors when biases are not present, since the Cadence `quantized_w8a32_gru` replacement requires bias tensors.
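
A hedged sketch of the suggested guard; the argument layout of the matched GRU node is assumed from the review text, and `gru_bias_nodes` is a hypothetical helper, not existing API:

```python
from typing import Optional, Sequence, Tuple

def gru_bias_nodes(wrapper_args: Sequence) -> Optional[Tuple[object, object]]:
    # Return (bias_ih, bias_hh) if the params tuple carries both biases,
    # or None so the caller can return empty=True anchors instead.
    if len(wrapper_args) < 4:
        return None
    return wrapper_args[2], wrapper_args[3]

print(gru_bias_nodes(["x", "w_ih"]))                  # None -> bail out
print(gru_bias_nodes(["x", "w_ih", "b_ih", "b_hh"]))  # ('b_ih', 'b_hh')
```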
Summary:
Pull Request resolved: pytorch#16607
#### Summary
This diff fixes the Conv1d w8a32 operator by adding a transformation to the `val` attribute of the `other_inputs[0].meta` dictionary. Specifically, the `permute` operation is applied to the `original_val` tensor within the `fake_mode` context, and the resulting `transposed_val` is assigned to `transposed_inputs.meta["val"]`.
Differential Revision: D89863750
Reviewed By: mcremon-meta
Force-pushed from 5b16405 to 8e98f88
Force-pushed from 8e98f88 to 76ec628
Force-pushed from 76ec628 to 5d8ce13
Summary:
Context
This diff fixes the reference implementation of the w8a32 GRU operator and enhances the operator's pattern matching.
Mitigation
The reference implementation now has the right output dimension, and pattern matching now uses a safer check for the operator parameters.
Reviewed By: hsharma35
Differential Revision: D90437262