Squeeze Onnx Import (#1753)
agelas committed May 17, 2024
1 parent 8de05e1 commit 9c5b07c
Showing 15 changed files with 274 additions and 5 deletions.
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -180,7 +180,7 @@ represent the corresponding Burn Op.
| [Split][173] |||
| [SplitToSequence][174] |||
| [Sqrt][175] |||
| [Squeeze][176] | ||
| [Squeeze][176] | ||
| [STFT][177] |||
| [StringNormalizer][178] |||
| [Sub][179] |||
2 changes: 2 additions & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -64,6 +64,8 @@ fn main() {
.input("tests/unsqueeze/unsqueeze_opset16.onnx")
.input("tests/unsqueeze/unsqueeze_opset11.onnx")
.input("tests/mask_where/mask_where.onnx")
.input("tests/squeeze/squeeze_opset16.onnx")
.input("tests/squeeze/squeeze_opset13.onnx")
.out_dir("model/")
.run_from_script();

26 changes: 25 additions & 1 deletion crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -72,7 +72,9 @@ include_models!(
pow_int,
unsqueeze,
unsqueeze_opset16,
unsqueeze_opset11
unsqueeze_opset11,
squeeze_opset16,
squeeze_opset13
);

#[cfg(test)]
@@ -1295,4 +1297,26 @@ mod tests {

        output.to_data().assert_approx_eq(&expected, 4);
    }

    #[test]
    fn squeeze_opset16() {
        let device = Default::default();
        let model = squeeze_opset16::Model::<Backend>::new(&device);
        let input_shape = Shape::from([3, 4, 1, 5]);
        let expected_shape = Shape::from([3, 4, 5]);
        let input = Tensor::ones(input_shape, &device);
        let output = model.forward(input);
        assert_eq!(expected_shape, output.shape());
    }

    #[test]
    fn squeeze_opset13() {
        let device = Default::default();
        let model = squeeze_opset13::Model::<Backend>::new(&device);
        let input_shape = Shape::from([3, 4, 1, 5]);
        let expected_shape = Shape::from([3, 4, 5]);
        let input = Tensor::ones(input_shape, &device);
        let output = model.forward(input);
        assert_eq!(expected_shape, output.shape());
    }
}
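For reference, the ONNX Squeeze op is lowered to burn's `Tensor::squeeze`, which the generated models above call internally. Below is a minimal standalone sketch, not part of this commit, assuming the burn `Tensor`/`Shape` API already used in the tests above:

use burn::backend::NdArray;
use burn::tensor::{Shape, Tensor};

fn main() {
    let device = Default::default();
    // Rank-4 input with a singleton dimension at index 2: [3, 4, 1, 5]
    let x: Tensor<NdArray, 4> = Tensor::ones(Shape::from([3, 4, 1, 5]), &device);
    // Squeezing dim 2 removes it and drops the static rank from 4 to 3
    let y: Tensor<NdArray, 3> = x.squeeze(2);
    assert_eq!(y.shape(), Shape::from([3, 4, 5]));
}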
Binary file not shown.
48 changes: 48 additions & 0 deletions crates/burn-import/onnx-tests/tests/squeeze/squeeze.py
@@ -0,0 +1,48 @@
#!/usr/bin/env python3

# used to generate models: squeeze_opset16.onnx and squeeze_opset13.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.axis = 2

    def forward(self, x):
        x = torch.squeeze(x, self.axis)
        return x


def main():
    # Set seed for reproducibility
    torch.manual_seed(42)

    torch.set_printoptions(precision=8)

    # Create the model in eval mode for export
    model = Model()
    model.eval()
    device = torch.device("cpu")

    test_input = torch.randn(3, 4, 1, 5, device=device)

    # Export to ONNX for opsets 16 and 13
    torch.onnx.export(model, test_input, "squeeze_opset16.onnx", verbose=False, opset_version=16)
    torch.onnx.export(model, test_input, "squeeze_opset13.onnx", verbose=False, opset_version=13)

    print("Finished exporting models for opset versions 16 and 13")

    # Output some test data for use in the test
    output = model(test_input)
    print(f"Test input data: {test_input}")
    print(f"Test input data shape: {test_input.shape}")
    print(f"Test output data shape: {output.shape}")
    print(f"Test output: {output}")


if __name__ == "__main__":
    main()
Binary file not shown.
Binary file not shown.
5 changes: 4 additions & 1 deletion crates/burn-import/src/burn/node/base.rs
@@ -5,7 +5,7 @@ use super::{
dropout::DropoutNode, gather::GatherNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode,
unary::UnaryNode, unsqueeze::UnsqueezeNode,
squeeze::SqueezeNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
@@ -95,6 +95,7 @@ pub enum Node<PS: PrecisionSettings> {
MaxPool1d(MaxPool1dNode),
MaxPool2d(MaxPool2dNode),
Reshape(ReshapeNode),
Squeeze(SqueezeNode),
Unary(UnaryNode),
Unsqueeze(UnsqueezeNode),
Where(WhereNode),
@@ -124,6 +125,7 @@ macro_rules! match_all {
Node::MaxPool1d(node) => $func(node),
Node::MaxPool2d(node) => $func(node),
Node::Reshape(node) => $func(node),
Node::Squeeze(node) => $func(node),
Node::Unary(node) => $func(node),
Node::Unsqueeze(node) => $func(node),
Node::Where(node) => $func(node),
@@ -163,6 +165,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::MaxPool1d(_) => "max_pool1d",
Node::MaxPool2d(_) => "max_pool2d",
Node::Reshape(_) => "reshape",
Node::Squeeze(_) => "squeeze",
Node::Unary(unary) => unary.kind.as_str(),
Node::Unsqueeze(_) => "unsqueeze",
Node::Where(_) => "where",
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
@@ -21,6 +21,7 @@ pub(crate) mod max_pool1d;
pub(crate) mod max_pool2d;
pub(crate) mod prelu;
pub(crate) mod reshape;
pub(crate) mod squeeze;
pub(crate) mod unary;
pub(crate) mod unsqueeze;
pub(crate) use base::*;
92 changes: 92 additions & 0 deletions crates/burn-import/src/burn/node/squeeze.rs
@@ -0,0 +1,92 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, ToTokens, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

#[derive(Debug, Clone, new)]
pub struct SqueezeNode {
    pub input: TensorType,
    pub output: TensorType,
    pub axes: Vec<i64>,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for SqueezeNode {
    fn output_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.output.clone())]
    }

    fn input_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.input.clone())]
    }

    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        let input = scope.tensor_use_owned(&self.input, node_position);
        let output = &self.output.name;

        let axis = &self.axes.first().unwrap().to_tokens();

        quote! {
            let #output = #input.squeeze(#axis);
        }
    }

    fn into_node(self) -> Node<PS> {
        Node::Squeeze(self)
    }
}

#[cfg(test)]
mod tests {
    use burn::record::FullPrecisionSettings;

    use super::*;
    use crate::burn::{
        graph::BurnGraph,
        node::{squeeze::SqueezeNode, test::assert_tokens},
        TensorType,
    };

    #[test]
    fn test_codegen_nodes() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();

        graph.register(SqueezeNode::new(
            TensorType::new_float("tensor1", 3),
            TensorType::new_float("tensor2", 2),
            [1].into(),
        ));

        graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);

        let expected = quote! {
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
                device: burn::module::Ignored<B::Device>,
            }

            impl<B: Backend> Model <B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                        device: burn::module::Ignored(device.clone()),
                    }
                }
                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(&self, tensor1: Tensor<B, 3>) -> Tensor<B, 2> {
                    let tensor2 = tensor1.squeeze(1);
                    tensor2
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }
}
41 changes: 41 additions & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
@@ -64,6 +64,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
NodeType::LeakyRelu => same_as_input(node),
NodeType::PRelu => same_as_input(node),
NodeType::Where => where_update_outputs(node),
NodeType::Squeeze => squeeze_update_output(node),
// Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated.
_ => temporary_pass_through_stub(node),
}
@@ -266,6 +267,46 @@ fn reduce_mean_update_outputs(node: &mut Node) {
}
}

/// Update the output tensor rank for Squeeze: a single axis is removed, so the
/// output rank is the input rank minus one.
fn squeeze_update_output(node: &mut Node) {
    let axes = if node.inputs.len() == 2 {
        match &node.inputs[1].value {
            Some(value) => match value {
                Data::Int64s(axes) => Some(axes.clone()),
                _ => panic!("Squeeze: invalid input types"),
            },
            None => None,
        }
    } else {
        node.attrs.get("axes").cloned().map(|v| v.into_i64s())
    };

    if axes.is_none() {
        panic!("Squeeze must specify an axis");
    } else if axes.as_ref().unwrap().len() > 1 {
        panic!(
            "Squeeze must specify only 1 axis, found {:?}",
            axes.as_ref().unwrap().len()
        );
    }

    let input_dim = match &node.inputs[0].ty {
        ArgType::Tensor(tensor) => tensor.dim,
        _ => panic!("Squeeze: invalid input type"),
    };

    let output_elem = match &node.outputs[0].ty {
        ArgType::Tensor(tensor) => tensor.elem_type.clone(),
        _ => panic!("Squeeze: invalid output type"),
    };

    node.outputs[0].ty = ArgType::Tensor(TensorType {
        dim: input_dim - 1,
        shape: None, // shape is tracked and calculated at runtime
        elem_type: output_elem,
    });
}
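// Illustration only (not part of this commit): a simplified, standalone sketch of
// the rank rule that `squeeze_update_output` applies above. Because exactly one
// axis may be squeezed, the output rank is always the input rank minus one,
// e.g. a [3, 4, 1, 5] input squeezed on axis 2 becomes a rank-3 [3, 4, 5] output.
fn squeezed_rank(input_rank: usize, axes: &[i64]) -> usize {
    assert_eq!(axes.len(), 1, "Squeeze must specify exactly one axis");
    input_rank - 1
}

fn main() {
    assert_eq!(squeezed_rank(4, &[2]), 3);
}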

/// Update the output tensor dimension based on the "axes" attribute or the second input
fn unsqueeze_update_output(node: &mut Node) {
let axes = if node.inputs.len() == 2 {
3 changes: 2 additions & 1 deletion crates/burn-import/src/onnx/from_onnx.rs
@@ -17,7 +17,7 @@ use super::ir::{ArgType, Argument, Node, NodeType};

use protobuf::Message;

const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 8] = [
const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 9] = [
NodeType::BatchNormalization,
NodeType::Clip,
NodeType::Conv1d,
@@ -26,6 +26,7 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 8] = [
NodeType::Reshape,
NodeType::Unsqueeze,
NodeType::ReduceSum,
NodeType::Squeeze,
];

#[derive(Debug)]
47 changes: 47 additions & 0 deletions crates/burn-import/src/onnx/op_configuration.rs
@@ -971,3 +971,50 @@ pub fn transpose_config(curr: &Node) -> Vec<i64> {

perm
}

pub fn squeeze_config(curr: &Node) -> Vec<i64> {
    let mut axes = curr
        .attrs
        .iter()
        .filter_map(|(key, value)| {
            if key == "axes" {
                Some(value.clone().into_i64s())
            } else {
                None
            }
        })
        .next()
        .unwrap_or_else(Vec::new);

    // If axes are not found in the attributes, try to extract them from the second input tensor
    if axes.is_empty() {
        assert!(curr.inputs.len() == 2, "Squeeze: axes input must be present");

        let input_value = &curr.inputs[1];
        match &input_value.ty {
            ArgType::Tensor(tensor) => {
                assert_eq!(tensor.dim, 1, "Squeeze: axes tensor must be 1D");
                if let Some(Data::Int64s(data)) = &input_value.value {
                    axes.clone_from(data)
                } else {
                    panic!("Squeeze: Tensor data type must be int64");
                }
            }
            _ => panic!("Squeeze: Argument for axes must be a tensor"),
        }
    }

    let tensor = match curr.inputs.first().unwrap().clone().ty {
        ArgType::Tensor(tensor) => tensor,
        _ => panic!("Only tensor input is valid"),
    };

    // Adjust negative axes
    axes.iter_mut().for_each(|x| {
        if *x < 0 {
            *x += tensor.dim as i64;
        }
    });

    axes
}
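A small worked example of the negative-axis adjustment at the end of `squeeze_config` (a hypothetical standalone helper, not part of this commit): ONNX allows axes in the range `[-rank, rank - 1]`, while burn's `squeeze` expects a non-negative dimension index, so negative axes are shifted by the input rank.

// Hypothetical sketch of the axis normalization performed above.
fn normalize_axes(mut axes: Vec<i64>, rank: usize) -> Vec<i64> {
    for axis in axes.iter_mut() {
        if *axis < 0 {
            *axis += rank as i64; // e.g. axis -2 on a rank-4 tensor becomes 2
        }
    }
    axes
}

fn main() {
    assert_eq!(normalize_axes(vec![-2], 4), vec![2]);
    assert_eq!(normalize_axes(vec![1], 4), vec![1]);
}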
10 changes: 10 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
@@ -35,6 +35,7 @@ use crate::{
max_pool2d::MaxPool2dNode,
prelu::PReluNode,
reshape::ReshapeNode,
squeeze::SqueezeNode,
unary::UnaryNode,
unsqueeze::UnsqueezeNode,
},
@@ -289,6 +290,7 @@ impl OnnxGraph {
NodeType::Unsqueeze => graph.register(Self::unsqueeze_conversion(node)),
NodeType::Where => graph.register(Self::where_conversion(node)),
NodeType::Sign => graph.register(Self::sign_conversion(node)),
NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)),
node_type => unsupported_ops.push(node_type),
}
}
@@ -825,6 +827,14 @@ impl OnnxGraph {
        let output = node.outputs.first().unwrap().to_type();
        UnaryNode::sign(input, output)
    }

    fn squeeze_conversion(node: Node) -> SqueezeNode {
        let input = node.inputs.first().unwrap().to_tensor_type();
        let output = node.outputs.first().unwrap().to_tensor_type();
        let axes = squeeze_config(&node);

        SqueezeNode::new(input, output, axes)
    }
}

/// Extract data from node states and convert it to `DataSerialize`.