From 152509c3787dec1c4dd7d35b1b2bb8f177da0f64 Mon Sep 17 00:00:00 2001 From: Arjun31415 Date: Sun, 5 May 2024 00:15:42 +0530 Subject: [PATCH] PReLu ONNX import (#1721) * added prelu onnx operator * bug fix * added onnx tests and burn codegen tests * fix tests * added prelu to supported onnx ops and add prelu to dim_inference --- crates/burn-import/SUPPORTED-ONNX-OPS.md | 2 +- crates/burn-import/onnx-tests/build.rs | 1 + .../onnx-tests/tests/onnx_tests.rs | 24 +++ .../onnx-tests/tests/prelu/prelu.onnx | Bin 0 -> 172 bytes .../onnx-tests/tests/prelu/prelu.py | 49 ++++++ crates/burn-import/src/burn/node/base.rs | 4 + crates/burn-import/src/burn/node/mod.rs | 1 + crates/burn-import/src/burn/node/prelu.rs | 151 ++++++++++++++++++ crates/burn-import/src/onnx/dim_inference.rs | 1 + .../burn-import/src/onnx/op_configuration.rs | 1 - crates/burn-import/src/onnx/to_burn.rs | 11 ++ 11 files changed, 243 insertions(+), 2 deletions(-) create mode 100644 crates/burn-import/onnx-tests/tests/prelu/prelu.onnx create mode 100644 crates/burn-import/onnx-tests/tests/prelu/prelu.py create mode 100644 crates/burn-import/src/burn/node/prelu.rs diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md index f976bafae2..c884bee82e 100644 --- a/crates/burn-import/SUPPORTED-ONNX-OPS.md +++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md @@ -126,7 +126,7 @@ represent the corresponding Burn Op. | [Or][119] | ❌ | ❌ | | [Pad][120] | ❌ | ✅ | | [Pow][121] | ✅ | ✅ | -| [PRelu][122] | ❌ | ✅ | +| [PRelu][122] | ✅ | ✅ | | [QLinearConv][123] | ❌ | ❌ | | [QLinearMatMul][124] | ❌ | ❌ | | [QuantizeLinear][125] | ❌ | ❌ | diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index 94b32209e9..3a2c9b685d 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -39,6 +39,7 @@ fn main() { .input("tests/recip/recip.onnx") .input("tests/relu/relu.onnx") .input("tests/leaky_relu/leaky_relu.onnx") + .input("tests/prelu/prelu.onnx") .input("tests/reduce_max/reduce_max.onnx") .input("tests/reduce_mean/reduce_mean.onnx") .input("tests/reshape/reshape.onnx") diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index 34ddfa5f87..c237542d41 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -47,6 +47,7 @@ include_models!( mul, neg, not, + prelu, recip, reduce_max, reduce_mean, @@ -658,6 +659,29 @@ mod tests { assert_eq!(output.to_data(), expected); } + #[test] + fn prelu() { + // Initialize the model without weights (because the exported file does not contain them) + let device = Default::default(); + let model: prelu::Model = prelu::Model::new(&device); + + // Run the model + let input = Tensor::::from_floats( + [ + [0.33669037, 0.0, 0.23446237], + [0.23033303, -1.122_856, -0.18632829], + ], + &device, + ); + let output = model.forward(input); + let expected = Data::from([ + [0.33669037, 0.0, 0.23446237], + [0.23033303, -0.280714, -0.046582073], + ]); + + assert_eq!(output.to_data(), expected); + } + #[test] fn relu() { // Initialize the model without weights (because the exported file does not contain them) diff --git a/crates/burn-import/onnx-tests/tests/prelu/prelu.onnx b/crates/burn-import/onnx-tests/tests/prelu/prelu.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d9644b84e5928e285e0e8db292545ab131abf537 GIT binary patch literal 172 zcmdtc@ { Conv1d(Conv1dNode), Conv2d(Conv2dNode), ConvTranspose2d(ConvTranspose2dNode), + PRelu(PReluNode), Dropout(DropoutNode), Gather(GatherNode), GlobalAvgPool(GlobalAvgPoolNode), @@ -111,6 +113,7 @@ macro_rules! match_all { Node::Conv1d(node) => $func(node), Node::Conv2d(node) => $func(node), Node::ConvTranspose2d(node) => $func(node), + Node::PRelu(node) => $func(node), Node::Dropout(node) => $func(node), Node::Gather(node) => $func(node), Node::GlobalAvgPool(node) => $func(node), @@ -147,6 +150,7 @@ impl Node { Node::Conv1d(_) => "conv1d", Node::Conv2d(_) => "conv2d", Node::ConvTranspose2d(_) => "conv_transpose2d", + Node::PRelu(_) => "prelu", Node::Dropout(_) => "dropout", Node::Gather(_) => "gather", Node::GlobalAvgPool(_) => "global_avg_pool", diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs index 965652c7a2..ae936bbadb 100644 --- a/crates/burn-import/src/burn/node/mod.rs +++ b/crates/burn-import/src/burn/node/mod.rs @@ -17,6 +17,7 @@ pub(crate) mod linear; pub(crate) mod mask_where; pub(crate) mod matmul; pub(crate) mod max_pool2d; +pub(crate) mod prelu; pub(crate) mod reshape; pub(crate) mod unary; pub(crate) mod unsqueeze; diff --git a/crates/burn-import/src/burn/node/prelu.rs b/crates/burn-import/src/burn/node/prelu.rs new file mode 100644 index 0000000000..a1e474d7d9 --- /dev/null +++ b/crates/burn-import/src/burn/node/prelu.rs @@ -0,0 +1,151 @@ +use super::{Node, NodeCodegen, SerializationBackend}; +use crate::burn::{BurnImports, OtherType, Scope, TensorType, Type}; +use burn::{ + module::{Param, ParamId}, + nn::{PReluConfig, PReluRecord}, + record::{PrecisionSettings, Record}, + tensor::{DataSerialize, Tensor}, +}; +use proc_macro2::TokenStream; +use quote::quote; +use serde::Serialize; + +#[derive(Clone, Debug)] +pub struct PReluNode { + pub field: OtherType, + pub input: TensorType, + pub output: TensorType, + pub alpha: DataSerialize, + pub config: PReluConfig, +} + +impl PReluNode { + pub fn new>( + name: S, + input: TensorType, + output: TensorType, + alpha: DataSerialize, + config: PReluConfig, + ) -> Self { + Self { + field: OtherType::new( + name, + quote! { + PRelu + }, + ), + input, + output, + alpha, + config, + } + } +} + +impl NodeCodegen for PReluNode { + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + fn field_type(&self) -> Option { + Some(Type::Other(self.field.clone())) + } + + fn field_init(&self) -> Option { + let name = &self.field.name; + let tokens = quote! { + let #name = PReluConfig::new() + .init(device); + }; + + Some(tokens) + } + + fn field_serialize(&self, serializer: S) -> Result { + let device = Default::default(); + let record = PReluRecord:: { + alpha: Param::initialized( + ParamId::new(), + Tensor::from_data(self.alpha.clone().convert(), &device), + ), + }; + + let item = Record::into_item::(record); + item.serialize(serializer) + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let field = &self.field.name; + + quote! { + let #output = self.#field.forward(#input); + } + } + fn register_imports(&self, imports: &mut BurnImports) { + imports.register("burn::nn::PRelu"); + imports.register("burn::nn::PReluConfig"); + } + + fn into_node(self) -> Node { + Node::PRelu(self) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType}; + use burn::{record::FullPrecisionSettings, tensor::Data}; + + #[test] + fn test_codegen() { + let mut graph = BurnGraph::::default(); + + graph.register(PReluNode::new( + "prelu", + TensorType::new_float("input", 4), + TensorType::new_float("output", 4), + Data::from([2.]).serialize(), + PReluConfig::new(), + )); + + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! { + use burn::nn::PRelu; + use burn::nn::PReluConfig; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + #[derive(Module, Debug)] + pub struct Model { + prelu: PRelu, + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + let prelu = PReluConfig::new().init(device); + Self { + prelu, + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.prelu.forward(input); + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } +} diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs index 0775e631b8..bc5f75a662 100644 --- a/crates/burn-import/src/onnx/dim_inference.rs +++ b/crates/burn-import/src/onnx/dim_inference.rs @@ -59,6 +59,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) { NodeType::Unsqueeze => unsqueeze_update_output(node), NodeType::Pow => same_as_input(node), NodeType::LeakyRelu => same_as_input(node), + NodeType::PRelu => same_as_input(node), NodeType::Where => where_update_outputs(node), // Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated. _ => temporary_pass_through_stub(node), diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index b5750bce02..1416e55230 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -120,7 +120,6 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig { .with_padding(padding) .with_dilation([dilations[0] as usize, dilations[1] as usize]) } - pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { let mut attrs = curr.attrs.clone(); let kernel_shape = attrs diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index a736b564b5..44b328492a 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -5,6 +5,7 @@ use std::{ }; use burn::{ + nn::PReluConfig, record::{FullPrecisionSettings, HalfPrecisionSettings, PrecisionSettings}, tensor::{DataSerialize, Element}, }; @@ -30,6 +31,7 @@ use crate::{ mask_where::WhereNode, matmul::MatmulNode, max_pool2d::MaxPool2dNode, + prelu::PReluNode, reshape::ReshapeNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, @@ -238,6 +240,7 @@ impl OnnxGraph { NodeType::Conv1d => graph.register(Self::conv1d_conversion::(node)), NodeType::Conv2d => graph.register(Self::conv2d_conversion::(node)), NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)), + NodeType::PRelu => graph.register(Self::prelu_conversion::(node)), NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)), NodeType::MatMul => graph.register(Self::matmul_conversion(node)), NodeType::Neg => graph.register(Self::neg_conversion(node)), @@ -701,6 +704,14 @@ impl OnnxGraph { MaxPool2dNode::new(name, input, output, config) } + fn prelu_conversion(node: Node) -> PReluNode { + let input = node.inputs.first().unwrap().to_tensor_type(); + let output = node.outputs.first().unwrap().to_tensor_type(); + let weight = extract_data_serialize::(1, &node).unwrap(); + let config = PReluConfig::new(); + let name = &node.name; + PReluNode::::new(name, input, output, weight, config) + } fn conv_transpose2d_conversion(node: Node) -> ConvTranspose2dNode { let input = node.inputs.first().unwrap().to_tensor_type(); let output = node.outputs.first().unwrap().to_tensor_type();