From 55f7a21530904ac0df346a679c4fc54d09404580 Mon Sep 17 00:00:00 2001
From: Zsombor Gegesy
Date: Sun, 12 Nov 2023 23:24:29 +0100
Subject: [PATCH 1/6] Implement tensor.recip() function to calculate
 elementwise reciprocals

---
 burn-autodiff/src/ops/tensor.rs      | 17 +++++++++++++++++
 burn-candle/src/lib.rs               |  1 +
 burn-candle/src/ops/tensor.rs        |  4 ++++
 burn-fusion/src/graph/ops.rs         |  7 +++++++
 burn-fusion/src/ops/float.rs         | 15 +++++++++++++++
 burn-ndarray/src/ops/base.rs         |  7 +++++++
 burn-ndarray/src/ops/tensor.rs       |  4 ++++
 burn-tch/src/ops/tensor.rs           |  4 ++++
 burn-tensor/src/tensor/api/float.rs  |  5 +++++
 burn-tensor/src/tensor/ops/tensor.rs |  3 +++
 burn-tensor/src/tests/mod.rs         |  1 +
 burn-tensor/src/tests/ops/mod.rs     |  1 +
 burn-tensor/src/tests/ops/recip.rs   | 16 ++++++++++++++++
 burn-wgpu/src/ops/float_ops.rs       | 13 +++++++++++++
 14 files changed, 98 insertions(+)
 create mode 100644 burn-tensor/src/tests/ops/recip.rs

diff --git a/burn-autodiff/src/ops/tensor.rs b/burn-autodiff/src/ops/tensor.rs
index 406de5d6a3..0c2e71f50e 100644
--- a/burn-autodiff/src/ops/tensor.rs
+++ b/burn-autodiff/src/ops/tensor.rs
@@ -435,6 +435,23 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
             .stateless(B::neg(tensor.primitive))
     }
 
+    fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
+        #[derive(Debug)]
+        struct Recip;
+
+        impl<B: Backend, const D: usize> Backward<B, D, 1> for Recip {
+            type State = ();
+
+            fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
+                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| B::recip(grad));
+            }
+        }
+
+        Recip
+            .prepare([tensor.node], [tensor.graph])
+            .stateless(B::recip(tensor.primitive))
+    }
+
     fn swap_dims<const D: usize>(
         tensor: FloatTensor<Self, D>,
         dim1: usize,
diff --git a/burn-candle/src/lib.rs b/burn-candle/src/lib.rs
index b508ed0095..c3b1b07a44 100644
--- a/burn-candle/src/lib.rs
+++ b/burn-candle/src/lib.rs
@@ -56,6 +56,7 @@ mod tests {
     burn_tensor::testgen_arg!();
     burn_tensor::testgen_cast!();
     burn_tensor::testgen_cat!();
+    burn_tensor::testgen_recip!();
     burn_tensor::testgen_clamp!();
     burn_tensor::testgen_cos!();
     // burn_tensor::testgen_div!();
diff --git a/burn-candle/src/ops/tensor.rs b/burn-candle/src/ops/tensor.rs
index 2242f07762..3f096e92f4 100644
--- a/burn-candle/src/ops/tensor.rs
+++ b/burn-candle/src/ops/tensor.rs
@@ -442,4 +442,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
     ) -> FloatTensor<Self, D> {
         CandleTensor::new(tensor.tensor.clamp(min, max).unwrap())
     }
+
+    fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
+        CandleTensor::new(tensor.tensor.recip().unwrap())
+    }
 }
diff --git a/burn-fusion/src/graph/ops.rs b/burn-fusion/src/graph/ops.rs
index fe16ef3a06..5d7b60d413 100644
--- a/burn-fusion/src/graph/ops.rs
+++ b/burn-fusion/src/graph/ops.rs
@@ -100,6 +100,11 @@ pub enum FloatOpsDescription<B: FusionBackend> {
         (TensorDescription, Distribution<FloatElem<B>>),
         Box<dyn Ops<B, Args = (TensorDescription, Distribution<FloatElem<B>>)>>,
     ),
+    /// Operation corresponding to [recip](burn_tensor::ops::TensorOps::recip).
+    Recip(
+        UnaryOpsDescription,
+        Box<dyn Ops<B, Args = UnaryOpsDescription>>,
+    ),
 }
 
 /// Operation description specific to module.
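A note on the autodiff backward pass above: the elementwise derivative of the reciprocal is

```latex
\frac{\partial}{\partial x}\left(x^{-1}\right) = -x^{-2}
\qquad\Longrightarrow\qquad
\bar{x} = -\,x^{-2}\,\bar{y},
```

so a correct gradient must scale the incoming gradient by the negated squared reciprocal of the forward input, which requires keeping that input as state. The stateless closure `|grad| B::recip(grad)` in this first patch computes `1/grad` instead; patch 6 at the end of this series replaces it with a stateful implementation that applies this rule.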
@@ -1252,6 +1257,7 @@ impl<B: FusionBackend> FloatOpsDescription<B> {
             FloatOpsDescription::Log(desc, _) => handles.cleanup(&desc.input),
             FloatOpsDescription::Log1p(desc, _) => handles.cleanup(&desc.input),
             FloatOpsDescription::Erf(desc, _) => handles.cleanup(&desc.input),
+            FloatOpsDescription::Recip(desc, _) => handles.cleanup(&desc.input),
             FloatOpsDescription::Powf(desc, _) => handles.cleanup(&desc.lhs),
             FloatOpsDescription::Sqrt(desc, _) => handles.cleanup(&desc.input),
             FloatOpsDescription::Cos(desc, _) => handles.cleanup(&desc.input),
@@ -1268,6 +1274,7 @@
             FloatOpsDescription::Log(desc, ops) => ops.execute(desc, handles),
             FloatOpsDescription::Log1p(desc, ops) => ops.execute(desc, handles),
             FloatOpsDescription::Erf(desc, ops) => ops.execute(desc, handles),
+            FloatOpsDescription::Recip(desc, ops) => ops.execute(desc, handles),
             FloatOpsDescription::Powf(desc, ops) => ops.execute(desc, handles),
             FloatOpsDescription::Sqrt(desc, ops) => ops.execute(desc, handles),
             FloatOpsDescription::Cos(desc, ops) => ops.execute(desc, handles),
diff --git a/burn-fusion/src/ops/float.rs b/burn-fusion/src/ops/float.rs
index 5a1b3bdbc6..af56fc460d 100644
--- a/burn-fusion/src/ops/float.rs
+++ b/burn-fusion/src/ops/float.rs
@@ -1391,6 +1391,21 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
         out
     }
 
+    fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
+        unary_float_ops!(Recip, B::recip);
+
+        let out = tensor.client.create_tensor_empty(tensor.shape.clone());
+        out.client
+            .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Recip(
+                UnaryOpsDescription {
+                    input: tensor.into_description(),
+                    out: out.to_description_out(),
+                },
+                Box::new(Recip::<D>),
+            )));
+        out
+    }
+
     fn erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
         unary_float_ops!(TanhOps, B::erf);
 
diff --git a/burn-ndarray/src/ops/base.rs b/burn-ndarray/src/ops/base.rs
index 980e459dc8..bb73a02d43 100644
--- a/burn-ndarray/src/ops/base.rs
+++ b/burn-ndarray/src/ops/base.rs
@@ -181,6 +181,13 @@ where
         NdArrayTensor { array }
     }
 
+    pub fn recip<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
+        let array = tensor.array.map(|x| 1.elem::<E>() / *x);
+        let array = array.into_shared();
+
+        NdArrayTensor { array }
+    }
+
     pub fn mean<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, 1> {
         let data = Data::from([tensor.array.mean().unwrap()]);
         NdArrayTensor::from_data(data)
diff --git a/burn-ndarray/src/ops/tensor.rs b/burn-ndarray/src/ops/tensor.rs
index 29b6cd61ad..0ac3f20226 100644
--- a/burn-ndarray/src/ops/tensor.rs
+++ b/burn-ndarray/src/ops/tensor.rs
@@ -127,6 +127,10 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
         Self::mul_scalar(tensor, (-1f32).elem::<E>())
     }
 
+    fn recip<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
+        NdArrayMathOps::recip(tensor)
+    }
+
     fn swap_dims<const D: usize>(
         tensor: NdArrayTensor<E, D>,
         dim1: usize,
diff --git a/burn-tch/src/ops/tensor.rs b/burn-tch/src/ops/tensor.rs
index 0b99df20e4..2326b8a76f 100644
--- a/burn-tch/src/ops/tensor.rs
+++ b/burn-tch/src/ops/tensor.rs
@@ -172,6 +172,10 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
         Self::mul_scalar(tensor, (-1f32).elem::<E>())
     }
 
+    fn recip<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
+        TchTensor::new(tensor.tensor.reciprocal())
+    }
+
     fn swap_dims<const D: usize>(
         tensor: TchTensor<E, D>,
         dim1: usize,
diff --git a/burn-tensor/src/tensor/api/float.rs b/burn-tensor/src/tensor/api/float.rs
index 4cc42c5411..d2cededbab 100644
--- a/burn-tensor/src/tensor/api/float.rs
+++ b/burn-tensor/src/tensor/api/float.rs
@@ -67,6 +67,11 @@ where
         Self::new(B::powf(self.primitive, value))
     }
 
+    /// Applies element wise reciprocal operation.
+    pub fn recip(self) -> Self {
+        Self::new(B::recip(self.primitive))
+    }
+
     /// Applies element wise root square operation.
     pub fn sqrt(self) -> Self {
         Self::new(B::sqrt(self.primitive))
diff --git a/burn-tensor/src/tensor/ops/tensor.rs b/burn-tensor/src/tensor/ops/tensor.rs
index 18e2271f3d..63c081df40 100644
--- a/burn-tensor/src/tensor/ops/tensor.rs
+++ b/burn-tensor/src/tensor/ops/tensor.rs
@@ -410,6 +410,9 @@
         Self::mul_scalar(tensor, (-1.0_f32).elem::<FloatElem<B>>())
     }
 
+    /// Calculates the reciprocal of each element.
+    fn recip<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
+
     /// Transposes a tensor.
     ///
     /// # Arguments
diff --git a/burn-tensor/src/tests/mod.rs b/burn-tensor/src/tests/mod.rs
index ee0e559fee..1d0cd93807 100644
--- a/burn-tensor/src/tests/mod.rs
+++ b/burn-tensor/src/tests/mod.rs
@@ -60,6 +60,7 @@ macro_rules! testgen_all {
         burn_tensor::testgen_one_hot!();
         burn_tensor::testgen_powf!();
         burn_tensor::testgen_random!();
+        burn_tensor::testgen_recip!();
         burn_tensor::testgen_repeat!();
         burn_tensor::testgen_reshape!();
         burn_tensor::testgen_select!();
diff --git a/burn-tensor/src/tests/ops/mod.rs b/burn-tensor/src/tests/ops/mod.rs
index 25e5afaef4..cde87af36e 100644
--- a/burn-tensor/src/tests/ops/mod.rs
+++ b/burn-tensor/src/tests/ops/mod.rs
@@ -28,6 +28,7 @@ mod neg;
 mod one_hot;
 mod powf;
 mod random;
+mod recip;
 mod repeat;
 mod reshape;
 mod select;
diff --git a/burn-tensor/src/tests/ops/recip.rs b/burn-tensor/src/tests/ops/recip.rs
new file mode 100644
index 0000000000..70395fd60b
--- /dev/null
+++ b/burn-tensor/src/tests/ops/recip.rs
@@ -0,0 +1,16 @@
+#[burn_tensor_testgen::testgen(recip)]
+mod tests {
+    use super::*;
+    use burn_tensor::{Data, Tensor};
+
+    #[test]
+    fn should_support_recip_ops() {
+        let data = Data::from([[0.5, 1.0, 2.0], [3.0, -4.0, -5.0]]);
+        let tensor = Tensor::<TestBackend, 2>::from_data(data);
+
+        let data_actual = tensor.recip().into_data();
+
+        let data_expected = Data::from([[2.0, 1.0, 0.5], [0.33333, -0.25, -0.2]]);
+        data_expected.assert_approx_eq(&data_actual, 3);
+    }
+}
diff --git a/burn-wgpu/src/ops/float_ops.rs b/burn-wgpu/src/ops/float_ops.rs
index 4dbd3c722e..35dcbe12f3 100644
--- a/burn-wgpu/src/ops/float_ops.rs
+++ b/burn-wgpu/src/ops/float_ops.rs
@@ -510,4 +510,17 @@ where
     ) -> FloatTensor<Self, D> {
         kernel::clamp(tensor, min, max)
     }
+
+    fn recip<const D: usize>(
+        tensor: FloatTensor<Wgpu<G, F, I>, D>,
+    ) -> FloatTensor<Wgpu<G, F, I>, D> {
+        unary!(Recip, func "1.0 /");
+        unary_inplace!(RecipInplace, func "1.0 /");
+
+        if tensor.can_mut() {
+            return unary_inplace_default::<RecipInplace, F, D>(tensor);
+        }
+
+        unary_default::<Recip, F, D>(tensor)
+    }
 }

From 9fef1e5187de807cfb4598c04dac0fa2dac85cda Mon Sep 17 00:00:00 2001
From: Zsombor Gegesy
Date: Mon, 13 Nov 2023 22:49:05 +0100
Subject: [PATCH 2/6] Add Reciprocal to the supported ONNX operations

---
 burn-import/SUPPORTED-ONNX-OPS.md     |  2 +-
 burn-import/src/burn/node/unary.rs    | 26 ++++++++++++++++++++++++++
 burn-import/src/onnx/dim_inference.rs |  1 +
 burn-import/src/onnx/to_burn.rs       |  8 ++++++++
 4 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/burn-import/SUPPORTED-ONNX-OPS.md b/burn-import/SUPPORTED-ONNX-OPS.md
index d9ff5bcbce..ab0de80b82 100644
--- a/burn-import/SUPPORTED-ONNX-OPS.md
+++ b/burn-import/SUPPORTED-ONNX-OPS.md
@@ -134,7 +134,7 @@ represent the corresponding Burn Op.
 | [RandomUniform][128]     | ❌ | ✅ |
 | [RandomUniformLike][129] | ❌ | ✅ |
 | [Range][130]             | ❌ | ✅ |
-| [Reciprocal][131]        | ❌ | ❌ |
+| [Reciprocal][131]        | ✅ | ✅ |
 | [ReduceL][132]           | ❌ | ❌ |
 | [ReduceLogSum][133]      | ❌ | ❌ |
 | [ReduceLogSumExp][134]   | ❌ | ❌ |
diff --git a/burn-import/src/burn/node/unary.rs b/burn-import/src/burn/node/unary.rs
index bbc16c9c3f..7f557b785c 100644
--- a/burn-import/src/burn/node/unary.rs
+++ b/burn-import/src/burn/node/unary.rs
@@ -26,6 +26,7 @@ pub enum UnaryNodeKind {
     LogSoftmax,
     Softmax,
     Relu,
+    Reciprocal,
     Sigmoid,
     Tanh,
     Transpose,
@@ -40,6 +41,7 @@
             Self::LogSoftmax => "log_softmax",
             Self::Softmax => "softmax",
             Self::Relu => "relu",
+            Self::Reciprocal => "reciprocal",
             Self::Sigmoid => "sigmoid",
             Self::Tanh => "tanh",
             Self::Transpose => "transpose",
@@ -141,6 +143,11 @@ impl UnaryNode {
         Self::new(input, output, UnaryNodeKind::Transpose, Rc::new(function))
     }
 
+    pub(crate) fn reciprocal(input: Type, output: Type) -> Self {
+        let function = move |input| quote! { #input.recip() };
+        Self::new(input, output, UnaryNodeKind::Reciprocal, Rc::new(function))
+    }
+
     /// Casts the input to the output type.
     ///
     /// Currently this function only supports the following conversions:
@@ -334,6 +341,25 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_unary_codegen_reciprocal() {
+        one_node_graph(
+            UnaryNode::reciprocal(
+                Type::Tensor(TensorType::new_float("tensor1", 4)),
+                Type::Tensor(TensorType::new_float("tensor2", 4)),
+            ),
+            quote! {
+                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
+                    let tensor2 = tensor1.recip();
+
+                    tensor2
+                }
+            },
+            vec!["tensor1".to_string()],
+            vec!["tensor2".to_string()],
+        );
+    }
+
     #[test]
     fn test_unary_codegen_cast() {
         one_node_graph(
diff --git a/burn-import/src/onnx/dim_inference.rs b/burn-import/src/onnx/dim_inference.rs
index fc71790710..ce2a9ce4f3 100644
--- a/burn-import/src/onnx/dim_inference.rs
+++ b/burn-import/src/onnx/dim_inference.rs
@@ -79,6 +79,7 @@ pub fn dim_inference(
         NodeType::Erf => same_as_input(node),
         NodeType::Sqrt => same_as_input(node),
         NodeType::Tanh => same_as_input(node),
+        NodeType::Reciprocal => same_as_input(node),
         NodeType::Softmax => same_as_input(node),
         NodeType::ReduceMean => mean_update_outputs(node),
         NodeType::Constant => constant_update_outputs(node),
diff --git a/burn-import/src/onnx/to_burn.rs b/burn-import/src/onnx/to_burn.rs
index ab7939aa06..4bf912bd44 100644
--- a/burn-import/src/onnx/to_burn.rs
+++ b/burn-import/src/onnx/to_burn.rs
@@ -250,6 +250,7 @@ impl ONNXGraph {
                 NodeType::Tanh => graph.register(Self::tanh_conversion(node)),
                 NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
                 NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
+                NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)),
                 NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
                 NodeType::Transpose => graph.register(Self::transpose_conversion(node)),
                 NodeType::Concat => graph.register(Self::concat_conversion(node)),
@@ -447,6 +448,13 @@ impl ONNXGraph {
         UnaryNode::sigmoid(input, output)
     }
 
+    fn reciprocal_conversion(node: Node) -> UnaryNode {
+        let input = node.inputs.get(0).unwrap().to_type();
+        let output = node.outputs.get(0).unwrap().to_type();
+
+        UnaryNode::reciprocal(input, output)
+    }
+
     fn log_softmax_conversion(node: Node) -> UnaryNode {
         let input = node.inputs.get(0).unwrap().to_type();
         let output = node.outputs.get(0).unwrap().to_type();

From aca536d3bd72132f3e3226397f2e4ac07be725ae Mon Sep 17 00:00:00 2001
From: Zsombor Gegesy
Date: Mon, 13 Nov 2023 23:00:42 +0100
Subject: [PATCH 3/6] Update burn-book too

---
 burn-book/src/building-blocks/tensor.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md
index 829f8cd20f..270a2477c9 100644
--- a/burn-book/src/building-blocks/tensor.md
+++ b/burn-book/src/building-blocks/tensor.md
@@ -126,6 +126,7 @@ Those operations are only available for `Float` tensors.
 | `tensor.erf()`       | `tensor.erf()`        |
 | `tensor.powf(value)` | `tensor.pow(value)`   |
 | `tensor.sqrt()`      | `tensor.sqrt()`       |
+| `tensor.recip()`     | `tensor.reciprocal()` |
 | `tensor.cos()`       | `tensor.cos()`        |
 | `tensor.sin()`       | `tensor.sin()`        |
 | `tensor.tanh()`      | `tensor.tanh()`       |

From 07ed23557ff1236ada4b5d0ef88bb3e38b674856 Mon Sep 17 00:00:00 2001
From: Zsombor Gegesy
Date: Tue, 14 Nov 2023 00:37:51 +0100
Subject: [PATCH 4/6] Update the supported onnx ops documentation

---
 burn-import/SUPPORTED-ONNX-OPS.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/burn-import/SUPPORTED-ONNX-OPS.md b/burn-import/SUPPORTED-ONNX-OPS.md
index ab0de80b82..7a4c0ebf7c 100644
--- a/burn-import/SUPPORTED-ONNX-OPS.md
+++ b/burn-import/SUPPORTED-ONNX-OPS.md
@@ -134,7 +134,7 @@ represent the corresponding Burn Op.
 | [RandomUniform][128]     | ❌ | ✅ |
 | [RandomUniformLike][129] | ❌ | ✅ |
 | [Range][130]             | ❌ | ✅ |
-| [Reciprocal][131]        | ✅ | ✅ |
+| [Reciprocal][131]        | ❌ | ✅ |
 | [ReduceL][132]           | ❌ | ❌ |
 | [ReduceLogSum][133]      | ❌ | ❌ |
 | [ReduceLogSumExp][134]   | ❌ | ❌ |

From 454d2e74b8de7a2547328f5009550d1da22de841 Mon Sep 17 00:00:00 2001
From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
Date: Tue, 14 Nov 2023 14:08:07 -0600
Subject: [PATCH 5/6] Add recip ONNX file test

---
 burn-import/SUPPORTED-ONNX-OPS.md             |  2 +-
 burn-import/onnx-tests/build.rs               |  1 +
 burn-import/onnx-tests/tests/onnx_tests.rs    | 14 +++++++
 burn-import/onnx-tests/tests/recip/recip.onnx | 17 ++++++++
 burn-import/onnx-tests/tests/recip/recip.py   | 42 +++++++++++++++++++
 5 files changed, 75 insertions(+), 1 deletion(-)
 create mode 100644 burn-import/onnx-tests/tests/recip/recip.onnx
 create mode 100755 burn-import/onnx-tests/tests/recip/recip.py

diff --git a/burn-import/SUPPORTED-ONNX-OPS.md b/burn-import/SUPPORTED-ONNX-OPS.md
index 7a4c0ebf7c..ab0de80b82 100644
--- a/burn-import/SUPPORTED-ONNX-OPS.md
+++ b/burn-import/SUPPORTED-ONNX-OPS.md
@@ -134,7 +134,7 @@ represent the corresponding Burn Op.
 | [RandomUniform][128]     | ❌ | ✅ |
 | [RandomUniformLike][129] | ❌ | ✅ |
 | [Range][130]             | ❌ | ✅ |
-| [Reciprocal][131]        | ❌ | ✅ |
+| [Reciprocal][131]        | ✅ | ✅ |
 | [ReduceL][132]           | ❌ | ❌ |
 | [ReduceLogSum][133]      | ❌ | ❌ |
 | [ReduceLogSumExp][134]   | ❌ | ❌ |
diff --git a/burn-import/onnx-tests/build.rs b/burn-import/onnx-tests/build.rs
index 6f0aa2848a..dc1576099e 100644
--- a/burn-import/onnx-tests/build.rs
+++ b/burn-import/onnx-tests/build.rs
@@ -27,6 +27,7 @@ fn main() {
         .input("tests/log_softmax/log_softmax.onnx")
         .input("tests/maxpool2d/maxpool2d.onnx")
         .input("tests/mul/mul.onnx")
+        .input("tests/recip/recip.onnx")
         .input("tests/relu/relu.onnx")
         .input("tests/reshape/reshape.onnx")
         .input("tests/sigmoid/sigmoid.onnx")
diff --git a/burn-import/onnx-tests/tests/onnx_tests.rs b/burn-import/onnx-tests/tests/onnx_tests.rs
index 24e9dffa74..f1e515845d 100644
--- a/burn-import/onnx-tests/tests/onnx_tests.rs
+++ b/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -34,6 +34,7 @@ include_models!(
     log_softmax,
     maxpool2d,
     mul,
+    recip,
     relu,
     reshape,
     sigmoid,
@@ -591,4 +592,17 @@ mod tests {
         let expected = Data::from([[[[0.7616, 0.9640, 0.9951, 0.9993]]]]);
         output.to_data().assert_approx_eq(&expected, 4);
     }
+
+    #[test]
+    fn recip() {
+        // Initialize the model
+        let model = recip::Model::<Backend>::new();
+
+        // Run the model
+        let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]]);
+        let output = model.forward(input);
+        // data from PyTorch
+        let expected = Data::from([[[[1.0000, 0.5000, 0.3333, 0.2500]]]]);
+        output.to_data().assert_approx_eq(&expected, 4);
+    }
 }
diff --git a/burn-import/onnx-tests/tests/recip/recip.onnx b/burn-import/onnx-tests/tests/recip/recip.onnx
new file mode 100644
index 0000000000..6879a1f9a6
--- /dev/null
+++ b/burn-import/onnx-tests/tests/recip/recip.onnx
@@ -0,0 +1,17 @@
+pytorch2.1.0:
+0
+onnx::Reciprocal_01 /Reciprocal"
+Reciprocal
+main_graphZ,
+onnx::Reciprocal_0
+
+
+
+
+b
+1
+
+
+
+
+B
\ No newline at end of file
diff --git a/burn-import/onnx-tests/tests/recip/recip.py b/burn-import/onnx-tests/tests/recip/recip.py
new file mode 100755
index 0000000000..f489af305e
--- /dev/null
+++ b/burn-import/onnx-tests/tests/recip/recip.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python3
+
+# used to generate model: onnx-tests/tests/recip/recip.onnx
+
+import torch
+import torch.nn as nn
+
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x):
+        return x.reciprocal()
+
+
+def main():
+    # Set random seed for reproducibility
+    torch.manual_seed(0)
+
+    # Export to onnx
+    model = Model()
+    model.eval()
+    device = torch.device("cpu")
+    onnx_name = "recip.onnx"
+    dummy_input = torch.randn(1, 2, 3, 4, device=device)
+
+    torch.onnx.export(model, (dummy_input), onnx_name,
+                      verbose=False, opset_version=16)
+
+    print("Finished exporting model to {}".format(onnx_name))
+
+    # Output some test data for use in the test
+    test_input = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])
+
+    print("Test input data: {}".format(test_input))
+    output = model.forward(test_input)
+    print("Test output data: {}".format(output))
+
+
+if __name__ == '__main__':
+    main()

From b55995e6c322c3aefb40fdd00dae211dfeac95a8 Mon Sep 17 00:00:00 2001
From: Zsombor Gegesy
Date: Tue, 14 Nov 2023 23:50:02 +0100
Subject: [PATCH 6/6] Add autodiff test for recip() call, and fix the
 implementation

---
 burn-autodiff/src/ops/tensor.rs  | 19 ++++++++++++++-----
 burn-autodiff/src/tests/mod.rs   |  2 ++
 burn-autodiff/src/tests/recip.rs | 20 ++++++++++++++++++++
 burn-candle/src/lib.rs           |  1 +
 4 files changed, 37 insertions(+), 5 deletions(-)
 create mode 100644 burn-autodiff/src/tests/recip.rs

diff --git a/burn-autodiff/src/ops/tensor.rs b/burn-autodiff/src/ops/tensor.rs
index 0c2e71f50e..e4cd4755e7 100644
--- a/burn-autodiff/src/ops/tensor.rs
+++ b/burn-autodiff/src/ops/tensor.rs
@@ -440,16 +440,25 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
         struct Recip;
 
         impl<B: Backend, const D: usize> Backward<B, D, 1> for Recip {
-            type State = ();
+            type State = B::TensorPrimitive<D>;
 
             fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
-                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| B::recip(grad));
+                let tensor = ops.state;
+                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
+                    let tmp = B::powf(tensor, -2.0);
+                    let value = B::neg(tmp);
+
+                    B::mul(grad, value)
+                });
             }
         }
 
-        Recip
-            .prepare([tensor.node], [tensor.graph])
-            .stateless(B::recip(tensor.primitive))
+        match Recip.prepare([tensor.node], [tensor.graph]).stateful() {
+            OpsKind::Tracked(prep) => {
+                prep.finish(tensor.primitive.clone(), B::recip(tensor.primitive))
+            }
+            OpsKind::UnTracked(prep) => prep.finish(B::recip(tensor.primitive)),
+        }
     }
 
     fn swap_dims<const D: usize>(
diff --git a/burn-autodiff/src/tests/mod.rs b/burn-autodiff/src/tests/mod.rs
index fd07148738..23e84426d0 100644
--- a/burn-autodiff/src/tests/mod.rs
+++ b/burn-autodiff/src/tests/mod.rs
@@ -34,6 +34,7 @@ mod mul;
 mod multithread;
 mod neg;
 mod pow;
+mod recip;
 mod relu;
 mod reshape;
 mod select;
@@ -94,6 +95,7 @@
         burn_autodiff::testgen_ad_mul!();
         burn_autodiff::testgen_ad_neg!();
         burn_autodiff::testgen_ad_powf!();
+        burn_autodiff::testgen_ad_recip!();
         burn_autodiff::testgen_ad_reshape!();
         burn_autodiff::testgen_ad_sin!();
        burn_autodiff::testgen_ad_softmax!();
diff --git a/burn-autodiff/src/tests/recip.rs b/burn-autodiff/src/tests/recip.rs
new file mode 100644
index 0000000000..c77579e273
--- /dev/null
+++ b/burn-autodiff/src/tests/recip.rs
@@ -0,0 +1,20 @@
+#[burn_tensor_testgen::testgen(ad_recip)]
+mod tests {
+    use super::*;
+    use burn_tensor::Data;
+
+    #[test]
+    fn should_diff_recip() {
+        let data = Data::from([2.0, 5.0, 0.4]);
+
+        let tensor = TestAutodiffTensor::from_data(data).require_grad();
+        let tensor_out = tensor.clone().recip();
+
+        let grads = tensor_out.backward();
+        let grad = tensor.grad(&grads).unwrap();
+
+        assert_eq!(tensor_out.into_data(), Data::from([0.5, 0.2, 2.5]));
+        grad.to_data()
+            .assert_approx_eq(&Data::from([-0.25, -0.04, -6.25]), 3);
+    }
+}
diff --git a/burn-candle/src/lib.rs b/burn-candle/src/lib.rs
index c3b1b07a44..a0ee7f0490 100644
--- a/burn-candle/src/lib.rs
+++ b/burn-candle/src/lib.rs
@@ -134,6 +134,7 @@ mod tests {
     burn_autodiff::testgen_ad_mul!();
     burn_autodiff::testgen_ad_neg!();
     burn_autodiff::testgen_ad_powf!();
+    burn_autodiff::testgen_ad_recip!();
     burn_autodiff::testgen_ad_reshape!();
     burn_autodiff::testgen_ad_sin!();
     burn_autodiff::testgen_ad_softmax!();
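As a worked check of the expected values in the new autodiff test, with the output gradient seeded to ones the gradient reaching `tensor` is simply the negated squared reciprocal, evaluated elementwise:

```latex
y_i = x_i^{-1},\qquad
\frac{\partial y_i}{\partial x_i} = -x_i^{-2},\qquad
x = (2,\; 5,\; 0.4)
\;\mapsto\;
\left(-\tfrac{1}{4},\; -\tfrac{1}{25},\; -\tfrac{1}{0.16}\right)
= (-0.25,\; -0.04,\; -6.25).
```

This matches the asserted `Data::from([-0.25, -0.04, -6.25])`, and the forward values 1/2, 1/5, and 1/0.4 give the asserted `[0.5, 0.2, 2.5]`, confirming that the stateful `-x^{-2} * grad` rule introduced in this patch is the correct replacement for the earlier stateless version.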