Add where onnx op support (#1653)
* Add where onnx op support

* Add broadcasting support

* Remove broadcasting limitation comment

* Fix broadcasting in mask where

* Forgot to reflect changes in codegen test

* Fix clippy
laggui committed Apr 18, 2024
1 parent 7705fd9 commit 9fbcbed
Showing 10 changed files with 341 additions and 1 deletion.
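In short, the new op maps ONNX `Where(condition, x, y)` onto Burn's `Tensor::mask_where`, unsqueezing lower-rank inputs so all operands broadcast to a common rank. A minimal sketch of the equal-rank case (the helper name `onnx_where` is illustrative, not part of this commit):

use burn::tensor::{backend::Backend, Bool, Tensor};

/// ONNX `Where(condition, x, y)` yields `x` where `condition` is true and
/// `y` elsewhere, which corresponds to `y.mask_where(condition, x)` in Burn.
fn onnx_where<B: Backend, const D: usize>(
    condition: Tensor<B, D, Bool>,
    x: Tensor<B, D>,
    y: Tensor<B, D>,
) -> Tensor<B, D> {
    y.mask_where(condition, x)
}

The rank-mismatched case is handled by the codegen in mask_where.rs below, which unsqueezes the lower-rank operands first.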
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -195,7 +195,7 @@ represent the corresponding Burn Op.
| [Trilu][188]     | ❌ | ❌ |
| [Unique][189]    | ❌ | ❌ |
| [Upsample][190]  | ❌ | ❌ |
| [Where][191]     | ❌ | ❌ |
| [Where][191]     | ✅ | ✅ |
| [Xor][192]       | ❌ | ❌ |
| [Unsqueeze][193] | ✅ | ✅ |

1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -54,6 +54,7 @@ fn main() {
.input("tests/pow/pow.onnx")
.input("tests/pow/pow_int.onnx")
.input("tests/unsqueeze/unsqueeze.onnx")
.input("tests/mask_where/mask_where.onnx")
.out_dir("model/")
.run_from_script();

38 changes: 38 additions & 0 deletions crates/burn-import/onnx-tests/tests/mask_where/mask_where.onnx
@@ -0,0 +1,38 @@
(Binary ONNX protobuf, not rendered as text: a pytorch 2.1.2 export of main_graph containing two Where nodes over inputs onnx::Where_0 through onnx::Where_4.)
42 changes: 42 additions & 0 deletions crates/burn-import/onnx-tests/tests/mask_where/mask_where.py
@@ -0,0 +1,42 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/mask_where/mask_where.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, condition, x1, y1, x2, y2):
        return torch.where(condition, x1, y1), torch.where(condition, x2, y2)


def main():
    # Set random seed for reproducibility
    torch.manual_seed(0)

    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")
    onnx_name = "mask_where.onnx"
    x = torch.ones(2, 2, device=device)
    y = torch.zeros(2, 2, device=device)
    mask = torch.tensor([[True, False], [False, True]], device=device)
    test_input = (mask, x, y, x[0], y[0])

    torch.onnx.export(model, (test_input), onnx_name, verbose=False, opset_version=16)

    print(f"Finished exporting model to {onnx_name}")

    # Output some test data for use in the test
    print(f"Test input data: {test_input}")
    output = model.forward(*test_input)
    print(f"Test output data: {output}")


if __name__ == "__main__":
    main()
19 changes: 19 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -40,6 +40,7 @@ include_models!(
    linear,
    log_softmax,
    log,
    mask_where,
    matmul,
    maxpool2d,
    mul,
@@ -1064,4 +1065,22 @@

        assert_eq!(output_scalar, expected_scalar);
    }

    #[test]
    fn mask_where() {
        let device = Default::default();
        let model: mask_where::Model<Backend> = mask_where::Model::new(&device);

        let x1 = Tensor::ones([2, 2], &device);
        let y1 = Tensor::zeros([2, 2], &device);
        let x2 = Tensor::ones([2], &device);
        let y2 = Tensor::zeros([2], &device);
        let mask = Tensor::from_bool([[true, false], [false, true]].into(), &device);

        let (output, output_broadcasted) = model.forward(mask, x1, y1, x2, y2);
        let expected = Data::from([[1.0, 0.0], [0.0, 1.0]]);

        assert_eq!(output.to_data(), expected);
        assert_eq!(output_broadcasted.to_data(), expected);
    }
}
4 changes: 4 additions & 0 deletions crates/burn-import/src/burn/node/base.rs
@@ -1,3 +1,4 @@
use super::mask_where::WhereNode;
use super::unsqueeze::UnsqueezeNode;
use super::{
avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode,
@@ -92,6 +93,7 @@ pub enum Node<PS: PrecisionSettings> {
    Reshape(ReshapeNode),
    Unary(UnaryNode),
    Unsqueeze(UnsqueezeNode),
    Where(WhereNode),
}

macro_rules! match_all {
@@ -116,6 +118,7 @@ macro_rules! match_all {
            Node::Reshape(node) => $func(node),
            Node::Unary(node) => $func(node),
            Node::Unsqueeze(node) => $func(node),
            Node::Where(node) => $func(node),
        }
    }};
}
@@ -150,6 +153,7 @@ impl<PS: PrecisionSettings> Node<PS> {
            Node::Reshape(_) => "reshape",
            Node::Unary(unary) => unary.kind.as_str(),
            Node::Unsqueeze(_) => "unsqueeze",
            Node::Where(_) => "where",
        }
    }
}
200 changes: 200 additions & 0 deletions crates/burn-import/src/burn/node/mask_where.rs
@@ -0,0 +1,200 @@
use core::cmp::max;

use super::{Node, NodeCodegen};
use crate::burn::{BurnImports, TensorType, ToTokens, Type};

use burn::record::PrecisionSettings;
use quote::quote;

#[derive(Debug, Clone, new)]
pub struct WhereNode {
    /// Bool tensor. When True (nonzero), yield X, otherwise yield Y.
    pub condition: TensorType,
    /// Values selected at indices where condition is True.
    pub x: TensorType,
    /// Values selected at indices where condition is False.
    pub y: TensorType,
    pub output: TensorType,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for WhereNode {
    fn output_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.output.clone())]
    }

    fn input_types(&self) -> Vec<crate::burn::Type> {
        vec![
            Type::Tensor(self.condition.clone()),
            Type::Tensor(self.x.clone()),
            Type::Tensor(self.y.clone()),
        ]
    }

    fn forward(
        &self,
        scope: &mut crate::burn::Scope,
        node_position: usize,
    ) -> proc_macro2::TokenStream {
        let mut mask = scope.tensor_use_owned(&self.condition, node_position);
        let mut x = scope.tensor_use_owned(&self.x, node_position);
        let mut y = scope.tensor_use_owned(&self.y, node_position);
        let output = &self.output.name;

        // x, y and condition need to be broadcastable
        let broadcasted_dim = max(max(self.x.dim, self.y.dim), self.condition.dim);
        let unsqueeze_dims = broadcasted_dim.to_tokens();

        if self.condition.dim < broadcasted_dim {
            mask = quote! { #mask.unsqueeze::<#unsqueeze_dims>()};
        }

        if self.x.dim < broadcasted_dim {
            x = quote! { #x.unsqueeze::<#unsqueeze_dims>()};
        }

        if self.y.dim < broadcasted_dim {
            y = quote! { #y.unsqueeze::<#unsqueeze_dims>()};
        }

        quote! {
            let #output = #y.mask_where(#mask, #x);
        }
    }

    fn register_imports(&self, imports: &mut BurnImports) {
        imports.register("burn::tensor::Bool");
    }

    fn into_node(self) -> super::Node<PS> {
        Node::Where(self)
    }
}

#[cfg(test)]
mod tests {

    use burn::record::FullPrecisionSettings;

    use super::*;
    use crate::burn::{
        graph::BurnGraph,
        node::{mask_where::WhereNode, test::assert_tokens},
        TensorType,
    };

    #[test]
    fn test_codegen_where() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();

        graph.register(WhereNode::new(
            TensorType::new_bool("tensor1", 2),
            TensorType::new_float("tensor2", 2),
            TensorType::new_float("tensor3", 2),
            TensorType::new_float("tensor4", 2),
        ));

        graph.register_input_output(
            vec![
                "tensor1".to_string(),
                "tensor2".to_string(),
                "tensor3".to_string(),
            ],
            vec!["tensor4".to_string()],
        );

        let expected = quote! {
            use burn::tensor::Bool;
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
            }

            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                    }
                }

                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(
                    &self,
                    tensor1: Tensor<B, 2, Bool>,
                    tensor2: Tensor<B, 2>,
                    tensor3: Tensor<B, 2>
                ) -> Tensor<B, 2> {
                    let tensor4 = tensor3.mask_where(tensor1, tensor2);

                    tensor4
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }

    #[test]
    fn test_codegen_where_broadcasted() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();

        graph.register(WhereNode::new(
            TensorType::new_bool("tensor1", 4),
            TensorType::new_float("tensor2", 2),
            TensorType::new_float("tensor3", 3),
            TensorType::new_float("tensor4", 4),
        ));

        graph.register_input_output(
            vec![
                "tensor1".to_string(),
                "tensor2".to_string(),
                "tensor3".to_string(),
            ],
            vec!["tensor4".to_string()],
        );

        let expected = quote! {
            use burn::tensor::Bool;
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
            }

            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                    }
                }

                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(
                    &self,
                    tensor1: Tensor<B, 4, Bool>,
                    tensor2: Tensor<B, 2>,
                    tensor3: Tensor<B, 3>
                ) -> Tensor<B, 4> {
                    let tensor4 = tensor3
                        .unsqueeze::<4>()
                        .mask_where(tensor1, tensor2.unsqueeze::<4>());

                    tensor4
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }
}
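The broadcasting handled in `forward` above leans on `Tensor::unsqueeze`, which pads the shape with leading size-1 dimensions up to the requested rank; the padded tensor then broadcasts against the highest-rank operand. A small sketch of that behavior, assuming the `NdArray` backend (the `ndarray` feature) is available:

use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    // Rank-2 tensor of shape [2, 2].
    let t: Tensor<NdArray, 2> = Tensor::ones([2, 2], &device);

    // unsqueeze::<4>() pads leading dims with 1s: [2, 2] -> [1, 1, 2, 2],
    // mirroring the generated `tensor2.unsqueeze::<4>()` call in the test above.
    let t4: Tensor<NdArray, 4> = t.unsqueeze::<4>();
    assert_eq!(t4.dims(), [1, 1, 2, 2]);
}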
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
@@ -13,6 +13,7 @@ pub(crate) mod dropout;
pub(crate) mod gather;
pub(crate) mod global_avg_pool;
pub(crate) mod linear;
pub(crate) mod mask_where;
pub(crate) mod matmul;
pub(crate) mod max_pool2d;
pub(crate) mod reshape;
19 changes: 19 additions & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
@@ -57,6 +57,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
        NodeType::Unsqueeze => unsqueeze_update_output(node),
        NodeType::Pow => same_as_input(node),
        NodeType::LeakyRelu => same_as_input(node),
        NodeType::Where => where_update_outputs(node),
        // Intentionally leave outputs unchanged, but issue a warning so the IR file can still be generated.
        _ => temporary_pass_through_stub(node),
    }
@@ -498,3 +499,21 @@ fn reduce_max_update_outputs(node: &mut Node) {
        node.outputs[0].ty = ArgType::Tensor(TensorType { dim: 1, ..tensor });
    }
}

fn where_update_outputs(node: &mut Node) {
    match (
        node.inputs[0].ty.clone(),
        node.inputs[1].ty.clone(),
        node.inputs[2].ty.clone(),
    ) {
        (ArgType::Tensor(condition), ArgType::Tensor(x), ArgType::Tensor(y)) => {
            // With broadcasting support, output dim has to be computed based on the inputs
            node.outputs[0].ty = ArgType::Tensor(TensorType {
                elem_type: x.elem_type.clone(),
                dim: max(condition.dim, max(x.dim, y.dim)),
                ..Default::default()
            });
        }
        _ => panic!("Only tensor input is valid"),
    }
}
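With broadcasting, rank inference for `Where` reduces to taking the maximum rank across the three inputs, matching the broadcasted codegen test above (condition rank 4, x rank 2, y rank 3 give output rank 4). A toy restatement of the rule (hypothetical helper, not part of this commit):

/// Output rank of a broadcast ONNX `Where`: the maximum of the input ranks.
fn where_output_rank(condition_rank: usize, x_rank: usize, y_rank: usize) -> usize {
    condition_rank.max(x_rank).max(y_rank)
}

fn main() {
    // Matches test_codegen_where_broadcasted: ranks (4, 2, 3) -> 4.
    assert_eq!(where_output_rank(4, 2, 3), 4);
}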
