diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index 00400403983..f6bf26beacb 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -1701,31 +1701,38 @@ def rope(
         input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2], -1
     )

-    _, s, h, hd = input_tensor.shape
+    _, seq, _, hd = input_tensor.shape

     if hd % 2:
         raise ValueError("Hidden dimension must be divisible by 2")

-    if sin_tensor.shape != (s, hd // 2) or cos_tensor.shape != (s, hd // 2):
+    if (
+        sin_tensor.size(-1) * 2 != hd
+        or cos_tensor.size(-1) * 2 != hd
+        or sin_tensor.size(0) < seq
+        or cos_tensor.size(0) < seq
+    ):
         raise ValueError(
-            f"sin_tensor and cos_tensor must have shape {s, hd // 2}. Got {sin_tensor.shape} and {cos_tensor.shape}"
+            f"sin_tensor and cos_tensor must have shape (>= {seq}) x {hd // 2}. Got {sin_tensor.shape} and {cos_tensor.shape}"
         )

     if pos is not None:
-        if pos.shape != (input_tensor.shape[1],):
+        if pos.shape != (seq,):
             raise ValueError(
                 f"pos must have shape {input_tensor.shape[1]}. Got {pos.shape}"
             )
         sin_tensor = sin_tensor[pos]
         cos_tensor = cos_tensor[pos]

+    # seq x 1 x hd
     sin_tensor = sin_tensor.unsqueeze(1)
     cos_tensor = cos_tensor.unsqueeze(1)

+    # batch x seq x num_heads x head_dim_by_two
     x0, x1 = input_tensor[..., ::2], input_tensor[..., 1::2]
-    rotated = torch.cat(
-        [x0 * cos_tensor - x1 * sin_tensor, x0 * sin_tensor + x1 * cos_tensor], dim=-1
-    )
+    o0 = x0 * cos_tensor - x1 * sin_tensor
+    o1 = x0 * sin_tensor + x1 * cos_tensor
+    rotated = torch.cat([o0.view(-1, 1), o1.view(-1, 1)], dim=-1)

     return rotated.view(original_shape)
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index 24bbe7ee644..13519332511 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -1458,7 +1458,7 @@ def test_where_Scalar(self) -> None:
                 torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]], dtype=torch.float32),
                 torch.tensor([[0.0, 0.0]], dtype=torch.float32),
                 torch.tensor([[1.0, 1.0]], dtype=torch.float32),
-                torch.tensor([[[[1.0, 3.0, 2.0, 4.0]]]], dtype=torch.float32),
+                torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]], dtype=torch.float32),
             ),
             (
                 "h2xhd4",
@@ -1469,7 +1469,7 @@ def test_where_Scalar(self) -> None:
                 torch.tensor([[0.0, 1.0]], dtype=torch.float32),
                 torch.tensor([[1.0, 0.0]], dtype=torch.float32),
                 torch.tensor(
-                    [[[[1.0, -4.0, 2.0, 3.0], [5, -8.0, 6.0, 7.0]]]],
+                    [[[[1.0, 2.0, -4.0, 3.0], [5, 6.0, -8.0, 7.0]]]],
                     dtype=torch.float32,
                 ),
             ),
@@ -1489,8 +1489,8 @@ def test_where_Scalar(self) -> None:
                 torch.tensor(
                     [
                         [
-                            [[1.0, -4.0, 2.0, 3.0], [5.0, -8.0, 6.0, 7.0]],
-                            [[9.0, -12.0, 10.0, 11.0], [13.0, -16.0, 14.0, 15.0]],
+                            [[1.0, 2.0, -4.0, 3.0], [5.0, 6.0, -8.0, 7.0]],
+                            [[9.0, 10.0, -12.0, 11.0], [13.0, 14.0, -16.0, 15.0]],
                         ]
                     ],
                     dtype=torch.float32,
@@ -1512,8 +1512,8 @@ def test_where_Scalar(self) -> None:
                 torch.tensor(
                     [
                         [
-                            [[1.0, -4.0, 2.0, 3.0], [5.0, -8.0, 6.0, 7.0]],
-                            [[-10.0, 11.0, 9.0, 12.0], [-14.0, 15.0, 13.0, 16.0]],
+                            [[1.0, 2.0, -4.0, 3.0], [5.0, 6.0, -8.0, 7.0]],
+                            [[-10.0, 9.0, 11.0, 12.0], [-14.0, 13.0, 15.0, 16.0]],
                         ]
                     ],
                     dtype=torch.float32,
diff --git a/backends/cadence/generic/operators/op_rope.cpp b/backends/cadence/generic/operators/op_rope.cpp
index f47e9752125..4a392bed1ee 100644
--- a/backends/cadence/generic/operators/op_rope.cpp
+++ b/backends/cadence/generic/operators/op_rope.cpp
@@ -23,17 +23,20 @@ Tensor& rope_out(
     const optional<Tensor>& pos,
     Tensor& out) {
   // Input shape is [1, seq, h, hd / 2, 2] or [1, seq, h, hd]
-  const auto kSeq = input.size(1);
-  const auto kH = input.size(2);
-  const auto kHd = input.numel() / (kSeq * kH);
-  for (int32_t s = 0; s < kSeq; ++s) {
-    for (int32_t h = 0; h < kH; ++h) {
-      for (int32_t hd_o = 0; hd_o < kHd / 2; ++hd_o) {
-        float x_0 =
-            input.const_data_ptr<float>()[s * kH * kHd + h * kHd + hd_o * 2];
-        float x_1 =
-            input
-                .const_data_ptr<float>()[s * kH * kHd + h * kHd + hd_o * 2 + 1];
+  const ssize_t seq_length = input.size(1);
+  const ssize_t num_heads = input.size(2);
+  const ssize_t head_dimension = input.numel() / (seq_length * num_heads);
+  const ssize_t head_dimension_by_two = head_dimension / 2;
+  for (int32_t s = 0; s < seq_length; ++s) {
+    for (int32_t h = 0; h < num_heads; ++h) {
+      for (int32_t hd_o = 0; hd_o < head_dimension_by_two; ++hd_o) {
+        // Process 2 elements in head dimension at a time.
+        const float x_0 = input.const_data_ptr<float>()
+            [s * num_heads * head_dimension +
+             h * head_dimension + hd_o * 2];
+        const float x_1 = input.const_data_ptr<float>()
+            [s * num_heads * head_dimension +
+             h * head_dimension + hd_o * 2 + 1];
         int64_t token_id = s;
         if (pos.has_value()) {
           if (pos->scalar_type() == ::executorch::aten::ScalarType::Int) {
@@ -42,17 +45,21 @@ Tensor& rope_out(
             token_id = pos.has_value() ? pos->const_data_ptr<int64_t>()[s] : s;
           }
         }
-        float sin =
-            sin_tensor.const_data_ptr<float>()[token_id * kHd / 2 + hd_o];
-        float cos =
-            cos_tensor.const_data_ptr<float>()[token_id * kHd / 2 + hd_o];
-        float out_0 = x_0 * cos - x_1 * sin;
-        float out_1 = x_0 * sin + x_1 * cos;
-        out.mutable_data_ptr<float>()[s * kH * kHd + h * kHd + hd_o * 2] =
+        const float sin = sin_tensor.const_data_ptr<
+            float>()[token_id * head_dimension_by_two + hd_o];
+        const float cos = cos_tensor.const_data_ptr<
+            float>()[token_id * head_dimension_by_two + hd_o];
+
+        const float out_0 = x_0 * cos - x_1 * sin;
+        out.mutable_data_ptr<float>()
+            [s * num_heads * head_dimension + h * head_dimension + hd_o * 2] =
             out_0;
-        out.mutable_data_ptr<float>()[s * kH * kHd + h * kHd + hd_o * 2 + 1] =
-            out_1;
+
+        const float out_1 = x_0 * sin + x_1 * cos;
+        out.mutable_data_ptr<float>()
+            [s * num_heads * head_dimension + h * head_dimension + hd_o * 2 +
+             1] = out_1;
       }
     }
   }
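
Note (not part of the patch): a minimal standalone sketch of the interleaved pairing that the updated reference implementation and test expectations assume. Lanes 2i and 2i + 1 of the head dimension now form one rotated pair, replacing the previous half-split concatenation. The helper name rope_interleaved and the example input are illustrative assumptions, not code from the tree; with the sin/cos values of the "h2xhd4" case and an ascending input, it reproduces the updated expected tensor.

import torch


def rope_interleaved(x, sin, cos):
    # x: [batch, seq, heads, head_dim]; sin/cos: [seq, head_dim // 2] (hypothetical helper).
    x0, x1 = x[..., ::2], x[..., 1::2]  # even / odd lanes form one rotated pair
    sin = sin.unsqueeze(1)  # seq x 1 x head_dim_by_two, broadcasts over heads
    cos = cos.unsqueeze(1)
    o0 = x0 * cos - x1 * sin
    o1 = x0 * sin + x1 * cos
    # Interleave back so lane 2i holds o0 and lane 2i + 1 holds o1.
    return torch.stack([o0, o1], dim=-1).reshape(x.shape)


# Assumed ascending input; sin/cos taken from the "h2xhd4" case above.
x = torch.tensor([[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]]])
sin = torch.tensor([[0.0, 1.0]])
cos = torch.tensor([[1.0, 0.0]])
print(rope_interleaved(x, sin, cos))
# -> [[[[1., 2., -4., 3.], [5., 6., -8., 7.]]]], matching the updated expectation.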