Skip to content

Commit

Permalink
PReLu ONNX import (#1721)
Browse files Browse the repository at this point in the history
* added prelu onnx operator

* bug fix

* added onnx tests and burn codegen tests

* fix tests

* added prelu to supported onnx ops and add prelu to dim_inference
  • Loading branch information
Arjun31415 committed May 4, 2024
1 parent a8661a2 commit 152509c
Show file tree
Hide file tree
Showing 11 changed files with 243 additions and 2 deletions.
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ represent the corresponding Burn Op.
| [Or][119] |||
| [Pad][120] |||
| [Pow][121] |||
| [PRelu][122] | ||
| [PRelu][122] | ||
| [QLinearConv][123] |||
| [QLinearMatMul][124] |||
| [QuantizeLinear][125] |||
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ fn main() {
.input("tests/recip/recip.onnx")
.input("tests/relu/relu.onnx")
.input("tests/leaky_relu/leaky_relu.onnx")
.input("tests/prelu/prelu.onnx")
.input("tests/reduce_max/reduce_max.onnx")
.input("tests/reduce_mean/reduce_mean.onnx")
.input("tests/reshape/reshape.onnx")
Expand Down
24 changes: 24 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ include_models!(
mul,
neg,
not,
prelu,
recip,
reduce_max,
reduce_mean,
Expand Down Expand Up @@ -658,6 +659,29 @@ mod tests {
assert_eq!(output.to_data(), expected);
}

#[test]
fn prelu() {
// Initialize the model without weights (because the exported file does not contain them)
let device = Default::default();
let model: prelu::Model<Backend> = prelu::Model::new(&device);

// Run the model
let input = Tensor::<Backend, 2>::from_floats(
[
[0.33669037, 0.0, 0.23446237],
[0.23033303, -1.122_856, -0.18632829],
],
&device,
);
let output = model.forward(input);
let expected = Data::from([
[0.33669037, 0.0, 0.23446237],
[0.23033303, -0.280714, -0.046582073],
]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn relu() {
// Initialize the model without weights (because the exported file does not contain them)
Expand Down
Binary file added crates/burn-import/onnx-tests/tests/prelu/prelu.onnx
Binary file not shown.
49 changes: 49 additions & 0 deletions crates/burn-import/onnx-tests/tests/prelu/prelu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#!/usr/bin/env python3

# used to generate model: prelu.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.relu1 = nn.PReLU()

def forward(self, x):
x = self.relu1(x)
return x


def main():

# Set seed for reproducibility
torch.manual_seed(42)

torch.set_printoptions(precision=8)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")

file_name = "prelu.onnx"
test_input = torch.randn(2, 3, device=device)
torch.onnx.export(model, test_input, file_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(file_name))

# Output some test data for use in the test
print("Test input data of ones: {}".format(test_input))
print("Test input data shape of ones: {}".format(test_input.shape))
output = model.forward(test_input)
print("Test output data shape: {}".format(output.shape))

print("Test output: {}".format(output))


if __name__ == '__main__':
main()

4 changes: 4 additions & 0 deletions crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::layer_norm::LayerNormNode;
use super::mask_where::WhereNode;
use super::prelu::PReluNode;
use super::unsqueeze::UnsqueezeNode;
use super::{
avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode,
Expand Down Expand Up @@ -85,6 +86,7 @@ pub enum Node<PS: PrecisionSettings> {
Conv1d(Conv1dNode<PS>),
Conv2d(Conv2dNode<PS>),
ConvTranspose2d(ConvTranspose2dNode<PS>),
PRelu(PReluNode<PS>),
Dropout(DropoutNode),
Gather(GatherNode),
GlobalAvgPool(GlobalAvgPoolNode),
Expand All @@ -111,6 +113,7 @@ macro_rules! match_all {
Node::Conv1d(node) => $func(node),
Node::Conv2d(node) => $func(node),
Node::ConvTranspose2d(node) => $func(node),
Node::PRelu(node) => $func(node),
Node::Dropout(node) => $func(node),
Node::Gather(node) => $func(node),
Node::GlobalAvgPool(node) => $func(node),
Expand Down Expand Up @@ -147,6 +150,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Conv1d(_) => "conv1d",
Node::Conv2d(_) => "conv2d",
Node::ConvTranspose2d(_) => "conv_transpose2d",
Node::PRelu(_) => "prelu",
Node::Dropout(_) => "dropout",
Node::Gather(_) => "gather",
Node::GlobalAvgPool(_) => "global_avg_pool",
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub(crate) mod linear;
pub(crate) mod mask_where;
pub(crate) mod matmul;
pub(crate) mod max_pool2d;
pub(crate) mod prelu;
pub(crate) mod reshape;
pub(crate) mod unary;
pub(crate) mod unsqueeze;
Expand Down
151 changes: 151 additions & 0 deletions crates/burn-import/src/burn/node/prelu.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
use super::{Node, NodeCodegen, SerializationBackend};
use crate::burn::{BurnImports, OtherType, Scope, TensorType, Type};
use burn::{
module::{Param, ParamId},
nn::{PReluConfig, PReluRecord},
record::{PrecisionSettings, Record},
tensor::{DataSerialize, Tensor},
};
use proc_macro2::TokenStream;
use quote::quote;
use serde::Serialize;

#[derive(Clone, Debug)]
pub struct PReluNode<PS: PrecisionSettings> {
pub field: OtherType,
pub input: TensorType,
pub output: TensorType,
pub alpha: DataSerialize<PS::FloatElem>,
pub config: PReluConfig,
}

impl<PS: PrecisionSettings> PReluNode<PS> {
pub fn new<S: AsRef<str>>(
name: S,
input: TensorType,
output: TensorType,
alpha: DataSerialize<PS::FloatElem>,
config: PReluConfig,
) -> Self {
Self {
field: OtherType::new(
name,
quote! {
PRelu<B>
},
),
input,
output,
alpha,
config,
}
}
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode<PS> {
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
}
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}
fn field_type(&self) -> Option<Type> {
Some(Type::Other(self.field.clone()))
}

fn field_init(&self) -> Option<TokenStream> {
let name = &self.field.name;
let tokens = quote! {
let #name = PReluConfig::new()
.init(device);
};

Some(tokens)
}

fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let device = Default::default();
let record = PReluRecord::<SerializationBackend> {
alpha: Param::initialized(
ParamId::new(),
Tensor::from_data(self.alpha.clone().convert(), &device),
),
};

let item = Record::into_item::<PS>(record);
item.serialize(serializer)
}

fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;
let field = &self.field.name;

quote! {
let #output = self.#field.forward(#input);
}
}
fn register_imports(&self, imports: &mut BurnImports) {
imports.register("burn::nn::PRelu");
imports.register("burn::nn::PReluConfig");
}

fn into_node(self) -> Node<PS> {
Node::PRelu(self)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType};
use burn::{record::FullPrecisionSettings, tensor::Data};

#[test]
fn test_codegen() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(PReluNode::new(
"prelu",
TensorType::new_float("input", 4),
TensorType::new_float("output", 4),
Data::from([2.]).serialize(),
PReluConfig::new(),
));

graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

let expected = quote! {
use burn::nn::PRelu;
use burn::nn::PReluConfig;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
prelu: PRelu<B>,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model<B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
let prelu = PReluConfig::new().init(device);
Self {
prelu,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let output = self.prelu.forward(input);
output
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
1 change: 1 addition & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
NodeType::Unsqueeze => unsqueeze_update_output(node),
NodeType::Pow => same_as_input(node),
NodeType::LeakyRelu => same_as_input(node),
NodeType::PRelu => same_as_input(node),
NodeType::Where => where_update_outputs(node),
// Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated.
_ => temporary_pass_through_stub(node),
Expand Down
1 change: 0 additions & 1 deletion crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig {
.with_padding(padding)
.with_dilation([dilations[0] as usize, dilations[1] as usize])
}

pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig {
let mut attrs = curr.attrs.clone();
let kernel_shape = attrs
Expand Down
11 changes: 11 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::{
};

use burn::{
nn::PReluConfig,
record::{FullPrecisionSettings, HalfPrecisionSettings, PrecisionSettings},
tensor::{DataSerialize, Element},
};
Expand All @@ -30,6 +31,7 @@ use crate::{
mask_where::WhereNode,
matmul::MatmulNode,
max_pool2d::MaxPool2dNode,
prelu::PReluNode,
reshape::ReshapeNode,
unary::UnaryNode,
unsqueeze::UnsqueezeNode,
Expand Down Expand Up @@ -238,6 +240,7 @@ impl OnnxGraph {
NodeType::Conv1d => graph.register(Self::conv1d_conversion::<PS>(node)),
NodeType::Conv2d => graph.register(Self::conv2d_conversion::<PS>(node)),
NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)),
NodeType::PRelu => graph.register(Self::prelu_conversion::<PS>(node)),
NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)),
NodeType::MatMul => graph.register(Self::matmul_conversion(node)),
NodeType::Neg => graph.register(Self::neg_conversion(node)),
Expand Down Expand Up @@ -701,6 +704,14 @@ impl OnnxGraph {
MaxPool2dNode::new(name, input, output, config)
}

fn prelu_conversion<PS: PrecisionSettings>(node: Node) -> PReluNode<PS> {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let weight = extract_data_serialize::<PS::FloatElem>(1, &node).unwrap();
let config = PReluConfig::new();
let name = &node.name;
PReluNode::<PS>::new(name, input, output, weight, config)
}
fn conv_transpose2d_conversion<PS: PrecisionSettings>(node: Node) -> ConvTranspose2dNode<PS> {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();
Expand Down

0 comments on commit 152509c

Please sign in to comment.