feat: added reduce min onnx import #1894

Merged (1 commit) on Jun 18, 2024
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -141,7 +141,7 @@ represent the corresponding Burn Op.
| [ReduceLogSumExp][134] | ❌ | ❌ |
| [ReduceMax][135] | ✅ | ✅ |
| [ReduceMean][136] | ✅ | ✅ |
| [ReduceMin][137] | ❌ | ✅ |
| [ReduceMin][137] | ✅ | ✅ |
| [ReduceProd][138] | ❌ | ✅ |
| [ReduceSum][139] | ✅ | ✅ |
| [ReduceSumSquare][140] | ❌ | ❌ |
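For readers unfamiliar with the Burn tensor API this importer targets, the sketch below (not part of the diff; the function name and the fixed rank are illustrative, assuming the `burn` crate as a dependency) shows the two calls that ONNX ReduceMin is lowered to: `min()` for a global reduction and `min_dim(axis)` for a single-axis reduction that keeps the reduced dimension.

use burn::tensor::{backend::Backend, Tensor};

// Illustrative only: ReduceMin with axes=None, keepdims=0 maps to `min()`, which
// collapses the input to a rank-1 tensor; ReduceMin with a single axis and
// keepdims=1 maps to `min_dim(axis)`, which keeps the original rank.
fn reduce_min_lowering<B: Backend>(x: Tensor<B, 4>) -> (Tensor<B, 1>, Tensor<B, 4>) {
    let global_min = x.clone().min(); // ReduceMin, keepdims=0, axes=None
    let axis1_min = x.min_dim(1);     // ReduceMin, keepdims=1, axes=[1]
    (global_min, axis1_min)
}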
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -52,6 +52,7 @@ fn main() {
.input("tests/leaky_relu/leaky_relu.onnx")
.input("tests/prelu/prelu.onnx")
.input("tests/reduce_max/reduce_max.onnx")
.input("tests/reduce_min/reduce_min.onnx")
.input("tests/reduce_mean/reduce_mean.onnx")
.input("tests/reduce_sum/reduce_sum_opset13.onnx")
.input("tests/reduce_sum/reduce_sum_opset11.onnx")
17 changes: 17 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -62,6 +62,7 @@ include_models!(
range,
recip,
reduce_max,
reduce_min,
reduce_mean,
reduce_sum_opset13,
reduce_sum_opset11,
@@ -728,6 +729,22 @@ mod tests {
assert_eq!(output_value.to_data(), expected);
}

#[test]
fn reduce_min() {
let device = Default::default();
let model: reduce_min::Model<Backend> = reduce_min::Model::new(&device);

// Run the model
let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]], &device);
let (output_scalar, output_tensor, output_value) = model.forward(input.clone());
let expected_scalar = Data::from([1.]);
let expected = Data::from([[[[1.]]]]);

assert_eq!(output_scalar.to_data(), expected_scalar);
assert_eq!(output_tensor.to_data(), input.to_data());
assert_eq!(output_value.to_data(), expected);
}

#[test]
fn reduce_mean() {
let device = Default::default();
Binary file crates/burn-import/onnx-tests/tests/reduce_min/reduce_min.onnx not shown.
47 changes: 47 additions & 0 deletions crates/burn-import/onnx-tests/tests/reduce_min/reduce_min.py
@@ -0,0 +1,47 @@

#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/reduce_min/reduce_min.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x):
return (
# ReduceMin, keepdims=0, axes=None
torch.min(x),
# ReduceMin, keepdims=1, axes=[1]
torch.min(x, dim=1, keepdim=True).values,
# ReduceMin, keepdims=1, axes=[-1]
torch.min(x, dim=-1, keepdim=True).values,
)


def main():
# Set random seed for reproducibility
torch.manual_seed(0)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "reduce_min.onnx"
test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device)

torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=16)

print(f"Finished exporting model to {onnx_name}")

# Output some test data for use in the test
print(f"Test input data: {test_input}")
output = model.forward(test_input)
print(f"Test output data: {output}")


if __name__ == "__main__":
main()
68 changes: 68 additions & 0 deletions crates/burn-import/src/burn/node/unary.rs
@@ -33,6 +33,7 @@ pub enum UnaryNodeKind {
Neg,
Not,
ReduceMax,
ReduceMin,
ReduceMean,
ReduceSum,
Reciprocal,
@@ -62,6 +63,7 @@ impl UnaryNodeKind {
Self::Neg => "neg",
Self::Not => "not",
Self::ReduceMax => "reduce_max",
Self::ReduceMin => "reduce_min",
Self::ReduceMean => "reduce_mean",
Self::ReduceSum => "reduce_sum",
Self::Reciprocal => "reciprocal",
@@ -331,6 +333,35 @@ impl UnaryNode {
}
}

pub(crate) fn reduce_min(input: Type, output: Type, dim: Option<usize>) -> Self {
if let Type::Tensor(ref tensor) = output {
if let Some(dim) = dim {
if tensor.kind == TensorKind::Bool {
// Min is only implemented on numeric tensors
panic!("ReduceMin is not supported for boolean");
}
// ReduceMin, keepdims=1, axes=[dim]
let dim = dim.to_tokens();
Self::new(
input,
output,
UnaryNodeKind::ReduceMin,
Rc::new(move |input| quote! { #input.min_dim(#dim) }),
)
} else {
// ReduceMin, keepdims=0, axes=None
Self::new(
input,
output,
UnaryNodeKind::ReduceMin,
Rc::new(move |input| quote! { #input.min() }),
)
}
} else {
panic!("ReduceMin only supports tensor output");
}
}

pub(crate) fn reduce_mean(input: Type, output: Type, dim: Option<usize>) -> Self {
// ReduceMean is constrained to numeric tensors, so no need to check for bool.
if let Type::Tensor(_) = output {
@@ -629,6 +660,43 @@ mod tests {
);
}

#[test]
fn test_unary_codegen_reduce_min() {
one_node_graph(
UnaryNode::reduce_min(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
Some(1),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.min_dim(1);

tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);

one_node_graph(
UnaryNode::reduce_min(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 1)),
None,
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 1> {
let tensor2 = tensor1.min();

tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}

#[test]
fn test_unary_codegen_reduce_mean() {
one_node_graph(
25 changes: 25 additions & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
@@ -55,6 +55,7 @@ pub fn dim_inference(node: &mut Node) {
NodeType::Range => range_update_outputs(node),
NodeType::Reciprocal => same_as_input(node),
NodeType::ReduceMax => reduce_max_update_outputs(node),
NodeType::ReduceMin => reduce_min_update_outputs(node),
NodeType::ReduceMean => reduce_mean_update_outputs(node),
NodeType::ReduceSum => reduce_sum_update_outputs(node),
NodeType::Relu => same_as_input(node),
@@ -716,6 +717,30 @@ fn reduce_max_update_outputs(node: &mut Node) {
}
}

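/// Infers the shape of a ReduceMin node and replaces the shape of the output tensor.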
fn reduce_min_update_outputs(node: &mut Node) {
if node.inputs.len() != 1 {
panic!("ReduceMin: multiple inputs are not supported");
}
let node_input = &mut node.inputs[0];
let tensor = match node_input.clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};
let dim_only = match node.attrs.get("axes") {
Some(value) => match &value {
AttributeValue::Int64(_) => true,
AttributeValue::Int64s(ints) => ints.len() == 1,
_ => false,
},
None => false,
};
if dim_only {
node.outputs[0].ty = ArgType::Tensor(tensor);
} else {
node.outputs[0].ty = ArgType::Tensor(TensorType { dim: 1, ..tensor });
}
}

/// Infers the shape of a ReduceSum node and replaces the shape of the output tensor.
fn reduce_sum_update_outputs(node: &mut Node) {
let node_input = &mut node.inputs[0];
42 changes: 42 additions & 0 deletions crates/burn-import/src/onnx/op_configuration.rs
@@ -902,6 +902,48 @@ pub fn reduce_max_config(node: &Node) -> Option<usize> {
}
}

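/// Extracts the single reduction axis (if any) from a ReduceMin node's `axes` and `keepdims` attributes.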
pub fn reduce_min_config(node: &Node) -> Option<usize> {
let mut axes = Vec::new();
let mut keepdims = 1;

let tensor = match node.inputs.first().unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};

// Extract the attributes
for (key, value) in node.attrs.iter() {
match key.as_str() {
"axes" => axes = value.clone().into_i64s(),
"keepdims" => keepdims = value.clone().into_i64(),
_ => {}
}
}

if axes.len() > 1 {
panic!("ReduceMin: reducing on multiple dimensions is not supported")
}

if axes.is_empty() && keepdims == 1 {
panic!("ReduceMin: axes must be provided with keepdims")
}

if !axes.is_empty() && keepdims == 0 {
panic!("ReduceMin: the reduce operation must preserve the reduced dimension")
}

if axes.is_empty() {
None
} else {
let mut dim = axes[0];

if dim < 0 {
dim += tensor.dim as i64;
}
Some(dim as usize)
}
}

pub fn reduce_mean_config(node: &Node) -> Option<usize> {
let mut axes = Vec::new();
let mut keepdims = 1;
9 changes: 9 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
@@ -289,6 +289,7 @@ impl OnnxGraph {
NodeType::Min => graph.register(Self::min_conversion(node)),
NodeType::Range => graph.register(Self::range_conversion(node)),
NodeType::ReduceMax => graph.register(Self::reduce_max_conversion(node)),
NodeType::ReduceMin => graph.register(Self::reduce_min_conversion(node)),
NodeType::ReduceMean => graph.register(Self::reduce_mean_conversion(node)),
NodeType::ReduceSum => graph.register(Self::reduce_sum_conversion(node)),
NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
@@ -640,6 +641,14 @@ impl OnnxGraph {
UnaryNode::reduce_max(input, output, dim)
}

fn reduce_min_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let dim = reduce_min_config(&node);

UnaryNode::reduce_min(input, output, dim)
}

fn reduce_mean_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();