Skip to content

Commit

Permalink
feat: implement the right max operation (hopefully)
Browse files Browse the repository at this point in the history
  • Loading branch information
JachymPutta committed May 15, 2024
1 parent 946bf92 commit 131b2f6
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 54 deletions.
Binary file modified crates/burn-import/onnx-tests/tests/max/max.onnx
Binary file not shown.
13 changes: 7 additions & 6 deletions crates/burn-import/onnx-tests/tests/max/max.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x):
return torch.max(x)
def forward(self, x, y):
return torch.maximum(x, y)

def main():
# Set seed for reproducibility
Expand All @@ -24,13 +24,14 @@ def main():

onnx_name = "max.onnx"

test_input = torch.randn(4, 4, device=device)
torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=16)
test_input1 = torch.randn(4, 4, device=device)
test_input2 = torch.randn(4, 4, device=device)
torch.onnx.export(model, (test_input1, test_input2), onnx_name, verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

print("Test input data: {}".format(test_input))
output = model.forward(test_input)
print("Test input data: {} {}".format(test_input1, test_input2))
output = model.forward(test_input1, test_input2)
print("Test output data: {}".format(output))

if __name__ == '__main__':
Expand Down
7 changes: 4 additions & 3 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -453,10 +453,11 @@ mod tests {
let device = Default::default();

let model: max::Model<Backend> = max::Model::new(&device);
let input = Tensor::<Backend, 2>::from_floats([[1.0, 4.0, 9.0, 25.0]], &device);
let input1 = Tensor::<Backend, 2>::from_floats([[1.0, 42.0, 9.0, 42.0]], &device);
let input2 = Tensor::<Backend, 2>::from_floats([[42.0, 4.0, 42.0, 25.0]], &device);

let output = model.forward(input);
let expected = Data::from([25.0]);
let output = model.forward(input1, input2);
let expected = Data::from([[42.0, 42.0, 42.0, 42.0]]);

assert_eq!(output.to_data(), expected);
}
Expand Down
15 changes: 15 additions & 0 deletions crates/burn-import/src/burn/node/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub enum BinaryType {
Equal,
Powf,
Powi,
Max,
}

impl BinaryType {
Expand All @@ -26,6 +27,7 @@ impl BinaryType {
BinaryType::Equal => "equal",
BinaryType::Powi => "powi",
BinaryType::Powf => "powf",
BinaryType::Max => "max_pair",
}
}
}
Expand Down Expand Up @@ -173,6 +175,14 @@ impl BinaryNode {
};
Self::new(lhs, rhs, output, BinaryType::Powi, Arc::new(function))
}

pub(crate) fn max_pair(lhs: Type, rhs: Type, output: Type) -> Self {
let function = match (&lhs, &rhs) {
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.max_pair(#rhs) },
_ => panic!("max is supported for tensor only"),
};
Self::new(lhs, rhs, output, BinaryType::Max, Arc::new(function))
}
}

#[cfg(test)]
Expand Down Expand Up @@ -328,6 +338,11 @@ mod tests {
test_binary_operator_on_scalar_and_scalar!(div, /);
}

#[test]
fn test_binary_codegen_max() {
test_binary_operator_on_tensors!(max_pair);
}

#[test]
fn test_binary_codegen_equal_tensors() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
Expand Down
25 changes: 0 additions & 25 deletions crates/burn-import/src/burn/node/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ pub enum UnaryNodeKind {
LogSoftmax,
Neg,
Not,
Max,
ReduceMax,
ReduceMean,
ReduceSum,
Expand Down Expand Up @@ -62,7 +61,6 @@ impl UnaryNodeKind {
Self::LogSoftmax => "log_softmax",
Self::Neg => "neg",
Self::Not => "not",
Self::Max => "max",
Self::ReduceMax => "reduce_max",
Self::ReduceMean => "reduce_mean",
Self::ReduceSum => "reduce_sum",
Expand Down Expand Up @@ -188,11 +186,6 @@ impl UnaryNode {
Self::new(input, output, UnaryNodeKind::LogSoftmax, Rc::new(function))
}

pub(crate) fn max(input: Type, output: Type) -> Self {
let function = move |input| quote! { #input.max() };
Self::new(input, output, UnaryNodeKind::Max, Rc::new(function))
}

pub(crate) fn softmax(input: Type, output: Type, dim: usize) -> Self {
let dim = dim.to_tokens();
let function = move |input| quote! { burn::tensor::activation::softmax(#input, #dim) };
Expand Down Expand Up @@ -540,24 +533,6 @@ mod tests {
);
}

#[test]
fn test_unary_codegen_max() {
one_node_graph(
UnaryNode::max(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 1)),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 1> {
let tensor2 = tensor1.max();
tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}

#[test]
fn test_unary_codegen_softmax() {
one_node_graph(
Expand Down
15 changes: 1 addition & 14 deletions crates/burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
NodeType::Log => same_as_input(node),
NodeType::LogSoftmax => same_as_input(node),
NodeType::MatMul => matmul_update_outputs(node),
NodeType::Max => max_update_outputs(node),
NodeType::Max => same_as_input(node),
NodeType::MaxPool1d => same_as_input(node),
NodeType::MaxPool2d => same_as_input(node),
NodeType::Mul => same_as_input(node),
Expand Down Expand Up @@ -432,19 +432,6 @@ fn matmul_update_outputs(node: &mut Node) {
}
}

fn max_update_outputs(node: &mut Node) {
let tensor = match node.inputs[0].clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};

node.outputs[0].ty = ArgType::Tensor(TensorType {
elem_type: tensor.elem_type,
dim: 1,
shape: Some(vec![tensor.shape.unwrap().first().unwrap().clone()]),
})
}

/// Infers the shape of a ReduceMax node and replaces the shape of the output tensor.
fn reduce_max_update_outputs(node: &mut Node) {
if node.inputs.len() != 1 {
Expand Down
14 changes: 8 additions & 6 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,14 @@ impl OnnxGraph {
BinaryNode::equal(lhs, rhs, output)
}

fn max_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.first().unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();

BinaryNode::max_pair(lhs, rhs, output)
}

fn erf_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
Expand Down Expand Up @@ -599,12 +607,6 @@ impl OnnxGraph {
UnaryNode::tanh(input, output)
}

fn max_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
UnaryNode::max(input, output)
}

fn concat_conversion(node: Node) -> ConcatNode {
let inputs = node
.inputs
Expand Down

0 comments on commit 131b2f6

Please sign in to comment.