Add cumulative sum tensor operation #1722

Draft: wants to merge 11 commits into base: main
4 changes: 4 additions & 0 deletions crates/burn-autodiff/src/ops/int_tensor.rs
@@ -161,6 +161,10 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_sum_dim(tensor, dim)
}

fn int_cumsum<const D: usize>(tensor: IntTensor<B, D>, dim: usize) -> IntTensor<B, D> {
B::int_cumsum(tensor, dim)
}

fn int_mean<const D: usize>(tensor: IntTensor<B, D>) -> IntTensor<B, 1> {
B::int_mean(tensor)
}
38 changes: 38 additions & 0 deletions crates/burn-autodiff/src/ops/tensor.rs
@@ -1600,6 +1600,44 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
}

fn float_cumsum<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> FloatTensor<Self, D> {
#[derive(Debug)]
struct CumSumDim;

impl<B: Backend, const D: usize> Backward<B, D, 1> for CumSumDim {
type State = (Shape<D>, usize);

fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
_checkpointer: &mut Checkpointer,
) {
let (_shape, dim) = ops.state;

unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
// The gradient of a cumulative sum is the reversed cumulative sum of the
// upstream gradient: flip along `dim`, take the running sum, then flip back.
let flipped = B::float_flip(grad, &[dim]);
B::float_flip(B::float_cumsum(flipped, dim), &[dim])
});
}
}

match CumSumDim
.prepare::<C>([tensor.node])
.compute_bound()
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(B::float_shape(&tensor.primitive), dim),
B::float_cumsum(tensor.primitive, dim),
),
OpsKind::UnTracked(prep) => prep.finish(B::float_cumsum(tensor.primitive, dim)),
}
}

fn float_argmax<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> IntTensor<B, D> {
B::float_argmax(tensor.primitive, dim)
}
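Note on the backward rule above: for y = cumsum(x, dim), each input element feeds every output at or after its position, so dL/dx[i] is the sum of the upstream gradient g[j] over j >= i, i.e. the reversed cumulative sum flip(cumsum(flip(g))). A minimal standalone sketch of that identity in plain Rust (no Burn types; the values are purely illustrative):

fn reverse_cumsum(g: &[f32]) -> Vec<f32> {
    // Flip, running sum, flip back: out[i] = sum of g[j] for j >= i.
    let mut out: Vec<f32> = g.iter().rev().copied().collect();
    for i in 1..out.len() {
        out[i] += out[i - 1];
    }
    out.reverse();
    out
}

fn main() {
    // With a non-uniform upstream gradient, a plain cumsum followed by a flip
    // would give [111.0, 11.0, 1.0], while the correct gradient is below.
    let g = [1.0_f32, 10.0, 100.0];
    assert_eq!(reverse_cumsum(&g), vec![111.0, 110.0, 100.0]);
}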
25 changes: 25 additions & 0 deletions crates/burn-autodiff/src/tests/cumsum.rs
@@ -0,0 +1,25 @@
#[burn_tensor_testgen::testgen(ad_cumsum)]
mod tests {
use super::*;
use burn_tensor::Data;

#[test]
fn should_diff_cumsum() {
let device = Default::default();
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]);
// Original Tensor
let tensor_0 = TestAutodiffTensor::from_data(data, &device).require_grad();
// Cumsum Tensor
let dim = 1;
let tensor_1 = tensor_0.clone().cumsum(dim);
// Fake loss
let loss = tensor_1.clone().sum();
// Gradients with respect to the original tensor
let grads = loss.backward();
let grad_0 = tensor_0.grad(&grads).unwrap();
// Gradient is correct
let grad_0_expected = Data::from([[3., 2., 1.], [3., 2., 1.], [3., 2., 1.]]);
grad_0.into_data().assert_approx_eq(&grad_0_expected, 2);
}
}
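As a quick check of the expected gradient above: the fake loss is a plain sum, so the upstream gradient is all ones, and each input element x[i][j] contributes to outputs j..=2 of its row along dim 1. Hence grad[i][j] = 3 - j, which gives [3, 2, 1] for every row.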
2 changes: 2 additions & 0 deletions crates/burn-autodiff/src/tests/mod.rs
@@ -19,6 +19,7 @@ mod conv_transpose1d;
mod conv_transpose2d;
mod cos;
mod cross_entropy;
mod cumsum;
mod div;
mod erf;
mod exp;
@@ -112,6 +113,7 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_mul!();
burn_autodiff::testgen_ad_neg!();
burn_autodiff::testgen_ad_powf!();
burn_autodiff::testgen_ad_cumsum!();
burn_autodiff::testgen_ad_recip!();
burn_autodiff::testgen_ad_reshape!();
burn_autodiff::testgen_ad_sin!();
7 changes: 7 additions & 0 deletions crates/burn-candle/src/ops/int_tensor.rs
@@ -321,6 +321,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap())
}

fn int_cumsum<const D: usize>(
tensor: IntTensor<Self, D>,
dim: usize,
) -> IntTensor<Self, D> {
CandleTensor::new(tensor.tensor.cumsum(dim).unwrap())
}

fn int_prod<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
todo!("prod is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)")
}
7 changes: 7 additions & 0 deletions crates/burn-candle/src/ops/tensor.rs
@@ -373,6 +373,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap())
}

fn float_cumsum<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.cumsum(dim).unwrap())
}

fn float_mean_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
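For reference, the Candle implementations above defer entirely to candle's own cumsum kernel. A minimal standalone usage sketch (assuming the candle-core crate; exact method signatures should be checked against the candle version pinned by burn-candle):

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.]], &Device::Cpu)?;
    // Running sum along dimension 1, mirroring float_cumsum above.
    let c = t.cumsum(1)?;
    println!("{c}");
    Ok(())
}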
2 changes: 2 additions & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -55,6 +55,8 @@ fn main() {
.input("tests/conv_transpose2d/conv_transpose2d.onnx")
.input("tests/pow/pow.onnx")
.input("tests/pow/pow_int.onnx")
.input("tests/pow/cumsum.onnx")
.input("tests/pow/cumsum_int.onnx")
.input("tests/unsqueeze/unsqueeze.onnx")
.input("tests/unsqueeze/unsqueeze_opset16.onnx")
.input("tests/unsqueeze/unsqueeze_opset11.onnx")
47 changes: 47 additions & 0 deletions crates/burn-import/onnx-tests/tests/cumsum/cumsum.py
@@ -0,0 +1,47 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/cumsum/cumsum.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, d):
# cumulative sum of a tensor along dimension d
x = x.cumsum(d)

return x


def main():
# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "cumsum.onnx"
dummy_input = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.float32, device=device)

dim = 1

torch.onnx.export(
model, (dummy_input, dim), onnx_name, verbose=False, opset_version=16
)

print(f"Finished exporting model to {onnx_name}")

# Output some test data for use in the test
test_input = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.float32, device=device)

print(f"Test input data: {test_input}, {dim}")
output = model.forward(test_input, dim)
print(f"Test output data: {output}")



if __name__ == "__main__":
main()
45 changes: 45 additions & 0 deletions crates/burn-import/onnx-tests/tests/cumsum/cumsum_int.py
@@ -0,0 +1,45 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/cumsum/cumsum_int.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, d):
# cumulative sum of a tensor along dimension d
x = x.cumsum(d)

return x


def main():
# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "cumsum.onnx"
test_input = torch.tensor([[0,1,2], [3,4,5], [6, 7, 8]], dtype=torch.int32, device=device)

dim= 1

torch.onnx.export(
model, (test_input, dim), onnx_name, verbose=False, opset_version=16
)

print(f"Finished exporting model to {onnx_name}")


print(f"Test input data: {test_input}, {dim}")
output = model.forward(test_input, dim)
print(f"Test output data: {output}")



if __name__ == "__main__":
main()
29 changes: 29 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -1049,6 +1049,35 @@ mod tests {

assert_eq!(output.to_data(), expected);
}

#[test]
fn cumsum() {
let device = Default::default();
let model: cumsum::Model<Backend> = cumsum::Model::new(&device);

let input1 =
Tensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]], &device);
let input2 = 1;

let output = model.forward(input1, input2);

let expected = Data::from([[0.0, 1.0, 3.0], [3.0, 7.0, 12.0], [6.0, 13.0, 21.0]]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn cumsum_int() {
let device = Default::default();
let model: cumsum_int::Model<Backend> = cumsum_int::Model::new(&device);

let input1 = Tensor::from_ints([[0, 1, 2], [3, 4, 5], [6, 7, 8]], &device);
let input2 = 1;

let output = model.forward(input1, input2);

let expected = Data::from([[0, 1, 3], [3, 7, 12], [6, 13, 21]]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn unsqueeze() {
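As a sanity check of the expected values in the two cumsum tests above, the output is just the running sum along dim 1; for the middle row: [3, 3 + 4, 3 + 4 + 5] = [3, 7, 12].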
35 changes: 35 additions & 0 deletions crates/burn-import/src/burn/node/binary.rs
@@ -14,6 +14,8 @@ pub enum BinaryType {
Equal,
Powf,
Powi,
FloatCumsum,
IntCumsum,
}

impl BinaryType {
@@ -26,6 +28,8 @@ impl BinaryType {
BinaryType::Equal => "equal",
BinaryType::Powi => "powi",
BinaryType::Powf => "powf",
BinaryType::FloatCumsum => "float_cumsum",
BinaryType::IntCumsum => "int_cumsum",
}
}
}
@@ -173,6 +177,28 @@ impl BinaryNode {
};
Self::new(lhs, rhs, output, BinaryType::Powi, Arc::new(function))
}

pub(crate) fn float_cumsum(lhs: Type, rhs: Type, output: Type) -> Self {
let function = match (&lhs, &rhs) {
(Type::Tensor(_), Type::Scalar(_)) => {
move |lhs, rhs| quote! { #lhs.float_cumsum(#rhs) }
}
_ => panic!("cumsum is supported for tensor-scalar input pairs only"),
};
Self::new(
lhs,
rhs,
output,
BinaryType::FloatCumsum,
Arc::new(function),
)
}

pub(crate) fn int_cumsum(lhs: Type, rhs: Type, output: Type) -> Self {
let function = match (&lhs, &rhs) {
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.int_cumsum(#rhs) },
_ => panic!("cumsum is supported for tensor-scalar input pairs only"),
};
Self::new(lhs, rhs, output, BinaryType::IntCumsum, Arc::new(function))
}
}

#[cfg(test)]
Expand Down Expand Up @@ -312,6 +338,15 @@ mod tests {
fn test_binary_codegen_powf_scalar() {
test_binary_operator_on_tensor_and_scalar!(powf, powf_scalar);
}

#[test]
fn test_binary_codegen_int_cumsum() {
test_binary_operator_on_tensor_and_scalar!(int_cumsum, int_cumsum_scalar);
}

#[test]
fn test_binary_codegen_float_cumsum() {
test_binary_operator_on_tensor_and_scalar!(float_cumsum, float_cumsum_scalar);
}

#[test]
fn test_binary_codegen_div() {
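Side note on the codegen above: the closure stored in BinaryNode only splices the lhs/rhs token streams into a method-call expression. A minimal standalone sketch of that pattern (assuming the proc-macro2 and quote crates; the identifiers are illustrative):

use proc_macro2::TokenStream;
use quote::quote;

// Mirrors the closure registered for BinaryType::FloatCumsum.
fn cumsum_call(lhs: TokenStream, rhs: TokenStream) -> TokenStream {
    quote! { #lhs.float_cumsum(#rhs) }
}

fn main() {
    let tokens = cumsum_call(quote! { input }, quote! { dim });
    // Prints roughly: input . float_cumsum (dim)
    println!("{tokens}");
}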
1 change: 1 addition & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
@@ -58,6 +58,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
NodeType::Transpose => same_as_input(node),
NodeType::Unsqueeze => unsqueeze_update_output(node),
NodeType::Pow => same_as_input(node),
NodeType::CumSum => same_as_input(node),
NodeType::LeakyRelu => same_as_input(node),
NodeType::Where => where_update_outputs(node),
// Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated.
14 changes: 14 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
@@ -278,6 +278,7 @@ impl OnnxGraph {
graph.register(Self::conv_transpose2d_conversion(node))
}
NodeType::Pow => graph.register(Self::pow_conversion(node)),
NodeType::CumSum => graph.register(Self::cumsum_conversion(node)),
NodeType::Unsqueeze => graph.register(Self::unsqueeze_conversion(node)),
NodeType::Where => graph.register(Self::where_conversion(node)),
NodeType::Sign => graph.register(Self::sign_conversion(node)),
@@ -779,6 +780,19 @@ impl OnnxGraph {
_ => panic!("pow function only supports RHS scalar or tensor types"),
}
}

fn cumsum_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.first().unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
match &lhs {
Type::Tensor(x) => match x.kind {
TensorKind::Int => BinaryNode::int_cumsum(lhs, rhs, output),
TensorKind::Float => BinaryNode::float_cumsum(lhs, rhs, output),
_ => panic!("cumsum function requires LHS to be int or float type"),
},
_ => panic!("cumsum function only supports LHS tensor type"),
}
}

fn sign_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
13 changes: 12 additions & 1 deletion crates/burn-ndarray/src/ops/int_tensor.rs
@@ -7,7 +7,7 @@ use burn_tensor::{Distribution, Reader};

use burn_tensor::ElementConversion;
use core::ops::Range;
use ndarray::IntoDimension;
use ndarray::{Axis, IntoDimension};

// Current crate
use crate::element::ExpElement;
@@ -286,6 +286,17 @@ impl<E: FloatNdArrayElement> IntTensorOps<Self> for NdArray<E> {
NdArrayMathOps::sum_dim(tensor, dim)
}

fn int_cumsum<const D: usize>(
tensor: NdArrayTensor<i64, D>,
dim: usize,
) -> NdArrayTensor<i64, D> {
let mut array = tensor.array.into_owned();

array.accumulate_axis_inplace(Axis(dim), |&prev, curr| *curr += prev);

NdArrayTensor::new(array.to_shared())
}

fn int_prod<const D: usize>(tensor: NdArrayTensor<i64, D>) -> NdArrayTensor<i64, 1> {
NdArrayMathOps::prod(tensor)
}
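The ndarray backend builds the cumulative sum out of accumulate_axis_inplace. A minimal standalone sketch of that call (assuming the ndarray crate as pulled in by burn-ndarray):

use ndarray::{array, Axis};

fn main() {
    let mut a = array![[0_i64, 1, 2], [3, 4, 5]];
    // Each element becomes the running sum along axis 1 (across each row).
    a.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev);
    assert_eq!(a, array![[0, 1, 3], [3, 7, 12]]);
}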