diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index d8db866fa4e..e4be6a09641 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -55,7 +55,6 @@ def _validate_ref_impl_exists() -> None:
     _WARN_ONLY = {
         "cadence::quantized_softmax.per_tensor",
         "cadence::quantized_softmax",
-        "cadence::quantized_w8a32_gru",
     }
 
     ref_impls = get_registered_ref_implementations()
@@ -2753,7 +2752,7 @@ def quantized_w8a32_gru_meta(
     bias_hidden: torch.Tensor,
     b_h_scale: float,
 ) -> torch.Tensor:
-    return inputs.new_empty((2, hidden.shape[-1]), dtype=inputs.dtype)
+    return hidden.new_empty((2, hidden.shape[-1]), dtype=torch.float32)
 
 
 # Validate that all meta kernels have reference implementations
diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index 6e0c116ad45..3e08cdc358c 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -985,6 +985,72 @@ def quantized_w8a32_linear(
     return output
 
 
+@impl_tracked(m, "quantized_w8a32_gru")
+def quantized_w8a32_gru(
+    inputs: torch.Tensor,
+    hidden: torch.Tensor,
+    weights_inputs: torch.Tensor,
+    w_i_scale: float,
+    weights_hidden: torch.Tensor,
+    w_h_scale: float,
+    bias_inputs: torch.Tensor,
+    b_i_scale: float,
+    bias_hidden: torch.Tensor,
+    b_h_scale: float,
+) -> torch.Tensor:
+    assert weights_inputs.dtype == torch.int8
+    assert weights_hidden.dtype == torch.int8
+    assert bias_inputs.dtype == torch.int8
+    assert bias_hidden.dtype == torch.int8
+    assert inputs.dtype == torch.float32
+    assert hidden.dtype == torch.float32
+
+    if len(hidden.shape) > 2:
+        raise ValueError("Hidden state must be 1D or 2D")
+
+    if len(hidden.shape) == 2 and hidden.shape[0] != 1:
+        raise ValueError("Leading dimension of hidden state must be 1")
+
+    original_hidden_shape = hidden.shape
+    hidden = hidden.view(-1)
+
+    hidden_dim = hidden.shape[0]
+    if (hidden_dim % 4) != 0:
+        raise ValueError(
+            "Hidden dimension must be a multiple of 4 for HiFi SIMD operations"
+        )
+
+    dequant_weights_inputs = weights_inputs.float() * w_i_scale
+    dequant_weights_hidden = weights_hidden.float() * w_h_scale
+
+    # The C++ kernel averages the two bias scales before applying them.
+    avg_bias_scale = (b_i_scale + b_h_scale) / 2
+    dequant_bias_inputs = bias_inputs.float() * avg_bias_scale
+    dequant_bias_hidden = bias_hidden.float() * avg_bias_scale
+
+    gi = F.linear(inputs, dequant_weights_inputs, dequant_bias_inputs)
+    gh = F.linear(hidden, dequant_weights_hidden, dequant_bias_hidden)
+
+    # Standard GRU cell update; gate order matches PyTorch (reset, update, new).
+    i_r, i_z, i_n = gi.chunk(3, -1)
+    h_r, h_z, h_n = gh.chunk(3, -1)
+
+    reset_gate = torch.sigmoid(i_r + h_r)
+    update_gate = torch.sigmoid(i_z + h_z)
+    new_gate = torch.tanh(i_n + reset_gate * h_n)
+
+    new_hidden = (1 - update_gate) * new_gate + update_gate * hidden
+
+    if new_hidden.shape[0] != 1:
+        raise ValueError("Leading dimension of hidden state must be 1")
+
+    assert new_hidden.shape == original_hidden_shape
+
+    new_hidden = new_hidden.view(-1)
+    # The op contract is a (2, hidden_dim) output holding the new hidden state twice.
+    return torch.stack([new_hidden, new_hidden], dim=0)
+
+
 @impl_tracked(m, "quantized_conv2d_nhwc.per_tensor")
 def quantized_conv2d_nhwc_per_tensor(
     input_tensor: torch.Tensor,
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index e9ba52c58b9..c38668b76c6 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -2894,3 +2894,188 @@ def test_softmax_f32_f32(self) -> None:
         output = torch.ops.cadence._softmax_f32_f32(input_tensor, dim=1)
         self.assertEqual(output.dtype, torch.float32)
         self.assertEqual(output.shape, input_tensor.shape)
+
+    @expand(
+        [
+            (
+                "basic_hidden_dim_4",
+                torch.tensor([[1.0, 2.0]], dtype=torch.float32),  # inputs: 1x2
+                torch.tensor(
+                    [[0.5, 0.5, 0.5, 0.5]], dtype=torch.float32
+                ),  # hidden: 1x4
+                torch.ones(
+                    (12, 2), dtype=torch.int8
+                ),  # weights_inputs: 12x2 (3*4 x input_dim=2)
+                0.1,  # w_i_scale
+                torch.ones((12, 4), dtype=torch.int8),  # weights_hidden: 12x4 (3*4 x 4)
+                0.1,  # w_h_scale
+                torch.zeros(12, dtype=torch.int8),  # bias_inputs: 12
+                0.1,  # b_i_scale
+                torch.zeros(12, dtype=torch.int8),  # bias_hidden: 12
+                0.1,  # b_h_scale
+            ),
+            (
+                "invalid_batch_size_2",
+                torch.tensor(
+                    [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], dtype=torch.float32
+                ),  # inputs: 2x3
+                torch.tensor(
+                    [[0.5, 0.5, 0.5, 0.5], [0.3, 0.3, 0.3, 0.3]], dtype=torch.float32
+                ),  # hidden: 2x4
+                torch.ones((12, 3), dtype=torch.int8),  # weights_inputs: 12x3
+                0.1,  # w_i_scale
+                torch.ones((12, 4), dtype=torch.int8),  # weights_hidden: 12x4
+                0.1,  # w_h_scale
+                torch.zeros(12, dtype=torch.int8),  # bias_inputs: 12
+                0.1,  # b_i_scale
+                torch.zeros(12, dtype=torch.int8),  # bias_hidden: 12
+                0.1,  # b_h_scale
+            ),
+            (
+                "non_zero_biases",
+                torch.tensor([[1.0, 1.0]], dtype=torch.float32),  # inputs: 1x2
+                torch.zeros((1, 4), dtype=torch.float32),  # hidden: 1x4
+                torch.ones((12, 2), dtype=torch.int8),  # weights_inputs: 12x2
+                0.2,  # w_i_scale
+                torch.ones((12, 4), dtype=torch.int8),  # weights_hidden: 12x4
+                0.1,  # w_h_scale
+                torch.tensor(
+                    [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int8
+                ),  # bias_inputs: 12
+                0.1,  # b_i_scale
+                torch.tensor(
+                    [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int8
+                ),  # bias_hidden: 12
+                0.1,  # b_h_scale
+            ),
+            (
+                "negative_weights",
+                torch.tensor([[1.0, -1.0]], dtype=torch.float32),  # inputs: 1x2
+                torch.tensor(
+                    [[0.5, -0.5, 0.5, -0.5]], dtype=torch.float32
+                ),  # hidden: 1x4
+                torch.tensor(
+                    [[1, -1], [-1, 1]] * 6, dtype=torch.int8
+                ),  # weights_inputs: 12x2 (alternating pattern)
+                0.1,  # w_i_scale
+                torch.tensor(
+                    [[1, -1, 1, -1], [-1, 1, -1, 1]] * 6, dtype=torch.int8
+                ),  # weights_hidden: 12x4 (alternating pattern)
+                0.1,  # w_h_scale
+                torch.zeros(12, dtype=torch.int8),  # bias_inputs: 12
+                0.1,  # b_i_scale
+                torch.zeros(12, dtype=torch.int8),  # bias_hidden: 12
+                0.1,  # b_h_scale
+            ),
+            (
+                "hidden_dim_8",
+                torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32),  # inputs: 1x3
+                torch.tensor(
+                    [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]], dtype=torch.float32
+                ),  # hidden: 1x8
+                torch.ones((24, 3), dtype=torch.int8),  # weights_inputs: 24x3 (3*8 x 3)
+                0.1,  # w_i_scale
+                torch.ones((24, 8), dtype=torch.int8),  # weights_hidden: 24x8 (3*8 x 8)
+                0.1,  # w_h_scale
+                torch.zeros(24, dtype=torch.int8),  # bias_inputs: 24
+                0.1,  # b_i_scale
+                torch.zeros(24, dtype=torch.int8),  # bias_hidden: 24
+                0.1,  # b_h_scale
+            ),
+        ]
+    )
+    def test_quantized_w8a32_gru(
+        self,
+        name: str,
+        inputs: torch.Tensor,
+        hidden: torch.Tensor,
+        weights_inputs: torch.Tensor,
+        w_i_scale: float,
+        weights_hidden: torch.Tensor,
+        w_h_scale: float,
+        bias_inputs: torch.Tensor,
+        b_i_scale: float,
+        bias_hidden: torch.Tensor,
+        b_h_scale: float,
+    ) -> None:
+
+        if name == "invalid_batch_size_2":
+            with self.assertRaises(ValueError) as context:
+                torch.ops.cadence.quantized_w8a32_gru(
+                    inputs,
+                    hidden,
+                    weights_inputs,
+                    w_i_scale,
+                    weights_hidden,
+                    w_h_scale,
+                    bias_inputs,
+                    b_i_scale,
+                    bias_hidden,
+                    b_h_scale,
+                )
+            self.assertIn(
+                "Leading dimension of hidden state must be 1", str(context.exception)
+            )
+            return
+
+        output = torch.ops.cadence.quantized_w8a32_gru(
+            inputs,
+            hidden,
+            weights_inputs,
+            w_i_scale,
+            weights_hidden,
+            w_h_scale,
+            bias_inputs,
+            b_i_scale,
+            bias_hidden,
+            b_h_scale,
+        )
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype,
+            torch.float32,
+            f"Output dtype should be float32 in {name}",
+        )
+        self.assertEqual(
+            output.shape,
+            (2, hidden.shape[-1]),
+            f"Output shape should be {(2, hidden.shape[-1])} in {name}",
+        )
+        assert isinstance(output, torch.Tensor)
+
+        # Verify the output is bounded: the new hidden state is a convex combination of
+        # the tanh new gate (in [-1, 1]) and the previous hidden (in [-1, 1] in these tests).
+        self.assertTrue(
+            torch.all(output >= -1.0) and torch.all(output <= 1.0),
+            f"Output values should be in [-1.0, 1.0] in {name}. Got min={output.min():.4f}, max={output.max():.4f}",
+        )
+
+    def test_quantized_w8a32_gru_invalid_hidden_dim(self) -> None:
+        # A hidden dimension that is not a multiple of 4 must raise a ValueError
+        inputs = torch.tensor([[1.0, 2.0]], dtype=torch.float32)  # 1x2
+        hidden = torch.tensor(
+            [[0.5, 0.5, 0.5]], dtype=torch.float32
+        )  # 1x3 (not divisible by 4)
+        weights_inputs = torch.zeros((9, 2), dtype=torch.int8)  # 9x2
+        weights_hidden = torch.zeros((9, 3), dtype=torch.int8)  # 9x3
+        bias_inputs = torch.zeros(9, dtype=torch.int8)
+        bias_hidden = torch.zeros(9, dtype=torch.int8)
+
+        with self.assertRaises(ValueError) as context:
+            torch.ops.cadence.quantized_w8a32_gru(
+                inputs,
+                hidden,
+                weights_inputs,
+                0.1,
+                weights_hidden,
+                0.1,
+                bias_inputs,
+                0.1,
+                bias_hidden,
+                0.1,
+            )
+
+        self.assertIn(
+            "Hidden dimension must be a multiple of 4", str(context.exception)
+        )
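
Reviewer note (not part of the diff): since the reference implementation follows the standard GRU cell with PyTorch's gate order (reset, update, new) and applies the reset gate after adding the hidden bias, it can be sanity-checked against `torch.nn.GRUCell` built from the dequantized parameters. The sketch below is illustrative only; the shapes, values, and scales are made up, and it assumes the `cadence` ops are registered as in the tests above.

```python
import torch

torch.manual_seed(0)
input_dim, hidden_dim = 2, 4

inputs = torch.randn(1, input_dim)
hidden = torch.rand(1, hidden_dim) * 2 - 1  # previous hidden state in [-1, 1]

# Hypothetical quantized parameters; rows are stacked (reset | update | new),
# matching both the reference implementation's chunk(3, -1) and torch.nn.GRUCell.
weights_inputs = torch.randint(-128, 128, (3 * hidden_dim, input_dim), dtype=torch.int8)
weights_hidden = torch.randint(-128, 128, (3 * hidden_dim, hidden_dim), dtype=torch.int8)
bias_inputs = torch.randint(-128, 128, (3 * hidden_dim,), dtype=torch.int8)
bias_hidden = torch.randint(-128, 128, (3 * hidden_dim,), dtype=torch.int8)
w_i_scale, w_h_scale, b_i_scale, b_h_scale = 0.02, 0.02, 0.01, 0.03

# Build a float GRUCell from the dequantized parameters. Both bias tensors use
# the averaged bias scale, mirroring the C++ kernel behavior noted in the ref impl.
cell = torch.nn.GRUCell(input_dim, hidden_dim)
avg_bias_scale = (b_i_scale + b_h_scale) / 2
with torch.no_grad():
    cell.weight_ih.copy_(weights_inputs.float() * w_i_scale)
    cell.weight_hh.copy_(weights_hidden.float() * w_h_scale)
    cell.bias_ih.copy_(bias_inputs.float() * avg_bias_scale)
    cell.bias_hh.copy_(bias_hidden.float() * avg_bias_scale)
expected = cell(inputs, hidden)  # shape (1, hidden_dim)

output = torch.ops.cadence.quantized_w8a32_gru(
    inputs, hidden,
    weights_inputs, w_i_scale,
    weights_hidden, w_h_scale,
    bias_inputs, b_i_scale,
    bias_hidden, b_h_scale,
)  # shape (2, hidden_dim): the new hidden state, duplicated

torch.testing.assert_close(output[0], expected.view(-1))
torch.testing.assert_close(output[1], expected.view(-1))
```

This kind of cross-check would also catch regressions in the bias-scale averaging, since any per-tensor scale the GRUCell and the op disagree on shows up directly in the compared hidden states.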