diff --git a/crates/burn-autodiff/src/ops/int_tensor.rs b/crates/burn-autodiff/src/ops/int_tensor.rs
index 5e3e44a285..d5dd844c8b 100644
--- a/crates/burn-autodiff/src/ops/int_tensor.rs
+++ b/crates/burn-autodiff/src/ops/int_tensor.rs
@@ -161,6 +161,10 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
         B::int_sum_dim(tensor, dim)
     }
 
+    fn int_cumsum<const D: usize>(tensor: IntTensor<B, D>, dim: usize) -> IntTensor<B, D> {
+        B::int_cumsum(tensor, dim)
+    }
+
     fn int_mean<const D: usize>(tensor: IntTensor<B, D>) -> IntTensor<B, 1> {
         B::int_mean(tensor)
     }
diff --git a/crates/burn-autodiff/src/ops/tensor.rs b/crates/burn-autodiff/src/ops/tensor.rs
index d8dfe17102..fccc414b75 100644
--- a/crates/burn-autodiff/src/ops/tensor.rs
+++ b/crates/burn-autodiff/src/ops/tensor.rs
@@ -1600,6 +1600,47 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
         }
     }
 
+    fn float_cumsum<const D: usize>(
+        tensor: FloatTensor<Self, D>,
+        dim: usize,
+    ) -> FloatTensor<Self, D> {
+        #[derive(Debug)]
+        struct CumSumDim;
+
+        impl<B: Backend, const D: usize> Backward<B, D, 1> for CumSumDim {
+            type State = (Shape<D>, usize);
+
+            fn backward(
+                self,
+                ops: Ops<Self::State, 1>,
+                grads: &mut Gradients,
+                _checkpointer: &mut Checkpointer,
+            ) {
+                let (_shape, dim) = ops.state;
+
+                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
+                    // The gradient of a cumulative sum is the *reverse* cumulative
+                    // sum of the incoming gradient: flip, cumsum, then flip back.
+                    let grad = B::float_flip(grad, &[dim]);
+                    let cumsum_grad = B::float_cumsum(grad, dim);
+                    B::float_flip(cumsum_grad, &[dim])
+                });
+            }
+        }
+
+        match CumSumDim
+            .prepare::<C>([tensor.node])
+            .compute_bound()
+            .stateful()
+        {
+            OpsKind::Tracked(prep) => prep.finish(
+                (B::float_shape(&tensor.primitive), dim),
+                B::float_cumsum(tensor.primitive, dim),
+            ),
+            OpsKind::UnTracked(prep) => prep.finish(B::float_cumsum(tensor.primitive, dim)),
+        }
+    }
+
     fn float_argmax<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> IntTensor<B, D> {
         B::float_argmax(tensor.primitive, dim)
     }
diff --git a/crates/burn-autodiff/src/tests/cumsum.rs b/crates/burn-autodiff/src/tests/cumsum.rs
new file mode 100644
index 0000000000..7ca5eb41c0
--- /dev/null
+++ b/crates/burn-autodiff/src/tests/cumsum.rs
@@ -0,0 +1,24 @@
+#[burn_tensor_testgen::testgen(ad_cumsum)]
+mod tests {
+    use super::*;
+    use burn_tensor::Data;
+
+    #[test]
+    fn should_diff_cumsum() {
+        let device = Default::default();
+        let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]);
+        // Original tensor
+        let tensor_0 = TestAutodiffTensor::from_data(data, &device).require_grad();
+        // Cumulative sum along dim 1
+        let dim = 1;
+        let tensor_1 = tensor_0.clone().cumsum(dim);
+        // Fake loss
+        let loss = tensor_1.sum();
+        // Gradients with respect to the original tensor
+        let grads = loss.backward();
+        let grad_0 = tensor_0.grad(&grads).unwrap();
+        // The upstream gradient is all ones; its reverse cumsum along dim 1 is [3, 2, 1]
+        let grad_0_expected = Data::from([[3., 2., 1.], [3., 2., 1.], [3., 2., 1.]]);
+        grad_0.into_data().assert_approx_eq(&grad_0_expected, 2);
+    }
+}
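The flip, cumsum, flip sequence in `backward` computes the reverse cumulative sum: for y = cumsum(x) along `dim`, dL/dx[i] = sum over j >= i of dL/dy[j]. The test above only exercises an all-ones upstream gradient, for which the reversal is invisible. A sketch of a complementary test with a non-uniform upstream gradient; the `weights` tensor and its expected values are illustrative additions, not part of the patch:

    // Sketch: weight the cumsum output non-uniformly, so the expected
    // gradient is the reverse cumulative sum of the weights.
    #[test]
    fn should_diff_cumsum_weighted() {
        let device = Default::default();
        let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]);
        let weights_data = Data::from([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]);

        let tensor_0 = TestAutodiffTensor::from_data(data, &device).require_grad();
        let weights = TestAutodiffTensor::from_data(weights_data, &device);

        // loss = sum(weights * cumsum(x, 1)), so the upstream gradient of the
        // cumsum output is `weights` rather than all ones.
        let loss = (tensor_0.clone().cumsum(1) * weights).sum();
        let grads = loss.backward();
        let grad_0 = tensor_0.grad(&grads).unwrap();

        // Reverse cumsum of [1, 2, 3] along dim 1 is [6, 5, 3].
        let expected = Data::from([[6.0, 5.0, 3.0], [6.0, 5.0, 3.0], [6.0, 5.0, 3.0]]);
        grad_0.into_data().assert_approx_eq(&expected, 2);
    }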
diff --git a/crates/burn-autodiff/src/tests/mod.rs b/crates/burn-autodiff/src/tests/mod.rs
index 1d79c07ad5..47fc3122b5 100644
--- a/crates/burn-autodiff/src/tests/mod.rs
+++ b/crates/burn-autodiff/src/tests/mod.rs
@@ -19,6 +19,7 @@ mod conv_transpose1d;
 mod conv_transpose2d;
 mod cos;
 mod cross_entropy;
+mod cumsum;
 mod div;
 mod erf;
 mod exp;
@@ -112,6 +113,7 @@ macro_rules! testgen_all {
         burn_autodiff::testgen_ad_mul!();
         burn_autodiff::testgen_ad_neg!();
         burn_autodiff::testgen_ad_powf!();
+        burn_autodiff::testgen_ad_cumsum!();
         burn_autodiff::testgen_ad_recip!();
         burn_autodiff::testgen_ad_reshape!();
         burn_autodiff::testgen_ad_sin!();
diff --git a/crates/burn-candle/src/ops/int_tensor.rs b/crates/burn-candle/src/ops/int_tensor.rs
index 7081b5d660..d9e8a9d044 100644
--- a/crates/burn-candle/src/ops/int_tensor.rs
+++ b/crates/burn-candle/src/ops/int_tensor.rs
@@ -321,6 +321,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F, I> {
         CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap())
     }
 
+    fn int_cumsum<const D: usize>(
+        tensor: IntTensor<Self, D>,
+        dim: usize,
+    ) -> IntTensor<Self, D> {
+        CandleTensor::new(tensor.tensor.cumsum(dim).unwrap())
+    }
+
     fn int_prod<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
         todo!("prod is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)")
     }
diff --git a/crates/burn-candle/src/ops/tensor.rs b/crates/burn-candle/src/ops/tensor.rs
index a15a671cb3..50c29152a0 100644
--- a/crates/burn-candle/src/ops/tensor.rs
+++ b/crates/burn-candle/src/ops/tensor.rs
@@ -373,6 +373,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I> {
         CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap())
     }
 
+    fn float_cumsum<const D: usize>(
+        tensor: FloatTensor<Self, D>,
+        dim: usize,
+    ) -> FloatTensor<Self, D> {
+        CandleTensor::new(tensor.tensor.cumsum(dim).unwrap())
+    }
+
     fn float_mean_dim<const D: usize>(
         tensor: FloatTensor<Self, D>,
         dim: usize,
diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs
index 94b32209e9..fd38e72ed4 100644
--- a/crates/burn-import/onnx-tests/build.rs
+++ b/crates/burn-import/onnx-tests/build.rs
@@ -55,6 +55,8 @@ fn main() {
         .input("tests/conv_transpose2d/conv_transpose2d.onnx")
         .input("tests/pow/pow.onnx")
         .input("tests/pow/pow_int.onnx")
+        .input("tests/cumsum/cumsum.onnx")
+        .input("tests/cumsum/cumsum_int.onnx")
         .input("tests/unsqueeze/unsqueeze.onnx")
         .input("tests/unsqueeze/unsqueeze_opset16.onnx")
         .input("tests/unsqueeze/unsqueeze_opset11.onnx")
diff --git a/crates/burn-import/onnx-tests/tests/cumsum/cumsum.py b/crates/burn-import/onnx-tests/tests/cumsum/cumsum.py
new file mode 100644
index 0000000000..ee271e3c6b
--- /dev/null
+++ b/crates/burn-import/onnx-tests/tests/cumsum/cumsum.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+
+# Used to generate model: onnx-tests/tests/cumsum/cumsum.onnx
+
+import torch
+import torch.nn as nn
+
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x, d):
+        # Cumulative sum of a tensor along dimension d
+        return x.cumsum(d)
+
+
+def main():
+    # Export to onnx
+    model = Model()
+    model.eval()
+    device = torch.device("cpu")
+    onnx_name = "cumsum.onnx"
+    dummy_input = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.float32, device=device)
+    dim = 1
+
+    torch.onnx.export(
+        model, (dummy_input, dim), onnx_name, verbose=False, opset_version=16
+    )
+
+    print(f"Finished exporting model to {onnx_name}")
+
+    # Output some test data for use in the test
+    test_input = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.float32, device=device)
+    print(f"Test input data: {test_input}, {dim}")
+    output = model.forward(test_input, dim)
+    print(f"Test output data: {output}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/crates/burn-import/onnx-tests/tests/cumsum/cumsum_int.py b/crates/burn-import/onnx-tests/tests/cumsum/cumsum_int.py
new file mode 100644
index 0000000000..dd344d0b5f
--- /dev/null
+++ b/crates/burn-import/onnx-tests/tests/cumsum/cumsum_int.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+
+# Used to generate model: onnx-tests/tests/cumsum/cumsum_int.onnx
+
+import torch
+import torch.nn as nn
+
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x, d):
+        # Cumulative sum of a tensor along dimension d
+        return x.cumsum(d)
+
+
+def main():
+    # Export to onnx
+    model = Model()
+    model.eval()
+    device = torch.device("cpu")
+    onnx_name = "cumsum_int.onnx"
+    test_input = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int32, device=device)
+    dim = 1
+
+    torch.onnx.export(
+        model, (test_input, dim), onnx_name, verbose=False, opset_version=16
+    )
+
+    print(f"Finished exporting model to {onnx_name}")
+
+    print(f"Test input data: {test_input}, {dim}")
+    output = model.forward(test_input, dim)
+    print(f"Test output data: {output}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs
index 34ddfa5f87..e34f4922e1 100644
--- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs
+++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -1049,5 +1049,36 @@ mod tests {
         assert_eq!(output.to_data(), expected);
     }
 
+    #[test]
+    fn cumsum() {
+        let device = Default::default();
+        let model: cumsum::Model<Backend> = cumsum::Model::new(&device);
+
+        let input1 = Tensor::<Backend, 2>::from_floats(
+            [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
+            &device,
+        );
+        let input2 = 1;
+
+        let output = model.forward(input1, input2);
+
+        let expected = Data::from([[0.0, 1.0, 3.0], [3.0, 7.0, 12.0], [6.0, 13.0, 21.0]]);
+
+        assert_eq!(output.to_data(), expected);
+    }
+
+    #[test]
+    fn cumsum_int() {
+        let device = Default::default();
+        let model: cumsum_int::Model<Backend> = cumsum_int::Model::new(&device);
+
+        let input1 = Tensor::<Backend, 2, Int>::from_ints([[0, 1, 2], [3, 4, 5], [6, 7, 8]], &device);
+        let input2 = 1;
+
+        let output = model.forward(input1, input2);
+
+        let expected = Data::from([[0, 1, 3], [3, 7, 12], [6, 13, 21]]);
+
+        assert_eq!(output.to_data(), expected);
+    }
+
     #[test]
     fn unsqueeze() {
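For context, the two ONNX models generated above exercise the CumSum conversion added below. The forward burn-import is expected to emit for such a model reduces to a single `cumsum` call, roughly as in this sketch; the identifier names and the exact scalar type of the axis input are illustrative, not the exact generated code:

    // Rough shape of the code burn-import generates for a CumSum node once the
    // conversion below is registered; identifiers are illustrative only.
    pub fn forward(&self, input1: Tensor<B, 2>, input2: usize) -> Tensor<B, 2> {
        let cumsum1 = input1.cumsum(input2);
        cumsum1
    }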
diff --git a/crates/burn-import/src/burn/node/binary.rs b/crates/burn-import/src/burn/node/binary.rs
index 8221ba300e..9392bac094 100644
--- a/crates/burn-import/src/burn/node/binary.rs
+++ b/crates/burn-import/src/burn/node/binary.rs
@@ -14,6 +14,8 @@ pub enum BinaryType {
     Equal,
     Powf,
     Powi,
+    FloatCumsum,
+    IntCumsum,
 }
 
 impl BinaryType {
@@ -26,6 +28,8 @@ impl BinaryType {
             BinaryType::Equal => "equal",
             BinaryType::Powi => "powi",
             BinaryType::Powf => "powf",
+            BinaryType::FloatCumsum => "float_cumsum",
+            BinaryType::IntCumsum => "int_cumsum",
         }
     }
 }
@@ -173,6 +177,34 @@ impl BinaryNode {
         };
         Self::new(lhs, rhs, output, BinaryType::Powi, Arc::new(function))
     }
 
+    pub(crate) fn float_cumsum(lhs: Type, rhs: Type, output: Type) -> Self {
+        let function = match (&lhs, &rhs) {
+            (Type::Tensor(_), Type::Scalar(_)) => {
+                // The generated code operates on the high-level `Tensor` API,
+                // which exposes a single `cumsum` method for both tensor kinds.
+                move |lhs, rhs| quote! { #lhs.cumsum(#rhs) }
+            }
+            _ => panic!("cumsum is supported for tensor-scalar input pairs only"),
+        };
+        Self::new(
+            lhs,
+            rhs,
+            output,
+            BinaryType::FloatCumsum,
+            Arc::new(function),
+        )
+    }
+
+    pub(crate) fn int_cumsum(lhs: Type, rhs: Type, output: Type) -> Self {
+        let function = match (&lhs, &rhs) {
+            (Type::Tensor(_), Type::Scalar(_)) => {
+                move |lhs, rhs| quote! { #lhs.cumsum(#rhs) }
+            }
+            _ => panic!("cumsum is supported for tensor-scalar input pairs only"),
+        };
+        Self::new(lhs, rhs, output, BinaryType::IntCumsum, Arc::new(function))
+    }
 }
 
 #[cfg(test)]
@@ -312,6 +343,16 @@ mod tests {
     fn test_binary_codegen_powf_scalar() {
         test_binary_operator_on_tensor_and_scalar!(powf, powf_scalar);
     }
 
+    #[test]
+    fn test_binary_codegen_int_cumsum() {
+        test_binary_operator_on_tensor_and_scalar!(int_cumsum, cumsum);
+    }
+
+    #[test]
+    fn test_binary_codegen_float_cumsum() {
+        test_binary_operator_on_tensor_and_scalar!(float_cumsum, cumsum);
+    }
+
     #[test]
     fn test_binary_codegen_div() {
diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs
index 0775e631b8..022d1a40e9 100644
--- a/crates/burn-import/src/onnx/dim_inference.rs
+++ b/crates/burn-import/src/onnx/dim_inference.rs
@@ -58,6 +58,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
         NodeType::Transpose => same_as_input(node),
         NodeType::Unsqueeze => unsqueeze_update_output(node),
         NodeType::Pow => same_as_input(node),
+        NodeType::CumSum => same_as_input(node),
         NodeType::LeakyRelu => same_as_input(node),
         NodeType::Where => where_update_outputs(node),
         // Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated.
diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs
index a736b564b5..d6335fc71f 100644
--- a/crates/burn-import/src/onnx/to_burn.rs
+++ b/crates/burn-import/src/onnx/to_burn.rs
@@ -278,6 +278,7 @@ impl OnnxGraph {
                 graph.register(Self::conv_transpose2d_conversion(node))
             }
             NodeType::Pow => graph.register(Self::pow_conversion(node)),
+            NodeType::CumSum => graph.register(Self::cumsum_conversion(node)),
             NodeType::Unsqueeze => graph.register(Self::unsqueeze_conversion(node)),
             NodeType::Where => graph.register(Self::where_conversion(node)),
             NodeType::Sign => graph.register(Self::sign_conversion(node)),
@@ -779,6 +780,21 @@ impl OnnxGraph {
             _ => panic!("pow function only supports RHS scalar or tensor types"),
         }
     }
 
+    fn cumsum_conversion(node: Node) -> BinaryNode {
+        let lhs = node.inputs.first().unwrap().to_type();
+        let rhs = node.inputs.get(1).unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
+
+        match &lhs {
+            Type::Tensor(x) => match x.kind {
+                TensorKind::Int => BinaryNode::int_cumsum(lhs, rhs, output),
+                TensorKind::Float => BinaryNode::float_cumsum(lhs, rhs, output),
+                _ => panic!("cumsum function requires LHS to be an int or float tensor"),
+            },
+            _ => panic!("cumsum function only supports LHS tensor type"),
+        }
+    }
+
     fn sign_conversion(node: Node) -> UnaryNode {
         let input = node.inputs.first().unwrap().to_type();
diff --git a/crates/burn-ndarray/src/ops/int_tensor.rs b/crates/burn-ndarray/src/ops/int_tensor.rs
index fba8d30c0e..2ad2225f36 100644
--- a/crates/burn-ndarray/src/ops/int_tensor.rs
+++ b/crates/burn-ndarray/src/ops/int_tensor.rs
@@ -7,7 +7,7 @@ use burn_tensor::{Distribution, Reader};
 use burn_tensor::ElementConversion;
 use core::ops::Range;
 
-use ndarray::IntoDimension;
+use ndarray::{Axis, IntoDimension};
 
 // Current crate
 use crate::element::ExpElement;
@@ -286,6 +286,19 @@ impl<E: FloatNdArrayElement> IntTensorOps<Self> for NdArray<E> {
         NdArrayMathOps::sum_dim(tensor, dim)
     }
 
+    fn int_cumsum<const D: usize>(
+        tensor: NdArrayTensor<i64, D>,
+        dim: usize,
+    ) -> NdArrayTensor<i64, D> {
+        let mut array = tensor.array.into_owned();
+
+        // Accumulate in place along the axis: each element becomes the sum of
+        // itself and all preceding elements along `dim`.
+        array.accumulate_axis_inplace(Axis(dim), |&prev, curr| *curr += prev);
+
+        NdArrayTensor::new(array.into_shared())
+    }
+
     fn int_prod<const D: usize>(tensor: NdArrayTensor<i64, D>) -> NdArrayTensor<i64, 1> {
         NdArrayMathOps::prod(tensor)
     }
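Both ndarray implementations (the int one above and the float one below) lean on `accumulate_axis_inplace`, which folds each element with its already-updated predecessor along the axis. A standalone sketch of those semantics, independent of burn:

    use ndarray::{array, Axis};

    fn main() {
        let mut a = array![[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]];
        // Each element is replaced by itself plus the already-accumulated
        // previous element along axis 1, i.e. a running sum per row.
        a.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev);
        assert_eq!(a, array![[0.0, 1.0, 3.0], [3.0, 7.0, 12.0]]);
    }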
diff --git a/crates/burn-ndarray/src/ops/tensor.rs b/crates/burn-ndarray/src/ops/tensor.rs
index ce638ad66a..4683dc56f2 100644
--- a/crates/burn-ndarray/src/ops/tensor.rs
+++ b/crates/burn-ndarray/src/ops/tensor.rs
@@ -1,7 +1,7 @@
 // Language
 use alloc::vec::Vec;
 use core::ops::Range;
-use ndarray::IntoDimension;
+use ndarray::{Axis, IntoDimension};
 
 // Current crate
 use super::{matmul::matmul, NdArrayMathOps, NdArrayOps};
@@ -338,6 +338,17 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
         NdArrayMathOps::sum_dim(tensor, dim)
     }
 
+    fn float_cumsum<const D: usize>(
+        tensor: NdArrayTensor<E, D>,
+        dim: usize,
+    ) -> NdArrayTensor<E, D> {
+        let mut array = tensor.array.into_owned();
+
+        array.accumulate_axis_inplace(Axis(dim), |&prev, curr| *curr += prev);
+
+        NdArrayTensor::new(array.into_shared())
+    }
+
     fn float_argmax<const D: usize>(
         tensor: NdArrayTensor<E, D>,
         dim: usize,
diff --git a/crates/burn-tch/src/ops/base.rs b/crates/burn-tch/src/ops/base.rs
index c307e30dac..112d23a551 100644
--- a/crates/burn-tch/src/ops/base.rs
+++ b/crates/burn-tch/src/ops/base.rs
@@ -335,6 +335,10 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
         )
     }
 
+    pub fn cumsum<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
+        TchTensor::from_existing(tensor.tensor.cumsum(dim as i64, E::KIND), tensor.storage)
+    }
+
     pub fn prod<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, 1> {
         let tensor = tensor.tensor.prod(E::KIND);
         TchTensor::new(tensor)
diff --git a/crates/burn-tch/src/ops/int_tensor.rs b/crates/burn-tch/src/ops/int_tensor.rs
index 460c341632..7ec4d4742e 100644
--- a/crates/burn-tch/src/ops/int_tensor.rs
+++ b/crates/burn-tch/src/ops/int_tensor.rs
@@ -270,6 +270,10 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
         TchOps::sum_dim(tensor, dim)
     }
 
+    fn int_cumsum<const D: usize>(tensor: TchTensor<i64, D>, dim: usize) -> TchTensor<i64, D> {
+        TchOps::cumsum(tensor, dim)
+    }
+
     fn int_prod<const D: usize>(tensor: TchTensor<i64, D>) -> TchTensor<i64, 1> {
         TchOps::prod(tensor)
     }
diff --git a/crates/burn-tch/src/ops/tensor.rs b/crates/burn-tch/src/ops/tensor.rs
index 86f01248e2..a237302586 100644
--- a/crates/burn-tch/src/ops/tensor.rs
+++ b/crates/burn-tch/src/ops/tensor.rs
@@ -331,6 +331,10 @@ impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
         TchOps::sum_dim(tensor, dim)
     }
 
+    fn float_cumsum<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
+        TchOps::cumsum(tensor, dim)
+    }
+
     fn float_mean_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
         TchOps::mean_dim(tensor, dim)
     }
diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs
index e7390689ce..c91222647a 100644
--- a/crates/burn-tensor/src/tensor/api/check.rs
+++ b/crates/burn-tensor/src/tensor/api/check.rs
@@ -803,6 +803,23 @@ impl TensorCheck {
         check
     }
 
+    /// Checks the dimension of a running calculation such as a cumulative sum.
+    pub(crate) fn running_dim<const D: usize>(ops: &str, dim: usize) -> Self {
+        let mut check = Self::Ok;
+
+        // Valid axes for a D-dimensional tensor are 0..D-1.
+        if dim >= D {
+            check = check.register(
+                ops,
+                TensorError::new(format!(
+                    "Can't perform a running calculation along axis ({dim}) of a tensor with ({D}) dimensions"
+                )),
+            );
+        }
+
+        check
+    }
+
     pub(crate) fn sort_dim<const D: usize>(ops: &str, dim: usize) -> Self {
         let mut check = Self::Ok;
diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs
index 40674d6d64..b96998af4f 100644
--- a/crates/burn-tensor/src/tensor/api/numeric.rs
+++ b/crates/burn-tensor/src/tensor/api/numeric.rs
@@ -142,6 +142,12 @@ where
         Self::new(K::sum_dim(self.primitive, dim))
     }
 
+    /// Computes the cumulative sum of the elements along the given *dimension* or *axis*.
+    pub fn cumsum(self, dim: usize) -> Self {
+        check!(TensorCheck::running_dim::<D>("Cumsum", dim));
+        Self::new(K::cumsum(self.primitive, dim))
+    }
+
     /// Aggregate all elements along the given *dimension* or *axis*
     /// in the tensor with the product operation.
    pub fn prod(self) -> Tensor<B, 1, K> {
@@ -1173,6 +1179,27 @@ where
     /// which is more high-level and designed for public use.
     fn sum_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
 
+    /// Computes the cumulative sum of the elements of the tensor along a dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to perform the cumulative sum on.
+    /// * `dim` - The dimension along which to perform the cumulative sum.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the cumulative sum of the elements of `tensor` along `dim`.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// Users should prefer the [Tensor::cumsum](Tensor::cumsum) function,
+    /// which is more high-level and designed for public use.
+    fn cumsum<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
+
     /// Computes the product of all the elements of the tensor.
     ///
     /// # Arguments
@@ -2176,6 +2203,10 @@ impl<B: Backend> Numeric<B> for Int {
         B::int_sum_dim(tensor, dim)
     }
 
+    fn cumsum<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D> {
+        B::int_cumsum(tensor, dim)
+    }
+
     fn prod<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
         B::int_prod(tensor)
     }
@@ -2521,6 +2552,10 @@ impl<B: Backend> Numeric<B> for Float {
         B::float_sum_dim(tensor, dim)
     }
 
+    fn cumsum<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D> {
+        B::float_cumsum(tensor, dim)
+    }
+
     fn prod<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
         B::float_prod(tensor)
     }
diff --git a/crates/burn-tensor/src/tensor/ops/int_tensor.rs b/crates/burn-tensor/src/tensor/ops/int_tensor.rs
index 8f3fc56d51..a84c6cbb42 100644
--- a/crates/burn-tensor/src/tensor/ops/int_tensor.rs
+++ b/crates/burn-tensor/src/tensor/ops/int_tensor.rs
@@ -770,6 +770,18 @@ pub trait IntTensorOps<B: Backend> {
     /// The sum of all elements in the tensor along the dimension.
     fn int_sum_dim<const D: usize>(tensor: IntTensor<B, D>, dim: usize) -> IntTensor<B, D>;
 
+    /// Cumulative sum of all elements in a tensor along a dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to perform the cumulative sum on.
+    /// * `dim` - The dimension along which to perform the cumulative sum.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the cumulative sum of all elements in `tensor` along `dim`.
+    fn int_cumsum<const D: usize>(tensor: IntTensor<B, D>, dim: usize) -> IntTensor<B, D>;
+
     /// Computes the product of all elements in the tensor.
     ///
     /// # Arguments
diff --git a/crates/burn-tensor/src/tensor/ops/tensor.rs b/crates/burn-tensor/src/tensor/ops/tensor.rs
index 27b28b3070..ac85e6e5fa 100644
--- a/crates/burn-tensor/src/tensor/ops/tensor.rs
+++ b/crates/burn-tensor/src/tensor/ops/tensor.rs
@@ -842,6 +842,19 @@ pub trait FloatTensorOps<B: Backend> {
     /// A tensor with the sum of all elements in `tensor` along `dim`.
     fn float_sum_dim<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> FloatTensor<B, D>;
 
+    /// Cumulative sum of all elements in a tensor along a dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to perform the cumulative sum on.
+    /// * `dim` - The dimension along which to perform the cumulative sum.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the cumulative sum of all elements in `tensor` along `dim`.
+    fn float_cumsum<const D: usize>(tensor: FloatTensor<B, D>, dim: usize)
+        -> FloatTensor<B, D>;
+
     /// Product of all elements in a tensor.
     ///
     /// # Arguments
diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs
index 527860306f..145fe6b402 100644
--- a/crates/burn-tensor/src/tests/mod.rs
+++ b/crates/burn-tensor/src/tests/mod.rs
@@ -50,6 +50,7 @@ macro_rules! testgen_all {
         burn_tensor::testgen_close!();
         burn_tensor::testgen_cos!();
         burn_tensor::testgen_create_like!();
+        burn_tensor::testgen_cumsum!();
         burn_tensor::testgen_div!();
         burn_tensor::testgen_erf!();
         burn_tensor::testgen_exp!();
diff --git a/crates/burn-tensor/src/tests/ops/cumsum.rs b/crates/burn-tensor/src/tests/ops/cumsum.rs
new file mode 100644
index 0000000000..815dc8b157
--- /dev/null
+++ b/crates/burn-tensor/src/tests/ops/cumsum.rs
@@ -0,0 +1,18 @@
+#[burn_tensor_testgen::testgen(cumsum)]
+mod tests {
+    use super::*;
+    use burn_tensor::{Data, Tensor};
+
+    #[test]
+    fn should_cumsum_over_dim() {
+        let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]);
+        let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
+
+        let dim = 1;
+
+        let data_actual = tensor.cumsum(dim).into_data();
+        let data_expected = Data::from([[0.0, 1.0, 3.0], [3.0, 7.0, 12.0], [6.0, 13.0, 21.0]]);
+
+        data_expected.assert_approx_eq(&data_actual, 3);
+    }
+}
diff --git a/crates/burn-tensor/src/tests/ops/mod.rs b/crates/burn-tensor/src/tests/ops/mod.rs
index 1de78ac2a7..22268e34c3 100644
--- a/crates/burn-tensor/src/tests/ops/mod.rs
+++ b/crates/burn-tensor/src/tests/ops/mod.rs
@@ -15,6 +15,7 @@ mod clamp;
 mod close;
 mod cos;
 mod create_like;
+mod cumsum;
 mod div;
 mod erf;
 mod exp;
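Taken together, the patch exposes cumulative sums through a single high-level entry point for both float and int tensors. A minimal usage sketch, assuming any enabled backend; the function and variable names here are illustrative, not part of the patch:

    use burn::tensor::{backend::Backend, Int, Tensor};

    // Demonstrates the new public `cumsum` API on both tensor kinds.
    fn demo<B: Backend>(device: &B::Device) {
        let float = Tensor::<B, 2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], device);
        // Running sums along axis 1: [[0, 1, 3], [3, 7, 12]]
        let float_cumsum = float.cumsum(1);

        let int = Tensor::<B, 2, Int>::from_ints([[0, 1, 2], [3, 4, 5]], device);
        // Same operation on an integer tensor: [[0, 1, 3], [3, 7, 12]]
        let int_cumsum = int.cumsum(1);

        println!("{float_cumsum} {int_cumsum}");
    }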