Skip to content

Commit

Permalink
feat: added min onnx import (#1778)
Browse files Browse the repository at this point in the history
  • Loading branch information
JachymPutta committed May 22, 2024
1 parent 550086a commit 0918cf0
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 1 deletion.
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ represent the corresponding Burn Op.
| [Mean][101] |||
| [MeanVarianceNormalization][102] |||
| [MelWeightMatrix][103] |||
| [Min][104] | ||
| [Min][104] | ||
| [Mish][105] |||
| [Mod][106] |||
| [Mul][107] |||
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ fn main() {
.input("tests/log_softmax/log_softmax.onnx")
.input("tests/log/log.onnx")
.input("tests/matmul/matmul.onnx")
.input("tests/min/min.onnx")
.input("tests/max/max.onnx")
.input("tests/maxpool1d/maxpool1d.onnx")
.input("tests/maxpool2d/maxpool2d.onnx")
Expand Down
17 changes: 17 additions & 0 deletions crates/burn-import/onnx-tests/tests/min/min.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
pytorch2.3.0:�
(
onnx::Min_0
onnx::Min_12/Min"Min
main_graphZ
onnx::Min_0


Z
onnx::Min_1


b
2


B
38 changes: 38 additions & 0 deletions crates/burn-import/onnx-tests/tests/min/min.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/min/min.onnx

import torch
import torch.nn as nn

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y):
return torch.minimum(x, y)

def main():
# Set seed for reproducibility
torch.manual_seed(42)
torch.set_printoptions(precision=8)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")

onnx_name = "min.onnx"

test_input1 = torch.randn(4, 4, device=device)
test_input2 = torch.randn(4, 4, device=device)
torch.onnx.export(model, (test_input1, test_input2), onnx_name, verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

print("Test input data: {} {}".format(test_input1, test_input2))
output = model.forward(test_input1, test_input2)
print("Test output data: {}".format(output))

if __name__ == '__main__':
main()
15 changes: 15 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ include_models!(
log,
mask_where,
matmul,
min,
max,
maxpool1d,
maxpool2d,
Expand Down Expand Up @@ -450,6 +451,20 @@ mod tests {
assert_eq!(output2, expected2);
}

#[test]
fn min() {
let device = Default::default();

let model: min::Model<Backend> = min::Model::new(&device);
let input1 = Tensor::<Backend, 2>::from_floats([[-1.0, 42.0, 0.0, 42.0]], &device);
let input2 = Tensor::<Backend, 2>::from_floats([[2.0, 4.0, 42.0, 25.0]], &device);

let output = model.forward(input1, input2);
let expected = Data::from([[-1.0, 4.0, 0.0, 25.0]]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn max() {
let device = Default::default();
Expand Down
15 changes: 15 additions & 0 deletions crates/burn-import/src/burn/node/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub enum BinaryType {
Equal,
Powf,
Powi,
Min,
Max,
}

Expand All @@ -27,6 +28,7 @@ impl BinaryType {
BinaryType::Equal => "equal",
BinaryType::Powi => "powi",
BinaryType::Powf => "powf",
BinaryType::Min => "min_pair",
BinaryType::Max => "max_pair",
}
}
Expand Down Expand Up @@ -176,6 +178,14 @@ impl BinaryNode {
Self::new(lhs, rhs, output, BinaryType::Powi, Arc::new(function))
}

pub(crate) fn min_pair(lhs: Type, rhs: Type, output: Type) -> Self {
let function = match (&lhs, &rhs) {
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.min_pair(#rhs) },
_ => panic!("min_pair is supported for tensor only"),
};
Self::new(lhs, rhs, output, BinaryType::Min, Arc::new(function))
}

pub(crate) fn max_pair(lhs: Type, rhs: Type, output: Type) -> Self {
let function = match (&lhs, &rhs) {
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.max_pair(#rhs) },
Expand Down Expand Up @@ -338,6 +348,11 @@ mod tests {
test_binary_operator_on_scalar_and_scalar!(div, /);
}

#[test]
fn test_binary_codegen_min() {
test_binary_operator_on_tensors!(min_pair);
}

#[test]
fn test_binary_codegen_max() {
test_binary_operator_on_tensors!(max_pair);
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
NodeType::Log => same_as_input(node),
NodeType::LogSoftmax => same_as_input(node),
NodeType::MatMul => matmul_update_outputs(node),
NodeType::Min => same_as_input(node),
NodeType::Max => same_as_input(node),
NodeType::MaxPool1d => same_as_input(node),
NodeType::MaxPool2d => same_as_input(node),
Expand Down
9 changes: 9 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ impl OnnxGraph {
NodeType::Sqrt => graph.register(Self::sqrt_conversion(node)),
NodeType::Tanh => graph.register(Self::tanh_conversion(node)),
NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
NodeType::Min => graph.register(Self::min_conversion(node)),
NodeType::ReduceMax => graph.register(Self::reduce_max_conversion(node)),
NodeType::ReduceMean => graph.register(Self::reduce_mean_conversion(node)),
NodeType::ReduceSum => graph.register(Self::reduce_sum_conversion(node)),
Expand Down Expand Up @@ -501,6 +502,14 @@ impl OnnxGraph {
ReshapeNode::new(input, output, shape)
}

fn min_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.first().unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();

BinaryNode::min_pair(lhs, rhs, output)
}

fn reduce_max_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
Expand Down

0 comments on commit 0918cf0

Please sign in to comment.