From d3536f7d2d390d84805ce25c63fddc074bbc9d6d Mon Sep 17 00:00:00 2001 From: Tommy van der Vorst Date: Tue, 23 May 2023 00:03:39 +0200 Subject: [PATCH] chore: document the builder functions --- wonnx/src/builder.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/wonnx/src/builder.rs b/wonnx/src/builder.rs index 2fd25ff0..082cd65c 100644 --- a/wonnx/src/builder.rs +++ b/wonnx/src/builder.rs @@ -26,11 +26,13 @@ impl<'model> From<&TensorRef<'model>> for Input<'model> { } impl<'model> TensorRef<'model> { + /// Element-wise addition of two tensors (ONNX Add operator) pub fn add(&self, rhs: &Self) -> Self { assert_eq!(self.output_shape, rhs.output_shape); self.binary_op(rhs, "Add", self.output_shape.clone()) } + /// Element-wise negation of this tensor (ONNX Neg operator) pub fn neg(&self) -> Self { self.unary_mapping_op("Neg") } @@ -54,6 +56,7 @@ impl<'model> TensorRef<'model> { } } + /// Convolution (ONNX Conv operator) pub fn conv(&self, weights: &Self, kernel_shape: &[usize], output_dims: &[usize]) -> Self { let output_shape = Shape::from(self.output_shape.data_type, output_dims); let mut op = OperatorDefinition::new( @@ -95,6 +98,7 @@ impl<'model> TensorRef<'model> { } } +/// Create a tensor reference representing input supplied at inference time. pub fn input<'model, S: ToString>( name: S, scalar_type: ScalarType, @@ -114,6 +118,7 @@ pub fn input<'model, S: ToString>( } } +/// Create a tensor reference containing static data (included in the model). pub fn tensor<'model, S: ToString>( name: S, dims: &[usize],