diff --git a/crates/burn-core/src/nn/rope_encoding.rs b/crates/burn-core/src/nn/rope_encoding.rs index 6eb732e4b1..b8f4e97903 100644 --- a/crates/burn-core/src/nn/rope_encoding.rs +++ b/crates/burn-core/src/nn/rope_encoding.rs @@ -104,6 +104,22 @@ impl RotaryEncoding { /// /// Panics if the input tensor does not have at least 2 dimensions for sequence length and hidden dimension. pub fn forward(&self, x: Tensor) -> Tensor { + self.apply(x, 0) + } + + /// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model), slicing `freq_complex` with the range `start..start + seq_len`. + /// + /// Arguments: + /// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodates both 3D and 4D tensors + /// for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim) + /// respectively. + /// * `start` - Sequence start position index: the offset at which `freq_complex` is sliced (`forward` passes 0). + /// + /// Returns: + /// * Output tensor with the same shape as input tensor after applying rotary encoding. + /// + /// Panics if the input tensor does not have at least 2 dimensions for sequence length and hidden dimension. NOTE(review): `start + seq_len` presumably must not exceed the length of `freq_complex`; no explicit check is made here, so confirm how the slice behaves when out of range. + pub fn apply(&self, x: Tensor, start: usize) -> Tensor { assert!( D >= 2, "Input tensor must have at least 2 dimensions for sequence length and hidden dimension" @@ -127,7 +143,11 @@ impl RotaryEncoding { .reshape([dummy_dim_size, seq_len, d_model / 2, 2]) .matmul(sign_tensor.unsqueeze()) .reshape([dummy_dim_size, seq_len, d_model, 2]) - * self.freq_complex.clone().slice([0..seq_len]).unsqueeze(); + * self + .freq_complex + .clone() + .slice([start..start + seq_len]) + .unsqueeze(); // Sum the real and imaginary components to get output tensor and reshape to original shape out.sum_dim(D - 1).reshape(input_shape)