Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement tensor.recip() function to calculate elementwise reciprocals #953

Merged
merged 6 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions burn-autodiff/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,32 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
.stateless(B::neg(tensor.primitive))
}

fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
#[derive(Debug)]
struct Recip;

impl<B: Backend, const D: usize> Backward<B, D, 1> for Recip {
type State = B::TensorPrimitive<D>;

fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
let tensor = ops.state;
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
let tmp = B::powf(tensor, -2.0);
let value = B::neg(tmp);

B::mul(grad, value)
});
}
}

match Recip.prepare([tensor.node], [tensor.graph]).stateful() {
OpsKind::Tracked(prep) => {
prep.finish(tensor.primitive.clone(), B::recip(tensor.primitive))
}
OpsKind::UnTracked(prep) => prep.finish(B::recip(tensor.primitive)),
}
}

fn swap_dims<const D: usize>(
tensor: FloatTensor<Self, D>,
dim1: usize,
Expand Down
2 changes: 2 additions & 0 deletions burn-autodiff/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ mod mul;
mod multithread;
mod neg;
mod pow;
mod recip;
mod relu;
mod reshape;
mod select;
Expand Down Expand Up @@ -94,6 +95,7 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_mul!();
burn_autodiff::testgen_ad_neg!();
burn_autodiff::testgen_ad_powf!();
burn_autodiff::testgen_ad_recip!();
burn_autodiff::testgen_ad_reshape!();
burn_autodiff::testgen_ad_sin!();
burn_autodiff::testgen_ad_softmax!();
Expand Down
20 changes: 20 additions & 0 deletions burn-autodiff/src/tests/recip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#[burn_tensor_testgen::testgen(ad_recip)]
mod tests {
use super::*;
use burn_tensor::Data;

#[test]
fn should_diff_recip() {
let data = Data::from([2.0, 5.0, 0.4]);

let tensor = TestAutodiffTensor::from_data(data).require_grad();
let tensor_out = tensor.clone().recip();

let grads = tensor_out.backward();
let grad = tensor.grad(&grads).unwrap();

assert_eq!(tensor_out.into_data(), Data::from([0.5, 0.2, 2.5]));
grad.to_data()
.assert_approx_eq(&Data::from([-0.25, -0.04, -6.25]), 3);
}
}
1 change: 1 addition & 0 deletions burn-book/src/building-blocks/tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ Those operations are only available for `Float` tensors.
| `tensor.erf()` | `tensor.erf()` |
| `tensor.powf(value)` | `tensor.pow(value)` |
| `tensor.sqrt()` | `tensor.sqrt()` |
| `tensor.recip()` | `tensor.reciprocal()` |
| `tensor.cos()` | `tensor.cos()` |
| `tensor.sin()` | `tensor.sin()` |
| `tensor.tanh()` | `tensor.tanh()` |
Expand Down
2 changes: 2 additions & 0 deletions burn-candle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ mod tests {
burn_tensor::testgen_arg!();
burn_tensor::testgen_cast!();
burn_tensor::testgen_cat!();
burn_tensor::testgen_recip!();
burn_tensor::testgen_clamp!();
burn_tensor::testgen_cos!();
// burn_tensor::testgen_div!();
Expand Down Expand Up @@ -133,6 +134,7 @@ mod tests {
burn_autodiff::testgen_ad_mul!();
burn_autodiff::testgen_ad_neg!();
burn_autodiff::testgen_ad_powf!();
burn_autodiff::testgen_ad_recip!();
burn_autodiff::testgen_ad_reshape!();
burn_autodiff::testgen_ad_sin!();
burn_autodiff::testgen_ad_softmax!();
Expand Down
4 changes: 4 additions & 0 deletions burn-candle/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,4 +442,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.clamp(min, max).unwrap())
}

fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.recip().unwrap())
}
}
7 changes: 7 additions & 0 deletions burn-fusion/src/graph/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ pub enum FloatOpsDescription<B: FusionBackend> {
(TensorDescription, Distribution<FloatElem<B>>),
Box<dyn Ops<B, Args = (TensorDescription, Distribution<FloatElem<B>>)>>,
),
/// Operation corresponding to [recip](burn_tensor::ops::TensorOps::recip).
Recip(
UnaryOpsDescription,
Box<dyn Ops<B, Args = UnaryOpsDescription>>,
),
}

/// Operation description specific to module.
Expand Down Expand Up @@ -1252,6 +1257,7 @@ impl<B: FusionBackend> FloatOpsDescription<B> {
FloatOpsDescription::Log(desc, _) => handles.cleanup(&desc.input),
FloatOpsDescription::Log1p(desc, _) => handles.cleanup(&desc.input),
FloatOpsDescription::Erf(desc, _) => handles.cleanup(&desc.input),
FloatOpsDescription::Recip(desc, _) => handles.cleanup(&desc.input),
FloatOpsDescription::Powf(desc, _) => handles.cleanup(&desc.lhs),
FloatOpsDescription::Sqrt(desc, _) => handles.cleanup(&desc.input),
FloatOpsDescription::Cos(desc, _) => handles.cleanup(&desc.input),
Expand All @@ -1268,6 +1274,7 @@ impl<B: FusionBackend> FloatOpsDescription<B> {
FloatOpsDescription::Log(desc, ops) => ops.execute(desc, handles),
FloatOpsDescription::Log1p(desc, ops) => ops.execute(desc, handles),
FloatOpsDescription::Erf(desc, ops) => ops.execute(desc, handles),
FloatOpsDescription::Recip(desc, ops) => ops.execute(desc, handles),
FloatOpsDescription::Powf(desc, ops) => ops.execute(desc, handles),
FloatOpsDescription::Sqrt(desc, ops) => ops.execute(desc, handles),
FloatOpsDescription::Cos(desc, ops) => ops.execute(desc, handles),
Expand Down
15 changes: 15 additions & 0 deletions burn-fusion/src/ops/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1391,6 +1391,21 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}

fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(Recip, B::recip);

let out = tensor.client.create_tensor_empty(tensor.shape.clone());
out.client
.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Recip(
UnaryOpsDescription {
input: tensor.into_description(),
out: out.to_description_out(),
},
Box::new(Recip::<D>),
)));
out
}

fn erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(TanhOps, B::erf);

Expand Down
2 changes: 1 addition & 1 deletion burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ represent the corresponding Burn Op.
| [RandomUniform][128] | ❌ | ✅ |
| [RandomUniformLike][129] | ❌ | ✅ |
| [Range][130] | ❌ | ✅ |
| [Reciprocal][131] | | |
| [Reciprocal][131] | | |
gzsombor marked this conversation as resolved.
Show resolved Hide resolved
| [ReduceL][132] | ❌ | ❌ |
| [ReduceLogSum][133] | ❌ | ❌ |
| [ReduceLogSumExp][134] | ❌ | ❌ |
Expand Down
1 change: 1 addition & 0 deletions burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ fn main() {
.input("tests/log_softmax/log_softmax.onnx")
.input("tests/maxpool2d/maxpool2d.onnx")
.input("tests/mul/mul.onnx")
.input("tests/recip/recip.onnx")
.input("tests/relu/relu.onnx")
.input("tests/reshape/reshape.onnx")
.input("tests/sigmoid/sigmoid.onnx")
Expand Down
14 changes: 14 additions & 0 deletions burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ include_models!(
log_softmax,
maxpool2d,
mul,
recip,
relu,
reshape,
sigmoid,
Expand Down Expand Up @@ -591,4 +592,17 @@ mod tests {
let expected = Data::from([[[[0.7616, 0.9640, 0.9951, 0.9993]]]]);
output.to_data().assert_approx_eq(&expected, 4);
}

#[test]
fn recip() {
// Initialize the model
let model = recip::Model::<Backend>::new();

// Run the model
let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]]);
let output = model.forward(input);
// data from pyTorch
let expected = Data::from([[[[1.0000, 0.5000, 0.3333, 0.2500]]]]);
output.to_data().assert_approx_eq(&expected, 4);
}
}
17 changes: 17 additions & 0 deletions burn-import/onnx-tests/tests/recip/recip.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
pytorch2.1.0:�
0
onnx::Reciprocal_01 /Reciprocal"
Reciprocal
main_graphZ,
onnx::Reciprocal_0




b
1




B
42 changes: 42 additions & 0 deletions burn-import/onnx-tests/tests/recip/recip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/recip/recip.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x):
return x.reciprocal()


def main():
# Set random seed for reproducibility
torch.manual_seed(0)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "recip.onnx"
dummy_input = torch.randn(1, 2, 3, 4, device=device)

torch.onnx.export(model, (dummy_input), onnx_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

# Output some test data for use in the test
test_input = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])

print("Test input data: {}".format(test_input))
output = model.forward(test_input)
print("Test output data: {}".format(output))


if __name__ == '__main__':
main()
26 changes: 26 additions & 0 deletions burn-import/src/burn/node/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub enum UnaryNodeKind {
LogSoftmax,
Softmax,
Relu,
Reciprocal,
Sigmoid,
Tanh,
Transpose,
Expand All @@ -40,6 +41,7 @@ impl UnaryNodeKind {
Self::LogSoftmax => "log_softmax",
Self::Softmax => "softmax",
Self::Relu => "relu",
Self::Reciprocal => "reciprocal",
Self::Sigmoid => "sigmoid",
Self::Tanh => "tanh",
Self::Transpose => "transpose",
Expand Down Expand Up @@ -141,6 +143,11 @@ impl UnaryNode {
Self::new(input, output, UnaryNodeKind::Transpose, Rc::new(function))
}

pub(crate) fn reciprocal(input: Type, output: Type) -> Self {
let function = move |input| quote! { #input.recip() };
Self::new(input, output, UnaryNodeKind::Reciprocal, Rc::new(function))
}

/// Casts the input to the output type.
///
/// Currently this function only supports the following conversions:
Expand Down Expand Up @@ -334,6 +341,25 @@ mod tests {
);
}

#[test]
fn test_unary_codegen_reciprocal() {
one_node_graph(
UnaryNode::reciprocal(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.recip();

tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}

#[test]
fn test_unary_codegen_cast() {
one_node_graph(
Expand Down
1 change: 1 addition & 0 deletions burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ pub fn dim_inference(
NodeType::Erf => same_as_input(node),
NodeType::Sqrt => same_as_input(node),
NodeType::Tanh => same_as_input(node),
NodeType::Reciprocal => same_as_input(node),
NodeType::Softmax => same_as_input(node),
NodeType::ReduceMean => mean_update_outputs(node),
NodeType::Constant => constant_update_outputs(node),
Expand Down
8 changes: 8 additions & 0 deletions burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ impl ONNXGraph {
NodeType::Tanh => graph.register(Self::tanh_conversion(node)),
NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)),
NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
NodeType::Transpose => graph.register(Self::transpose_conversion(node)),
NodeType::Concat => graph.register(Self::concat_conversion(node)),
Expand Down Expand Up @@ -447,6 +448,13 @@ impl ONNXGraph {
UnaryNode::sigmoid(input, output)
}

fn reciprocal_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();

UnaryNode::reciprocal(input, output)
}

fn log_softmax_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
Expand Down
7 changes: 7 additions & 0 deletions burn-ndarray/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,13 @@ where
NdArrayTensor { array }
}

pub fn recip<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor.array.map(|x| 1.elem::<E>() / *x);
let array = array.into_shared();

NdArrayTensor { array }
}

pub fn mean<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, 1> {
let data = Data::from([tensor.array.mean().unwrap()]);
NdArrayTensor::from_data(data)
Expand Down
4 changes: 4 additions & 0 deletions burn-ndarray/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
Self::mul_scalar(tensor, (-1f32).elem::<E>())
}

fn recip<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
NdArrayMathOps::recip(tensor)
}

fn swap_dims<const D: usize>(
tensor: NdArrayTensor<E, D>,
dim1: usize,
Expand Down
4 changes: 4 additions & 0 deletions burn-tch/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
Self::mul_scalar(tensor, (-1f32).elem::<E>())
}

fn recip<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
TchTensor::new(tensor.tensor.reciprocal())
}

fn swap_dims<const D: usize>(
tensor: TchTensor<E, D>,
dim1: usize,
Expand Down
5 changes: 5 additions & 0 deletions burn-tensor/src/tensor/api/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ where
Self::new(B::powf(self.primitive, value))
}

/// Applies element wise reciprocal operation.
pub fn recip(self) -> Self {
Self::new(B::recip(self.primitive))
}

/// Applies element wise root square operation.
pub fn sqrt(self) -> Self {
Self::new(B::sqrt(self.primitive))
Expand Down
3 changes: 3 additions & 0 deletions burn-tensor/src/tensor/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,9 @@ pub trait TensorOps<B: Backend> {
Self::mul_scalar(tensor, (-1.0_f32).elem::<FloatElem<B>>())
}

/// Calculates the reciprocals elementwise
fn recip<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;

/// Transposes a tensor.
///
/// # Arguments
Expand Down
Loading