Skip to content

Commit

Permalink
Feat: Add movedim tensor operator (#1876)
Browse files Browse the repository at this point in the history
* ✨ (burn-tensor): add movedim function to tensor API

---------

Co-authored-by: Georgy Andreev <g.andreev@insilicomedicine.com>
  • Loading branch information
LilDojd and Georgy Andreev committed Jun 14, 2024
1 parent 47a8127 commit b71c300
Show file tree
Hide file tree
Showing 6 changed files with 486 additions and 0 deletions.
1 change: 1 addition & 0 deletions burn-book/src/building-blocks/tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.not_equal(other)` | `x != y` |
| `tensor.permute(axes)` | `tensor.permute(axes)` |
| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` |
| `tensor.repeat(2, 4)` | `tensor.repeat([1, 1, 4])` |
| `tensor.reshape(shape)` | `tensor.view(shape)` |
| `tensor.shape()` | `tensor.shape` |
Expand Down
115 changes: 115 additions & 0 deletions crates/burn-tensor/src/tensor/api/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,60 @@ where
Tensor::new(K::permute(self.primitive, transformed_axes))
}

/// Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.
///
/// Other dimensions of input that are not explicitly moved remain in their original order and appear
/// at the positions not specified in destination.
///
/// # Arguments
///
/// * `src` - The dimension(s) to move. The values must be unique and in the range of the number of dimensions.
/// The values can be negative, in which case they are used as an offset from the end.
///
/// * `dst` - Destination positions for each of the original dims. These must also be unique.
///
/// # Panics
///
/// - If the source and destination dimensions are not of the same length.
/// - If the source and destination vectors contain duplicate values.
/// - If the source and destination vectors contain values that are out of bounds.
///
/// # Returns
///
/// The tensor with the dimensions moved.
// This is a semantic sugar for `permute`. It is used widely enough, so we define a separate Op
// for it
pub fn movedim<S1: MovedimArgs, S2: MovedimArgs>(self, src: S1, dst: S2) -> Tensor<B, D, K> {
let source_dims = src.into_dim_vec::<D>();
let destination_dims = dst.into_dim_vec::<D>();

check!(TensorCheck::movedim_args_length(
&source_dims,
&destination_dims
));

let mut m = [-1; D];
for (&d, &s) in destination_dims.iter().zip(source_dims.iter()) {
m[d] = s as isize;
}
let mut axes: [isize; D] = [0; D];
let mut source_i = 0;
for (dest_i, item) in axes.iter_mut().enumerate().take(D) {
*item = if m[dest_i] != -1 {
m[dest_i]
} else {
while source_dims.contains(&source_i) {
source_i += 1;
}
let result = source_i as isize;
source_i += 1;
result
};
}

self.permute(axes)
}

/// Reverse the order of elements in the tensor along the given dimensions.
///
/// # Arguments
Expand Down Expand Up @@ -1983,6 +2037,67 @@ impl<B: Backend> BasicOps<B> for Bool {
}
}

/// Trait used for movedim arguments
pub trait MovedimArgs {
/// Converts into a set of dimensions `Vec<usize>` for the `tensor.movedim()` function
fn into_dim_vec<const D: usize>(self) -> Vec<usize>;
}

impl MovedimArgs for Vec<i32> {
fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
let set = self
.iter()
.map(|&dim| {
if dim < 0 {
(D as i32 + dim) as usize
} else {
dim as usize
}
})
.collect::<Vec<usize>>();
check!(TensorCheck::movedim_args_vec::<D>(&set));

set
}
}

impl MovedimArgs for Vec<usize> {
fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
check!(TensorCheck::movedim_args_vec::<D>(&self));
self
}
}

impl MovedimArgs for usize {
#[allow(clippy::vec_init_then_push)]
fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
check!(TensorCheck::movedim_args_usize::<D>(self));

let mut set = Vec::with_capacity(1);
set.push(self);

set
}
}

impl MovedimArgs for i32 {
#[allow(clippy::vec_init_then_push)]
fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
check!(TensorCheck::movedim_args_i32::<D>(self));

let dim = if self < 0 {
(D as i32 + self) as usize
} else {
self as usize
};

let mut set = Vec::with_capacity(1);
set.push(dim);

set
}
}

/// Trait used for reshape arguments.
pub trait ReshapeArgs<const D2: usize> {
/// Converts to a shape.
Expand Down
130 changes: 130 additions & 0 deletions crates/burn-tensor/src/tensor/api/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,98 @@ impl TensorCheck {
check
}

pub(crate) fn movedim_args_usize<const D: usize>(dim: usize) -> Self {
let mut check = Self::Ok;

if dim >= D {
check = check.register(
"Movedim",
TensorError::new(
"The given dimension exceeds the number of dimensions of the current tensor.",
)
.details(format!(
"Current tensor has {D} dimensions, but the given dimension is {dim}.",
)),
);
}

check
}

pub(crate) fn movedim_args_i32<const D: usize>(dim: i32) -> Self {
let mut check = Self::Ok;

if dim < -(D as i32) || dim >= D as i32 {
check = check.register(
"Movedim",
TensorError::new(
"The given dimension is out of bounds for the current tensor dimensions.",
)
.details(format!(
"Current tensor has {D} dimensions, but the given dimension is {dim}.",
)),
);
}

check
}

pub(crate) fn movedim_args_vec<const D: usize>(dims: &Vec<usize>) -> Self {
let mut check = Self::Ok;

// Check out of bounds
if dims.iter().any(|&x| x >= D) {
check = check.register(
"Movedim",
TensorError::new("The given dimensions are out of bounds.").details(format!(
"Current tensor has {D} dimensions, but the given dimensions are {:?}.",
dims
)),
);
}

// Check there are no duplicates
for (i, &dim_i) in dims.iter().enumerate() {
for &dim_j in dims.iter().skip(i + 1) {
if dim_i == dim_j {
check = check.register(
"Movedim",
TensorError::new("The given dimensions contain duplicates.").details(
format!(
"The dimension {} is duplicated in the given dimensions {:?}.",
dim_i, dims
),
),
);
}
}
}

check
}

pub(crate) fn movedim_args_length(
source_dims: &Vec<usize>,
destination_dims: &Vec<usize>,
) -> Self {
let mut check = Self::Ok;

if source_dims.len() != destination_dims.len() {
check = check.register(
"Movedim",
TensorError::new(
"The number of dimensions in source and destination must be equal.",
)
.details(format!(
"Source dimensions: {:?}, Destination dimensions: {:?}.",
source_dims, destination_dims
)),
)
}

check
}

pub(crate) fn flatten<const D1: usize, const D2: usize>(
start_dim: usize,
end_dim: usize,
Expand Down Expand Up @@ -1104,4 +1196,42 @@ mod tests {
&8
));
}

#[test]
#[should_panic]
fn movedim_args_out_of_bounds() {
check!(TensorCheck::movedim_args_usize::<3>(5));
}

#[test]
fn movedim_args_i32() {
check!(TensorCheck::movedim_args_i32::<3>(-3));
}

#[test]
#[should_panic]
fn movedim_args_too_negative() {
check!(TensorCheck::movedim_args_i32::<3>(-4));
}

#[test]
#[should_panic]
fn movedim_args_vec_out_of_bounds() {
check!(TensorCheck::movedim_args_vec::<3>(&vec![0, 1, 3]));
}

#[test]
#[should_panic]
fn movedim_args_vec_duplicates() {
check!(TensorCheck::movedim_args_vec::<3>(&vec![0, 1, 1]));
}

#[test]
#[should_panic]
fn movedim_args_length() {
check!(TensorCheck::movedim_args_length(
&vec![0, 1],
&vec![0, 1, 2]
));
}
}
1 change: 1 addition & 0 deletions crates/burn-tensor/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_any!();
burn_tensor::testgen_all_op!();
burn_tensor::testgen_permute!();
burn_tensor::testgen_movedim!();
burn_tensor::testgen_flip!();
burn_tensor::testgen_bool!();
burn_tensor::testgen_argwhere_nonzero!();
Expand Down
1 change: 1 addition & 0 deletions crates/burn-tensor/src/tests/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ mod map_comparison;
mod mask;
mod matmul;
mod maxmin;
mod movedim;
mod mul;
mod narrow;
mod neg;
Expand Down
Loading

0 comments on commit b71c300

Please sign in to comment.