Add seq start position when applying RoPE encoding (#1796)
laggui committed May 22, 2024
1 parent 0918cf0 commit b466fd7
Showing 1 changed file with 21 additions and 1 deletion.
crates/burn-core/src/nn/rope_encoding.rs
@@ -104,6 +104,22 @@ impl<B: Backend> RotaryEncoding<B> {
     ///
     /// Panics if the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.
     pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
+        self.apply(x, 0)
+    }
+
+    /// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model).
+    ///
+    /// Arguments:
+    /// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodates both 3D and 4D tensors
+    ///   for (batch_size, seq_len, hidden_dim) or (batch_size, num_heads, seq_len, hidden_dim),
+    ///   respectively.
+    /// * `start` - Sequence start position index.
+    ///
+    /// Returns:
+    /// * Output tensor with the same shape as the input tensor after applying rotary encoding.
+    ///
+    /// Panics if the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.
+    pub fn apply<const D: usize>(&self, x: Tensor<B, D>, start: usize) -> Tensor<B, D> {
         assert!(
             D >= 2,
             "Input tensor must have at least 2 dimensions for sequence length and hidden dimension"
@@ -127,7 +143,11 @@ impl<B: Backend> RotaryEncoding<B> {
             .reshape([dummy_dim_size, seq_len, d_model / 2, 2])
             .matmul(sign_tensor.unsqueeze())
             .reshape([dummy_dim_size, seq_len, d_model, 2])
-            * self.freq_complex.clone().slice([0..seq_len]).unsqueeze();
+            * self
+                .freq_complex
+                .clone()
+                .slice([start..start + seq_len])
+                .unsqueeze();
 
         // Sum the real and imaginary components to get output tensor and reshape to original shape
         out.sum_dim(D - 1).reshape(input_shape)
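For context, a minimal usage sketch (not part of the commit) of what the new `start` argument enables: `forward` still rotates inputs as if they begin at position 0, while `apply` slices the precomputed frequencies at [start..start + seq_len], so a token produced mid-sequence is rotated at its true absolute position (e.g. during autoregressive decoding with a KV cache). The sketch assumes burn's `ndarray` backend feature and the `RotaryEncodingConfig::new(max_sequence_length, d_model)` / `init(&device)` constructor pattern; adapt to your backend.

use burn::backend::NdArray;
use burn::nn::RotaryEncodingConfig;
use burn::tensor::Tensor;

type B = NdArray;

fn main() {
    let device = Default::default();
    let d_model = 64;

    // Precompute rotary frequencies for up to 512 positions.
    let rope = RotaryEncodingConfig::new(512, d_model).init::<B>(&device);

    // Prefill: rotate a 10-token prompt. `forward(x)` is now sugar for `apply(x, 0)`.
    let prompt: Tensor<B, 4> = Tensor::ones([1, 8, 10, d_model], &device); // (batch, heads, seq, dim)
    let _prompt_rotated = rope.forward(prompt);

    // Decode step: the next token sits at absolute position 10, so pass `start = 10`
    // rather than rotating it as if it were the first token of a sequence.
    let next_token: Tensor<B, 4> = Tensor::ones([1, 8, 1, d_model], &device);
    let _next_rotated = rope.apply(next_token, 10);
}

Since `forward` now delegates to `apply(x, 0)`, existing callers keep the old behavior unchanged.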
