
Fix GRU w8a32 operator (#17226)

Merged

meta-codesync[bot] merged 2 commits into pytorch:main from mgiordy:export-D90437262 on Apr 15, 2026

Conversation

@mgiordy
Contributor

@mgiordy mgiordy commented Feb 4, 2026

Summary:

Context

This diff fixes the reference implementation of the w8a32 GRU operator and enhances the operator's pattern matching.

Mitigation

The reference implementation now returns the correct output dimensions, and pattern matching now uses a safer check for the operator parameters.

Reviewed By: hsharma35

Differential Revision: D90437262
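As a rough illustration of the shape fix (a sketch only; the function name is hypothetical and the real change lives in ops_registrations.py and ref_implementations.py), the corrected operator stacks the output and the new hidden state along a leading dimension of size 2:

```python
def w8a32_gru_output_shape(inputs_shape, hidden_shape):
    # Sketch of the corrected output shape: the leading 2 stacks the
    # GRU output and the new hidden state.
    batch, seq_len = inputs_shape[0], inputs_shape[1]
    hidden_size = hidden_shape[-1]
    return (2, batch, seq_len, hidden_size)

# batch=1, seq_len=1, input_size=4, hidden_size=8
print(w8a32_gru_output_shape((1, 1, 4), (1, 1, 8)))  # → (2, 1, 1, 8)
```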

Copilot AI review requested due to automatic review settings February 4, 2026 23:19
@pytorch-bot

pytorch-bot Bot commented Feb 4, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17226

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 2 Unrelated Failures

As of commit 5d8ce13 with merge base b5cf3c3:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 4, 2026
@meta-codesync
Contributor

meta-codesync Bot commented Feb 4, 2026

@mgiordy has exported this pull request. If you are a Meta employee, you can view the originating Diff in D90437262.

@github-actions

github-actions Bot commented Feb 4, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Contributor

Copilot AI left a comment


Pull request overview

This pull request fixes the GRU w8a32 operator by correcting the output shape in both the reference implementation and meta kernel, and enhancing pattern matching with safer parameter checks.

Changes:

  • Fixed GRU w8a32 operator output shape from (2, hidden_dim) to (2, batch, input_dim, hidden_dim) to properly reflect the expected dimensions
  • Enhanced pattern matching safety by using .get() method instead of direct dictionary access for tensor metadata
  • Added SharedQuantizationSpec for GRU biases to ensure consistent quantization scales
  • Added metadata propagation for transposed tensors in fusion pass
  • Added input shape validation for conv operator

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| backends/cadence/aot/tests/test_ref_implementations.py | Updated test expectations to match corrected output shape |
| backends/cadence/aot/ref_implementations.py | Fixed output shape calculation by expanding hidden state instead of flattening |
| backends/cadence/aot/quantizer/patterns.py | Added safer metadata access, input validation, and shared bias quantization spec |
| backends/cadence/aot/quantizer/fusion_pass.py | Added val metadata propagation for transposed inputs and weights |
| backends/cadence/aot/ops_registrations.py | Updated meta kernel to return correct output shape with improved documentation |


Comment on lines +2857 to +2865
```python
seq_len = inputs.shape[1]
assert seq_len == 1
# inputs comes in shape [batch, seq_len, input_size]
# hidden comes in shape [batch, seq_len, hidden_size]
# weights_inputs comes in shape [3 * hidden_size, input_size]
# weights_hidden comes in shape [3 * hidden_size, hidden_size]
# output comes in empty with shape [2, batch, seq_len, hidden_size]
# The first dimension stacks the output and the new hidden state
return hidden.new_empty((2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32)
```

Copilot AI Feb 4, 2026


The assertion `assert seq_len == 1` will fail with the 2D inputs used in the test cases. Looking at test_ref_implementations.py, the tests use 2D inputs like `torch.tensor([[1.0, 2.0]], dtype=torch.float32)` with shape `[1, 2]`, where `inputs.shape[1]` is 2, not 1. Either the assertion logic needs to check whether inputs is 3D before asserting `seq_len == 1`, or the comments describing the expected shapes are incorrect and should indicate 2D inputs `[batch, input_size]` instead of 3D `[batch, seq_len, input_size]`.

Suggested change
```python
# inputs may come in shape [batch, seq_len, input_size] or [batch, input_size] (for seq_len == 1)
if inputs.dim() == 3:
    batch = inputs.shape[0]
    seq_len = inputs.shape[1]
    assert seq_len == 1
elif inputs.dim() == 2:
    batch = inputs.shape[0]
    seq_len = 1
else:
    raise AssertionError(
        f"Unsupported inputs.dim() for quantized_w8a32_gru_meta: {inputs.dim()}"
    )
# hidden is expected to have hidden_size in its last dimension
hidden_size = hidden.shape[-1]
# weights_inputs comes in shape [3 * hidden_size, input_size]
# weights_hidden comes in shape [3 * hidden_size, hidden_size]
# output comes in empty with shape [2, batch, seq_len, hidden_size]
# The first dimension stacks the output and the new hidden state
return hidden.new_empty((2, batch, seq_len, hidden_size), dtype=torch.float32)
```

Comment on lines +2857 to +2865
```python
seq_len = inputs.shape[1]
assert seq_len == 1
# inputs comes in shape [batch, seq_len, input_size]
# hidden comes in shape [batch, seq_len, hidden_size]
# weights_inputs comes in shape [3 * hidden_size, input_size]
# weights_hidden comes in shape [3 * hidden_size, hidden_size]
# output comes in empty with shape [2, batch, seq_len, hidden_size]
# The first dimension stacks the output and the new hidden state
return hidden.new_empty((2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32)
```

Copilot AI Feb 4, 2026


The comments describe inputs as 3D [batch, seq_len, input_size] and hidden as 3D [batch, seq_len, hidden_size], but the test cases in test_ref_implementations.py use 2D tensors [batch, input_size] and [batch, hidden_size]. The comments should be updated to reflect the actual expected shapes, or the implementation should be modified to match the documented shapes.

Suggested change
```python
# inputs comes in shape [batch, input_size]
# hidden comes in shape [batch, hidden_size]
# weights_inputs comes in shape [3 * hidden_size, input_size]
# weights_hidden comes in shape [3 * hidden_size, hidden_size]
# output comes in empty with shape [2, batch, hidden_size]
# The first dimension stacks the output and the new hidden state
assert len(inputs.shape) == 2
assert len(hidden.shape) == 2
assert inputs.shape[0] == hidden.shape[0]
return hidden.new_empty((2, inputs.shape[0], hidden.shape[-1]), dtype=torch.float32)
```

@github-actions

github-actions Bot commented Apr 6, 2026

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions Bot added the stale PRs inactive for over 60 days label Apr 6, 2026
mgiordy pushed a commit to mgiordy/executorch that referenced this pull request Apr 14, 2026
Summary:

# Context
This diff fixes the reference implementation of the w8a32 GRU operator and enhances the operator's pattern matching.

# Mitigation
The reference implementation now returns the correct output dimensions, and pattern matching now uses a safer check for the operator parameters.

Reviewed By: hsharma35

Differential Revision: D90437262
mgiordy pushed a commit to mgiordy/executorch that referenced this pull request Apr 14, 2026
@meta-codesync meta-codesync Bot changed the title Fix GRU w8a32 operator Fix GRU w8a32 operator (#17226) Apr 14, 2026
mgiordy pushed a commit to mgiordy/executorch that referenced this pull request Apr 14, 2026
mgiordy pushed a commit to mgiordy/executorch that referenced this pull request Apr 14, 2026
Copilot AI review requested due to automatic review settings April 14, 2026 22:06
mgiordy pushed a commit to mgiordy/executorch that referenced this pull request Apr 14, 2026
mgiordy pushed a commit to mgiordy/executorch that referenced this pull request Apr 14, 2026
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.



Comment on lines +750 to +751
```python
# Bail if length != kernel size - Not yet supported
if inputs_shape[-1] != cnn_weights_shape[2]:
```

Copilot AI Apr 14, 2026


In MixedW8A32ConvPattern, this new guard compares the input tensor's last dimension (which is conv1d length for NCL inputs) against the kernel size. That will reject common valid conv1d shapes (e.g., length=5, kernel=3) even though the w8a32 conv ref/meta implementations support them. If the intent is to validate shapes, this should instead check that the input channel dimension matches the weight's in_channels and that length >= kernel_size (not length == kernel_size).

Suggested change
```python
# Bail if the input channels do not match the weight's in_channels
# or if the input length is smaller than the kernel size.
if (
    len(inputs_shape) < 3
    or inputs_shape[1] != cnn_weights_shape[1]
    or inputs_shape[-1] < cnn_weights_shape[2]
):
```

Comment on lines 524 to +527
```python
assert len(dequants_biases) == 2
w_i_scale = dequants_weights[0].args[1]
w_h_scale = dequants_weights[1].args[1]
b_i_scale = dequants_biases[0].args[1]
b_h_scale = dequants_biases[1].args[1]
b_scale = dequants_biases[0].args[1]
```

Copilot AI Apr 14, 2026


get_args_and_kwargs_mixed_w8a32_gru still indexes `dequants_weights[0]` and `[1]`, but the defensive `assert len(dequants_weights) == 2` was removed. If the partition ever produces an unexpected number of dequant nodes, this will fail with an IndexError and be harder to diagnose. Please restore the assert (or otherwise validate the length) before indexing.
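The defensive check being asked for might look like this (a sketch with stand-in node objects, not the actual fusion-pass code):

```python
from collections import namedtuple

Node = namedtuple("Node", ["args"])  # stand-in for an FX dequantize node

def gru_dequant_scales(dequants_weights, dequants_biases):
    # Validate lengths before indexing so an unexpected partition fails
    # with a clear message instead of a bare IndexError.
    assert len(dequants_weights) == 2, (
        f"expected 2 weight dequant nodes, got {len(dequants_weights)}"
    )
    assert len(dequants_biases) == 2, (
        f"expected 2 bias dequant nodes, got {len(dequants_biases)}"
    )
    w_i_scale = dequants_weights[0].args[1]
    w_h_scale = dequants_weights[1].args[1]
    b_scale = dequants_biases[0].args[1]  # both biases share one scale
    return w_i_scale, w_h_scale, b_scale

weights = [Node(("w_i", 0.1)), Node(("w_h", 0.2))]
biases = [Node(("b_i", 0.05)), Node(("b_h", 0.05))]
print(gru_dequant_scales(weights, biases))  # → (0.1, 0.2, 0.05)
```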

Comment on lines +3066 to +3075
```python
seq_len = inputs.shape[1]
assert seq_len == 1
# inputs comes in shape [batch, seq_len, input_size]
# hidden comes in shape [batch, seq_len, hidden_size]
# weights_inputs comes in shape [3 * hidden_size, input_size]
# weights_hidden comes in shape [3 * hidden_size, hidden_size]
# output comes in empty with shape [2, batch, seq_len, hidden_size]
# The first dimension stacks the output and the new hidden state
return hidden.new_empty(
    (2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32
)
```

Copilot AI Apr 14, 2026


The meta kernel computes `seq_len = inputs.shape[1]` and asserts `seq_len == 1`, but the ref implementation and unit tests treat inputs as a 2D tensor shaped like `[batch, input_dim]` (e.g., 1x2, 1x3). With the current meta logic, any `input_dim != 1` will trip this assert during fake-tensor shape propagation / export. Please remove or relax this assertion and update the shape comments to match the actual operator contract (and keep the output shape computation consistent with ref_implementations.quantized_w8a32_gru).

Suggested change
```python
# inputs comes in shape [batch, input_size]
# hidden comes in shape [batch, hidden_size]
# weights_inputs comes in shape [3 * hidden_size, input_size]
# weights_hidden comes in shape [3 * hidden_size, hidden_size]
# output comes in empty with shape [2, batch, hidden_size]
# The first dimension stacks the output and the new hidden state
return hidden.new_empty(
    (2, inputs.shape[0], hidden.shape[-1]), dtype=torch.float32
)
```

mgiordy pushed a commit to mgiordy/executorch that referenced this pull request Apr 14, 2026
mgiordy pushed a commit to mgiordy/executorch that referenced this pull request Apr 14, 2026
Copilot AI review requested due to automatic review settings April 14, 2026 22:46
mgiordy pushed a commit to mgiordy/executorch that referenced this pull request Apr 14, 2026
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.



Comment on lines +3066 to +3067
```python
seq_len = inputs.shape[1]
assert seq_len == 1
```

Copilot AI Apr 14, 2026


quantized_w8a32_gru_meta indexes inputs.shape[1] without first validating inputs rank, which can raise an IndexError (and produce a confusing failure) if the op is ever invoked with a 1D/2D input. Add an explicit shape/rank check (e.g., inputs.dim() == 3 / hidden.dim() == 3) before reading shape[1], and then assert seq_len == 1 with a clear message.

Suggested change
```python
assert inputs.dim() == 3, (
    "quantized_w8a32_gru expects inputs to have shape "
    "[batch, seq_len, input_size]"
)
assert hidden.dim() == 3, (
    "quantized_w8a32_gru expects hidden to have shape "
    "[batch, seq_len, hidden_size]"
)
seq_len = inputs.shape[1]
assert (
    seq_len == 1
), "quantized_w8a32_gru fake kernel only supports seq_len == 1"
```

Comment on lines +3023 to +3027
```diff
+expected_shape = (2, inputs.shape[0], inputs.shape[1], hidden.shape[-1])
 self.assertEqual(
     output.shape,
-    (2, *hidden.shape),
-    f"Output shape should match {(2, *hidden.shape)} in {name}",
+    expected_shape,
+    f"Output shape should match {expected_shape} in {name}",
```

Copilot AI Apr 14, 2026


The updated expected shape uses inputs.shape[1] as the sequence-length dimension, but the test vectors still pass 2D inputs/hidden tensors. This means the test isn’t exercising the (documented in quantized_w8a32_gru_meta) [batch, seq_len, input_size] / [batch, seq_len, hidden_size] path and can mask shape bugs for the intended 3D case. Consider updating these fixtures (and/or adding an additional case) to use inputs and hidden with an explicit seq_len dimension (typically seq_len==1) so the test matches the operator contract.

Comment on lines +1310 to 1318
```python
batch_size = inputs.shape[0]
input_dim = inputs.shape[1]
hidden_dim = hidden.shape[-1]

new_hidden_expanded = new_hidden.unsqueeze(1).expand(batch_size, input_dim, hidden_dim)

return torch.stack([new_hidden_expanded, new_hidden_expanded], dim=0)
```

Copilot AI Apr 14, 2026


new_hidden is already shaped like the computed GRU output (e.g., [batch, seq_len, hidden] when inputs is 3D), but the new unsqueeze(1).expand(batch_size, input_dim, hidden_dim) path will throw for 3D inputs/hidden because it adds an extra dimension and then calls expand with too few sizes. This currently only works when new_hidden is 2D.

Consider normalizing new_hidden to [batch, seq_len, hidden] (e.g., only unsqueeze when it’s missing the seq_len dim) and then stacking directly, rather than expanding based on inputs.shape[1].

Suggested change
```python
if new_hidden.dim() == 1:
    new_hidden_normalized = new_hidden.unsqueeze(0).unsqueeze(0)
elif new_hidden.dim() == 2:
    new_hidden_normalized = new_hidden.unsqueeze(1)
elif new_hidden.dim() == 3:
    new_hidden_normalized = new_hidden
else:
    raise ValueError(
        f"Hidden state must be 1D, 2D, or 3D, got shape {tuple(new_hidden.shape)}"
    )
return torch.stack([new_hidden_normalized, new_hidden_normalized], dim=0)
```

mgiordy pushed a commit to mgiordy/executorch that referenced this pull request Apr 14, 2026
mgiordy pushed a commit to mgiordy/executorch that referenced this pull request Apr 14, 2026
Copilot AI review requested due to automatic review settings April 14, 2026 23:03
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.



Comment on lines +1310 to +1314
```python
batch_size = inputs.shape[0]
input_dim = inputs.shape[1]
hidden_dim = hidden.shape[-1]

new_hidden_expanded = new_hidden.unsqueeze(1).expand(batch_size, input_dim, hidden_dim)
```

Copilot AI Apr 14, 2026


new_hidden is asserted to have the same shape as original_hidden_shape, and the function comment says hidden can be 3D (e.g. (1, 1, hidden_dim)). In that case new_hidden will already be 3D, so new_hidden.unsqueeze(1) becomes 4D and the subsequent expand(batch_size, input_dim, hidden_dim) (3 sizes) will raise a runtime error. Consider normalizing new_hidden to the desired 3D shape (batch, seq_len, hidden_dim) (only unsqueeze when it is 2D), or simply torch.stack([new_hidden, new_hidden], dim=0) when new_hidden already matches the output layout.

Suggested change
```python
if new_hidden.dim() == 1:
    new_hidden_3d = new_hidden.view(1, 1, -1)
elif new_hidden.dim() == 2:
    new_hidden_3d = new_hidden.unsqueeze(1)
elif new_hidden.dim() == 3:
    new_hidden_3d = new_hidden
else:
    raise ValueError(
        f"Hidden state must be 1D, 2D, or 3D, got {new_hidden.dim()}D"
    )
batch_size = inputs.shape[0]
input_dim = inputs.shape[1]
hidden_dim = hidden.shape[-1]
new_hidden_expanded = new_hidden_3d.expand(batch_size, input_dim, hidden_dim)
```

Comment on lines +3023 to 3028
```diff
+expected_shape = (2, inputs.shape[0], inputs.shape[1], hidden.shape[-1])
 self.assertEqual(
     output.shape,
-    (2, *hidden.shape),
-    f"Output shape should match {(2, *hidden.shape)} in {name}",
+    expected_shape,
+    f"Output shape should match {expected_shape} in {name}",
 )
```

Copilot AI Apr 14, 2026


The test only validates the new 4D output shape against inputs.shape[:2], but it still uses 2D inputs / 2D hidden fixtures. Since this op is produced from aten.gru fusion, it’s important to add a case with the expected real input ranks (e.g. 3D inputs and 3D hidden/state) so the output-shape logic is exercised for those ranks as well (this would also catch shape errors like an unsqueeze/expand mismatch).

Comment on lines +816 to +821
```python
# Using SharedQuantizationSpec so that bias_hh has the same observer as bias_ih
# Both biases get the same quantization scale to match the cpp operator
bias_ih_node = wrapper.args[2]
bias_ih_edge = (bias_ih_node, gru_layer)
shared_bias_qspec = SharedQuantizationSpec(edge_or_node=bias_ih_edge)
```

Copilot AI Apr 14, 2026


This code assumes the GRU params tuple contains both biases and immediately indexes `wrapper.args[2]`/`[3]`. For `aten.gru.input`, `has_biases` can be false (e.g. `nn.GRU(..., bias=False)`), in which case the params list may not contain bias entries and this will raise an IndexError during pattern matching. Consider explicitly checking the `has_biases` argument (and/or `len(wrapper.args) >= 4`) and returning `empty=True` anchors when biases are not present, since the Cadence quantized_w8a32_gru replacement requires bias tensors.
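A guard along those lines might look like this (a sketch with a plain tuple standing in for the matched args; the helper name is hypothetical):

```python
def gru_bias_anchors(wrapper_args):
    # has_biases may be false (nn.GRU(..., bias=False)); guard the length
    # before indexing args[2]/args[3] instead of risking an IndexError.
    if len(wrapper_args) < 4:
        return None  # caller would return empty anchors here
    return wrapper_args[2], wrapper_args[3]

print(gru_bias_anchors(("x", "h", "b_ih", "b_hh")))  # → ('b_ih', 'b_hh')
print(gru_bias_anchors(("x", "h")))                  # → None
```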

Summary:
Pull Request resolved: pytorch#16607

#### Summary

This diff fixes the Conv1d w8a32 operator by adding a transformation to the `val` attribute of the `other_inputs[0].meta` dictionary. Specifically, the `permute` operation is applied to the `original_val` tensor with the `fake_mode` context, and the resulting `transposed_val` is assigned to `transposed_inputs.meta["val"]`.

Differential Revision: D89863750

Reviewed By: mcremon-meta
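The shape effect of that permute can be sketched without torch (the helper name and dims tuple are illustrative; conv1d inputs are laid out NCL):

```python
def permuted_shape(shape, dims):
    # The transposed tensor's "val" metadata carries the permuted shape,
    # e.g. NCL -> NLC for a conv1d input transpose.
    return tuple(shape[d] for d in dims)

print(permuted_shape((1, 3, 5), (0, 2, 1)))  # → (1, 5, 3)
```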
mgiordy pushed a commit to mgiordy/executorch that referenced this pull request Apr 15, 2026
mgiordy pushed a commit to mgiordy/executorch that referenced this pull request Apr 15, 2026
@meta-codesync meta-codesync Bot merged commit f1209c5 into pytorch:main Apr 15, 2026
158 of 165 checks passed

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported meta-exported stale PRs inactive for over 60 days
