21 changes: 14 additions & 7 deletions backends/cadence/aot/ref_implementations.py
@@ -1701,31 +1701,38 @@ def rope(
         input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2], -1
     )
 
-    _, s, h, hd = input_tensor.shape
+    _, seq, _, hd = input_tensor.shape
 
     if hd % 2:
         raise ValueError("Hidden dimension must be divisible by 2")
 
-    if sin_tensor.shape != (s, hd // 2) or cos_tensor.shape != (s, hd // 2):
+    if (
+        sin_tensor.size(-1) * 2 != hd
+        or cos_tensor.size(-1) * 2 != hd
+        or sin_tensor.size(0) < seq
+        or cos_tensor.size(0) < seq
+    ):
         raise ValueError(
-            f"sin_tensor and cos_tensor must have shape {s, hd // 2}. Got {sin_tensor.shape} and {cos_tensor.shape}"
+            f"sin_tensor and cos_tensor must have shape <kvseq (> {seq}) x {hd // 2}>. Got {sin_tensor.shape} and {cos_tensor.shape}"
         )
 
     if pos is not None:
-        if pos.shape != (input_tensor.shape[1],):
+        if pos.shape != (seq,):
            raise ValueError(
                f"pos must have shape {input_tensor.shape[1]}. Got {pos.shape}"
            )
        sin_tensor = sin_tensor[pos]
        cos_tensor = cos_tensor[pos]
 
+    # seq x 1 x hd
     sin_tensor = sin_tensor.unsqueeze(1)
     cos_tensor = cos_tensor.unsqueeze(1)
 
+    # batch x seq x num_heads x head_dim_by_two
     x0, x1 = input_tensor[..., ::2], input_tensor[..., 1::2]
-    rotated = torch.cat(
-        [x0 * cos_tensor - x1 * sin_tensor, x0 * sin_tensor + x1 * cos_tensor], dim=-1
-    )
+    o0 = x0 * cos_tensor - x1 * sin_tensor
+    o1 = x0 * sin_tensor + x1 * cos_tensor
+    rotated = torch.cat([o0.view(-1, 1), o1.view(-1, 1)], dim=-1)
     return rotated.view(original_shape)
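Note on the Python reference change above: the old implementation concatenated the two rotated halves along the last dimension, producing a block layout (all even-index outputs followed by all odd-index outputs), while the interleaved RoPE convention expects each (o0, o1) pair to land in adjacent positions. A minimal sketch of the difference, assuming an identity rotation (sin = 0, cos = 1) on a 1 x 1 x 1 x 4 input; the values mirror the first test case in the diff below:

    import torch

    x = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])   # batch x seq x heads x hd
    x0, x1 = x[..., ::2], x[..., 1::2]             # even/odd elements: (1, 3) and (2, 4)
    sin, cos = torch.zeros(1, 1, 2), torch.ones(1, 1, 2)  # identity rotation

    o0 = x0 * cos - x1 * sin                       # (1, 3)
    o1 = x0 * sin + x1 * cos                       # (2, 4)

    # Old layout: halves concatenated -> [1, 3, 2, 4], wrong for interleaved RoPE.
    blocked = torch.cat([o0, o1], dim=-1).view(x.shape)

    # New layout: pairs re-interleaved -> [1, 2, 3, 4].
    interleaved = torch.cat([o0.view(-1, 1), o1.view(-1, 1)], dim=-1).view(x.shape)

    print(blocked)      # tensor([[[[1., 3., 2., 4.]]]])
    print(interleaved)  # tensor([[[[1., 2., 3., 4.]]]])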
12 changes: 6 additions & 6 deletions backends/cadence/aot/tests/test_ref_implementations.py
@@ -1458,7 +1458,7 @@ def test_where_Scalar(self) -> None:
             torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]], dtype=torch.float32),
             torch.tensor([[0.0, 0.0]], dtype=torch.float32),
             torch.tensor([[1.0, 1.0]], dtype=torch.float32),
-            torch.tensor([[[[1.0, 3.0, 2.0, 4.0]]]], dtype=torch.float32),
+            torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]], dtype=torch.float32),
         ),
         (
             "h2xhd4",
@@ -1469,7 +1469,7 @@ def test_where_Scalar(self) -> None:
             torch.tensor([[0.0, 1.0]], dtype=torch.float32),
             torch.tensor([[1.0, 0.0]], dtype=torch.float32),
             torch.tensor(
-                [[[[1.0, -4.0, 2.0, 3.0], [5, -8.0, 6.0, 7.0]]]],
+                [[[[1.0, 2.0, -4.0, 3.0], [5, 6.0, -8.0, 7.0]]]],
                 dtype=torch.float32,
             ),
         ),
@@ -1489,8 +1489,8 @@ def test_where_Scalar(self) -> None:
             torch.tensor(
                 [
                     [
-                        [[1.0, -4.0, 2.0, 3.0], [5.0, -8.0, 6.0, 7.0]],
-                        [[9.0, -12.0, 10.0, 11.0], [13.0, -16.0, 14.0, 15.0]],
+                        [[1.0, 2.0, -4.0, 3.0], [5.0, 6.0, -8.0, 7.0]],
+                        [[9.0, 10.0, -12.0, 11.0], [13.0, 14.0, -16.0, 15.0]],
                     ]
                 ],
                 dtype=torch.float32,
@@ -1512,8 +1512,8 @@ def test_where_Scalar(self) -> None:
             torch.tensor(
                 [
                     [
-                        [[1.0, -4.0, 2.0, 3.0], [5.0, -8.0, 6.0, 7.0]],
-                        [[-10.0, 11.0, 9.0, 12.0], [-14.0, 15.0, 13.0, 16.0]],
+                        [[1.0, 2.0, -4.0, 3.0], [5.0, 6.0, -8.0, 7.0]],
+                        [[-10.0, 9.0, 11.0, 12.0], [-14.0, 13.0, 15.0, 16.0]],
                     ]
                 ],
                 dtype=torch.float32,
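The updated expectations follow directly from the interleaved-pair convention. Working the "h2xhd4" case by hand as a sanity check (inputs taken from the test above; this is illustrative arithmetic, not part of the test suite):

    # Head dim 4 -> pairs (x0, x1) = (1, 2) and (3, 4), with per-pair
    # (sin, cos) = (0, 1) and (1, 0) respectively.
    pairs = [(1.0, 2.0), (3.0, 4.0)]
    sins, coss = [0.0, 1.0], [1.0, 0.0]

    out = []
    for (x0, x1), s, c in zip(pairs, sins, coss):
        out += [x0 * c - x1 * s, x0 * s + x1 * c]  # interleave (o0, o1) per pair

    print(out)  # [1.0, 2.0, -4.0, 3.0] -- matches the updated expectation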
47 changes: 27 additions & 20 deletions backends/cadence/generic/operators/op_rope.cpp
@@ -23,17 +23,20 @@ Tensor& rope_out(
     const optional<Tensor>& pos,
     Tensor& out) {
   // Input shape is [1, seq, h, hd / 2, 2] or [1, seq, h, hd]
-  const auto kSeq = input.size(1);
-  const auto kH = input.size(2);
-  const auto kHd = input.numel() / (kSeq * kH);
-  for (int32_t s = 0; s < kSeq; ++s) {
-    for (int32_t h = 0; h < kH; ++h) {
-      for (int32_t hd_o = 0; hd_o < kHd / 2; ++hd_o) {
-        float x_0 =
-            input.const_data_ptr<float>()[s * kH * kHd + h * kHd + hd_o * 2];
-        float x_1 =
-            input
-                .const_data_ptr<float>()[s * kH * kHd + h * kHd + hd_o * 2 + 1];
+  const ssize_t seq_length = input.size(1);
+  const ssize_t num_heads = input.size(2);
+  const ssize_t head_dimension = input.numel() / (seq_length * num_heads);
+  const ssize_t head_dimension_by_two = head_dimension / 2;
+  for (int32_t s = 0; s < seq_length; ++s) {
+    for (int32_t h = 0; h < num_heads; ++h) {
+      for (int32_t hd_o = 0; hd_o < head_dimension_by_two; ++hd_o) {
+        // Process 2 elements in head dimension at a time.
+        const float x_0 = input.const_data_ptr<float>()
+            [s * num_heads * head_dimension +
+             h * head_dimension + hd_o * 2];
+        const float x_1 = input.const_data_ptr<float>()
+            [s * num_heads * head_dimension +
+             h * head_dimension + hd_o * 2 + 1];
         int64_t token_id = s;
         if (pos.has_value()) {
           if (pos->scalar_type() == ::executorch::aten::ScalarType::Int) {
@@ -42,17 +42,21 @@
             token_id = pos.has_value() ? pos->const_data_ptr<int64_t>()[s] : s;
           }
         }
-        float sin =
-            sin_tensor.const_data_ptr<float>()[token_id * kHd / 2 + hd_o];
-        float cos =
-            cos_tensor.const_data_ptr<float>()[token_id * kHd / 2 + hd_o];
-
-        float out_0 = x_0 * cos - x_1 * sin;
-        float out_1 = x_0 * sin + x_1 * cos;
-        out.mutable_data_ptr<float>()[s * kH * kHd + h * kHd + hd_o * 2] =
-            out_0;
-        out.mutable_data_ptr<float>()[s * kH * kHd + h * kHd + hd_o * 2 + 1] =
-            out_1;
+        const float sin = sin_tensor.const_data_ptr<
+            float>()[token_id * head_dimension_by_two + hd_o];
+        const float cos = cos_tensor.const_data_ptr<
+            float>()[token_id * head_dimension_by_two + hd_o];
+
+        const float out_0 = x_0 * cos - x_1 * sin;
+        out.mutable_data_ptr<float>()
+            [s * num_heads * head_dimension + h * head_dimension + hd_o * 2] =
+                out_0;
+
+        const float out_1 = x_0 * sin + x_1 * cos;
+        out.mutable_data_ptr<float>()
+            [s * num_heads * head_dimension + h * head_dimension + hd_o * 2 +
+             1] = out_1;
       }
     }
   }
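The C++ kernel mirrors the same fix at the buffer level: out_0 and out_1 are written to the adjacent flat offsets hd_o * 2 and hd_o * 2 + 1, so the interleaved layout is preserved end to end. A sketch of the same flat-offset arithmetic in Python (names mirror the kernel's locals; the flat lists stand in for the const_data_ptr<float>() buffers, and rope_flat is a hypothetical helper for illustration):

    def rope_flat(inp, seq_length, num_heads, head_dimension, sin_tab, cos_tab, pos=None):
        """Flat-buffer RoPE following op_rope.cpp's index arithmetic."""
        out = [0.0] * len(inp)
        half = head_dimension // 2
        for s in range(seq_length):
            token_id = pos[s] if pos is not None else s
            for h in range(num_heads):
                base = s * num_heads * head_dimension + h * head_dimension
                for hd_o in range(half):
                    x_0 = inp[base + hd_o * 2]
                    x_1 = inp[base + hd_o * 2 + 1]
                    sin = sin_tab[token_id * half + hd_o]
                    cos = cos_tab[token_id * half + hd_o]
                    out[base + hd_o * 2] = x_0 * cos - x_1 * sin
                    out[base + hd_o * 2 + 1] = x_0 * sin + x_1 * cos
        return out

    # One token, one head, head_dim 4, with the same sin/cos as the tests:
    print(rope_flat([1.0, 2.0, 3.0, 4.0], 1, 1, 4, [0.0, 1.0], [1.0, 0.0]))
    # [1.0, 2.0, -4.0, 3.0]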