diff --git a/wonnx-cli/src/cpu.rs b/wonnx-cli/src/cpu.rs
index e543056e..3d2e1331 100644
--- a/wonnx-cli/src/cpu.rs
+++ b/wonnx-cli/src/cpu.rs
@@ -5,7 +5,7 @@ use async_trait::async_trait;
 use tract_onnx::prelude::*;
 use wonnx::{
     onnx::ModelProto,
-    utils::{OutputTensor, Shape},
+    tensor::{Shape, TensorData},
 };
 
 type RunnableOnnxModel =
@@ -73,7 +73,7 @@ impl Inferer for CPUInferer {
         outputs: &[String],
         inputs: &HashMap<String, Tensor>,
         model: &ModelProto,
-    ) -> Result<HashMap<String, OutputTensor>, NNXError> {
+    ) -> Result<HashMap<String, TensorData<'static>>, NNXError> {
         let mut cpu_inputs: HashMap<usize, Tensor> = HashMap::new();
 
         for (input_name, input_tensor) in inputs {
@@ -86,12 +86,7 @@ impl Inferer for CPUInferer {
                 .unwrap_or_else(|| panic!("input not found with name {}", input_name));
             log::info!("set input fact {} for cpu model", input_index.0,);
 
-            let dims: Vec<usize> = self.input_shapes[input_name]
-                .dims
-                .iter()
-                .map(|x| (*x) as usize)
-                .collect();
-
+            let dims: Vec<usize> = self.input_shapes[input_name].dims.to_vec();
             cpu_inputs.insert(input_index.0, input_tensor.to_tract_tensor(&dims)?);
         }
 
@@ -103,7 +98,7 @@ impl Inferer for CPUInferer {
         let result = self.model.run(cpu_inputs_ordered)?;
         log::debug!("cpu result: {:?}", result);
 
-        let mut output_tensors = HashMap::<String, OutputTensor>::new();
+        let mut output_tensors = HashMap::<String, TensorData>::new();
 
         for output_name in outputs {
             let result_vector = {
@@ -132,7 +127,7 @@ impl Inferer for CPUInferer {
             let av = result_vector.to_array_view()?;
             output_tensors.insert(
                 output_name.clone(),
-                OutputTensor::F32(av.as_slice().unwrap().to_vec()),
+                TensorData::F32(av.as_slice().unwrap().into()).into_static(),
             );
         }
         Ok(output_tensors)
diff --git a/wonnx-cli/src/gpu.rs b/wonnx-cli/src/gpu.rs
index c3067d37..cd45a5bb 100644
--- a/wonnx-cli/src/gpu.rs
+++ b/wonnx-cli/src/gpu.rs
@@ -4,7 +4,7 @@ use wonnx::onnx::ModelProto;
 use wonnx::SessionConfig;
 
 use async_trait::async_trait;
-use wonnx::utils::OutputTensor;
+use wonnx::tensor::TensorData;
 
 use crate::types::Inferer;
 use crate::types::NNXError;
@@ -33,14 +33,14 @@ impl Inferer for GPUInferer {
         outputs: &[String],
         inputs: &HashMap<String, Tensor>,
         _model: &ModelProto,
-    ) -> Result<HashMap<String, OutputTensor>, NNXError> {
+    ) -> Result<HashMap<String, TensorData<'static>>, NNXError> {
         let input_refs = inputs
             .iter()
             .map(|(k, v)| (k.clone(), v.input_tensor()))
             .collect();
         let mut result = self.session.run(&input_refs).await.expect("run failed");
 
-        let mut output_tensors = HashMap::<String, OutputTensor>::new();
+        let mut output_tensors = HashMap::<String, TensorData>::new();
 
         for output_name in outputs {
             let result = match result.remove(output_name) {
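Note: both inferers now return `TensorData<'static>` rather than `OutputTensor`. A minimal sketch of what this means for a caller, assuming the `Cow`-backed variants and the `into_static` method used in the hunks above (values are illustrative):

    use std::borrow::Cow;
    use wonnx::tensor::TensorData;

    // A tensor borrowing a slice can be promoted to an owned, 'static tensor,
    // which is what lets infer() store results in a long-lived map:
    let borrowed = TensorData::F32(Cow::Borrowed(&[1.0f32, 2.0][..]));
    let owned: TensorData<'static> = borrowed.into_static();
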
diff --git a/wonnx-cli/src/info.rs b/wonnx-cli/src/info.rs
index 3e7ab707..24761c7c 100644
--- a/wonnx-cli/src/info.rs
+++ b/wonnx-cli/src/info.rs
@@ -4,7 +4,7 @@ use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
 use prettytable::{row, table, Table};
 use wonnx::{
     onnx::{GraphProto, ModelProto, NodeProto, ValueInfoProto},
-    utils::{ScalarType, Shape},
+    tensor::{ScalarType, Shape},
     WonnxError,
 };
 
@@ -31,7 +31,14 @@ fn dimensions_infos(
     }
 
     for info in graph_proto.get_initializer() {
-        let shape = Shape::from(ScalarType::from_i32(info.get_data_type())?, info.get_dims());
+        let shape = Shape::from(
+            ScalarType::from_onnx_i32(info.get_data_type())?,
+            &info
+                .get_dims()
+                .iter()
+                .map(|x| *x as usize)
+                .collect::<Vec<_>>(),
+        );
         shapes_info.insert(info.get_name().to_string(), Some(shape));
     }
 
@@ -112,7 +119,14 @@ pub fn sizes_table(model: &ModelProto) -> Result<Table, WonnxError> {
     let mut initializer_size: usize = 0;
     for info in model.get_graph().get_initializer() {
-        let shape = Shape::from(ScalarType::from_i32(info.get_data_type())?, info.get_dims());
+        let shape = Shape::from(
+            ScalarType::from_onnx_i32(info.get_data_type())?,
+            &info
+                .get_dims()
+                .iter()
+                .map(|x| *x as usize)
+                .collect::<Vec<_>>(),
+        );
         initializer_size += shape.buffer_bytes_aligned();
     }
diff --git a/wonnx-cli/src/main.rs b/wonnx-cli/src/main.rs
index 3cb49880..4f3750ea 100644
--- a/wonnx-cli/src/main.rs
+++ b/wonnx-cli/src/main.rs
@@ -8,7 +8,8 @@ use std::fs::File;
 use structopt::StructOpt;
 use trace::trace_command;
 use wonnx::onnx::ModelProto;
-use wonnx::utils::{get_opset_version, OutputTensor, Shape};
+use wonnx::onnx_model::get_opset_version;
+use wonnx::tensor::{Shape, TensorData};
 use wonnx_preprocessing::shape_inference::{apply_dynamic_dimensions, infer_shapes};
 use wonnx_preprocessing::text::{get_lines, EncodedText};
 use wonnx_preprocessing::Tensor;
@@ -105,7 +106,7 @@ async fn run() -> Result<(), NNXError> {
 fn print_qa_output(
     infer_opt: &InferOptions,
     qa_encoding: &EncodedText,
-    mut outputs: HashMap<String, OutputTensor>,
+    mut outputs: HashMap<String, TensorData>,
 ) -> Result<(), NNXError> {
     let start_output: Vec<f32> = outputs
         .remove(&infer_opt.qa_answer_start)
@@ -133,7 +134,7 @@ fn print_qa_output(
 fn print_output(
     infer_opt: &InferOptions,
     output_name: &str,
-    output: OutputTensor,
+    output: TensorData,
     print_output_names: bool,
     print_newlines: bool,
 ) {
@@ -165,8 +166,8 @@ fn print_output(
 
     // Just print the output tensor values, one a line
     match output {
-        wonnx::utils::OutputTensor::F32(fs) => {
-            for i in fs {
+        wonnx::tensor::TensorData::F32(fs) => {
+            for i in fs.iter() {
                 if print_newlines {
                     println!("{:.3}", i);
                 } else {
@@ -174,8 +175,8 @@ fn print_output(
             }
         }
-        wonnx::utils::OutputTensor::I32(ints) => {
-            for i in ints {
+        wonnx::tensor::TensorData::I32(ints) => {
+            for i in ints.iter() {
                 if print_newlines {
                     println!("{}", i);
                 } else {
@@ -183,8 +184,8 @@ fn print_output(
             }
         }
-        wonnx::utils::OutputTensor::I64(ints) => {
-            for i in ints {
+        wonnx::tensor::TensorData::I64(ints) => {
+            for i in ints.iter() {
                 if print_newlines {
                     println!("{}", i);
                 } else {
@@ -192,8 +193,8 @@ fn print_output(
             }
         }
-        wonnx::utils::OutputTensor::U8(ints) => {
-            for i in ints {
+        wonnx::tensor::TensorData::U8(ints) => {
+            for i in ints.iter() {
                 if print_newlines {
                     println!("{}", i);
                 } else {
@@ -238,7 +239,10 @@ async fn prepare_command(prepare_opt: PrepareOptions) -> Result<(), NNXError> {
     {
         Some(input) => {
             let data_type = input.data_type().map_err(|_| NNXError::InvalidInputShape)?;
-            let shape = Shape::from(data_type, &new_dims);
+            let shape = Shape::from(
+                data_type,
+                &new_dims.iter().map(|x| *x as usize).collect::<Vec<_>>(),
+            );
             input.set_shape(&shape);
             log::info!("setting shape of input {input_name} to {shape}");
         }
diff --git a/wonnx-cli/src/types.rs b/wonnx-cli/src/types.rs
index 9a95dcf0..3579e3b3 100644
--- a/wonnx-cli/src/types.rs
+++ b/wonnx-cli/src/types.rs
@@ -4,7 +4,8 @@ use structopt::StructOpt;
 use thiserror::Error;
 use wonnx::{
     onnx::ModelProto,
-    utils::{OpsetError, OutputTensor, Shape, TensorConversionError},
+    onnx_model::OpsetError,
+    tensor::{Shape, TensorConversionError, TensorData},
     SessionError, WonnxError,
 };
 use wonnx_preprocessing::{
@@ -296,7 +297,7 @@ pub trait Inferer {
         outputs: &[String],
         inputs: &HashMap<String, Tensor>,
         model: &ModelProto,
-    ) -> Result<HashMap<String, OutputTensor>, NNXError>;
+    ) -> Result<HashMap<String, TensorData<'static>>, NNXError>;
 }
 
 pub struct InferenceInput {
diff --git a/wonnx-cli/src/utils.rs b/wonnx-cli/src/utils.rs
index 9d44ea1a..da714da0 100644
--- a/wonnx-cli/src/utils.rs
+++ b/wonnx-cli/src/utils.rs
@@ -2,7 +2,7 @@ use ndarray::{Array, ArrayBase};
 use std::collections::HashMap;
 use std::path::Path;
 use wonnx::onnx::{ModelProto, TensorShapeProto, ValueInfoProto};
-use wonnx::utils::{DataTypeError, ScalarType, Shape};
+use wonnx::tensor::{DataTypeError, ScalarType, Shape};
 use wonnx::WonnxError;
 use wonnx_preprocessing::image::{load_bw_image, load_rgb_image};
 use wonnx_preprocessing::text::{EncodedText, TextTokenizer};
@@ -55,7 +55,7 @@ impl ValueInfoProtoUtil for ValueInfoProto {
             Ok(match &self.get_field_type().value {
                 Some(x) => match x {
                     wonnx::onnx::TypeProto_oneof_value::tensor_type(t) => {
-                        ScalarType::from_i32(t.get_elem_type())?
+                        ScalarType::from_onnx_i32(t.get_elem_type())?
                     }
                     wonnx::onnx::TypeProto_oneof_value::sequence_type(_) => todo!(),
                     wonnx::onnx::TypeProto_oneof_value::map_type(_) => todo!(),
@@ -119,8 +119,8 @@ pub fn load_image_input(
     input_shape: &Shape,
 ) -> Result<ArrayBase<ndarray::OwnedRepr<f32>, ndarray::IxDyn>, NNXError> {
     if input_shape.rank() == 3 {
-        let mut w = input_shape.dim(1) as usize;
-        let mut h = input_shape.dim(2) as usize;
+        let mut w = input_shape.dim(1);
+        let mut h = input_shape.dim(2);
         if w == 0 {
             w = 224;
         }
@@ -138,8 +138,8 @@ pub fn load_image_input(
             Err(NNXError::InvalidInputShape)
         }
     } else if input_shape.rank() == 4 {
-        let mut w = input_shape.dim(2) as usize;
-        let mut h = input_shape.dim(3) as usize;
+        let mut w = input_shape.dim(2);
+        let mut h = input_shape.dim(3);
         if w == 0 {
             w = 224;
         }
@@ -179,12 +179,12 @@ impl InferenceInput {
             .get_input_shape(&infer_opt.qa_segment_input)?
             .ok_or_else(|| NNXError::InputNotFound(infer_opt.qa_segment_input.clone()))?;
 
-        let segment_length = tokens_input_shape.element_count() as usize;
+        let segment_length = tokens_input_shape.element_count();
 
-        if segment_length != mask_input_shape.element_count() as usize {
+        if segment_length != mask_input_shape.element_count() {
             return Err(NNXError::InvalidInputShape);
         }
-        if segment_length != segment_input_shape.element_count() as usize {
+        if segment_length != segment_input_shape.element_count() {
             return Err(NNXError::InvalidInputShape);
         }
 
@@ -272,7 +272,7 @@ impl InferenceInput {
         let values: Result<Vec<f32>, _> = text.split(',').map(|v| v.parse::<f32>()).collect();
         let mut values = values.map_err(NNXError::InvalidNumber)?;
-        values.resize(raw_input_shape.element_count() as usize, 0.0);
+        values.resize(raw_input_shape.element_count(), 0.0);
         inputs.insert(
             raw_input_name.clone(),
             Tensor::F32(Array::from_vec(values).into_dyn()),
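Note: `Shape::from` now takes `&[usize]` instead of ONNX's signed `i64` dims, which is why call sites throughout this change gain an explicit conversion. A hedged sketch of the recurring pattern (assuming the `wonnx::tensor` re-exports used in the imports above; the dims are illustrative):

    use wonnx::tensor::{ScalarType, Shape};

    // ONNX protobufs store dims as i64; the in-memory Shape is usize-based.
    let onnx_dims: &[i64] = &[1, 3, 224, 224];
    let dims: Vec<usize> = onnx_dims.iter().map(|x| *x as usize).collect();
    let shape = Shape::from(ScalarType::F32, &dims);
    assert_eq!(shape.rank(), 4);
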
diff --git a/wonnx-preprocessing/src/constant_folding.rs b/wonnx-preprocessing/src/constant_folding.rs
index b8baee1a..3d7080d3 100644
--- a/wonnx-preprocessing/src/constant_folding.rs
+++ b/wonnx-preprocessing/src/constant_folding.rs
@@ -5,17 +5,18 @@ use thiserror::Error;
 use wonnx::{
     constant_of_shape_output,
+    ir::{IrError, OperatorDefinition},
     onnx::{
         GraphProto, NodeProto, TensorProto, TensorShapeProto, TensorShapeProto_Dimension,
         TypeProto, TypeProto_Tensor, ValueInfoProto,
     },
-    utils::{
-        model_with_opset, DataTypeError, InputTensor, NodeAttributes, OutputTensor, ScalarType,
-        Shape,
-    },
+    onnx_model::onnx_model_with_opset,
+    tensor::{DataTypeError, ScalarType, Shape, TensorData},
     CompileError, GpuError, Session, SessionError,
 };
 
+use crate::utils::NodeAttributes;
+
 #[derive(Error, Debug)]
 pub enum ConstantFoldingError {
     #[error("unsupported data type encountered: {0}")]
@@ -28,69 +29,72 @@ pub enum ConstantFoldingError {
     #[error("error calculating constant value: {0}")]
     #[from(SessionError)]
     CalculationError(SessionError),
+
+    #[error("error in IR: {0}")]
+    #[from(IrError)]
+    IrError(IrError),
 }
 
 pub(crate) async fn calculate_constant_node_outputs<'a>(
     node: &'a NodeProto,
     shapes: &'a HashMap<String, Shape>,
-    inputs: &'a [InputTensor<'a>],
+    inputs: &'a [TensorData<'a>],
     output_shapes: &[Shape],
     _initializers: &HashMap<String, Cow<'a, TensorProto>>,
     opset_version: i64,
-) -> Result<Option<Vec<OutputTensor>>, ConstantFoldingError> {
+) -> Result<Option<Vec<TensorData<'static>>>, ConstantFoldingError> {
     Ok(match node.get_op_type() {
-        "Identity" | "Unsqueeze" | "Squeeze" | "Reshape" => {
-            Some(inputs.iter().map(OutputTensor::from).collect())
-        }
+        "Identity" | "Unsqueeze" | "Squeeze" | "Reshape" => Some(inputs.to_vec()),
 
         "Cast" => {
-            let cast_to_type =
-                ScalarType::from_i32(node.get_attribute_value::<i64>("to", None).map_err(|_| {
-                    ConstantFoldingError::InvalidNode("to attribute missing for Cast ".to_string())
-                })? as i32)
-                .map_err(ConstantFoldingError::UnsupportedDataType)?;
+            let cast_to_type = ScalarType::from_onnx_i32(
+                node.get_attribute_value::<i64>("to", None).map_err(|_| {
+                    ConstantFoldingError::InvalidNode("to attribute missing for Cast ".to_string())
+                })? as i32,
+            )
+            .map_err(ConstantFoldingError::UnsupportedDataType)?;
             let input_tensor = &inputs[0];
 
-            let output_tensor = match (input_tensor, cast_to_type) {
-                (InputTensor::F32(v), ScalarType::F32) => OutputTensor::F32(v.to_vec()),
-                (InputTensor::F32(v), ScalarType::I64) => {
-                    OutputTensor::I64(v.iter().map(|x| *x as i64).collect())
+            let output_tensor: TensorData<'static> = match (input_tensor, cast_to_type) {
+                (TensorData::F32(v), ScalarType::F32) => TensorData::F32(Cow::Owned(v.to_vec())),
+                (TensorData::F32(v), ScalarType::I64) => {
+                    TensorData::I64(v.iter().map(|x| *x as i64).collect())
                 }
-                (InputTensor::F32(v), ScalarType::I32) => {
-                    OutputTensor::I32(v.iter().map(|x| *x as i32).collect())
+                (TensorData::F32(v), ScalarType::I32) => {
+                    TensorData::I32(v.iter().map(|x| *x as i32).collect())
                 }
-                (InputTensor::F32(v), ScalarType::U8) => {
-                    OutputTensor::U8(v.iter().map(|x| *x as u8).collect())
+                (TensorData::F32(v), ScalarType::U8) => {
+                    TensorData::U8(v.iter().map(|x| *x as u8).collect())
                 }
-                (InputTensor::I32(v), ScalarType::F32) => {
-                    OutputTensor::F32(v.iter().map(|x| *x as f32).collect())
+                (TensorData::I32(v), ScalarType::F32) => {
+                    TensorData::F32(v.iter().map(|x| *x as f32).collect())
                 }
-                (InputTensor::I32(v), ScalarType::I64) => {
-                    OutputTensor::I64(v.iter().map(|x| *x as i64).collect())
+                (TensorData::I32(v), ScalarType::I64) => {
+                    TensorData::I64(v.iter().map(|x| *x as i64).collect())
                 }
-                (InputTensor::I32(v), ScalarType::I32) => OutputTensor::I32(v.to_vec()),
-                (InputTensor::I32(v), ScalarType::U8) => {
-                    OutputTensor::U8(v.iter().map(|x| *x as u8).collect())
+                (TensorData::I32(v), ScalarType::I32) => TensorData::I32(Cow::Owned(v.to_vec())),
+                (TensorData::I32(v), ScalarType::U8) => {
+                    TensorData::U8(v.iter().map(|x| *x as u8).collect())
                 }
-                (InputTensor::I64(v), ScalarType::F32) => {
-                    OutputTensor::F32(v.iter().map(|x| *x as f32).collect())
+                (TensorData::I64(v), ScalarType::F32) => {
+                    TensorData::F32(v.iter().map(|x| *x as f32).collect())
                 }
-                (InputTensor::I64(v), ScalarType::I64) => OutputTensor::I64(v.to_vec()),
-                (InputTensor::I64(v), ScalarType::I32) => {
-                    OutputTensor::I32(v.iter().map(|x| *x as i32).collect())
+                (TensorData::I64(v), ScalarType::I64) => TensorData::I64(Cow::Owned(v.to_vec())),
+                (TensorData::I64(v), ScalarType::I32) => {
+                    TensorData::I32(v.iter().map(|x| *x as i32).collect())
                 }
-                (InputTensor::I64(v), ScalarType::U8) => {
-                    OutputTensor::U8(v.iter().map(|x| *x as u8).collect())
+                (TensorData::I64(v), ScalarType::U8) => {
+                    TensorData::U8(v.iter().map(|x| *x as u8).collect())
                 }
-                (InputTensor::U8(v), ScalarType::F32) => {
-                    OutputTensor::F32(v.iter().map(|x| *x as f32).collect())
+                (TensorData::U8(v), ScalarType::F32) => {
+                    TensorData::F32(v.iter().map(|x| *x as f32).collect())
                 }
-                (InputTensor::U8(v), ScalarType::I64) => {
-                    OutputTensor::I64(v.iter().map(|x| *x as i64).collect())
+                (TensorData::U8(v), ScalarType::I64) => {
+                    TensorData::I64(v.iter().map(|x| *x as i64).collect())
                 }
-                (InputTensor::U8(v), ScalarType::I32) => {
-                    OutputTensor::I32(v.iter().map(|x| *x as i32).collect())
+                (TensorData::U8(v), ScalarType::I32) => {
+                    TensorData::I32(v.iter().map(|x| *x as i32).collect())
                 }
-                (InputTensor::U8(v), ScalarType::U8) => OutputTensor::U8(v.to_vec()),
+                (TensorData::U8(v), ScalarType::U8) => TensorData::U8(Cow::Owned(v.to_vec())),
             };
 
             Some(vec![output_tensor])
@@ -104,9 +108,10 @@ pub(crate) async fn calculate_constant_node_outputs<'a>(
 
         // ConstantOfShape: produces an output of the shape specified by the input, filled with a constant value specified in an attribute
         "ConstantOfShape" => {
-            if let InputTensor::I64(input_shape) = &inputs[0] {
+            if let TensorData::I64(input_shape) = &inputs[0] {
                 let element_count = input_shape.iter().product::<i64>() as usize;
-                Some(vec![constant_of_shape_output(node, element_count)
+                let op_def = OperatorDefinition::from(node, output_shapes.to_vec());
+                Some(vec![constant_of_shape_output(&op_def, element_count)
                     .map_err(|e| {
                         ConstantFoldingError::InvalidNode(e.to_string())
                     })?])
@@ -159,7 +164,7 @@ pub(crate) async fn calculate_constant_node_outputs<'a>(
             ));
             graph.set_node(RepeatedField::from(vec![temp_node]));
 
-            let model = model_with_opset(graph, opset_version);
+            let model = onnx_model_with_opset(graph, opset_version);
 
             let session = match Session::from_model(model).await {
                 Ok(v) => v,
@@ -177,9 +182,9 @@ pub(crate) async fn calculate_constant_node_outputs<'a>(
                 }
             };
 
-            let mut named_inputs: HashMap<String, InputTensor> = HashMap::new();
+            let mut named_inputs: HashMap<String, TensorData> = HashMap::new();
             for (index, input) in inputs.iter().enumerate() {
-                let input: InputTensor = input.to_owned();
+                let input: TensorData = input.to_owned();
                 named_inputs.insert(format!("input_{}", index), input);
             }
 
@@ -188,7 +193,7 @@ pub(crate) async fn calculate_constant_node_outputs<'a>(
                 .await
                 .map_err(ConstantFoldingError::CalculationError)?;
 
-            let outputs: Vec<OutputTensor> = (0..node.output.len())
+            let outputs: Vec<TensorData> = (0..node.output.len())
                 .map(|output_index| {
                     let output_key = format!("output_{}", output_index);
                     output_values.remove(&output_key).unwrap()
@@ -202,7 +207,7 @@ pub(crate) async fn calculate_constant_node_outputs<'a>(
 
 fn input_to_value_info(shape: &Shape, name: &str) -> ValueInfoProto {
     let mut ttp = TypeProto_Tensor::new();
-    ttp.set_elem_type(shape.data_type.to_datatype().value());
+    ttp.set_elem_type(shape.data_type.to_onnx_datatype().value());
     let mut tsp = TensorShapeProto::new();
     tsp.set_dim(RepeatedField::from(
         shape
@@ -224,10 +229,10 @@ fn input_to_value_info(shape: &Shape, name: &str) -> ValueInfoProto {
     vip
 }
 
-fn calculate_shape_operator(
+fn calculate_shape_operator<'a>(
     node: &NodeProto,
     input_shape: &Shape,
-) -> Result<OutputTensor, ConstantFoldingError> {
+) -> Result<TensorData<'a>, ConstantFoldingError> {
     let input_dims: Vec<i64> = input_shape.dims.iter().map(|x| *x as i64).collect();
     let mut start = node.get_attribute_value("start", Some(0)).unwrap();
     let mut end = node
@@ -254,12 +259,15 @@ fn calculate_shape_operator(
         log::warn!("Shape operator results in an empty output shape which is probably an issue... start={start} end={end} input_shape={}", input_shape);
     }
 
-    Ok(OutputTensor::I64(output_shape))
+    Ok(TensorData::I64(output_shape.into()))
 }
 
 #[cfg(test)]
 mod test {
-    use wonnx::utils::{attribute, node, OutputTensor, Shape};
+    use wonnx::{
+        onnx_model::{onnx_attribute, onnx_node},
+        tensor::{Shape, TensorData},
+    };
 
     use super::calculate_shape_operator;
 
@@ -271,16 +279,19 @@ mod test {
     ) {
         let mut attrs = vec![];
         if let Some(start) = start {
-            attrs.push(attribute("start", start));
+            attrs.push(onnx_attribute("start", start));
         }
         if let Some(end) = end {
-            attrs.push(attribute("end", end));
+            attrs.push(onnx_attribute("end", end));
         }
-        let node = node(vec!["X"], vec!["Y"], "s", "Shape", attrs);
-        let shape = Shape::from(wonnx::utils::ScalarType::F32, dims);
+        let node = onnx_node(vec!["X"], vec!["Y"], "Shape", attrs);
+        let shape = Shape::from(
+            wonnx::tensor::ScalarType::F32,
+            &dims.iter().map(|x| *x as usize).collect::<Vec<_>>(),
+        );
         assert_eq!(
             calculate_shape_operator(&node, &shape).unwrap(),
-            OutputTensor::I64(out_dims.to_vec())
+            TensorData::I64(out_dims.into())
         );
     }
diff --git a/wonnx-preprocessing/src/lib.rs b/wonnx-preprocessing/src/lib.rs
index 13945d13..cdde4a56 100644
--- a/wonnx-preprocessing/src/lib.rs
+++ b/wonnx-preprocessing/src/lib.rs
@@ -1,10 +1,11 @@
 use ndarray::ArrayBase;
-use wonnx::utils::InputTensor;
+use wonnx::tensor::TensorData;
 
 pub mod constant_folding;
 pub mod image;
 pub mod shape_inference;
 pub mod text;
+mod utils;
 
 pub enum Tensor {
     F32(ArrayBase<ndarray::OwnedRepr<f32>, ndarray::IxDyn>),
@@ -13,7 +14,7 @@ pub enum Tensor {
 }
 
 impl Tensor {
-    pub fn input_tensor(&self) -> InputTensor {
+    pub fn input_tensor(&self) -> TensorData {
         match self {
             Tensor::F32(a) => a.as_slice().unwrap().into(),
             Tensor::I32(a) => a.as_slice().unwrap().into(),
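Note: `Tensor::input_tensor` now yields a borrowed `TensorData` instead of an `InputTensor`. A small usage sketch (the array contents are hypothetical; `ndarray` is already a dependency of this crate):

    use ndarray::Array;
    use wonnx_preprocessing::Tensor;

    let t = Tensor::F32(Array::from_vec(vec![0.0f32, 1.0]).into_dyn());
    let td = t.input_tensor(); // borrows the array's storage as TensorData::F32
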
diff --git a/wonnx-preprocessing/src/shape_inference.rs b/wonnx-preprocessing/src/shape_inference.rs
index d20cb693..df268792 100644
--- a/wonnx-preprocessing/src/shape_inference.rs
+++ b/wonnx-preprocessing/src/shape_inference.rs
@@ -7,12 +7,13 @@ use wonnx::{
         GraphProto, NodeProto, TensorProto, TensorShapeProto, TensorShapeProto_Dimension,
         TypeProto, TypeProto_Tensor, TypeProto_oneof_value, ValueInfoProto,
     },
-    utils::{
-        AttributeNotFoundError, DataTypeError, InputTensor, NodeAttributes, ScalarType, Shape,
-    },
+    tensor::{DataTypeError, ScalarType, Shape, TensorData},
 };
 
-use crate::constant_folding::{calculate_constant_node_outputs, ConstantFoldingError};
+use crate::{
+    constant_folding::{calculate_constant_node_outputs, ConstantFoldingError},
+    utils::{AttributeNotFoundError, NodeAttributes},
+};
 
 pub fn apply_dynamic_dimensions(graph: &mut GraphProto, dynamic_dims: &HashMap<String, i64>) {
     // Apply to values
@@ -40,7 +41,7 @@ fn static_initializer_value_i64<'a>(
     name: &str,
 ) -> Result<&'a [i64], ShapeInferenceError> {
     if let Some(shape_tensor) = initializers.get(name) {
-        if shape_tensor.get_data_type() != ScalarType::I64.to_datatype().value() {
+        if shape_tensor.get_data_type() != ScalarType::I64.to_onnx_datatype().value() {
             return Err(ShapeInferenceError::Unsupported(format!(
                 "initializer {} has data type {} and not int64, which is currently not supported",
                 name,
@@ -148,8 +149,15 @@ pub(crate) fn dimensions_infos(
     }
 
     for info in graph_proto.get_initializer() {
-        if let Ok(data_type) = ScalarType::from_i32(info.get_data_type()) {
-            let shape = Shape::from(data_type, info.get_dims());
+        if let Ok(data_type) = ScalarType::from_onnx_i32(info.get_data_type()) {
+            let shape = Shape::from(
+                data_type,
+                &info
+                    .get_dims()
+                    .iter()
+                    .map(|x| *x as usize)
+                    .collect::<Vec<_>>(),
+            );
             if shapes_info
                 .insert(info.get_name().to_string(), shape)
                 .is_some()
@@ -217,21 +225,21 @@ fn replace_constant_ops_with_initializers(
                 // Get constant value
                 if let Ok(values) = node.get_attribute_value::<Vec<f32>>("value_floats", None) {
-                    initializer.set_data_type(ScalarType::F32.to_datatype().value());
+                    initializer.set_data_type(ScalarType::F32.to_onnx_datatype().value());
                     initializer.set_dims(vec![values.len() as i64]);
                     initializer.set_float_data(values);
                 } else if let Ok(values) = node.get_attribute_value::<Vec<i64>>("value_ints", None)
                 {
-                    initializer.set_data_type(ScalarType::I64.to_datatype().value());
+                    initializer.set_data_type(ScalarType::I64.to_onnx_datatype().value());
                     initializer.set_dims(vec![values.len() as i64]);
                     initializer.set_int64_data(values);
                 } else if let Ok(values) = node.get_attribute_value::<i64>("value_int", None) {
                     initializer.set_int64_data(vec![values]);
-                    initializer.set_data_type(ScalarType::I64.to_datatype().value());
+                    initializer.set_data_type(ScalarType::I64.to_onnx_datatype().value());
                     initializer.set_dims(vec![1]);
                 } else if let Ok(values) = node.get_attribute_value::<f32>("value_float", None) {
                     initializer.set_float_data(vec![values]);
-                    initializer.set_data_type(ScalarType::F32.to_datatype().value());
+                    initializer.set_data_type(ScalarType::F32.to_onnx_datatype().value());
                     initializer.set_dims(vec![1]);
                 } else if let Ok(tp) = node.get_attribute_value::<TensorProto>("value", None) {
                     initializer = tp;
@@ -344,7 +352,7 @@ pub async fn infer_shapes(
             let mut tip = TypeProto::new();
             let mut ttp = TypeProto_Tensor::new();
-            ttp.set_elem_type(output_shape.data_type.to_datatype().value());
+            ttp.set_elem_type(output_shape.data_type.to_onnx_datatype().value());
 
             let mut tsp = TensorShapeProto::new();
             tsp.set_dim(
@@ -379,16 +387,16 @@ pub async fn infer_shapes(
             log::debug!("node '{}' can be folded", node.get_name());
 
             // Collect constant inputs
-            let inputs: Vec<InputTensor> = node
+            let inputs: Vec<TensorData> = node
                 .input
                 .iter()
                 .map(|input_name| {
                     if let Some(initializer) = initializers.get(input_name) {
-                        InputTensor::try_from(initializer.as_ref())
+                        Ok(TensorData::try_from(initializer.as_ref())?.into_static())
                     } else {
                        // This should only happen when is_known_shape is true. In this case we will not do any GPU inference
                         // and the contents if this tensor don't matter
-                        Ok(InputTensor::I64(Cow::Owned(vec![])))
+                        Ok(TensorData::I64(Cow::Owned(vec![])))
                     }
                 })
                 .collect::<Result<_, _>>()
@@ -493,7 +501,7 @@ pub(crate) fn infer_output_shapes(
             let to_value: i64 = node
                 .get_attribute_value("to", None)
                 .map_err(ShapeInferenceError::MissingAttribute)?;
-            let to_data_type = ScalarType::from_i32(to_value as i32).map_err(|_| {
+            let to_data_type = ScalarType::from_onnx_i32(to_value as i32).map_err(|_| {
                 ShapeInferenceError::InvalidNode(
                     node.get_name().to_string(),
                     format!(
@@ -528,11 +536,11 @@ pub(crate) fn infer_output_shapes(
             let outer_dim = if axis == 0 {
                 1
             } else {
-                input_dims[0..=(axis - 1)].iter().product::<u64>() as i64
+                input_dims[0..=(axis - 1)].iter().product::<usize>() as i64
             };
-            let inner_dim = input_dims[axis..].iter().product::<u64>() as i64;
+            let inner_dim = input_dims[axis..].iter().product::<usize>() as i64;
 
-            let new_dims = vec![outer_dim, inner_dim];
+            let new_dims = vec![outer_dim as usize, inner_dim as usize];
             Ok(vec![Shape::from(input_shapes[0].data_type, &new_dims)])
         }
 
@@ -579,14 +587,14 @@ pub(crate) fn infer_output_shapes(
                 (0..out_rank)
                     .map(|idx| {
                         if idx < axis {
-                            input_shapes[0].dim(idx as usize) as i64
+                            input_shapes[0].dim(idx as usize)
                         } else if idx >= axis && idx < (axis + q) {
-                            input_shapes[1].dim((idx - axis) as usize) as i64
+                            input_shapes[1].dim((idx - axis) as usize)
                         } else {
-                            input_shapes[0].dim((idx - q + 1) as usize) as i64
+                            input_shapes[0].dim((idx - q + 1) as usize)
                         }
                     })
-                    .collect::<Vec<i64>>()
+                    .collect::<Vec<usize>>()
                     .as_ref(),
             )])
         }
@@ -604,7 +612,7 @@ pub(crate) fn infer_output_shapes(
 
             Ok(vec![Shape::from(
                 ScalarType::I64,
-                &[rank.clamp(start, end)],
+                &[rank.clamp(start, end) as usize],
             )])
         }
 
@@ -691,8 +699,7 @@ pub(crate) fn infer_output_shapes(
                 })
                 .collect();
 
-            let mut output_shape: Vec<i64> =
-                input_shapes[0].dims.iter().map(|x| *x as i64).collect();
+            let mut output_shape: Vec<usize> = input_shapes[0].dims.clone();
 
             // https://github.com/onnx/onnx/blob/fb80e3ade84e9f406711aa41b9f3665753158371/onnx/defs/tensor/defs.cc#L969
             for (axis_index, axis) in axes.iter().enumerate() {
@@ -706,7 +713,7 @@ pub(crate) fn infer_output_shapes(
                     &mut step,
                 )?;
                 let temp = div_ceil(end - start, step).max(0);
-                output_shape[*axis as usize] = temp;
+                output_shape[*axis as usize] = temp as usize;
             }
 
             Ok(vec![Shape::from(data_shape.data_type, &output_shape)])
@@ -752,7 +759,7 @@ pub(crate) fn infer_output_shapes(
                 (0..input_ndim as i64)
                     .flat_map(|i| {
                         if !axes.contains(&i) {
-                            vec![input_shape.dim(i as usize) as i64]
+                            vec![input_shape.dim(i as usize)]
                         } else if keep_dims == 1 {
                             vec![1]
                         } else {
@@ -919,10 +926,10 @@ pub(crate) fn infer_output_shapes(
             };
 
             // Determine output shape
-            let mut output_shape: Vec<i64> = vec![];
-            output_shape.push(input_shape.dim(0) as i64);
+            let mut output_shape: Vec<usize> = vec![];
+            output_shape.push(input_shape.dim(0));
             if require_kernel_shape {
-                output_shape.push(input_shape.dim(1) as i64);
+                output_shape.push(input_shape.dim(1));
             } else {
                 if input_shapes[1].rank() < 1 {
                     return Err(ShapeInferenceError::InvalidNode(
@@ -930,7 +937,7 @@ pub(crate) fn infer_output_shapes(
                         "second input has incorrect rank".to_string(),
                     ));
                 }
-                output_shape.push(input_shapes[1].dim(0) as i64);
+                output_shape.push(input_shapes[1].dim(0));
             }
 
             let kernel_shape_size = kernel_shape.len();
@@ -951,7 +958,7 @@ pub(crate) fn infer_output_shapes(
                     (effective_input_size - effective_kernel_shape[i]) / strides[i]
                 };
-                output_shape.push(1 + strided_kernel_positions);
+                output_shape.push((1 + strided_kernel_positions) as usize);
             }
 
             // MaxPool can have two outputs
@@ -968,30 +975,36 @@ pub(crate) fn infer_output_shapes(
                 .get_attribute_value::<TensorProto>("value", None)
                 .map_err(ShapeInferenceError::MissingAttribute)?;
 
-            let data_type = ScalarType::from_i32(value.get_data_type())
+            let data_type = ScalarType::from_onnx_i32(value.get_data_type())
                 .map_err(ShapeInferenceError::UnsupportedDataType)?;
-            Ok(vec![Shape::from(data_type, shape)])
+            Ok(vec![Shape::from(
+                data_type,
+                &shape.iter().map(|x| *x as usize).collect::<Vec<_>>(),
+            )])
         }
 
         ("Constant", 0, 1) => {
             if let Ok(values) = node.get_attribute_value::<Vec<f32>>("value_floats", None) {
-                Ok(vec![Shape::from(ScalarType::F32, &[values.len() as i64])])
+                Ok(vec![Shape::from(ScalarType::F32, &[values.len()])])
             } else if let Ok(values) = node.get_attribute_value::<Vec<i64>>("value_ints", None) {
-                Ok(vec![Shape::from(ScalarType::I64, &[values.len() as i64])])
+                Ok(vec![Shape::from(ScalarType::I64, &[values.len()])])
             } else if node.get_attribute_value::<f32>("value_float", None).is_ok() {
                 Ok(vec![Shape::from(ScalarType::F32, &[1])])
             } else if node.get_attribute_value::<i64>("value_int", None).is_ok() {
                 Ok(vec![Shape::from(ScalarType::I64, &[1])])
             } else if let Ok(tp) = node.get_attribute_value::<TensorProto>("value", None) {
                 Ok(vec![Shape::from(
-                    ScalarType::from_i32(tp.get_data_type()).map_err(|_| {
+                    ScalarType::from_onnx_i32(tp.get_data_type()).map_err(|_| {
                         ShapeInferenceError::InvalidNode(
                             node.get_name().to_string(),
                             "invalid tensor data type".to_string(),
                         )
                     })?,
-                    tp.get_dims(),
+                    &tp.get_dims()
+                        .iter()
+                        .map(|x| *x as usize)
+                        .collect::<Vec<_>>(),
                 )])
             } else {
                 log::debug!("{:#?}", node);
@@ -1021,22 +1034,22 @@ pub(crate) fn infer_output_shapes(
                 }
             }
 
-            let output_shape: Vec<i64> = shape_tensor_contents
+            let output_shape: Vec<usize> = shape_tensor_contents
                 .iter()
                 .enumerate()
                 .map(|(idx, dim)| {
                     if *dim == 0 && !allow_zero {
-                        input_shapes[0].dim(idx) as i64
+                        input_shapes[0].dim(idx)
                     } else {
-                        *dim
+                        *dim as usize
                     }
                 })
                 .collect();
 
-            if output_shape.iter().product::<i64>() != input_shapes[0].element_count() as i64 {
+            if output_shape.iter().product::<usize>() != input_shapes[0].element_count() {
                 return Err(ShapeInferenceError::InvalidNode(
                     node.get_name().to_string(),
-                    format!("Reshape input tensor (element count={}) must have the same number of elements as specified by the new shape ({})", input_shapes[0].element_count(), output_shape.iter().product::<i64>())));
+                    format!("Reshape input tensor (element count={}) must have the same number of elements as specified by the new shape ({})", input_shapes[0].element_count(), output_shape.iter().product::<usize>())));
             }
 
             Ok(vec![Shape::from(input_shapes[0].data_type, &output_shape)])
@@ -1053,7 +1066,7 @@ pub(crate) fn infer_output_shapes(
                 .map_err(ShapeInferenceError::MissingAttribute)?;
 
             // All input shapes must be the same except for the dimension at the specified axis
-            let mut shape: Vec<i64> = input_shapes[0].dims.iter().map(|x| *x as i64).collect();
+            let mut shape: Vec<usize> = input_shapes[0].dims.clone();
             if axis < -(shape.len() as i64) || axis > (shape.len() - 1) as i64 {
                 return Err(ShapeInferenceError::InvalidNode(
                     node.get_name().to_string(),
@@ -1066,7 +1079,7 @@ pub(crate) fn infer_output_shapes(
             } else {
                 axis as usize
             };
-            shape[axis_index] = input_shapes.iter().map(|s| s.dim(axis_index) as i64).sum();
+            shape[axis_index] = input_shapes.iter().map(|s| s.dim(axis_index)).sum();
             Ok(vec![Shape::from(input_shapes[0].data_type, &shape)])
         }
 
@@ -1092,8 +1105,7 @@ pub(crate) fn infer_output_shapes(
             };
 
             let output_rank = input_shapes[0].rank() + axes.len();
-            let mut input_shape: Vec<i64> =
-                input_shapes[0].dims.iter().map(|x| *x as i64).collect();
+            let mut input_shape: Vec<usize> = input_shapes[0].dims.clone();
             for i in axes {
                 let index = if i < 0 {
                     ((output_rank as i64) + i) as usize
@@ -1142,8 +1154,11 @@ pub(crate) fn infer_output_shapes(
                 ));
             }
 
-            let element_count = (end[0] - start[0]) / step[0];
-            Ok(vec![Shape::from(ScalarType::I64, &[element_count])])
+            let element_count: i64 = (end[0] - start[0]) / step[0];
+            Ok(vec![Shape::from(
+                ScalarType::I64,
+                &[element_count as usize],
+            )])
         }
 
         ("Squeeze", num_inputs @ 1..=2, 1) => {
@@ -1162,7 +1177,7 @@ pub(crate) fn infer_output_shapes(
                 vec![]
             };
 
-            let output_shape: Vec<i64> = input_shapes[0]
+            let output_shape: Vec<usize> = input_shapes[0]
                 .dims
                 .iter()
                 .enumerate()
@@ -1170,7 +1185,7 @@ pub(crate) fn infer_output_shapes(
                     if (has_axes && axes.contains(&(idx as i64))) || (!has_axes && *dim == 1) {
                         vec![]
                     } else {
-                        vec![*dim as i64]
+                        vec![*dim]
                     }
                 })
                 .collect();
@@ -1179,8 +1194,8 @@ pub(crate) fn infer_output_shapes(
         }
 
         ("Transpose", 1, 1) => {
-            let input_dims: Vec<i64> = input_shapes[0].dims.iter().map(|x| *x as i64).collect();
-            let output_dims: Vec<i64> = match node.get_attribute_value::<Vec<i64>>("perm", None) {
+            let input_dims = &input_shapes[0].dims;
+            let output_dims: Vec<usize> = match node.get_attribute_value::<Vec<i64>>("perm", None) {
                 Ok(perm) => perm.iter().map(|idx| input_dims[*idx as usize]).collect(),
                 Err(_) => input_dims.iter().rev().cloned().collect(),
             };
@@ -1270,7 +1285,7 @@ fn process_slice_inputs(
 fn fix_raw_tensor(tensor: &mut TensorProto) -> Result<(), ShapeInferenceError> {
     if tensor.has_raw_data() {
         let raw_data = tensor.take_raw_data();
-        match ScalarType::from_i32(tensor.get_data_type())
+        match ScalarType::from_onnx_i32(tensor.get_data_type())
             .map_err(ShapeInferenceError::UnsupportedDataType)?
         {
             ScalarType::F32 => tensor.set_float_data(bytemuck::cast_slice(&raw_data[..]).to_vec()),
diff --git a/wonnx-preprocessing/src/text.rs b/wonnx-preprocessing/src/text.rs
index 452d1c97..93c48dd2 100644
--- a/wonnx-preprocessing/src/text.rs
+++ b/wonnx-preprocessing/src/text.rs
@@ -4,7 +4,7 @@ use std::io::{BufRead, BufReader};
 use std::path::Path;
 use thiserror::Error;
 use tokenizers::{EncodeInput, Encoding, InputSequence, Tokenizer};
-use wonnx::utils::Shape;
+use wonnx::tensor::Shape;
 
 use crate::Tensor;
 
@@ -83,7 +83,7 @@ impl TextTokenizer {
         text: &str,
         shape: &Shape,
     ) -> Result<Tensor, PreprocessingError> {
-        let segment_length = shape.dim(shape.rank() - 1) as usize;
+        let segment_length = shape.dim(shape.rank() - 1);
         let tokenized = self.tokenize(text)?;
         let mut tokens = tokenized.get_mask();
         tokens.resize(segment_length, 0);
@@ -92,7 +92,7 @@ impl TextTokenizer {
     }
 
     pub fn get_input_for(&self, text: &str, shape: &Shape) -> Result<Tensor, PreprocessingError> {
-        let segment_length = shape.dim(shape.rank() - 1) as usize;
+        let segment_length = shape.dim(shape.rank() - 1);
         let tokenized = self.tokenize(text)?;
         let mut tokens = tokenized.get_tokens();
         tokens.resize(segment_length, 0);
diff --git a/wonnx-preprocessing/src/utils.rs b/wonnx-preprocessing/src/utils.rs
new file mode 100644
index 00000000..7b269ca9
--- /dev/null
+++ b/wonnx-preprocessing/src/utils.rs
@@ -0,0 +1,51 @@
+//! Various utilities to deal with the ONNX format structure
+use std::convert::Into;
+use thiserror::Error;
+use wonnx::onnx;
+
+#[derive(Error, Debug)]
+#[error("did not find attribute '{attribute}' for node '{node_name}'")]
+pub struct AttributeNotFoundError {
+    attribute: String,
+    node_name: String,
+}
+
+pub trait NodeAttributes {
+    fn has_attribute(&self, attribute_name: &str) -> bool;
+    fn get_attribute_value<'a, T: std::convert::From<&'a onnx::AttributeProto>>(
+        &'a self,
+        attribute: &str,
+        default: Option<T>,
+    ) -> Result<T, AttributeNotFoundError>;
+}
+
+impl NodeAttributes for onnx::NodeProto {
+    fn has_attribute(&self, attribute_name: &str) -> bool {
+        self.get_attribute()
+            .iter()
+            .any(|attr| attr.get_name() == attribute_name)
+    }
+
+    fn get_attribute_value<'a, T>(
+        &'a self,
+        attribute: &str,
+        default: Option<T>,
+    ) -> Result<T, AttributeNotFoundError>
+    where
+        T: From<&'a onnx::AttributeProto>,
+    {
+        match (
+            self.get_attribute()
+                .iter()
+                .find(|attr| attr.get_name() == attribute),
+            default,
+        ) {
+            (Some(attr), _) => Ok(attr.into()),
+            (None, Some(default_attr)) => Ok(default_attr),
+            (None, None) => Err(AttributeNotFoundError {
+                attribute: attribute.to_string(),
+                node_name: self.get_name().to_string(),
+            }),
+        }
+    }
+}
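Note: the `NodeAttributes` extension trait now lives in the preprocessing crate rather than `wonnx::utils`. A hedged sketch of how it is used inside this crate (the `node` value is a hypothetical `onnx::NodeProto` in scope; the call mirrors the trait defined above):

    use crate::utils::NodeAttributes;

    // Returns the attribute when present, the supplied default otherwise,
    // and AttributeNotFoundError only when neither exists:
    let axis: i64 = node.get_attribute_value("axis", Some(0))?;
    let has_perm = node.has_attribute("perm");
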
diff --git a/wonnx-py/src/lib.rs b/wonnx-py/src/lib.rs
index eb9af8dd..bad51a20 100644
--- a/wonnx-py/src/lib.rs
+++ b/wonnx-py/src/lib.rs
@@ -1,6 +1,7 @@
-use ::wonnx::utils::OutputTensor;
+use ::wonnx::tensor::TensorData;
 use pyo3::prelude::*;
 use pyo3::types::PyDict;
+use std::borrow::Cow;
 use std::collections::HashMap;
 
 use ::wonnx::Session;
@@ -10,15 +11,15 @@ pub struct PySession {
     pub session: Session,
 }
 
-pub struct PyOutputTensor(OutputTensor);
+pub struct PyOutputTensor(TensorData<'static>);
 
 impl IntoPy<PyObject> for PyOutputTensor {
     fn into_py(self, py: Python) -> PyObject {
         match self.0 {
-            OutputTensor::F32(fs) => fs.into_py(py),
-            OutputTensor::I32(fs) => fs.into_py(py),
-            OutputTensor::I64(fs) => fs.into_py(py),
-            OutputTensor::U8(fs) => fs.into_py(py),
+            TensorData::F32(fs) => fs.into_owned().into_py(py),
+            TensorData::I32(fs) => fs.into_owned().into_py(py),
+            TensorData::I64(fs) => fs.into_owned().into_py(py),
+            TensorData::U8(fs) => fs.into_py(py),
         }
     }
 }
@@ -40,8 +41,8 @@ impl PySession {
     pub fn run(&self, dict: &PyDict) -> PyResult<HashMap<String, PyOutputTensor>> {
         let map: HashMap<String, Vec<f32>> = dict.extract().unwrap();
         let mut inputs = HashMap::new();
-        for (key, value) in map.iter() {
-            inputs.insert(key.clone(), value.as_slice().into());
+        for (key, value) in map.into_iter() {
+            inputs.insert(key.clone(), TensorData::F32(Cow::Owned(value)));
         }
         let result = pollster::block_on(self.session.run(&inputs)).unwrap();
         Ok(result
diff --git a/wonnx-wasm/src/lib.rs b/wonnx-wasm/src/lib.rs
index 85505d1d..b02a7ed4 100644
--- a/wonnx-wasm/src/lib.rs
+++ b/wonnx-wasm/src/lib.rs
@@ -7,7 +7,7 @@ use std::collections::HashMap;
 use std::sync::Arc;
 use wasm_bindgen::prelude::*;
 use wasm_bindgen_futures::future_to_promise;
-use wonnx::utils::{InputTensor, OutputTensor};
+use wonnx::tensor::TensorData;
 
 #[wasm_bindgen(start)]
 pub fn main() {
@@ -78,24 +78,23 @@ impl Session {
         let engine = self.session.clone();
 
         future_to_promise(async move {
-            let input_data: HashMap<String, InputTensor<'_>> = input_copy
+            let input_data: HashMap<String, TensorData<'_>> = input_copy
                 .input_data
                 .iter()
                 .map(|(k, v)| (k.clone(), v.as_slice().into()))
                 .collect();
 
             let result = engine.run(&input_data).await.map_err(SessionError)?;
-            drop(input_copy);
             Ok(serde_wasm_bindgen::to_value(&result).unwrap())
         })
     }
 }
 
 /// Convert an OutputTensor to a JsValue (we cannot implement Into<JsValue> for OutputTensor here)
-pub fn tensor_to_js_value(tensor: OutputTensor) -> JsValue {
+pub fn tensor_to_js_value(tensor: TensorData) -> JsValue {
     match tensor {
-        OutputTensor::F32(fs) => serde_wasm_bindgen::to_value(&fs).unwrap(),
-        OutputTensor::I32(ints) => serde_wasm_bindgen::to_value(&ints).unwrap(),
-        OutputTensor::I64(ints) => serde_wasm_bindgen::to_value(&ints).unwrap(),
-        OutputTensor::U8(ints) => serde_wasm_bindgen::to_value(&ints).unwrap(),
+        TensorData::F32(fs) => serde_wasm_bindgen::to_value(&fs).unwrap(),
+        TensorData::I32(ints) => serde_wasm_bindgen::to_value(&ints).unwrap(),
+        TensorData::I64(ints) => serde_wasm_bindgen::to_value(&ints).unwrap(),
+        TensorData::U8(ints) => serde_wasm_bindgen::to_value(&ints).unwrap(),
     }
 }
diff --git a/wonnx/examples/mnist.rs b/wonnx/examples/mnist.rs
index 269a1ad5..61f7245a 100644
--- a/wonnx/examples/mnist.rs
+++ b/wonnx/examples/mnist.rs
@@ -4,7 +4,7 @@ use std::convert::TryInto;
 use image::{imageops::FilterType, ImageBuffer, Pixel, Rgb};
 use std::path::Path;
 use std::time::Instant;
-use wonnx::utils::OutputTensor;
+use wonnx::tensor::TensorData;
 
 // Args Management
 async fn run() {
@@ -22,7 +22,7 @@ async fn run() {
 }
 
 // Hardware management
-async fn execute_gpu() -> Option<HashMap<String, OutputTensor>> {
+async fn execute_gpu<'a>() -> Option<HashMap<String, TensorData<'a>>> {
     let mut input_data = HashMap::new();
 
     let image = load_image();
diff --git a/wonnx/examples/simple_graph.rs b/wonnx/examples/simple_graph.rs
index d8bccf10..cfbf786a 100644
--- a/wonnx/examples/simple_graph.rs
+++ b/wonnx/examples/simple_graph.rs
@@ -1,70 +1,42 @@
 use std::collections::HashMap;
+use wonnx::builder::*;
 
-use wonnx::{
-    utils::{attribute, graph, initializer, model, node, tensor, OutputTensor},
-    SessionError, WonnxError,
-};
+fn main() {
+    env_logger::init();
+    pollster::block_on(run()).unwrap();
+}
 
 async fn run() -> Result<(), WonnxError> {
     let result = execute_gpu().await?;
     let result = result.into_iter().next().unwrap().1;
+    println!("{:#?}", result);
+
     assert_eq!(
         result,
-        OutputTensor::F32(vec![54., 63., 72., 99., 108., 117., 144., 153., 162.])
+        TensorData::F32(vec![54., 63., 72., 99., 108., 117., 144., 153., 162.].into())
     );
     Ok(())
 }
 
-// Hardware management
-async fn execute_gpu() -> Result<HashMap<String, OutputTensor>, SessionError> {
-    // USER INPUT
+async fn execute_gpu() -> Result<HashMap<String, TensorData<'static>>, SessionError> {
+    // Hyperparameters
     let n = 5;
     let c = 1;
-    let mut input_data = HashMap::new();
-
-    let data: Vec<f32> = (0..25).map(|x| x as f32).collect();
-    input_data.insert("X".to_string(), data.as_slice().into());
-
-    // ONNX INPUTS
-    let shape = vec![1, c, n as i64, n as i64];
     let kernel_n = 3;
     let m = 1;
-    let data_w: Vec<f32> = (0..m * c * kernel_n * kernel_n).map(|_| 1.0f32).collect();
-    let model = model(graph(
-        vec![tensor("X", &shape)],
-        vec![tensor("Y", &[1, m, 3, 3])],
-        vec![],
-        vec![initializer("W", data_w, vec![m, c, 3, 3])],
-        vec![node(
-            vec!["X", "W"],
-            vec!["Y"],
-            "conv",
-            "Conv",
-            vec![attribute("kernel_shape", vec![3, 3])],
-        )],
-    ));
-
-    // LOGIC
-    let session = wonnx::Session::from_model(model)
-        .await
-        .expect("Session did not create");
+    let data_w: Vec<f32> = (0..m * c * kernel_n * kernel_n).map(|_| 1.0f32).collect();
 
-    session.run(&input_data).await
-}
+    let session = {
+        let input_x = input("X", ScalarType::F32, &[1, c, n, n]);
+        let weights = tensor("W", &[m, c, 3, 3], data_w.into());
+        let conv = input_x.conv(&weights, &[3, 3], &[m, c, 3, 3]);
+        session_for_outputs(&["Y"], &[conv], 13).await?
+    };
 
-// #[wasm_bindgen_test]
-fn main() {
-    #[cfg(not(target_arch = "wasm32"))]
-    {
-        env_logger::init();
-        pollster::block_on(run()).unwrap();
-    }
-    #[cfg(target_arch = "wasm32")]
-    {
-        std::panic::set_hook(Box::new(console_error_panic_hook::hook));
-        console_log::init().expect("could not initialize logger");
-        wasm_bindgen_futures::spawn_local(run());
-    }
+    let mut input_data = HashMap::new();
+    let data: Vec<f32> = (0..25).map(|x| x as f32).collect();
+    input_data.insert("X".to_string(), data.as_slice().into());
+    Ok(session.run(&input_data).await?.to_owned())
 }
diff --git a/wonnx/examples/squeeze.rs b/wonnx/examples/squeeze.rs
index 159a184c..0f26b90a 100644
--- a/wonnx/examples/squeeze.rs
+++ b/wonnx/examples/squeeze.rs
@@ -9,7 +9,7 @@ use std::{
     io::{BufRead, BufReader},
     path::Path,
 };
-use wonnx::utils::OutputTensor;
+use wonnx::tensor::TensorData;
 use wonnx::WonnxError;
 
 // Args Management
@@ -31,7 +31,7 @@ async fn run() {
 }
 
 // Hardware management
-async fn execute_gpu() -> Result<HashMap<String, OutputTensor>, WonnxError> {
+async fn execute_gpu<'a>() -> Result<HashMap<String, TensorData<'a>>, WonnxError> {
     let mut input_data = HashMap::new();
     let image = load_image();
     input_data.insert("data".to_string(), image.as_slice().unwrap().into());
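Note: the rewritten simple_graph example above depends on the new `wonnx::builder` module added below. A compact sketch of the API surface, assuming the re-exports declared at the top of the new file (values are illustrative):

    use wonnx::builder::{input, tensor, session_for_outputs, ScalarType};

    let x = input("X", ScalarType::F32, &[1, 3]);                 // supplied at run()
    let w = tensor("W", &[1, 3], vec![1.0f32, 2.0, 3.0].into());  // static model data
    let y = x.add(&w);                                            // emits an ONNX Add node
    // session_for_outputs(&["Y"], &[y], 13).await then builds a GPU Session
    // whose run() result map contains an entry for "Y".
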
diff --git a/wonnx/src/builder.rs b/wonnx/src/builder.rs
new file mode 100644
index 00000000..082cd65c
--- /dev/null
+++ b/wonnx/src/builder.rs
@@ -0,0 +1,193 @@
+use crate::{
+    gpu::GpuModel,
+    ir::{Input, Node, NodeDefinition, OperatorDefinition, Tensor},
+    optimizer::Optimizer,
+    resource, Session,
+};
+use std::sync::Arc;
+
+pub use crate::tensor::{ScalarType, Shape, TensorData};
+pub use crate::{SessionError, WonnxError};
+
+#[derive(Clone)]
+pub struct TensorRef<'model> {
+    node: Arc<Node<'model>>,
+    output_index: usize,
+    output_shape: Shape,
+}
+
+impl<'model> From<&TensorRef<'model>> for Input<'model> {
+    fn from(val: &TensorRef<'model>) -> Self {
+        Input {
+            source_node: val.node.clone(),
+            output_index: val.output_index,
+        }
+    }
+}
+
+impl<'model> TensorRef<'model> {
+    /// Element-wise addition of two tensors (ONNX Add operator)
+    pub fn add(&self, rhs: &Self) -> Self {
+        assert_eq!(self.output_shape, rhs.output_shape);
+        self.binary_op(rhs, "Add", self.output_shape.clone())
+    }
+
+    /// Element-wise negation of this tensor (ONNX Neg operator)
+    pub fn neg(&self) -> Self {
+        self.unary_mapping_op("Neg")
+    }
+
+    fn unary_mapping_op(&self, op_type: &str) -> Self {
+        let op_def = OperatorDefinition::new(
+            op_type,
+            vec![self.output_shape.clone()],
+            format!("{}_{}", self.node.definition().get_display_name(), op_type),
+        );
+        TensorRef {
+            node: Arc::new(Node::new(
+                NodeDefinition::Operator(op_def),
+                vec![Input {
+                    source_node: self.node.clone(),
+                    output_index: 0,
+                }],
+            )),
+            output_index: 0,
+            output_shape: self.output_shape.clone(),
+        }
+    }
+
+    /// Convolution (ONNX Conv operator)
+    pub fn conv(&self, weights: &Self, kernel_shape: &[usize], output_dims: &[usize]) -> Self {
+        let output_shape = Shape::from(self.output_shape.data_type, output_dims);
+        let mut op = OperatorDefinition::new(
+            "Conv",
+            vec![output_shape.clone()],
+            format!("{}_conv", self.node.definition.get_display_name()),
+        );
+        op.set_attribute(
+            "kernel_shape",
+            kernel_shape.iter().map(|x| *x as i64).collect::<Vec<i64>>(),
+        );
+        TensorRef {
+            node: Arc::new(Node::new(
+                NodeDefinition::Operator(op),
+                vec![self.into(), weights.into()],
+            )),
+            output_index: 0,
+            output_shape,
+        }
+    }
+
+    fn binary_op(&self, rhs: &Self, op_type: &str, output_shape: Shape) -> Self {
+        let def = NodeDefinition::Operator(OperatorDefinition::new(
+            op_type,
+            vec![output_shape.clone()],
+            format!(
+                "{}_{}_{}",
+                self.node.definition().get_display_name(),
+                rhs.node.definition().get_display_name(),
+                op_type
+            ),
+        ));
+
+        TensorRef {
+            node: Arc::new(Node::new(def, vec![self.into(), rhs.into()])),
+            output_index: 0,
+            output_shape,
+        }
+    }
+}
+
+/// Create a tensor reference representing input supplied at inference time.
+pub fn input<'model, S: ToString>(
+    name: S,
+    scalar_type: ScalarType,
+    dims: &[usize],
+) -> TensorRef<'model> {
+    let shape = Shape::from(scalar_type, dims);
+    TensorRef {
+        node: Arc::new(Node {
+            inputs: vec![],
+            definition: NodeDefinition::Input {
+                name: name.to_string(),
+                shape: shape.clone(),
+            },
+        }),
+        output_index: 0,
+        output_shape: shape,
+    }
+}
+
+/// Create a tensor reference containing static data (included in the model).
+pub fn tensor<'model, S: ToString>(
+    name: S,
+    dims: &[usize],
+    data: TensorData<'model>,
+) -> TensorRef<'model> {
+    let output_shape = Shape::from(data.scalar_type(), dims);
+    TensorRef {
+        node: Arc::new(Node {
+            inputs: vec![],
+            definition: NodeDefinition::Tensor(Tensor {
+                data,
+                dims: dims.to_vec(),
+                display_name: name.to_string(),
+            }),
+        }),
+        output_index: 0,
+        output_shape,
+    }
+}
+
+/// Start an inference session for calculating the outputs provided. The names will be used as keys in the resulting hashmap
+pub async fn session_for_outputs<'model, S: ToString>(
+    output_names: &[S],
+    outputs: &[TensorRef<'model>],
+    onnx_opset_version: i64,
+) -> Result<Session, SessionError> {
+    let output_names = output_names.iter().map(|x| x.to_string()).collect();
+
+    let outputs = Arc::new(Node::new(
+        NodeDefinition::Outputs {
+            names: output_names,
+        },
+        outputs
+            .iter()
+            .map(|x| Input {
+                source_node: x.node.clone(),
+                output_index: x.output_index,
+            })
+            .collect(),
+    ));
+
+    let (device, queue) = resource::request_device_queue().await;
+    let mut optimizer = Optimizer::new(onnx_opset_version);
+    let ir = optimizer.optimize(outputs).await?;
+    let gpu_model = GpuModel::from(ir, device, queue, onnx_opset_version)?;
+    Ok(Session { gpu_model })
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::{
+        builder::{session_for_outputs, tensor},
+        tensor::TensorData,
+    };
+    use std::collections::HashMap;
+
+    #[test]
+    pub fn test_builder() {
+        let _ = env_logger::builder().is_test(true).try_init();
+        pollster::block_on(async {
+            let a = tensor("x", &[1, 3], vec![0.1, 0.2, 0.3].into());
+            let b = tensor("y", &[1, 3], vec![3.0, 2.0, 1.0].into());
+            let axb = a.add(&b);
+            let sesh = session_for_outputs(&["result"], &[axb], 13).await.unwrap();
+            let result = sesh.run(&HashMap::new()).await.unwrap();
+            assert_eq!(
+                result["result"],
+                TensorData::F32(vec![3.1, 2.2, 1.3].into())
+            )
+        });
+    }
+}
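Note: from the compiler change below onwards, `compile` receives an IR `OperatorDefinition` instead of a raw `NodeProto`, and attribute lookups go through it. A hedged sketch of constructing one, based on the constructor and `set_attribute` calls visible in builder.rs above (the names are illustrative):

    use wonnx::ir::OperatorDefinition;
    use wonnx::tensor::{ScalarType, Shape};

    let out_shape = Shape::from(ScalarType::F32, &[1, 1, 3, 3]);
    // new(op_type, output_shapes, display_name), as used by the builder:
    let mut op = OperatorDefinition::new("Conv", vec![out_shape], "my_conv".to_string());
    op.set_attribute("kernel_shape", vec![3i64, 3]); // attributes live on the definition
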
diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs
index bed2b865..6b311192 100644
--- a/wonnx/src/compiler.rs
+++ b/wonnx/src/compiler.rs
@@ -1,6 +1,7 @@
 //! Compiles individual ONNX ops to a WebGPU shader using WGSL templates
-use crate::utils::{
-    ceil, AttributeNotFoundError, DataTypeError, MultiType, NodeAttributes, ScalarType, Shape,
+use crate::{
+    ir::{AttributeNotFoundError, OperatorDefinition},
+    tensor::{ceil, DataTypeError, MultiType, ScalarType, Shape},
 };
 use num::integer::gcd;
 use tera::{Context, Tera};
@@ -8,12 +9,12 @@ use thiserror::Error;
 
 /// The maximum number of threads that can be spawned in each dimension, according to the WebGPU specification. See
 /// <https://www.w3.org/TR/webgpu/#dom-supportedlimits-maxcomputeworkgroupsperdimension>
-pub const MAX_COMPUTE_WORKGROUPS_PER_DIMENSION: u32 = 65535;
+pub const MAX_COMPUTE_WORKGROUPS_PER_DIMENSION: usize = 65535;
 
 /// The maximum workgroup size per dimension (see <https://www.w3.org/TR/webgpu/#dom-supportedlimits-maxcomputeworkgroupsizex>)
-pub const MAX_WORKGROUP_SIZE_X: u32 = 256;
-pub const MAX_WORKGROUP_SIZE_Y: u32 = 256;
-// pub const MAX_WORKGROUP_SIZE_Z: u32 = 64;
+pub const MAX_WORKGROUP_SIZE_X: usize = 256;
+pub const MAX_WORKGROUP_SIZE_Y: usize = 256;
+// pub const MAX_WORKGROUP_SIZE_Z: usize = 64;
 
 lazy_static! {
     // Templates for shader source code that we generate for nodes
@@ -141,7 +142,7 @@ lazy_static! {
 
 pub struct CompiledNode {
     pub shader: String,
-    pub threads: (u32, u32, u32),
+    pub threads: (usize, usize, usize),
 }
 
 #[derive(Error, Debug)]
@@ -178,7 +179,7 @@ pub enum CompileError {
     },
 
     #[error("the model exceeds the limit for {0}: {1} > {2}")]
-    ComputeLimitExceeded(String, u32, u32),
+    ComputeLimitExceeded(String, usize, usize),
 
     #[error("cannot determine data type to use: {0} or {1}")]
     TypesDisagree(ScalarType, ScalarType),
@@ -202,7 +203,7 @@ pub enum CompileError {
 struct NodeTemplate {
     scalar_type: ScalarType,
     template: &'static str,
-    threads: (u32, u32, u32),
+    threads: (usize, usize, usize),
 }
 
 /// Returns the data type of the input and output shapes, but error if these types differ or when no input/output was specified
@@ -244,7 +245,7 @@ fn agreed_type(
 }
 
 pub fn compile(
-    node: &crate::onnx::NodeProto,
+    op_def: &OperatorDefinition,
     input_shapes: &[&Shape],
     output_shapes: &[&Shape],
     opset_version: i64,
@@ -252,17 +253,17 @@ pub fn compile(
     let input_lengths = input_shapes
         .iter()
         .map(|shape| shape.element_count())
-        .collect::<Vec<u64>>();
+        .collect::<Vec<usize>>();
 
     let output_lengths = output_shapes
         .iter()
         .map(|shape| shape.element_count())
-        .collect::<Vec<u64>>();
+        .collect::<Vec<usize>>();
 
-    let input_chunks: Vec<Vec<u64>> = input_shapes.iter().map(|d| d.chunks()).collect();
-    let output_chunks: Vec<Vec<u64>> = output_shapes.iter().map(|d| d.chunks()).collect();
-    let i_dims: Vec<&Vec<u64>> = input_shapes.iter().map(|s| &s.dims).collect();
-    let o_dims: Vec<&Vec<u64>> = output_shapes.iter().map(|s| &s.dims).collect();
+    let input_chunks: Vec<Vec<usize>> = input_shapes.iter().map(|d| d.chunks()).collect();
+    let output_chunks: Vec<Vec<usize>> = output_shapes.iter().map(|d| d.chunks()).collect();
+    let i_dims: Vec<&Vec<usize>> = input_shapes.iter().map(|s| &s.dims).collect();
+    let o_dims: Vec<&Vec<usize>> = output_shapes.iter().map(|s| &s.dims).collect();
 
     let mut context = Context::new();
     context.insert("i_lens", &input_lengths);
@@ -271,10 +272,10 @@ pub fn compile(
     context.insert("o_shape", &o_dims);
     context.insert("i_chunks", &input_chunks);
     context.insert("o_chunks", &output_chunks);
-    context.insert("op_type", &node.get_op_type());
+    context.insert("op_type", &op_def.get_op_type());
     context.insert("opset_version", &opset_version);
 
-    let node_template: NodeTemplate = match node.get_op_type() {
+    let node_template: NodeTemplate = match op_def.get_op_type() {
         op @ ("Reshape" | "Dropout" | "Identity" | "Flatten" | "Squeeze" | "Unsqueeze") => {
             // These ops should all be optimized away earlier
             return Err(CompileError::InvalidOperation(op.to_string()));
@@ -301,7 +302,7 @@ pub fn compile(
         | "ReduceL1" | "ReduceL2" | "ReduceLogSum" | "ReduceLogSumExp" | "ReduceSumSquare") => {
             let all_axes: Vec<i64> = (0..(i_dims[0].len() as i64)).collect();
-            let axes: Vec<i64> = node
+            let axes: Vec<i64> = op_def
                 .get_attribute_value("axes", Some(all_axes))?
                .into_iter()
                 .map(|idx| {
                     if idx < 0 {
                         (i_dims[0].len() as i64) + idx
                     } else {
                         idx
                     }
                 })
                 .collect();
 
             let scalar_type = agreed_type(&[input_shapes[0]], output_shapes)?;
-            let dims_removed: Vec<i64> = input_shapes[0]
+            let dims_removed: Vec<usize> = input_shapes[0]
                 .dims
                 .iter()
                 .enumerate()
                 .map(|(idx, dim)| {
                     if axes.contains(&(idx as i64)) {
                         1
                     } else {
-                        *dim as i64
+                        *dim
                     }
                 })
                 .collect();
@@ -357,7 +358,7 @@ pub fn compile(
 
         "OneHot" => {
             // Currently only OneHot on the last axis is supported
-            let axis = node.get_attribute_value("axis", Some(-1))?;
+            let axis = op_def.get_attribute_value("axis", Some(-1))?;
             if axis != -1 {
                 return Err(CompileError::UnimplementedVariant {
                     variant: format!("axis={}", axis),
@@ -401,7 +402,7 @@ pub fn compile(
             // Input 0 is data, input 1 is indices
             // Which axis to gather on. Negative value means counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(data).
             // Default is 0. See https://github.com/onnx/onnx/blob/main/docs/Operators.md#attributes-25
-            let axis = node.get_attribute_value("axis", Some(0))?;
+            let axis = op_def.get_attribute_value("axis", Some(0))?;
             if axis != 0 {
                 return Err(CompileError::UnimplementedVariant {
                     variant: format!("axis={}", axis),
@@ -411,7 +412,7 @@ pub fn compile(
 
             let elements_per_index = input_chunks[0][0];
             let scalar_type = agreed_type(&input_shapes[0..1], output_shapes)?;
-            let chunk_type = MultiType::for_size(elements_per_index as usize, scalar_type);
+            let chunk_type = MultiType::for_size(elements_per_index, scalar_type);
             let chunk_size = chunk_type.elements();
 
             // The X dimension represents the indexes
@@ -423,7 +424,7 @@ pub fn compile(
 
             // The Y dimension represents the elements to copy for each index
             let (y_threads, workgroup_size_y) = workgroup_size(
-                ceil(elements_per_index, chunk_size as u64),
+                ceil(elements_per_index, chunk_size),
                 MAX_COMPUTE_WORKGROUPS_PER_DIMENSION,
                 MAX_WORKGROUP_SIZE_Y,
             )?;
@@ -442,7 +443,7 @@ pub fn compile(
 
         "Cast" => {
             let cast_to_type =
-                ScalarType::from_i32(node.get_attribute_value::<i64>("to", None)? as i32)?;
+                ScalarType::from_onnx_i32(op_def.get_attribute_value::<i64>("to", None)? as i32)?;
 
             if !cast_to_type.wgsl_supported() {
                 return Err(CompileError::UnimplementedVariant {
@@ -476,7 +477,7 @@ pub fn compile(
             /* Describes the axis of the inputs when coerced to 2D; defaults to one because the 0th axis most likely describes the batch_size.
             From version 13 onwards, counting backwards is also allowed. */
-            let mut axis = node.get_attribute_value("axis", Some(default_axis))?;
+            let mut axis = op_def.get_attribute_value("axis", Some(default_axis))?;
             if axis < 0 {
                 if opset_version >= 13 {
                     axis += input_shapes[0].rank() as i64;
@@ -497,15 +498,11 @@ pub fn compile(
                 });
             }
 
-            let left_of_axis = input_shapes[0].dims[0..(axis as usize)]
-                .iter()
-                .product::<u64>();
-            let axis_chunk = input_shapes[0].dims[(axis as usize)..]
-                .iter()
-                .product::<u64>();
+            let left_of_axis = input_shapes[0].dims[0..(axis as usize)].iter().product();
+            let axis_chunk: usize = input_shapes[0].dims[(axis as usize)..].iter().product();
             let right_of_axis_chunk = input_shapes[0].dims[((axis + 1) as usize)..]
                .iter()
-                .product::<u64>();
+                .product();
 
             context.insert("axis_chunk", &axis_chunk);
@@ -545,7 +542,7 @@ pub fn compile(
         // Arithmetic operation
         op @ ("Add" | "And" | "Div" | "Equal" | "Greater" | "GreaterOrEqual" | "Less"
         | "LessOrEqual" | "Mod" | "Mul" | "Or" | "Sub" | "Pow" | "PRelu") => {
-            let broadcast = node.get_attribute_value("broadcast", Some(0))?;
+            let broadcast = op_def.get_attribute_value("broadcast", Some(0))?;
             if broadcast != 0 {
                 return Err(CompileError::UnimplementedVariant {
                     op: op.to_string(),
@@ -573,7 +570,7 @@ pub fn compile(
                 "PRelu" => "PRelu",
                 _ => {
                     return Err(CompileError::UnimplementedOp(
-                        node.get_op_type().to_string(),
+                        op_def.get_op_type().to_string(),
                     ))
                 }
             },
@@ -637,7 +634,7 @@ pub fn compile(
                 });
             } else {
                 // Not broadcasting
-                let coefficient = node.get_attribute_value("coefficient", Some(1.0))?;
+                let coefficient = op_def.get_attribute_value("coefficient", Some(1.0))?;
                 context.insert("coefficient", &coefficient);
 
                 let (x_threads, workgroup_size_x) = workgroup_size(
@@ -659,7 +656,7 @@ pub fn compile(
             /* Prior to version 9, BatchNormalization supported a 'spatial' mode where input mean/variance are of shape
             [C,W,H] instead of just [C]. See https://github.com/onnx/onnx/blob/master/docs/Changelog.md#BatchNormalization-7.
             This mode is not supported. */
-            if let Ok(spatial_value) = node.get_attribute_value::<i64>("spatial", None) {
+            if let Ok(spatial_value) = op_def.get_attribute_value::<i64>("spatial", None) {
                 if opset_version < 9 {
                     return Err(CompileError::UnimplementedVariant {
                         op: "BatchNormalization".to_string(),
@@ -708,31 +705,28 @@ pub fn compile(
             }
 
             // If w*h is a multiple of 4, we can use vec4 in our shader
-            let elem_type = MultiType::for_size((input_w * input_h) as usize, ScalarType::F32);
+            let elem_type = MultiType::for_size(input_w * input_h, ScalarType::F32);
             context.insert("elem_type", &elem_type.wgsl_type_name());
             context.insert("elem_stride", &elem_type.stride());
 
             // The default for epsilon is 1e05, see https://github.com/onnx/onnx/blob/master/docs/Changelog.md#attributes-252
-            let epsilon = node.get_attribute_value("epsilon", Some(1e-05))?;
+            let epsilon = op_def.get_attribute_value("epsilon", Some(1e-05))?;
             context.insert("epsilon", &epsilon);
             context.insert(
                 "batch_size",
-                &ceil(
-                    input_channels * input_w * input_h,
-                    elem_type.elements() as u64,
-                ),
+                &ceil(input_channels * input_w * input_h, elem_type.elements()),
             );
             context.insert(
                 "channel_size",
-                &ceil(input_w * input_h, elem_type.elements() as u64),
+                &ceil(input_w * input_h, elem_type.elements()),
             );
 
             NodeTemplate {
                 scalar_type: agreed_type(&input_shapes[0..1], &output_shapes[0..1])?,
                 template: "endomorphism/batchnormalization.wgsl",
                 threads: (
-                    ceil(input_w * input_h, elem_type.elements() as u64) as _,
+                    ceil(input_w * input_h, elem_type.elements()) as _,
                     input_channels as _,
                     input_batches as _,
                 ),
@@ -741,16 +735,16 @@ pub fn compile(
         op @ ("Relu" | "Sigmoid" | "Softsign" | "Softplus" | "Clip" | "Celu" | "Elu"
         | "LeakyRelu") => {
             let alpha = if op == "LeakyRelu" {
-                node.get_attribute_value("alpha", Some(0.01))?
+                op_def.get_attribute_value("alpha", Some(0.01))?
             } else {
-                node.get_attribute_value("alpha", Some(1.0))?
+                op_def.get_attribute_value("alpha", Some(1.0))?
}; context.insert("alpha", &alpha); if op == "Clip" { let min: Vec = - node.get_attribute_value("min", Some(vec![f32::NEG_INFINITY]))?; - let max: Vec = node.get_attribute_value("max", Some(vec![f32::INFINITY]))?; + op_def.get_attribute_value("min", Some(vec![f32::NEG_INFINITY]))?; + let max: Vec = op_def.get_attribute_value("max", Some(vec![f32::INFINITY]))?; if min.len() != 1 { return Err(CompileError::InvalidAttributeValue { attribute: "min".into(), @@ -795,7 +789,7 @@ pub fn compile( NodeTemplate { scalar_type: agreed_type(input_shapes, output_shapes)?, template: "matrix/concat.wgsl", - threads: (ceil(output_lengths[0], 256) as u32, 1, 1), + threads: (ceil(output_lengths[0], 256), 1, 1), } } op @ ("MaxPool" | "AveragePool" | "Conv" | "ConvRelu" | "ConvLeakyRelu" | "ConvMish" @@ -817,17 +811,17 @@ pub fn compile( context.insert("op_type", "AveragePool"); } - let auto_pad = node.get_attribute_value("auto_pad", Some("NOTSET".to_string()))?; - let dilations = node.get_attribute_value("dilations", Some(vec![1, 1]))?; + let auto_pad = op_def.get_attribute_value("auto_pad", Some("NOTSET".to_string()))?; + let dilations = op_def.get_attribute_value("dilations", Some(vec![1, 1]))?; let kernel_shape = if is_global_average_pool { vec![input_shapes[0].dim(2) as i64, input_shapes[0].dim(3) as i64] } else { - node.get_attribute_value::>("kernel_shape", None)? + op_def.get_attribute_value::>("kernel_shape", None)? }; - let strides = node.get_attribute_value("strides", Some(vec![1, 1]))?; - let pads = node.get_attribute_value("pads", Some(vec![0, 0, 0, 0]))?; - let count_include_pad = node.get_attribute_value("count_include_pad", Some(0))?; - let group = node.get_attribute_value("group", Some(1))? as u64; + let strides = op_def.get_attribute_value("strides", Some(vec![1, 1]))?; + let pads = op_def.get_attribute_value("pads", Some(vec![0, 0, 0, 0]))?; + let count_include_pad = op_def.get_attribute_value("count_include_pad", Some(0))?; + let group = op_def.get_attribute_value("group", Some(1))? as usize; let pads = match auto_pad.as_str() { "NOTSET" => pads.to_vec(), @@ -902,7 +896,7 @@ pub fn compile( context.insert("kernel_length", &(kernel_shape[0] * kernel_shape[1])); context.insert( "kernel_channel_len", - &((kernel_shape[0] as u64) * (kernel_shape[1] as u64) * channels_per_group), + &((kernel_shape[0]) * (kernel_shape[1]) * channels_per_group as i64), ); context.insert("pad", &pads); context.insert("count_include_pad", &count_include_pad); @@ -927,7 +921,7 @@ pub fn compile( } "Conv" | "ConvRelu" | "ConvLeakyRelu" | "ConvMish" => { // Alpha is the Leaky Relu attribute - let alpha = node.get_attribute_value("alpha", Some(0.01))?; + let alpha = op_def.get_attribute_value("alpha", Some(0.01))?; context.insert("alpha", &alpha); let scalar_type = agreed_type(input_shapes, output_shapes)?; @@ -983,9 +977,9 @@ pub fn compile( let mut input_left_shape = input_shapes[0].clone(); let mut input_right_shape = input_shapes[1].clone(); let mut output_shape = output_shapes[0].clone(); - let mut stack_left_stride: u64 = 0; - let mut stack_right_stride: u64 = 0; - let mut stack_output_stride: u64 = 0; + let mut stack_left_stride: usize = 0; + let mut stack_right_stride: usize = 0; + let mut stack_output_stride: usize = 0; if op == "MatMul" { // - If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes @@ -1065,9 +1059,9 @@ pub fn compile( if op == "Gemm" { // Check if A resp. 
B should be transposed, or C should be broadcast (default: 0 = false) - let transpose_a = node.get_attribute_value("transA", Some(0))?; - let transpose_b = node.get_attribute_value("transB", Some(0))?; - let broadcast = node.get_attribute_value("broadcast", Some(0))?; + let transpose_a = op_def.get_attribute_value("transA", Some(0))?; + let transpose_b = op_def.get_attribute_value("transB", Some(0))?; + let broadcast = op_def.get_attribute_value("broadcast", Some(0))?; if transpose_a != 0 || transpose_b != 0 || broadcast != 0 { return Err(CompileError::UnimplementedVariant { @@ -1129,8 +1123,8 @@ pub fn compile( } // Obtain alpha and beta coefficients - let alpha = node.get_attribute_value("alpha", Some(1.0))?; - let beta = node.get_attribute_value("beta", Some(1.0))?; + let alpha = op_def.get_attribute_value("alpha", Some(1.0))?; + let beta = op_def.get_attribute_value("beta", Some(1.0))?; context.insert("alpha", &alpha); context.insert("beta", &beta); @@ -1189,7 +1183,7 @@ pub fn compile( } } "Resize" => { - let coordinate_transformation_mode = node.get_attribute_value( + let coordinate_transformation_mode = op_def.get_attribute_value( "coordinate_transformation_mode", Some("half_pixel".to_string()), )?; @@ -1204,9 +1198,9 @@ pub fn compile( "align_corners" => {} "asymmetric" => {} "tf_crop_and_resize" => { - let roi = node.get_attribute_value::>("roi", None)?; + let roi = op_def.get_attribute_value::>("roi", None)?; let extrapolation_value = - node.get_attribute_value("extrapolation_value", Some(0.0))?; + op_def.get_attribute_value("extrapolation_value", Some(0.0))?; context.insert("roi", &roi); context.insert("extrapolation_value", &extrapolation_value); } @@ -1221,9 +1215,9 @@ pub fn compile( } } - let scales = node.get_attribute_value::>("scales", Some(vec![]))?; + let scales = op_def.get_attribute_value::>("scales", Some(vec![]))?; let scale_prints = if scales.is_empty() { - let sizes = node.get_attribute_value::>("sizes", Some(vec![]))?; + let sizes = op_def.get_attribute_value::>("sizes", Some(vec![]))?; sizes .iter() .enumerate() @@ -1236,13 +1230,13 @@ pub fn compile( scales.iter().map(|x| format!("{:.2}", x)).collect() }; - let mode = node.get_attribute_value("mode", Some("nearest".to_string()))?; + let mode = op_def.get_attribute_value("mode", Some("nearest".to_string()))?; context.insert("mode", &mode); context.insert("scales", &scale_prints); match mode.as_str() { "nearest" => { - let nearest_mode = node.get_attribute_value( + let nearest_mode = op_def.get_attribute_value( "nearest_mode", Some("round_prefer_floor".to_string()), )?; @@ -1257,7 +1251,7 @@ pub fn compile( } } "cubic" => { - let cubic_coeff_a = node.get_attribute_value("cubic_coeff_a", Some(-0.75))?; + let cubic_coeff_a = op_def.get_attribute_value("cubic_coeff_a", Some(-0.75))?; context.insert("cubic_coeff_a", &cubic_coeff_a); return Err(CompileError::UnimplementedVariant { op: String::from("Resize"), @@ -1273,39 +1267,39 @@ pub fn compile( } }; - let exclude_outside = node.get_attribute_value("exclude_outside", Some(0))?; + let exclude_outside = op_def.get_attribute_value("exclude_outside", Some(0))?; context.insert("exclude_outside", &exclude_outside); NodeTemplate { scalar_type: agreed_type(&input_shapes[0..1], &output_shapes[0..1])?, template: "matrix/resize.wgsl", - threads: (ceil(output_lengths[0], 256) as u32, 1, 1), + threads: (ceil(output_lengths[0], 256), 1, 1), } } "Sum" => return Err(CompileError::UnimplementedOp(String::from("Sum"))), "Split" => { - let mut axis = node.get_attribute_value("axis", 
Some(0))?; + let mut axis = op_def.get_attribute_value("axis", Some(0))?; if axis < 0 { axis += input_shapes[0].rank() as i64 } context.insert("axis", &axis); - let split_chunk = input_shapes[0].dim(axis as usize) as usize / output_shapes.len(); + let split_chunk = input_shapes[0].dim(axis as usize) / output_shapes.len(); let default_split = (1..=output_shapes.len()) .map(|x| (x * split_chunk) as _) .collect(); - let split = node.get_attribute_value::>("split", Some(default_split))?; + let split = op_def.get_attribute_value::>("split", Some(default_split))?; context.insert("split", &split); NodeTemplate { scalar_type: agreed_type(&input_shapes[0..1], output_shapes)?, template: "matrix/split.wgsl", - threads: (ceil(input_lengths[0], 256) as u32, 1, 1), + threads: (ceil(input_lengths[0], 256), 1, 1), } } "Pad" => { - let mode = node.get_attribute_value("mode", Some("constant".to_string()))?; + let mode = op_def.get_attribute_value("mode", Some("constant".to_string()))?; match mode.as_str() { "constant" => {} _ => { @@ -1316,7 +1310,7 @@ pub fn compile( } } - let pads: Vec = node.get_attribute_value("pads", None)?; + let pads: Vec = op_def.get_attribute_value("pads", None)?; if pads.len() != input_shapes[0].rank() * 2 { return Err(CompileError::InvalidAttributeValue { attribute: "pads".into(), @@ -1324,7 +1318,7 @@ pub fn compile( opset_version, }); } - let constant_value = node.get_attribute_value("constant_value", Some(0.0))?; + let constant_value = op_def.get_attribute_value("constant_value", Some(0.0))?; context.insert("constant_value", &constant_value); #[derive(serde::Serialize)] @@ -1339,7 +1333,7 @@ pub fn compile( pad_info.push(PadInfo { copy_start: begin as _, - end_pad_start: input_shapes[0].dim(axis) + begin as u64 - end as u64, + end_pad_start: input_shapes[0].dim(axis) as u64 + begin as u64 - end as u64, }); } context.insert("pad_info", &pad_info); @@ -1347,13 +1341,13 @@ pub fn compile( NodeTemplate { scalar_type: agreed_type(&input_shapes[0..1], &output_shapes[0..1])?, template: "matrix/pad.wgsl", - threads: (ceil(output_lengths[0], 256) as u32, 1, 1), + threads: (ceil(output_lengths[0], 256), 1, 1), } } "Transpose" => { let n_dims: i64 = input_shapes[0].rank() as i64; let default = (0..n_dims).rev().collect::>(); - let perms: Vec = node.get_attribute_value("perm", Some(default))?; + let perms: Vec = op_def.get_attribute_value("perm", Some(default))?; // The number of elements in the permutations list must be equal to the output shape rank if perms.len() != output_shapes[0].rank() { @@ -1366,12 +1360,8 @@ pub fn compile( let chunks = perms .iter() - .map(|p| { - input_shapes[0].dims[((*p as usize) + 1)..] 
- .iter() - .product::() - }) - .collect::>(); + .map(|p| input_shapes[0].dims[((*p as usize) + 1)..].iter().product()) + .collect::>(); context.insert("permuted_chunks", &chunks); @@ -1437,15 +1427,15 @@ pub fn compile( /// Determines the appropriate number of threads and workgroup size given a number of times the entry point of the shader should be run fn workgroup_size( - x: u64, - max_threads: u32, - max_workgroup_size: u32, -) -> Result<(u32, u32), CompileError> { - let max_x = max_threads as u64; + x: usize, + max_threads: usize, + max_workgroup_size: usize, +) -> Result<(usize, usize), CompileError> { + let max_x = max_threads; Ok(if x > max_x { let workgroup_size = ceil(x, max_x) as _; - let threads = ceil(x, workgroup_size as u64) as _; + let threads = ceil(x, workgroup_size) as _; log::debug!( "number of items ({}) exceeds maximum number of threads ({}); adjusting workgroup size={} and threads={} (this will compute {} items)", x, @@ -1473,6 +1463,6 @@ fn workgroup_size( (threads, workgroup_size) } else { - (x as u32, 1) + (x, 1) }) } diff --git a/wonnx/src/gpu.rs b/wonnx/src/gpu.rs index 9a3a2255..e04bf0b8 100644 --- a/wonnx/src/gpu.rs +++ b/wonnx/src/gpu.rs @@ -15,13 +15,9 @@ use wgpu::{Buffer, BufferAsyncError, BufferUsages, CommandEncoder, Device}; use crate::{ compiler::{compile, CompileError, CompiledNode}, - ir::{Node, NodeDefinition, NodeIdentifier, OperatorDefinition}, - onnx::TensorProto, + ir::{Node, NodeDefinition, NodeIdentifier, OperatorDefinition, Tensor}, resource::{self, resize}, - utils::{ - ceil, DataTypeError, InputTensor, OutputTensor, ScalarType, Shape, - MINIMUM_BUFFER_SIZE_BYTES, - }, + tensor::{ceil, DataTypeError, ScalarType, Shape, TensorData, MINIMUM_BUFFER_SIZE_BYTES}, }; /// The maximum number of bindings in a binding group (defined by wgpu) @@ -47,7 +43,7 @@ enum GpuStep { Operator { pipeline: wgpu::ComputePipeline, bind_groups: Vec, - threads: (u32, u32, u32), + threads: (usize, usize, usize), output_tensors: Vec, }, @@ -355,15 +351,15 @@ impl GpuModel { let tensor = outputs[input.output_index].clone(); InferenceOutput::Tensor(tensor) } - NodeDefinition::Input(proto) => { - InferenceOutput::InferenceInput(proto.get_name().to_string()) + NodeDefinition::Input { name, .. } => { + InferenceOutput::InferenceInput(name.clone()) + } + NodeDefinition::Missing => { + unimplemented!("missing input as output"); } NodeDefinition::Outputs { .. 
} => { unimplemented!("output after output node") } - NodeDefinition::Missing => { - unimplemented!("optional input after output node") - } }, ); } @@ -426,7 +422,7 @@ impl GpuModel { while let NodeDefinition::Operator(ultimate_input_op_def) = ultimate_input.source_node.definition() { - if op_forwards_input(ultimate_input_op_def.proto.get_op_type()) { + if op_forwards_input(ultimate_input_op_def.get_op_type()) { assert_eq!(ultimate_input.source_node.inputs.len(), 1); ultimate_input = ultimate_input.source_node.inputs[0].clone(); } else { @@ -434,7 +430,7 @@ impl GpuModel { } } - let output_shape = &source_node_def.output_shapes[node_input.output_index]; + let output_shape = &source_node_def.output_shapes()[node_input.output_index]; buffer_manager.lease( ultimate_input.source_node.identifier(), ultimate_input.output_index, @@ -452,7 +448,7 @@ impl GpuModel { if outputs_readable { if let NodeDefinition::Operator(op_def) = &node.definition { // For these ops we just forward the buffer (so we should also forward readability) - if op_forwards_input(op_def.proto.get_op_type()) { + if op_forwards_input(op_def.get_op_type()) { nodes_readable.insert(source_node_identifier.clone()); } } @@ -461,8 +457,8 @@ impl GpuModel { // Tell the buffer manager we are producing an intermediate value; nodes that run 'before' us may reuse this buffer if let NodeDefinition::Operator(op_def) = &node.definition { - if !op_forwards_input(op_def.proto.get_op_type()) { - for (output_index, output_shape) in op_def.output_shapes.iter().enumerate() { + if !op_forwards_input(op_def.get_op_type()) { + for (output_index, output_shape) in op_def.output_shapes().iter().enumerate() { buffer_manager.release( node_identifier.clone(), output_index, @@ -537,7 +533,7 @@ impl GpuModel { NodeDefinition::Operator(op_def) => { // Can we use shared buffers for outputs of this node? let shared_buffers: Vec>>> = - (0..op_def.output_shapes.len()) + (0..op_def.output_shapes().len()) .map(|output_index| { let identifier = node.identifier(); buffer_manager @@ -573,50 +569,49 @@ impl GpuModel { // For tensor (initializer) nodes, we just create a buffer and fill it with the initializer data NodeDefinition::Tensor(tensor_def) => { let tensor_buffer = - Arc::new(tensor_def.buffer(&self.device, outputs_readable)?); + Arc::new(self.tensor_to_buffer(tensor_def, outputs_readable)?); output_tensors.push(GpuTensor { - shape: Shape::from( - ScalarType::from_i32(tensor_def.get_data_type())?, - tensor_def.get_dims(), - ), + shape: tensor_def.shape(), buffer: tensor_buffer.clone(), }); GpuStep::Initializer(tensor_buffer) } // For inputs we create an empty buffer that can be used at inference time to supply input data - NodeDefinition::Input(input_def) => { + NodeDefinition::Input { + name: input_name, + shape: input_shape, + } => { if outputs_readable { log::warn!( "it looks like you will be reading back inference input '{}' as output", - input_def.get_name() + input_name ); } - let input_shape = input_def.get_shape()?; let buffer_size_aligned = input_shape.buffer_bytes_aligned(); log::debug!( "creating input buffer for {} shape {} size {}", - input_def.get_name(), + input_name, input_shape, buffer_size_aligned ); let input_buffer = Arc::new(resource::buffer( &self.device, input_shape.buffer_bytes_aligned(), - input_def.get_name(), + input_name, // Usage is not COPY_SRC/MAP_READ even when outputs_readable is true; we'll deal with the special // case of reading back inputs as outputs separately. 
BufferUsages::STORAGE | BufferUsages::COPY_DST, )); output_tensors.push(GpuTensor { - shape: input_shape, + shape: input_shape.clone(), buffer: input_buffer.clone(), }); - GpuStep::Input(input_def.get_name().to_string(), input_buffer) + GpuStep::Input(input_name.clone(), input_buffer) } - NodeDefinition::Missing | NodeDefinition::Outputs { .. } => { + NodeDefinition::Outputs { .. } | NodeDefinition::Missing => { // Nothing to sequence GpuStep::None } @@ -635,8 +630,8 @@ impl GpuModel { /// Perform inference using this model and the specified inference inputs. pub async fn infer<'a>( &self, - inference_inputs: &HashMap>, - ) -> Result, GpuError> { + inference_inputs: &HashMap>, + ) -> Result>, GpuError> { log::info!("encode inference steps"); let mut encoder = self .device @@ -647,100 +642,65 @@ impl GpuModel { log::debug!("submit inference steps"); self.queue.submit(Some(encoder.finish())); log::info!("inference completed"); - self.read_outputs(inference_inputs).await + Ok(self.read_outputs(inference_inputs).await?.to_owned()) } /// Reads the relevant buffers for the requested inference outputs async fn read_outputs<'a>( &self, - inference_inputs: &HashMap>, - ) -> Result, GpuError> { - let mut output_data: HashMap = HashMap::new(); + inference_inputs: &HashMap>, + ) -> Result>, GpuError> { + let mut output_data: HashMap> = HashMap::new(); for (output_name, output_source) in &self.inference_outputs { - output_data.insert( - output_name.to_string(), - match output_source { - InferenceOutput::InferenceInput(input_name) => { - (&inference_inputs[input_name]).into() - } - InferenceOutput::Tensor(tensor) => { - tensor.read_to_vec(&self.device, &self.queue).await? - } - }, - ); + let v: TensorData<'static> = match output_source { + InferenceOutput::InferenceInput(input_name) => { + inference_inputs[input_name].clone().into_static() + } + InferenceOutput::Tensor(tensor) => { + tensor.read_to_vec(&self.device, &self.queue).await? 
+ } + } + .to_owned(); + output_data.insert(output_name.to_string(), v); } Ok(output_data) } -} - -trait TensorProtoExtra { - fn buffer(&self, device: &wgpu::Device, readable: bool) -> Result; -} - -impl TensorProtoExtra for TensorProto { - /// Create a GPU buffer containing the data of this initializer - fn buffer(&self, device: &wgpu::Device, readable: bool) -> Result { - let scalar_type = ScalarType::from_i32(self.get_data_type())?; - let input_shape = Shape::from(scalar_type, self.get_dims()); - log::debug!( - "creating buffer for tensor {} shape {}", - self.get_name(), - input_shape - ); - match scalar_type { - ScalarType::F32 => { - let data = self.get_float_data(); - buffer_with_bytes( - device, - readable, - self.get_name(), - if !data.is_empty() { - bytemuck::cast_slice(data) - } else { - self.get_raw_data() - }, - ) - } - ScalarType::U8 => { - // WGSL doesn't support 8 bit unsigned integers, so we load them as 32 bit ints - log::warn!("initializers with uint8 data type are not supported, converting into int32 initializer"); - let ints: Vec = self - .get_raw_data() - .iter() - .map(|x| (*x).try_into()) - .collect::, _>>() - .map_err(|_e| GpuError::OutOfBoundsError)?; - let raw_data = bytemuck::cast_slice(&ints); - buffer_with_bytes(device, readable, self.get_name(), raw_data) - } - ScalarType::I64 => { + fn tensor_to_buffer(&mut self, tensor: &Tensor, readable: bool) -> Result { + match tensor.data() { + TensorData::F32(data) => buffer_with_bytes( + &self.device, + readable, + tensor.display_name(), + bytemuck::cast_slice(data), + ), + TensorData::I32(data) => buffer_with_bytes( + &self.device, + readable, + tensor.display_name(), + bytemuck::cast_slice(data), + ), + TensorData::I64(data) => { // WGSL doesn't support 64 bit integers, so we load 64 bit tensors as 32 bit ints log::warn!("initializers with int64 data type are not supported, converting into int32 initializer"); - let ints: Vec = self - .get_int64_data() + let ints: Vec = data .iter() .map(|x| (*x).try_into()) .collect::, _>>() .map_err(|_e| GpuError::OutOfBoundsError)?; - let raw_data = bytemuck::cast_slice(&ints); - buffer_with_bytes(device, readable, self.get_name(), raw_data) - } - ScalarType::I32 => { - let data = self.get_int32_data(); + buffer_with_bytes( - device, + &self.device, readable, - self.get_name(), - if !data.is_empty() { - bytemuck::cast_slice(data) - } else { - self.get_raw_data() - }, + tensor.display_name(), + bytemuck::cast_slice(&ints), ) } + TensorData::U8(data) => { + buffer_with_bytes(&self.device, readable, tensor.display_name(), data) + } } } } @@ -751,6 +711,7 @@ fn buffer_with_bytes( name: &str, raw_data: &[u8], ) -> Result { + log::info!("creating buffer: {} {}b", name, raw_data.len()); let buffer_usage = match readable { true => BufferUsages::STORAGE | BufferUsages::COPY_SRC, false => BufferUsages::STORAGE, @@ -775,7 +736,7 @@ fn op_forwards_input(op_type: &str) -> bool { ) } -impl<'model> OperatorDefinition<'model> { +impl OperatorDefinition { fn gpu_op( &self, device: &wgpu::Device, @@ -784,12 +745,10 @@ impl<'model> OperatorDefinition<'model> { input_tensors: &[GpuTensor], shared_buffers: &[Option>>], ) -> Result { - let proto = &self.proto; - // Some nodes have specific GPU implementations, match these here - if op_forwards_input(proto.get_op_type()) { + if op_forwards_input(self.get_op_type()) { // Some ops do nothing but forward their input - let value_shape = &self.output_shapes[0]; + let value_shape = &self.output_shapes()[0]; let output_tensor = GpuTensor { buffer: 
input_tensors[0].buffer.clone(), shape: value_shape.clone(), @@ -797,16 +756,15 @@ impl<'model> OperatorDefinition<'model> { return Ok(GpuStep::Forward(output_tensor)); } - let label = Some(proto.get_name()); + let label = Some(self.get_display_name()); // Create output buffers for this op node - let output_tensors: Vec = proto - .get_output() + let output_tensors: Vec = self + .output_shapes() .iter() .enumerate() - .map(|(output_index, output_name)| { - let value_shape = &self.output_shapes[output_index]; - + .map(|(output_index, value_shape)| { + let output_name = format!("{}_{}", self.get_display_name(), output_index); let buffer = match shared_buffers.get(output_index) { Some(Some(shared_buffer)) if !outputs_readable => { let mut shared_buffer = shared_buffer.borrow_mut(); @@ -817,7 +775,7 @@ impl<'model> OperatorDefinition<'model> { "creating non-shared buffer for output #{} ({}) of {} shaped {}", output_index, output_name, - proto.get_name(), + self.get_display_name(), value_shape ); @@ -844,17 +802,13 @@ impl<'model> OperatorDefinition<'model> { .collect(); let input_shapes: Vec<&Shape> = input_tensors.iter().map(|input| &input.shape).collect(); - let output_shapes: Vec<&Shape> = self.output_shapes.iter().collect(); + let output_shapes: Vec<&Shape> = self.output_shapes().iter().collect(); // Compile shader for node let CompiledNode { shader, threads } = - compile(proto, &input_shapes, &output_shapes, opset_version).map_err(|ce| { + compile(self, &input_shapes, &output_shapes, opset_version).map_err(|ce| { GpuError::CompileError { - node: if proto.has_name() { - proto.get_name().to_string() - } else { - proto.get_op_type().to_string() - }, + node: self.get_display_name().to_string(), error: ce, } })?; @@ -901,7 +855,7 @@ impl<'model> OperatorDefinition<'model> { }); // Create 'bind groups' (groups of bound buffers) - let number_of_groups = ceil(binding_counter as u64, MAX_BINDINGS_PER_GROUP as u64) as usize; + let number_of_groups = ceil(binding_counter, MAX_BINDINGS_PER_GROUP); for group_index in 0..number_of_groups { let group_range = group_index * MAX_BINDINGS_PER_GROUP ..usize::min( @@ -931,7 +885,7 @@ impl GpuStep { &self, queue: &wgpu::Queue, encoder: &mut CommandEncoder, - inputs: &HashMap, + inputs: &HashMap, ) -> Result<(), GpuError> { match self { GpuStep::None | GpuStep::Forward(_) | GpuStep::Initializer(_) => { @@ -947,21 +901,21 @@ impl GpuStep { log::debug!("write input data for {}", input_name); match input_data { - InputTensor::F32(float_input) => { + TensorData::F32(float_input) => { queue.write_buffer( input_buffer, 0, bytemuck::cast_slice(&resize(float_input.to_vec())), ); } - InputTensor::I32(int_input) => { + TensorData::I32(int_input) => { queue.write_buffer( input_buffer, 0, bytemuck::cast_slice(&resize(int_input.to_vec())), ); } - InputTensor::I64(int_input) => { + TensorData::I64(int_input) => { log::warn!("reading int64 input '{input_name}' as int32 (int64 is not supported for calculation but can be used as input as long as values fit in int32)"); let int32_input = int_input .iter() @@ -973,7 +927,7 @@ impl GpuStep { bytemuck::cast_slice(&resize(int32_input)), ); } - InputTensor::U8(int_input) => { + TensorData::U8(int_input) => { log::warn!("reading uint8 input as int32 (uint8 is not supported for calculation but can be used as input)"); let int32_input = int_input .iter() @@ -1003,7 +957,7 @@ impl GpuStep { compute_pass.set_bind_group(index as u32, bind_group, &[]); } let (x, y, z) = *threads; - compute_pass.dispatch_workgroups(x, y, z); + 
compute_pass.dispatch_workgroups(x as u32, y as u32, z as u32); Ok(()) } } @@ -1016,14 +970,14 @@ impl GpuTensor { &self, device: &wgpu::Device, queue: &wgpu::Queue, - ) -> Result { + ) -> Result, GpuError> { let shape = self.shape.clone(); #[cfg(target_arch = "wasm32")] { let buffer_slice = self.buffer.slice(..); let (sender, receiver) = - futures::channel::oneshot::channel::>(); + futures::channel::oneshot::channel::>(); wgpu::util::DownloadBuffer::read_buffer(device, queue, &buffer_slice, move |buffer| { // Called on download completed @@ -1050,39 +1004,40 @@ impl GpuTensor { wgpu::util::DownloadBuffer::read_buffer(device, queue, &buffer_slice, move |buffer| { // Called on download completed tx.send(match buffer { - Ok(bytes) => Ok(Self::read_bytes_to_vec(&bytes, shape)), + Ok(bytes) => Ok(bytes.to_vec()), Err(error) => Err(GpuError::BufferAsyncError(error)), }) .unwrap(); }); device.poll(wgpu::Maintain::Wait); // The callback will have been called by now due to poll(Wait) - rx.recv().unwrap() + let bytes_vec = rx.recv().unwrap()?; + Ok(Self::read_bytes_to_vec(&bytes_vec, shape)) } } - fn read_bytes_to_vec(output_data: &[A], shape: Shape) -> OutputTensor + fn read_bytes_to_vec(output_data: &[A], shape: Shape) -> TensorData<'static> where A: NoUninit, { // The actual buffer may be bigger than what we should return, because buffers have a minimum size in wgpu // Fetch the size we should expect so we can chop the buffer to the correct size - let output_buffer_size = shape.element_count() as usize; + let output_buffer_size = shape.element_count(); match shape.data_type { - ScalarType::F32 => { - OutputTensor::F32(bytemuck::cast_slice(output_data)[..output_buffer_size].to_vec()) - } - ScalarType::I32 => { - OutputTensor::I32(bytemuck::cast_slice(output_data)[..output_buffer_size].to_vec()) - } - ScalarType::U8 => { - OutputTensor::U8(bytemuck::cast_slice(output_data)[..output_buffer_size].to_vec()) - } + ScalarType::F32 => TensorData::F32(Cow::Owned( + bytemuck::cast_slice(output_data)[..output_buffer_size].to_vec(), + )), + ScalarType::I32 => TensorData::I32(Cow::Owned( + bytemuck::cast_slice(output_data)[..output_buffer_size].to_vec(), + )), + ScalarType::U8 => TensorData::U8(Cow::Owned( + bytemuck::cast_slice(output_data)[..output_buffer_size].to_vec(), + )), ScalarType::I64 => { log::warn!("reading int64 output as int32 because internally int64 scalars are not supported"); let result_ints: Vec = bytemuck::cast_slice(output_data)[..output_buffer_size].to_vec(); - OutputTensor::I64(result_ints.iter().map(|i| *i as i64).collect()) + TensorData::I64(Cow::Owned(result_ints.iter().map(|i| *i as i64).collect())) } } } diff --git a/wonnx/src/ir.rs b/wonnx/src/ir.rs index 11e15a90..c1065087 100644 --- a/wonnx/src/ir.rs +++ b/wonnx/src/ir.rs @@ -1,56 +1,263 @@ -//! DAG representation of ONNX ops allowing for transformations and optimizations before compilation -use crate::onnx::{ModelProto, NodeProto, TensorProto, ValueInfoProto}; -use crate::utils::{DataTypeError, Shape}; -use std::borrow::Cow; +//! 
DAG representation of ops allowing for transformations and optimizations before compilation +use crate::tensor::{DataTypeError, Shape, TensorData}; +use std::borrow::{Borrow, Cow}; +use std::convert::TryFrom; use std::fmt::Debug; use std::hash::Hash; use std::ptr; use std::{collections::HashMap, sync::Arc}; use thiserror::Error; +#[derive(Clone, Debug)] +pub enum AttributeValue<'a> { + F32(f32), + I64(i64), + I64s(Cow<'a, [i64]>), + F32s(Cow<'a, [f32]>), + String(String), + Tensor(Tensor<'a>), +} + +impl<'a> AttributeValue<'a> { + pub fn into_static(self) -> AttributeValue<'static> { + match self { + AttributeValue::F32(f) => AttributeValue::F32(f), + AttributeValue::I64(f) => AttributeValue::I64(f), + AttributeValue::I64s(f) => AttributeValue::I64s(Cow::Owned(f.into_owned())), + AttributeValue::F32s(f) => AttributeValue::F32s(Cow::Owned(f.into_owned())), + AttributeValue::String(s) => AttributeValue::String(s), + AttributeValue::Tensor(t) => AttributeValue::Tensor(t.into_static()), + } + } +} + +impl TryFrom<&AttributeValue<'_>> for f32 { + type Error = (); + + fn try_from(value: &AttributeValue) -> Result { + if let AttributeValue::F32(v) = value { + Ok(*v) + } else { + Err(()) + } + } +} + +impl TryFrom<&AttributeValue<'_>> for i64 { + type Error = (); + + fn try_from(value: &AttributeValue) -> Result { + if let AttributeValue::I64(v) = value { + Ok(*v) + } else { + Err(()) + } + } +} + +impl TryFrom<&AttributeValue<'_>> for Vec { + type Error = (); + + fn try_from(value: &AttributeValue) -> Result { + if let AttributeValue::I64s(v) = value { + Ok(v.to_vec()) + } else { + Err(()) + } + } +} + +impl TryFrom<&AttributeValue<'_>> for String { + type Error = (); + + fn try_from(value: &AttributeValue) -> Result { + if let AttributeValue::String(v) = value { + Ok(v.clone()) + } else { + Err(()) + } + } +} + +impl TryFrom<&AttributeValue<'_>> for Tensor<'static> { + type Error = (); + + fn try_from(value: &AttributeValue) -> Result { + if let AttributeValue::Tensor(v) = value { + Ok(v.clone().into_static()) + } else { + Err(()) + } + } +} + +impl TryFrom<&AttributeValue<'_>> for Vec { + type Error = (); + + fn try_from(value: &AttributeValue) -> Result { + if let AttributeValue::F32s(v) = value { + Ok(v.to_vec()) + } else { + Err(()) + } + } +} + +impl From for AttributeValue<'_> { + fn from(value: i64) -> Self { + AttributeValue::I64(value) + } +} + +impl From for AttributeValue<'_> { + fn from(value: f32) -> Self { + AttributeValue::F32(value) + } +} + +impl From> for AttributeValue<'_> { + fn from(value: Vec) -> Self { + AttributeValue::F32s(Cow::Owned(value)) + } +} + +impl From> for AttributeValue<'_> { + fn from(value: Vec) -> Self { + AttributeValue::I64s(Cow::Owned(value)) + } +} + #[derive(Clone)] -pub struct OperatorDefinition<'model> { - pub(crate) proto: Cow<'model, NodeProto>, +pub struct OperatorDefinition { + pub(crate) op_type: String, + pub(crate) attributes: HashMap>, pub(crate) output_shapes: Vec, + pub(crate) display_name: String, +} + +#[derive(Error, Debug)] +#[error("did not find attribute '{attribute}' for node '{node_name}'")] +pub struct AttributeNotFoundError { + attribute: String, + node_name: String, } -impl<'model> OperatorDefinition<'model> { - pub fn from( - node: Cow<'model, NodeProto>, - value_shapes: &HashMap<&'model str, Shape>, - ) -> Result, IrError> { - let mut output_shapes: Vec = Vec::with_capacity(node.get_output().len()); - for output_name in node.get_output() { - if !value_shapes.contains_key(output_name.as_str()) { - return 
Err(IrError::OutputNodeNotFound(output_name.to_string())); +impl OperatorDefinition { + pub fn new( + op_type: &str, + output_shapes: Vec, + display_name: String, + ) -> OperatorDefinition { + OperatorDefinition { + op_type: op_type.to_string(), + attributes: HashMap::new(), + output_shapes, + display_name, + } + } + + pub fn output_shapes(&self) -> &[Shape] { + &self.output_shapes + } + + pub fn get_display_name(&self) -> &str { + &self.display_name + } + + pub fn append_attributes_from(&mut self, rhs: &Self) { + for (k, v) in rhs.attributes.iter() { + self.attributes.insert(k.clone(), v.clone()); + } + } + + pub fn set_attribute(&mut self, name: &str, inputs: impl Into>) { + let attribute: AttributeValue = inputs.into(); + self.attributes.insert(name.to_string(), attribute); + } + + pub fn set_op_type(&mut self, op_type: &str) { + self.op_type = op_type.to_string(); + } + + pub fn get_op_type(&self) -> &str { + &self.op_type + } + + pub fn get_attribute_value<'a, T>( + &'a self, + attribute: &str, + default: Option, + ) -> Result + where + T: TryFrom<&'a AttributeValue<'a>>, + { + match (self.attributes.get(attribute), default) { + (Some(attribute_value), _) => { + Ok( + T::try_from(attribute_value).map_err(|_| AttributeNotFoundError { + attribute: attribute.to_string(), + node_name: self.get_display_name().to_string(), + })?, + ) } + (None, Some(default_value)) => Ok(default_value), + (None, None) => Err(AttributeNotFoundError { + attribute: attribute.to_string(), + node_name: self.get_display_name().to_string(), + }), + } + } +} + +#[derive(Clone, Debug)] +pub struct Tensor<'a> { + pub(crate) data: TensorData<'a>, + pub(crate) dims: Vec, + pub(crate) display_name: String, +} - output_shapes.push(value_shapes[&output_name.as_str()].clone()); +impl<'a> Tensor<'a> { + pub fn dims(&self) -> &[usize] { + &self.dims + } + + pub fn into_static(self) -> Tensor<'static> { + Tensor { + data: self.data.into_static(), + dims: self.dims, + display_name: self.display_name, } - Ok(OperatorDefinition { - proto: node, - output_shapes, - }) + } + + pub fn shape(&self) -> Shape { + Shape::from(self.data.scalar_type(), &self.dims) + } + + pub fn display_name(&self) -> &str { + self.display_name.borrow() + } + + pub fn data(&self) -> &TensorData<'a> { + &self.data } } #[derive(Clone)] pub enum NodeDefinition<'model> { - Operator(Box>), - Tensor(Box>), - Input(&'model ValueInfoProto), + Operator(OperatorDefinition), + Tensor(Tensor<'model>), + Input { name: String, shape: Shape }, Outputs { names: Vec }, Missing, // A missing input (optional) } -static MISSING_OPTIONAL_INPUT: NodeDefinition<'static> = NodeDefinition::Missing; - -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Input<'model> { pub source_node: Arc>, pub output_index: usize, } +#[derive(Debug)] pub struct Node<'model> { pub definition: NodeDefinition<'model>, pub inputs: Vec>, @@ -72,39 +279,31 @@ pub enum IrError { } impl<'m> NodeDefinition<'m> { - pub fn get_name(&self) -> Cow<'_, str> { + pub fn get_display_name(&self) -> Cow<'_, str> { match self { // Nodes are identified by their first output's name, because node names are optional (and "only to be used // for diagnostic purposes" according to the ONNX IR specification) whereas output names are required and should be unique. 
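// A hedged usage sketch (not part of the patch) of the attribute API defined
// above; it assumes the patched wonnx exports ir::OperatorDefinition and
// tensor::{ScalarType, Shape} as shown in this diff, and the shape/attribute
// values are invented for illustration:
use wonnx::ir::OperatorDefinition;
use wonnx::tensor::{ScalarType, Shape};

fn leaky_relu_op() -> OperatorDefinition {
    let out_shape = Shape::from(ScalarType::F32, &[1, 16]);
    let mut op = OperatorDefinition::new("LeakyRelu", vec![out_shape], "y".to_string());
    op.set_attribute("alpha", 0.2f32);
    // get_attribute_value returns the stored value, or the supplied default
    // when the attribute is absent (here "alpha" is present, so 0.2 wins).
    let alpha: f32 = op.get_attribute_value("alpha", Some(0.01)).unwrap();
    assert_eq!(alpha, 0.2);
    op
}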
- NodeDefinition::Operator(op_def) => Cow::from(&op_def.proto.get_output()[0]), - NodeDefinition::Tensor(t) => Cow::from(t.get_name()), - NodeDefinition::Input(i) => Cow::from(i.get_name()), + NodeDefinition::Operator(op_def) => Cow::from(op_def.get_display_name()), + NodeDefinition::Tensor(t) => Cow::from(t.display_name()), + NodeDefinition::Input { name, .. } => Cow::from(name), NodeDefinition::Outputs { .. } => Cow::from(" "), NodeDefinition::Missing => Cow::from(""), } } } -impl NodeProto { - // Nodes are identified by their first output's name, because node names are optional (and "only to be used - // for diagnostic purposes" according to the ONNX IR specification) whereas output names are required and should be unique. - fn unique_name(&self) -> String { - self.get_output()[0].clone() - } -} - impl<'model> Node<'model> { - pub fn new(variant: NodeDefinition<'model>) -> Node<'model> { + pub fn new(variant: NodeDefinition<'model>, inputs: Vec>) -> Node<'model> { Node { definition: variant, - inputs: vec![], + inputs, } } pub fn is_dynamic(&self) -> bool { matches!( self.definition, - NodeDefinition::Operator(..) | NodeDefinition::Input(..) + NodeDefinition::Operator(..) | NodeDefinition::Input { .. } ) } @@ -117,195 +316,16 @@ impl<'model> Node<'model> { pub fn definition(&self) -> &NodeDefinition<'model> { &self.definition } - - /// Construct part of the intermediate representation tree for the indicated node. - pub fn from_node<'a>( - node: Cow<'model, NodeProto>, - value_shapes: &HashMap<&'model str, Shape>, - node_definitions_by_output: &'a HashMap>, - nodes_by_unique_name: &mut HashMap>>, - ) -> Result>, IrError> { - let node_name = node.unique_name(); - // Did we already translate this node before? - if nodes_by_unique_name.contains_key(&node_name) { - let n = nodes_by_unique_name.get(&node_name).unwrap(); - return Ok(n.clone()); - } - - let inputs: Result>, IrError> = node - .get_input() - .iter() - .map(|input_name: &String| { - let my_input_name = input_name.clone(); - let source_node_definition = node_definitions_by_output - .get(&my_input_name) - .unwrap_or(&MISSING_OPTIONAL_INPUT); - - Ok(match source_node_definition { - // The source is another op - continue translating that node - NodeDefinition::Operator(source_node_proto) => Input { - source_node: Node::from_node( - source_node_proto.proto.clone(), - value_shapes, - node_definitions_by_output, - nodes_by_unique_name, - )?, - output_index: source_node_proto - .proto - .get_output() - .iter() - .position(|s| s == input_name) - .ok_or_else(|| IrError::OutputNodeNotFound(input_name.to_string()))?, - }, - _ => { - // The source is an initializer or model onput - let source_name = source_node_definition.get_name().to_string(); - - Input { - output_index: 0, - // Did we already translate this node? - source_node: match nodes_by_unique_name.get(&source_name) { - Some(node) => node.clone(), - None => { - let node = Arc::new(Node::new(source_node_definition.clone())); - nodes_by_unique_name.insert(source_name, node.clone()); - node - } - }, - } - } - }) - }) - .collect(); - - let translated = Arc::new(Node { - definition: NodeDefinition::Operator(Box::new(OperatorDefinition::from( - node.clone(), - value_shapes, - )?)), - inputs: inputs?, - }); - nodes_by_unique_name.insert(node.unique_name(), translated.clone()); - Ok(translated) - } - - /// Construct an intermediate representation graph for calculating the output with the specified name. 
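// The from_node translation removed above cached translated nodes in
// nodes_by_unique_name, so a node consumed by several downstream ops is
// translated once and shared via Arc (the rewritten version at the end of
// this patch keeps the same idea, keyed by every output name). A std-only
// sketch of that caching pattern, with String standing in for ir::Node:
use std::collections::HashMap;
use std::sync::Arc;

fn translate(name: &str, cache: &mut HashMap<String, Arc<String>>) -> Arc<String> {
    if let Some(node) = cache.get(name) {
        return node.clone(); // already translated: share the existing node
    }
    let node = Arc::new(format!("node:{name}"));
    cache.insert(name.to_string(), node.clone());
    node
}

fn main() {
    let mut cache = HashMap::new();
    let a = translate("conv_out", &mut cache);
    let b = translate("conv_out", &mut cache);
    assert!(Arc::ptr_eq(&a, &b)); // both consumers see the same node
}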
- pub fn from_model( - model: &'model ModelProto, - outputs: Option<&[String]>, - ) -> Result>, IrError> { - // Collect value shapes - let mut value_shapes: HashMap<&'model str, Shape> = HashMap::new(); - for vi in model.get_graph().get_value_info() { - value_shapes.insert(vi.get_name(), vi.get_shape()?); - } - - for vi in model.get_graph().get_output() { - let output_name = vi.get_name(); - if !output_name.is_empty() { - value_shapes.insert(output_name, vi.get_shape()?); - } - } - - // Sort nodes by output nodes - let mut node_definitions_by_output = HashMap::>::new(); - for node in model.get_graph().get_node().iter() { - let node_def = NodeDefinition::Operator(Box::new(OperatorDefinition::from( - Cow::Borrowed(node), - &value_shapes, - )?)); - for output in node.get_output() { - if !output.is_empty() { - node_definitions_by_output.insert(output.to_string(), node_def.clone()); - } - } - } - - // Collect intializer info - for initializer in model.get_graph().get_initializer().iter() { - node_definitions_by_output.insert( - initializer.get_name().to_string(), - NodeDefinition::Tensor(Box::new(Cow::Borrowed(initializer))), - ); - } - - let output_names: Vec = match outputs { - Some(outputs) => outputs.to_vec(), - None => model - .get_graph() - .get_output() - .iter() - .map(|x| x.get_name().to_string()) - .collect(), - }; - - // Collect input name - for input in model.get_graph().get_input().iter() { - if !node_definitions_by_output.contains_key(input.get_name()) { - node_definitions_by_output - .insert(input.get_name().to_string(), NodeDefinition::Input(input)); - } else { - log::debug!( - "Skipping input definition {}: already defined", - input.get_name() - ); - } - } - - let mut nodes_by_name = HashMap::new(); - - let output_nodes: Result>, IrError> = output_names - .iter() - .map(|output_name| { - let output_node = model - .get_graph() - .get_node() - .iter() - .find(|x| -> bool { x.get_output().contains(output_name) }) - .ok_or_else(|| IrError::OutputNodeNotFound(output_name.clone()))?; - - let source_node = Node::<'model>::from_node( - Cow::Borrowed(output_node), - &value_shapes, - &node_definitions_by_output, - &mut nodes_by_name, - )?; - - let output_index = output_node - .get_output() - .iter() - .position(|s| s == output_name) - .ok_or_else(|| IrError::OutputNodeNotFound(output_name.clone()))?; - - Ok(Input { - source_node, - output_index, - }) - }) - .collect(); - - Ok(Arc::new(Node { - definition: NodeDefinition::Outputs { - names: output_names, - }, - inputs: output_nodes?, - })) - } } impl<'model> Debug for NodeDefinition<'model> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { NodeDefinition::Operator(def) => { - write!( - f, - "op: {} ({})", - def.proto.get_name(), - def.proto.get_op_type() - ) + write!(f, "op: {} ({})", def.get_display_name(), def.get_op_type()) } - NodeDefinition::Tensor(def) => write!(f, "tensor {}", def.get_name()), - NodeDefinition::Input(def) => write!(f, "input {}", def.get_name()), + NodeDefinition::Tensor(def) => write!(f, "tensor {}", def.display_name()), + NodeDefinition::Input { name, .. } => write!(f, "input {}", name), NodeDefinition::Outputs { .. 
} => write!(f, "outputs"), NodeDefinition::Missing => write!(f, "missing (optional)"), } @@ -320,7 +340,7 @@ impl<'model> Debug for NodeIdentifier<'model> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_tuple("NodeIdentifier") .field(&Arc::as_ptr(&self.0)) - .field(&self.0.definition.get_name()) + .field(&self.0.definition.get_display_name()) .finish() } } diff --git a/wonnx/src/lib.rs b/wonnx/src/lib.rs index de622522..0423b546 100644 --- a/wonnx/src/lib.rs +++ b/wonnx/src/lib.rs @@ -1,10 +1,13 @@ +pub mod builder; +pub mod ir; +pub mod onnx; +pub mod onnx_model; +pub mod tensor; + mod compiler; mod gpu; -mod ir; -pub mod onnx; mod optimizer; mod resource; -pub mod utils; #[macro_use] extern crate lazy_static; @@ -12,13 +15,13 @@ extern crate lazy_static; pub use compiler::CompileError; pub use gpu::GpuError; use ir::IrError; +use onnx_model::OpsetError; pub use optimizer::constant_of_shape_output; -use optimizer::{Optimizer, OptimizerError}; -use protobuf::{self, Message, ProtobufError}; +use optimizer::OptimizerError; +use protobuf::{self, ProtobufError}; use std::collections::HashMap; -use std::path::Path; use std::result::Result; -use utils::{get_opset_version, DataTypeError, InputTensor, OpsetError, OutputTensor}; +use tensor::{DataTypeError, TensorData}; use crate::gpu::GpuModel; use thiserror::Error; @@ -39,14 +42,6 @@ pub enum WonnxError { } /// An inference [session](Session) represents a model that is loaded and ready to perform inference on the GPU. -/// -/// # Examples -/// -/// Basic usage: -/// -/// ```ignore -/// let mut session = Session::from_path("path/to/model.onnx").await.unwrap(); -/// ``` pub struct Session { gpu_model: GpuModel, } @@ -112,67 +107,11 @@ impl Default for SessionConfig { } impl Session { - // Read an ONNX model from a path and create a session, using default [session config](SessionConfig). - pub async fn from_path>(path: P) -> Result { - let model = onnx::ModelProto::parse_from_bytes(&std::fs::read(path)?)?; - Session::from_model(model).await - } - - // Read an ONNX model from a path and create a session using the specified [session config](SessionConfig). - pub async fn from_path_with_config>( - path: P, - config: &SessionConfig, - ) -> Result { - let model = onnx::ModelProto::parse_from_bytes(&std::fs::read(path)?)?; - Session::from_model_with_config(model, config).await - } - - /// Read an ONNX model from bytes and create a session, using default [session config](SessionConfig). - pub async fn from_bytes(bytes: &[u8]) -> Result { - let model = onnx::ModelProto::parse_from_bytes(bytes)?; - Session::from_model(model).await - } - - /// Read an ONNX model from bytes and create a session with the specified [session config](SessionConfig). - pub async fn from_bytes_with_config( - bytes: &[u8], - config: &SessionConfig, - ) -> Result { - let model = onnx::ModelProto::parse_from_bytes(bytes)?; - Session::from_model_with_config(model, config).await - } - - /// Create a session using the provided [`onnx::ModelProto`] and [session config](SessionConfig). - pub async fn from_model_with_config( - model: onnx::ModelProto, - config: &SessionConfig, - ) -> Result { - let (device, queue) = resource::request_device_queue().await; - - // Optimize and compile the model graph to a set of buffers and 'builders' which can basically run GPU shader code referencing these buffers - let onnx_opset_version = get_opset_version(&model) - .map_err(SessionError::OpsetError)? 
- .ok_or(SessionError::UnknownOnnxOpsetVersion)?; - - let mut optimizer = Optimizer::new(onnx_opset_version); - let ir = optimizer - .optimize(ir::Node::from_model(&model, config.outputs.as_deref())?) - .await?; - let gpu_model = GpuModel::from(ir, device, queue, onnx_opset_version)?; - - Ok(Session { gpu_model }) - } - - /// Create a Session given an ONNX model, using default configuration. - pub async fn from_model(model: onnx::ModelProto) -> Result { - Self::from_model_with_config(model, &SessionConfig::new()).await - } - /// Perform inference given the inputs provided and return all the outputs the model was compiled to return. pub async fn run<'a>( &self, - inputs: &HashMap>, - ) -> Result, SessionError> { + inputs: &HashMap>, + ) -> Result>, SessionError> { Ok(self.gpu_model.infer(inputs).await?) } } diff --git a/wonnx/src/onnx_model.rs b/wonnx/src/onnx_model.rs new file mode 100644 index 00000000..56a3665c --- /dev/null +++ b/wonnx/src/onnx_model.rs @@ -0,0 +1,853 @@ +//! Conversion of ONNX models to WONNX IR +use crate::gpu::GpuModel; +use crate::ir::AttributeValue; +use crate::ir::Input; +use crate::ir::IrError; +use crate::ir::Node; +use crate::ir::NodeDefinition; +use crate::ir::OperatorDefinition; +use crate::ir::Tensor; +use crate::onnx; +use crate::onnx::AttributeProto; +use crate::onnx::AttributeProto_AttributeType; +use crate::onnx::GraphProto; +use crate::onnx::ModelProto; +use crate::onnx::NodeProto; +use crate::onnx::OperatorSetIdProto; +use crate::onnx::TensorProto; +use crate::onnx::TensorProto_DataType; +use crate::onnx::TensorShapeProto; +use crate::onnx::TensorShapeProto_Dimension; +use crate::onnx::TypeProto; +use crate::onnx::TypeProto_Tensor; +use crate::onnx::TypeProto_oneof_value; +use crate::onnx::ValueInfoProto; +use crate::optimizer::Optimizer; +use crate::resource::request_device_queue; +use crate::tensor::DataTypeError; +use crate::tensor::ScalarType; +use crate::tensor::Shape; +use crate::tensor::TensorData; +use crate::Session; +use crate::SessionConfig; +use crate::SessionError; +use protobuf::Message; +use protobuf::ProtobufEnum; +use protobuf::RepeatedField; +use std::borrow::Cow; +use std::collections::HashMap; +use std::convert::From; +use std::convert::Into; +use std::convert::TryFrom; +use std::path::Path; +use std::str::from_utf8; +use std::sync::Arc; +use thiserror::Error; + +impl TensorProto { + pub fn from(value: TensorData, dims: Vec) -> Self { + let mut tensor = TensorProto::new(); + match value { + TensorData::F32(v) => { + tensor.set_data_type(ScalarType::F32.to_onnx_datatype().value()); + tensor.set_float_data(v.to_vec()); + } + TensorData::I32(v) => { + tensor.set_data_type(ScalarType::I32.to_onnx_datatype().value()); + tensor.set_int32_data(v.to_vec()); + } + TensorData::I64(v) => { + tensor.set_data_type(ScalarType::I64.to_onnx_datatype().value()); + tensor.set_int64_data(v.to_vec()); + } + TensorData::U8(v) => { + tensor.set_data_type(ScalarType::U8.to_onnx_datatype().value()); + tensor.set_raw_data(v.to_vec()); + } + } + tensor.set_dims(dims); + tensor + } +} + +impl<'a> TryFrom<&'a TensorProto> for TensorData<'a> { + type Error = DataTypeError; + + fn try_from(value: &'a TensorProto) -> Result { + Ok(match ScalarType::from_onnx_i32(value.get_data_type())? 
{ + ScalarType::F32 => TensorData::F32(Cow::Borrowed(value.get_float_data())), + ScalarType::I64 => TensorData::I64(Cow::Borrowed(value.get_int64_data())), + ScalarType::I32 => TensorData::I32(Cow::Borrowed(value.get_int32_data())), + ScalarType::U8 => TensorData::U8(Cow::Borrowed(value.get_raw_data())), + }) + } +} + +impl ScalarType { + pub fn from_onnx_i32(onnx: i32) -> Result { + let onnx_dt = + TensorProto_DataType::from_i32(onnx).ok_or(DataTypeError::NotRecognized(onnx))?; + Self::from(onnx_dt) + } + + pub fn from(onnx: TensorProto_DataType) -> Result { + Ok(match onnx { + TensorProto_DataType::FLOAT => ScalarType::F32, + TensorProto_DataType::INT64 => ScalarType::I64, + TensorProto_DataType::INT32 => ScalarType::I32, + TensorProto_DataType::UINT8 => ScalarType::U8, + _ => return Err(DataTypeError::NotSupported(onnx.value())), + }) + } + + pub fn to_onnx_datatype(&self) -> TensorProto_DataType { + match self { + ScalarType::F32 => TensorProto_DataType::FLOAT, + ScalarType::I64 => TensorProto_DataType::INT64, + ScalarType::I32 => TensorProto_DataType::INT32, + ScalarType::U8 => TensorProto_DataType::UINT8, + } + } +} + +impl ValueInfoProto { + pub fn get_shape(&self) -> Result { + Ok(match &self.get_field_type().value { + Some(t) => match t { + onnx::TypeProto_oneof_value::tensor_type(tensor_proto) => Shape::from( + ScalarType::from_onnx_i32(tensor_proto.get_elem_type())?, + self.get_field_type() + .get_tensor_type() + .get_shape() + .get_dim() + .iter() + .map(|x| { + if x.has_dim_param() { + return Err(DataTypeError::ParametrizedDimensionUnsupported( + x.get_dim_param().to_string(), + )); + } + Ok(x.get_dim_value() as usize) + }) + .collect::, DataTypeError>>()? + .as_slice(), + ), + onnx::TypeProto_oneof_value::sequence_type(_) => todo!(), + onnx::TypeProto_oneof_value::map_type(_) => todo!(), + onnx::TypeProto_oneof_value::optional_type(_) => todo!(), + onnx::TypeProto_oneof_value::sparse_tensor_type(_) => todo!(), + }, + None => return Err(DataTypeError::Undefined), + }) + } + + pub fn set_shape(&mut self, shape: &Shape) { + let mut tpt = TypeProto_Tensor::new(); + tpt.set_elem_type(shape.data_type.to_onnx_datatype().value()); + + let mut tsp = TensorShapeProto::new(); + tsp.dim.extend(shape.dims.iter().map(|x| { + let mut tspd = TensorShapeProto_Dimension::new(); + tspd.set_dim_value(*x as i64); + tspd + })); + tpt.set_shape(tsp); + + let mut tp = TypeProto::new(); + tp.value = Some(TypeProto_oneof_value::tensor_type(tpt)); + self.set_field_type(tp); + } +} + +/// Shorthand method to define an ONNX tensor with the specified name and shape (data type is f32) +pub fn onnx_tensor(name: &str, dimensions: &[i64]) -> onnx::ValueInfoProto { + onnx_tensor_of_type(name, dimensions, TensorProto_DataType::FLOAT) +} + +/// Shorthand method to define an ONNX tensor with the specified name, shape and data type +pub fn onnx_tensor_of_type( + name: &str, + dimensions: &[i64], + data_type: TensorProto_DataType, +) -> onnx::ValueInfoProto { + let mut dim_value = vec![]; + for dimension in dimensions { + let mut dim_channel = onnx::TensorShapeProto_Dimension::new(); + dim_channel.set_dim_value(*dimension); + dim_value.push(dim_channel); + } + + let mut shape_tensor_proto = onnx::TensorShapeProto::new(); + shape_tensor_proto.set_dim(protobuf::RepeatedField::from(dim_value)); + + let mut type_proto_tensor = onnx::TypeProto_Tensor::new(); + type_proto_tensor.set_elem_type(data_type.value()); + type_proto_tensor.set_shape(shape_tensor_proto); + + let mut type_proto = onnx::TypeProto::new(); + 
type_proto.set_tensor_type(type_proto_tensor); + + let mut tensor = onnx::ValueInfoProto::new(); + tensor.set_name(name.to_string()); + tensor.set_field_type(type_proto); + + tensor +} + +pub fn onnx_initializer(name: &str, data: Vec, dimensions: Vec) -> onnx::TensorProto { + let mut initializer = crate::onnx::TensorProto::new(); + assert_eq!( + dimensions.iter().cloned().product::() as usize, + data.len() + ); + initializer.set_dims(dimensions); + initializer.set_name(name.to_string()); + initializer.set_data_type(TensorProto_DataType::FLOAT.value()); + initializer.set_float_data(data); + initializer +} + +pub fn onnx_initializer_int64( + name: &str, + data: Vec, + dimensions: Vec, +) -> onnx::TensorProto { + let mut initializer = crate::onnx::TensorProto::new(); + assert_eq!( + dimensions.iter().cloned().product::() as usize, + data.len() + ); + initializer.set_name(name.to_string()); + initializer.set_dims(dimensions); + initializer.set_data_type(TensorProto_DataType::INT64.value()); + initializer.set_int64_data(data); + initializer +} + +pub fn onnx_attribute(name: &str, inputs: impl Into) -> onnx::AttributeProto { + let mut attributes: onnx::AttributeProto = inputs.into(); + attributes.set_name(name.to_string()); + attributes +} + +/// Create a node - the node name will be set to the name of the first output +pub fn onnx_node( + inputs: Vec<&str>, + outputs: Vec<&str>, + op_type: &str, + attributes: Vec, +) -> onnx::NodeProto { + let mut node = crate::onnx::NodeProto::new(); + + node.set_op_type(op_type.to_string()); + node.set_name(outputs[0].to_string()); + node.set_input(protobuf::RepeatedField::from( + inputs + .iter() + .map(|s| s.to_string()) + .collect::>(), + )); + node.set_output(protobuf::RepeatedField::from( + outputs + .iter() + .map(|s| s.to_string()) + .collect::>(), + )); + node.set_attribute(protobuf::RepeatedField::from(attributes)); + node +} + +pub fn onnx_graph( + inputs: Vec, + outputs: Vec, + mut infos: Vec, + initializers: Vec, + nodes: Vec, +) -> onnx::GraphProto { + let mut graph = onnx::GraphProto::new(); + graph.set_node(protobuf::RepeatedField::from(nodes)); + graph.set_input(protobuf::RepeatedField::from(inputs)); + graph.set_output(protobuf::RepeatedField::from(outputs)); + + // Auto-generate tensor information for initializers so users don't have to specify those + for i in &initializers { + infos.push(onnx_tensor_of_type( + i.get_name(), + i.get_dims(), + onnx::TensorProto_DataType::from_i32(i.get_data_type()).unwrap(), + )); + } + + graph.set_initializer(protobuf::RepeatedField::from(initializers)); + graph.set_value_info(protobuf::RepeatedField::from(infos)); + graph +} + +pub fn onnx_model_with_opset(graph: onnx::GraphProto, opset_version: i64) -> onnx::ModelProto { + let mut model = crate::onnx::ModelProto::new(); + let mut onnx_opset_import = OperatorSetIdProto::new(); + onnx_opset_import.set_domain("".to_string()); + onnx_opset_import.set_version(opset_version); + model.set_opset_import(RepeatedField::from_slice(&[onnx_opset_import])); + model.set_graph(graph); + model +} + +pub fn onnx_model(graph: onnx::GraphProto) -> onnx::ModelProto { + onnx_model_with_opset(graph, 13) +} + +impl From> for onnx::AttributeProto { + fn from(value: Vec) -> Self { + let mut attributes = crate::onnx::AttributeProto::new(); + attributes.set_ints(value); + attributes.set_field_type(AttributeProto_AttributeType::INTS); + attributes + } +} + +impl From> for onnx::AttributeProto { + fn from(value: Vec) -> Self { + let mut attributes = crate::onnx::AttributeProto::new(); 
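// A hedged sketch of composing a minimal model with the builder helpers
// defined above (onnx_node, onnx_tensor, onnx_attribute, onnx_graph,
// onnx_model); the tensor names and shapes are invented for illustration:
use wonnx::onnx_model::{onnx_attribute, onnx_graph, onnx_model, onnx_node, onnx_tensor};

fn tiny_relu_model() -> wonnx::onnx::ModelProto {
    let graph = onnx_graph(
        vec![onnx_tensor("x", &[1, 4])], // model input
        vec![onnx_tensor("y", &[1, 4])], // model output
        vec![],                          // extra value infos
        vec![],                          // initializers
        vec![onnx_node(
            vec!["x"],
            vec!["y"],
            "LeakyRelu",
            vec![onnx_attribute("alpha", 0.1f32)],
        )],
    );
    onnx_model(graph) // defaults to ONNX opset 13
}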
+ attributes.set_floats(value); + attributes.set_field_type(AttributeProto_AttributeType::FLOATS); + attributes + } +} + +impl From for onnx::AttributeProto { + fn from(value: f32) -> Self { + let mut attributes = crate::onnx::AttributeProto::new(); + attributes.set_f(value); + attributes.set_field_type(AttributeProto_AttributeType::FLOAT); + attributes + } +} + +impl From for onnx::AttributeProto { + fn from(value: i64) -> Self { + let mut attributes = crate::onnx::AttributeProto::new(); + attributes.set_i(value); + attributes.set_field_type(AttributeProto_AttributeType::INT); + attributes + } +} + +impl From for onnx::AttributeProto { + fn from(value: String) -> Self { + let mut attributes = crate::onnx::AttributeProto::new(); + attributes.set_s(value.into_bytes()); + attributes.set_field_type(AttributeProto_AttributeType::STRING); + attributes + } +} + +impl From<&str> for onnx::AttributeProto { + fn from(value: &str) -> Self { + let mut attributes = crate::onnx::AttributeProto::new(); + attributes.set_s(value.to_string().into_bytes()); + attributes.set_field_type(AttributeProto_AttributeType::STRING); + attributes + } +} + +impl From for onnx::AttributeProto { + fn from(value: TensorProto) -> Self { + let mut attributes = crate::onnx::AttributeProto::new(); + attributes.set_t(value); + attributes.set_field_type(AttributeProto_AttributeType::TENSOR); + attributes + } +} + +// Attribute to value conversions +impl From<&onnx::AttributeProto> for Vec { + fn from(value: &onnx::AttributeProto) -> Self { + value.get_ints().to_vec() + } +} + +impl From<&onnx::AttributeProto> for TensorProto { + fn from(value: &onnx::AttributeProto) -> Self { + value.get_t().clone() + } +} + +impl<'a> From<&'a onnx::AttributeProto> for &'a TensorProto { + fn from(value: &'a onnx::AttributeProto) -> Self { + value.get_t() + } +} + +impl From<&onnx::AttributeProto> for Vec { + fn from(value: &onnx::AttributeProto) -> Self { + value.get_floats().to_vec() + } +} + +impl From<&onnx::AttributeProto> for f32 { + fn from(value: &onnx::AttributeProto) -> Self { + value.get_f() + } +} + +impl From<&onnx::AttributeProto> for i64 { + fn from(value: &onnx::AttributeProto) -> Self { + value.get_i() + } +} + +impl From<&onnx::AttributeProto> for String { + fn from(value: &onnx::AttributeProto) -> Self { + from_utf8(value.get_s()).unwrap().to_string() + } +} + +#[derive(Error, Debug)] +pub enum OpsetError { + #[error("more than one ONNX opset was specified: {0} and {1}")] + DuplicateOnnxOpset(i64, i64), + + #[error("the model references an unknown opset: '{0}'")] + UnknownOpset(String), +} + +pub fn get_opset_version(model: &ModelProto) -> Result, OpsetError> { + // Find the version of the ONNX operator set this model is using (this is useful because some operators' specifications change over time). + // Note, if any other op set than the ONNX operator set is referenced, we cannot run the model. 
+    // See https://github.com/onnx/onnx/blob/master/docs/Versioning.md#operator-sets
+    let mut onnx_opset_version = None;
+    for opset_import in model.get_opset_import() {
+        match opset_import.get_domain() {
+            "" => {
+                // This is a reference to the ONNX specification op set
+                if let Some(onnx_version) = onnx_opset_version {
+                    if opset_import.get_version() != onnx_version {
+                        return Err(OpsetError::DuplicateOnnxOpset(
+                            onnx_version,
+                            opset_import.get_version(),
+                        ));
+                    }
+                } else {
+                    onnx_opset_version = Some(opset_import.get_version());
+                }
+            }
+            some_other_opset => {
+                return Err(OpsetError::UnknownOpset(some_other_opset.to_string()));
+            }
+        }
+    }
+    Ok(onnx_opset_version)
+}
+
+pub(crate) fn to_tensor<'model>(
+    proto: &'model TensorProto,
+) -> Result<Tensor<'model>, DataTypeError> {
+    let scalar_type = ScalarType::from_onnx_i32(proto.get_data_type())?;
+    let dims = proto
+        .get_dims()
+        .iter()
+        .map(|x| *x as usize)
+        .collect::<Vec<usize>>();
+    log::debug!(
+        "creating tensor for ONNX tensor {} shape {}",
+        proto.get_name(),
+        Shape::from(scalar_type, &dims)
+    );
+
+    let tensor_data: TensorData<'model> = match scalar_type {
+        ScalarType::F32 => {
+            let fd = proto.get_float_data();
+            if fd.is_empty() {
+                let fd: &[f32] = bytemuck::cast_slice(proto.get_raw_data());
+                TensorData::F32(Cow::from(fd))
+            } else {
+                TensorData::F32(Cow::from(fd))
+            }
+        }
+        ScalarType::I32 => {
+            let fd = proto.get_int32_data();
+            if fd.is_empty() {
+                let fd: &[i32] = bytemuck::cast_slice(proto.get_raw_data());
+                TensorData::I32(Cow::from(fd))
+            } else {
+                TensorData::I32(Cow::from(fd))
+            }
+        }
+        ScalarType::I64 => {
+            let fd = proto.get_int64_data();
+            if fd.is_empty() {
+                let fd: &[i64] = bytemuck::cast_slice(proto.get_raw_data());
+                TensorData::I64(Cow::from(fd))
+            } else {
+                TensorData::I64(Cow::from(fd))
+            }
+        }
+        ScalarType::U8 => TensorData::U8(Cow::from(proto.get_raw_data())),
+    };
+
+    Ok(Tensor {
+        data: tensor_data,
+        dims,
+        display_name: proto.get_name().to_string(),
+    })
+}
+
+impl<'a> From<&'a AttributeProto> for AttributeValue<'a> {
+    fn from(value: &'a AttributeProto) -> Self {
+        match value.get_field_type() {
+            AttributeProto_AttributeType::INT => AttributeValue::I64(value.get_i()),
+            AttributeProto_AttributeType::FLOAT => AttributeValue::F32(value.get_f()),
+            AttributeProto_AttributeType::INTS => AttributeValue::I64s(Cow::from(value.get_ints())),
+            AttributeProto_AttributeType::FLOATS => {
+                AttributeValue::F32s(Cow::from(value.get_floats()))
+            }
+            AttributeProto_AttributeType::STRING => {
+                AttributeValue::String(from_utf8(value.get_s()).unwrap().to_string())
+            }
+            AttributeProto_AttributeType::TENSOR => {
+                AttributeValue::Tensor(to_tensor(value.get_t()).unwrap())
+            }
+            _ => unimplemented!("attribute field type {}", value.get_field_type().value()),
+        }
+    }
+}
+
+impl OperatorDefinition {
+    pub fn from(node: &NodeProto, output_shapes: Vec<Shape>) -> OperatorDefinition {
+        assert_eq!(node.get_output().len(), output_shapes.len());
+        let mut attributes = HashMap::new();
+        for attr in node.get_attribute() {
+            attributes.insert(
+                attr.get_name().to_string(),
+                AttributeValue::from(attr).into_static(),
+            );
+        }
+
+        OperatorDefinition {
+            op_type: node.get_op_type().to_string(),
+            attributes,
+            output_shapes,
+            display_name: node.get_output()[0].to_string(),
+        }
+    }
+}
+
+impl<'model> Node<'model> {
+    /// Construct part of the intermediate representation tree for the indicated node.
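+    ///
+    /// Inputs are translated recursively; nodes that were already translated are looked
+    /// up in (and newly translated nodes are registered with) `nodes_by_output_names`, so
+    /// a node that feeds multiple consumers is only translated once.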
+    pub fn from_node<'a>(
+        node: Cow<'model, NodeProto>,
+        value_shapes: &HashMap<&'model str, Shape>,
+        node_definitions_by_output: &'a HashMap<String, Cow<'model, NodeProto>>,
+        nodes_by_output_names: &mut HashMap<String, Arc<Node<'model>>>,
+    ) -> Result<Arc<Node<'model>>, IrError> {
+        for output_name in node.get_output() {
+            if nodes_by_output_names.contains_key(output_name) {
+                let n = nodes_by_output_names.get(output_name).unwrap();
+                return Ok(n.clone());
+            }
+        }
+
+        let inputs: Result<Vec<Input<'model>>, IrError> = node
+            .get_input()
+            .iter()
+            .map(|input_name: &String| {
+                let my_input_name = input_name.clone();
+
+                // An empty input name signifies missing
+                if input_name.is_empty() {
+                    return Ok(Input {
+                        source_node: Arc::new(Node::new(NodeDefinition::Missing, vec![])),
+                        output_index: 0,
+                    });
+                }
+
+                Ok(match node_definitions_by_output.get(&my_input_name) {
+                    Some(source_node_proto) => {
+                        // The source is another op - continue translating that node
+                        Input {
+                            source_node: Node::from_node(
+                                source_node_proto.clone(),
+                                value_shapes,
+                                node_definitions_by_output,
+                                nodes_by_output_names,
+                            )?,
+                            output_index: source_node_proto
+                                .get_output()
+                                .iter()
+                                .position(|s| s == input_name)
+                                .ok_or_else(|| {
+                                    IrError::OutputNodeNotFound(input_name.to_string())
+                                })?,
+                        }
+                    }
+                    None => {
+                        Input {
+                            output_index: 0,
+                            // Did we already translate this node?
+                            source_node: match nodes_by_output_names.get(input_name) {
+                                Some(node) => node.clone(),
+                                None => {
+                                    return Err(IrError::InputNodeNotFound {
+                                        target_node_name: node.get_name().to_string(),
+                                        input_name: input_name.clone(),
+                                    })
+                                }
+                            },
+                        }
+                    }
+                })
+            })
+            .collect();
+
+        // Obtain output shapes
+        let mut output_shapes: Vec<Shape> = Vec::with_capacity(node.get_output().len());
+        for output_name in node.get_output() {
+            if !value_shapes.contains_key(output_name.as_str()) {
+                return Err(IrError::OutputNodeNotFound(output_name.to_string()));
+            }
+
+            output_shapes.push(value_shapes[&output_name.as_str()].clone());
+        }
+
+        let translated = Arc::new(Node {
+            definition: NodeDefinition::Operator(OperatorDefinition::from(&node, output_shapes)),
+            inputs: inputs?,
+        });
+
+        // Register the translated node by all of its output names
+        for output_name in node.get_output() {
+            nodes_by_output_names.insert(output_name.clone(), translated.clone());
+        }
+
+        Ok(translated)
+    }
+
+    /// Construct an intermediate representation graph for calculating the output with the specified name.
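+    ///
+    /// A minimal sketch of intended use (assuming `model` is a parsed `ModelProto`):
+    ///
+    /// ```ignore
+    /// // Build the IR for all graph outputs; the returned root is an Outputs node.
+    /// let ir = Node::from_model(&model, None)?;
+    /// ```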
+    pub fn from_model(
+        model: &'model ModelProto,
+        outputs: Option<&[String]>,
+    ) -> Result<Arc<Node<'model>>, IrError> {
+        let graph: &'model GraphProto = model.get_graph();
+
+        // Collect value shapes
+        let mut value_shapes: HashMap<&'model str, Shape> = HashMap::new();
+        for vi in graph.get_value_info() {
+            value_shapes.insert(vi.get_name(), vi.get_shape()?);
+        }
+
+        for vi in graph.get_output() {
+            let output_name = vi.get_name();
+            if !output_name.is_empty() {
+                value_shapes.insert(output_name, vi.get_shape()?);
+            }
+        }
+
+        // Sort nodes by output nodes
+        let mut node_definitions_by_output = HashMap::<String, Cow<'model, NodeProto>>::new();
+        for node in graph.get_node().iter() {
+            for output in node.get_output() {
+                if !output.is_empty() {
+                    node_definitions_by_output.insert(output.to_string(), Cow::Borrowed(node));
+                }
+            }
+        }
+
+        let mut nodes_by_output_name = HashMap::new();
+
+        // Translate initializers
+        for initializer in graph.initializer.iter() {
+            nodes_by_output_name.insert(
+                initializer.get_name().to_string(),
+                Arc::new(Node::new(
+                    NodeDefinition::Tensor(to_tensor(initializer)?),
+                    vec![],
+                )),
+            );
+        }
+
+        // Translate inputs
+        for input in model.get_graph().get_input().iter() {
+            if !nodes_by_output_name.contains_key(input.get_name()) {
+                nodes_by_output_name.insert(
+                    input.get_name().to_string(),
+                    Arc::new(Node::new(
+                        NodeDefinition::Input {
+                            name: input.get_name().to_string(),
+                            shape: input.get_shape()?,
+                        },
+                        vec![],
+                    )),
+                );
+            } else {
+                log::warn!(
+                    "Skipping input definition {}: already defined",
+                    input.get_name()
+                );
+            }
+        }
+
+        let output_names: Vec<String> = match outputs {
+            Some(outputs) => outputs.to_vec(),
+            None => model
+                .get_graph()
+                .get_output()
+                .iter()
+                .map(|x| x.get_name().to_string())
+                .collect(),
+        };
+
+        let output_nodes: Result<Vec<Input<'model>>, IrError> = output_names
+            .iter()
+            .map(|output_name| {
+                let output_node = model
+                    .get_graph()
+                    .get_node()
+                    .iter()
+                    .find(|x| -> bool { x.get_output().contains(output_name) })
+                    .ok_or_else(|| IrError::OutputNodeNotFound(output_name.clone()))?;
+
+                let source_node = Node::<'model>::from_node(
+                    Cow::Borrowed(output_node),
+                    &value_shapes,
+                    &node_definitions_by_output,
+                    &mut nodes_by_output_name,
+                )?;
+
+                let output_index = output_node
+                    .get_output()
+                    .iter()
+                    .position(|s| s == output_name)
+                    .ok_or_else(|| IrError::OutputNodeNotFound(output_name.clone()))?;
+
+                Ok(Input {
+                    source_node,
+                    output_index,
+                })
+            })
+            .collect();
+
+        Ok(Arc::new(Node {
+            definition: NodeDefinition::Outputs {
+                names: output_names,
+            },
+            inputs: output_nodes?,
+        }))
+    }
+}
+
+/// Support for creating [`Session`] from ONNX model files.
+///
+/// # Examples
+///
+/// Basic usage:
+///
+/// ```ignore
+/// let mut session = Session::from_path("path/to/model.onnx").await.unwrap();
+/// ```
+impl Session {
+    /// Read an ONNX model from a path and create a session, using default [session config](SessionConfig).
+    pub async fn from_path<P: AsRef<Path>>(path: P) -> Result<Session, SessionError> {
+        let model = onnx::ModelProto::parse_from_bytes(&std::fs::read(path)?)?;
+        Session::from_model(model).await
+    }
+
+    /// Read an ONNX model from a path and create a session using the specified [session config](SessionConfig).
+    pub async fn from_path_with_config<P: AsRef<Path>>(
+        path: P,
+        config: &SessionConfig,
+    ) -> Result<Session, SessionError> {
+        let model = onnx::ModelProto::parse_from_bytes(&std::fs::read(path)?)?;
+        Session::from_model_with_config(model, config).await
+    }
+
+    /// Read an ONNX model from bytes and create a session, using default [session config](SessionConfig).
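+    ///
+    /// A minimal sketch (assuming `bytes` holds a serialized ONNX model):
+    ///
+    /// ```ignore
+    /// let session = Session::from_bytes(&bytes).await?;
+    /// ```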
+    pub async fn from_bytes(bytes: &[u8]) -> Result<Session, SessionError> {
+        let model = onnx::ModelProto::parse_from_bytes(bytes)?;
+        Session::from_model(model).await
+    }
+
+    /// Read an ONNX model from bytes and create a session with the specified [session config](SessionConfig).
+    pub async fn from_bytes_with_config(
+        bytes: &[u8],
+        config: &SessionConfig,
+    ) -> Result<Session, SessionError> {
+        let model = onnx::ModelProto::parse_from_bytes(bytes)?;
+        Session::from_model_with_config(model, config).await
+    }
+
+    /// Create a session using the provided [`onnx::ModelProto`] and [session config](SessionConfig).
+    pub async fn from_model_with_config(
+        model: onnx::ModelProto,
+        config: &SessionConfig,
+    ) -> Result<Session, SessionError> {
+        let (device, queue) = request_device_queue().await;
+
+        // Optimize and compile the model graph to a set of buffers and 'builders' which can basically run GPU shader code referencing these buffers
+        let onnx_opset_version = get_opset_version(&model)
+            .map_err(SessionError::OpsetError)?
+            .ok_or(SessionError::UnknownOnnxOpsetVersion)?;
+
+        let mut optimizer = Optimizer::new(onnx_opset_version);
+        let ir = optimizer
+            .optimize(Node::from_model(&model, config.outputs.as_deref())?)
+            .await?;
+        let gpu_model = GpuModel::from(ir, device, queue, onnx_opset_version)?;
+
+        Ok(Session { gpu_model })
+    }
+
+    /// Create a Session given an ONNX model, using default configuration.
+    pub async fn from_model(model: onnx::ModelProto) -> Result<Session, SessionError> {
+        Self::from_model_with_config(model, &SessionConfig::new()).await
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::onnx_model::{
+        onnx_attribute, onnx_graph, onnx_initializer, onnx_model, onnx_node, onnx_tensor,
+    };
+    use crate::tensor::TensorData;
+
+    #[test]
+    fn test_use_onnx_model() {
+        // USER INPUT
+        let n = 5;
+        let c = 1;
+        let mut input_data = std::collections::HashMap::new();
+
+        let data: Vec<f32> = (0..25).map(|x| x as f32).collect();
+        input_data.insert("X".to_string(), data.as_slice().into());
+
+        // ONNX INPUTS
+        let shape = vec![1, c, n, n];
+        let kernel_n = 3;
+        let m = 1;
+        let data_w: Vec<f32> = (0..m * c * kernel_n * kernel_n).map(|_| 1.0f32).collect();
+        let conv_model = onnx_model(onnx_graph(
+            vec![onnx_tensor("X", &shape)],
+            vec![onnx_tensor("Y", &[1, 1, 3, 3])],
+            vec![],
+            vec![onnx_initializer("W", data_w, vec![m, c, 3, 3])],
+            vec![onnx_node(
+                vec!["X", "W"],
+                vec!["Y"],
+                "Conv",
+                vec![onnx_attribute("kernel_shape", vec![3, 3])],
+            )],
+        ));
+
+        // LOGIC
+
+        let session = pollster::block_on(crate::Session::from_model(conv_model))
+            .expect("Session did not create");
+
+        let result = pollster::block_on(session.run(&input_data)).unwrap();
+
+        assert_eq!(
+            result["Y"],
+            TensorData::F32(vec![54., 63., 72., 99., 108., 117., 144., 153., 162.].into())
+        );
+    }
+}
diff --git a/wonnx/src/optimizer.rs b/wonnx/src/optimizer.rs
index 7c068857..a6fc374d 100644
--- a/wonnx/src/optimizer.rs
+++ b/wonnx/src/optimizer.rs
@@ -1,20 +1,17 @@
 //!
Optimizer that walks the DAG and transforms or coalesces ops for quicker execution use crate::{ gpu::GpuModel, - ir::{Input, Node, NodeDefinition, NodeIdentifier, OperatorDefinition}, - onnx::{NodeProto, TensorProto}, - resource::{padding, request_device_queue}, - utils::{ - attribute, AttributeNotFoundError, DataTypeError, NodeAttributes, OutputTensor, ScalarType, - Shape, + ir::{ + AttributeNotFoundError, Input, Node, NodeDefinition, NodeIdentifier, OperatorDefinition, + Tensor, }, + resource::{padding, request_device_queue}, + tensor::{DataTypeError, ScalarType, TensorData}, GpuError, }; use async_recursion::async_recursion; -use bytemuck::pod_collect_to_vec; -use protobuf::RepeatedField; use std::{ - borrow::Cow, + borrow::{Borrow, Cow}, collections::{HashMap, VecDeque}, sync::Arc, }; @@ -49,7 +46,7 @@ pub enum OptimizerError { } pub struct Optimizer<'model> { - padded_tensors: HashMap>>, + padded_tensors: HashMap, Arc>>, optimized: HashMap, Arc>>, onnx_opset_version: i64, } @@ -73,7 +70,7 @@ impl<'model> Optimizer<'model> { match node.definition() { NodeDefinition::Operator(op_def) => { // TODO: constant nodes with multiple outputs - if op_def.proto.output.len() != 1 { + if op_def.output_shapes().len() != 1 { log::warn!( "node {:?} is constant, but has multiple outputs, which we can't fold yet", node.definition() @@ -81,28 +78,26 @@ impl<'model> Optimizer<'model> { return Ok(None); } - match op_def.proto.get_op_type() { + match op_def.get_op_type() { "Constant" => Ok(Some(Arc::new(Node { - definition: NodeDefinition::Tensor(Box::new(Cow::Owned( - Self::constant_node_to_tensor(node)?, - ))), + definition: NodeDefinition::Tensor(Self::constant_node_to_tensor(node)?), inputs: vec![], }))), _ => self.infer_constant_node_to_tensor(node.clone()).await, } } NodeDefinition::Tensor(_) => Ok(None), // already constantized - NodeDefinition::Input(_) | NodeDefinition::Missing => unreachable!(), + NodeDefinition::Input { .. } | NodeDefinition::Missing => unreachable!(), NodeDefinition::Outputs { .. } => Ok(None), // all the outputs themselves are already constant, so nothing to do } } // Takes a node with operator type 'Shape' and returns its output as a tensor - fn shape_node_to_tensor(node: Arc>) -> Result { + fn shape_node_to_tensor(node: Arc>) -> Result, OptimizerError> { let NodeDefinition::Operator(op_def) = node.definition() else { panic!("node must be a Shape node"); }; - assert_eq!(op_def.proto.get_op_type(), "Shape"); + assert_eq!(op_def.get_op_type(), "Shape"); if node.inputs.len() != 1 { return Err(OptimizerError::InvalidNode(format!( @@ -115,15 +110,11 @@ impl<'model> Optimizer<'model> { let input = &node.inputs[0]; let in_node = &input.source_node.definition; let in_shape = match in_node { - NodeDefinition::Input(input) => input.get_shape()?, + NodeDefinition::Input { shape, .. } => shape.clone(), NodeDefinition::Operator(input_op_def) => { - input_op_def.output_shapes[input.output_index].clone() + input_op_def.output_shapes()[input.output_index].clone() } - NodeDefinition::Tensor(input_tensor) => Shape::from( - ScalarType::from_i32(input_tensor.get_data_type()) - .map_err(OptimizerError::InvalidDataType)?, - input_tensor.get_dims(), - ), + NodeDefinition::Tensor(input_tensor) => input_tensor.shape(), NodeDefinition::Outputs { .. 
} => { return Err(OptimizerError::Unsupported( "output node cannot be used as an input to Shape node".to_string(), @@ -136,8 +127,8 @@ impl<'model> Optimizer<'model> { } }; let rank = in_shape.rank() as i64; - let mut start: i64 = op_def.proto.get_attribute_value("start", Some(0)).unwrap(); - let mut end: i64 = op_def.proto.get_attribute_value("end", Some(rank)).unwrap(); + let mut start: i64 = op_def.get_attribute_value("start", Some(0)).unwrap(); + let mut end: i64 = op_def.get_attribute_value("end", Some(rank)).unwrap(); if start < 0 { start += rank; } @@ -163,48 +154,65 @@ impl<'model> Optimizer<'model> { .iter() .map(|x| *x as i64) .collect(); - let dims = vec![values.len() as i64]; - Ok(TensorProto::from(OutputTensor::I64(values), dims)) + let dims = vec![values.len()]; + Ok(Tensor { + data: TensorData::I64(values.into()), + dims, + display_name: format!("{}", node.definition().get_display_name()), + }) } // Takes a node with operator type 'Constant' and returns its output as a tensor - fn constant_node_to_tensor(node: Arc>) -> Result { + fn constant_node_to_tensor(node: Arc>) -> Result, OptimizerError> { let NodeDefinition::Operator(op_def) = node.definition() else { panic!("node must be a Constant node"); }; - assert_eq!(op_def.proto.get_op_type(), "Constant"); - let proto = &op_def.proto; - let output_name = proto.output.get(0).unwrap().to_owned(); - - let mut tp: TensorProto = - if let Ok(values) = proto.get_attribute_value::>("value_floats", None) { - let dims = vec![values.len() as i64]; - TensorProto::from(OutputTensor::F32(values), dims) - } else if let Ok(values) = proto.get_attribute_value::>("value_ints", None) { - let dims = vec![values.len() as i64]; - TensorProto::from(OutputTensor::I64(values), dims) - } else if let Ok(value) = proto.get_attribute_value::("value_float", None) { - TensorProto::from(OutputTensor::F32(vec![value]), vec![1]) - } else if let Ok(value) = proto.get_attribute_value::("value_int", None) { - TensorProto::from(OutputTensor::I64(vec![value]), vec![1]) - } else if let Ok(tp) = proto.get_attribute_value::("value", None) { - tp + assert_eq!(op_def.get_op_type(), "Constant"); + let display_name = op_def.get_display_name().into(); + + let tp: Tensor = + if let Ok(values) = op_def.get_attribute_value::>("value_floats", None) { + let dims = vec![values.len()]; + Tensor { + data: TensorData::F32(values.into()), + dims, + display_name, + } + } else if let Ok(values) = op_def.get_attribute_value::>("value_ints", None) { + let dims = vec![values.len()]; + Tensor { + data: TensorData::I64(values.into()), + dims, + display_name, + } + } else if let Ok(value) = op_def.get_attribute_value::("value_float", None) { + Tensor { + data: TensorData::F32(vec![value].into()), + dims: vec![1], + display_name, + } + } else if let Ok(value) = op_def.get_attribute_value::("value_int", None) { + Tensor { + data: TensorData::I64(vec![value].into()), + dims: vec![1], + display_name, + } + } else if let Ok(t) = op_def.get_attribute_value::("value", None) { + t.into_static() } else { return Err(OptimizerError::Unsupported( "Constant node with unknown value type".to_string(), )); }; - - tp.set_name(output_name); Ok(tp) } // Takes a node with operator type 'Size' and returns its output as a tensor - fn size_node_to_tensor(node: Arc>) -> Result { + fn size_node_to_tensor(node: Arc>) -> Result, OptimizerError> { let NodeDefinition::Operator(op_def) = node.definition() else { panic!("node must be a Size node"); }; - assert_eq!(op_def.proto.get_op_type(), "Size"); + 
assert_eq!(op_def.get_op_type(), "Size"); if node.inputs.len() != 1 { return Err(OptimizerError::InvalidNode(format!( @@ -217,11 +225,11 @@ impl<'model> Optimizer<'model> { let input = &node.inputs[0]; let in_node = &input.source_node.definition; let in_element_count: i64 = match in_node { - NodeDefinition::Input(input) => input.get_shape()?.element_count() as i64, + NodeDefinition::Input { shape, .. } => shape.element_count() as i64, NodeDefinition::Operator(input_op_def) => { - input_op_def.output_shapes[input.output_index].element_count() as i64 + input_op_def.output_shapes()[input.output_index].element_count() as i64 } - NodeDefinition::Tensor(input_tensor) => input_tensor.get_dims().iter().product(), + NodeDefinition::Tensor(input_tensor) => input_tensor.shape().element_count() as i64, NodeDefinition::Outputs { .. } => { return Err(OptimizerError::Unsupported( "output node cannot be used as an input to Shape node".to_string(), @@ -234,10 +242,11 @@ impl<'model> Optimizer<'model> { } }; - Ok(TensorProto::from( - OutputTensor::I64(vec![in_element_count]), - vec![1], - )) + Ok(Tensor { + data: TensorData::I64(vec![in_element_count].into()), + dims: vec![1], + display_name: format!("{}", node.definition().get_display_name()), + }) } // Infers the output for a constant node (must be a constant and operator node, or the function panics) @@ -249,8 +258,6 @@ impl<'model> Optimizer<'model> { // Create an output node so we can perform inference for this node if let NodeDefinition::Operator(op_def) = node.definition() { - let output_name = op_def.proto.output.get(0).unwrap().to_owned(); - let out_node = Arc::new(Node { definition: NodeDefinition::Outputs { names: vec!["output".to_string()], @@ -269,19 +276,18 @@ impl<'model> Optimizer<'model> { // Take the output tensor and make it into an initializer node let (_, output_tensor) = outputs.drain().take(1).next().unwrap(); - log::info!("folded {output_name} to {output_tensor:?}"); - let mut output_tensor_proto = TensorProto::from( - output_tensor, - op_def.output_shapes[0] - .dims - .iter() - .map(|x| *x as i64) - .collect(), + log::info!( + "folded output of {} to {output_tensor:?}", + op_def.get_display_name() ); - output_tensor_proto.set_name(output_name); + let shape = op_def.output_shapes()[0].clone(); let tensor_node = Node { - definition: NodeDefinition::Tensor(Box::new(Cow::Owned(output_tensor_proto))), + definition: NodeDefinition::Tensor(Tensor { + data: output_tensor.into_static(), + dims: shape.dims, + display_name: format!("{}", op_def.get_display_name().to_owned()), + }), inputs: vec![], }; @@ -441,20 +447,16 @@ impl<'model> Optimizer<'model> { // Fold Shape/Size nodes (not considered constant but we can still fold it) if let NodeDefinition::Operator(op_def) = &node.definition { - match op_def.proto.get_op_type() { + match op_def.get_op_type() { "Shape" => { return Ok(Arc::new(Node { - definition: NodeDefinition::Tensor(Box::new(Cow::Owned( - Self::shape_node_to_tensor(node)?, - ))), + definition: NodeDefinition::Tensor(Self::shape_node_to_tensor(node)?), inputs: vec![], })) } "Size" => { return Ok(Arc::new(Node { - definition: NodeDefinition::Tensor(Box::new(Cow::Owned( - Self::size_node_to_tensor(node)?, - ))), + definition: NodeDefinition::Tensor(Self::size_node_to_tensor(node)?), inputs: vec![], })) } @@ -476,81 +478,110 @@ impl<'model> Optimizer<'model> { match &node.definition { NodeDefinition::Operator(op_def) => { - match op_def.proto.get_op_type() { + match op_def.get_op_type() { "Conv" | "ConvRelu" | "ConvLeakyRelu" => { // 
This optimization inserts some padding to convolution between kernels with kernel 3x3, because of // the stride of matrix3x3 is 16 in wgsl. It makes the computation matrixable and increases the performance. if new_inputs.len() > 2 - && op_def - .proto - .get_attribute_value::>("kernel_shape", None)? + && op_def.get_attribute_value::>("kernel_shape", None)? == [3, 3] - && (op_def - .proto - .get_attribute_value("pads", Some(vec![0, 0, 0, 0]))? + && (op_def.get_attribute_value("pads", Some(vec![0, 0, 0, 0]))? == [1, 1, 1, 1] - || op_def.proto.get_attribute_value( + || op_def.get_attribute_value( "auto_pad", Some("SAME_UPPER".to_string()), )? == "SAME_UPPER") - && op_def - .proto - .get_attribute_value("strides", Some(vec![1, 1]))? - == [1, 1] - && op_def.proto.get_attribute_value("group", Some(1))? == 1 - && op_def.output_shapes[0].dim(1) % 4 == 0 + && op_def.get_attribute_value("strides", Some(vec![1, 1]))? == [1, 1] + && op_def.get_attribute_value("group", Some(1))? == 1 + && op_def.output_shapes()[0].dim(1) % 4 == 0 { - if let NodeDefinition::Tensor(tensor) = - &new_inputs[1].source_node.definition - { - new_inputs[1] = Input { - output_index: 0, - source_node: match self.padded_tensors.get(tensor.get_name()) { - Some(padded_tensor_node) => padded_tensor_node.clone(), - None => { - let data = tensor.get_float_data(); - let raw_data = if !data.is_empty() { - bytemuck::cast_slice(data) - } else { - tensor.get_raw_data() - }; - - let padded_raw_data = padding(raw_data, 12, 4); - - log::info!( - "applying padding optimization to tensor {}: strides data is {} bytes before, {} bytes after", - tensor.get_name(), - raw_data.len(), - padded_raw_data.len() - ); - - // Create a new tensor with the padded data - let mut new_tensor = tensor.clone().into_owned(); - new_tensor.set_float_data(vec![]); - new_tensor.set_raw_data(padded_raw_data); - let new_node = Arc::new(Node { - definition: NodeDefinition::Tensor(Box::new( - Cow::Owned(new_tensor), - )), - inputs: vec![], - }); - self.padded_tensors.insert( - tensor.get_name().to_string(), - new_node.clone(), - ); - new_node + let source_node = { + if let NodeDefinition::Tensor(tensor) = + &new_inputs[1].source_node.definition + { + let source_identifier = new_inputs[1].source_node.identifier(); + if let Some(n) = self.padded_tensors.get(&source_identifier) { + log::info!( + "have cached padding optimized tensor for {:?}", + source_identifier + ); + n.clone() + } else { + match tensor.data() { + TensorData::F32(floats) => { + let raw_data: &[u8] = + bytemuck::cast_slice(floats.borrow()); + let padded_raw_data = padding(raw_data, 12, 4); + log::info!( + "applying padding optimization to tensor {} shape {}: strides data is {} bytes before, {} bytes after", + tensor.display_name(), + tensor.shape(), + raw_data.len(), + padded_raw_data.len() + ); + + let padded_data = + bytemuck::pod_collect_to_vec(&padded_raw_data); + + // Create a new tensor with the padded data + let new_node = Arc::new(Node { + definition: NodeDefinition::Tensor(Tensor { + data: TensorData::F32(Cow::Owned( + padded_data, + )), + dims: tensor.dims().to_vec(), + display_name: format!( + "{}", + tensor.display_name() + ), + }), + inputs: vec![], + }); + self.padded_tensors.insert( + source_identifier.clone(), + new_node.clone(), + ); + new_node + } + _ => { + log::warn!("not applying padding optimization as source tensor does not have float data"); + return Ok(node.clone()); + } } - }, + } + } else { + return Ok(Arc::new(Node { + inputs: new_inputs, + definition: node.definition().clone(), 
+ })); } - } + }; + + new_inputs[1] = Input { + source_node: source_node.clone(), + output_index: 0, + }; + + let new_node = Node { + inputs: new_inputs, + definition: NodeDefinition::Operator(op_def.clone()), + }; + log::info!( + "actually returning new node {} with new input 1 padded {}", + node.definition().get_display_name(), + source_node.definition().get_display_name() + ); + Ok(Arc::new(new_node)) + } else { + log::info!( + "actually returning old node {}", + node.definition().get_display_name() + ); + Ok(Arc::new(Node { + inputs: new_inputs, + definition: node.definition().clone(), + })) } - - let new_node = Node { - inputs: new_inputs, - definition: NodeDefinition::Operator(op_def.clone()), - }; - - Ok(Arc::new(new_node)) } // The Clip, Split, Resize, Reshape and Reduce* operators each take optional inputs that influence the operation. @@ -585,18 +616,16 @@ impl<'model> Optimizer<'model> { }; // Make a new copy of the attributes list (we're going to add attributes) - let mut new_proto = op_def.proto.clone().into_owned(); - let mut attributes = op_def.proto.get_attribute().to_vec(); + let mut new_proto = op_def.clone(); // Loop over the inputs (skipping the first one - that's going to be the data input) for input_index in 1..(new_inputs.len().min(attr_names.len())) { let source_node = &new_inputs[input_index].source_node; match &source_node.definition { // If the input is an initializer (Tensor) we can obtain the data from the definition and move it to an attribute - NodeDefinition::Tensor(tensor_proto) => { + NodeDefinition::Tensor(tensor) => { let attr_name = attr_names[input_index]; - let data_type = - ScalarType::from_i32(tensor_proto.get_data_type())?; + let data_type = tensor.shape().data_type; match (op, attr_name) { ("Split", "split") @@ -612,16 +641,8 @@ impl<'model> Optimizer<'model> { ) | ("Pad", "pads") | ("Resize", "scales") - | ("Clip", "min" | "max") => match data_type { - ScalarType::F32 => { - let value: Vec = if tensor_proto - .get_float_data() - .is_empty() - { - pod_collect_to_vec(tensor_proto.get_raw_data()) - } else { - tensor_proto.get_float_data().to_vec() - }; + | ("Clip", "min" | "max") => match tensor.data() { + TensorData::F32(value) => { log::info!( "transferring input {} for op {} to f32 attribute (initializer data type: {:?}): {:?}", attr_name, @@ -629,20 +650,12 @@ impl<'model> Optimizer<'model> { data_type, value, ); - attributes.push(attribute( + new_proto.set_attribute( attr_names[input_index], - value, - )); + value.to_vec(), + ); } - ScalarType::I64 => { - let value = if tensor_proto - .get_int64_data() - .is_empty() - { - pod_collect_to_vec(tensor_proto.get_raw_data()) - } else { - tensor_proto.get_int64_data().to_vec() - }; + TensorData::I64(value) => { log::info!( "transferring input {} for op {} to i64 attribute (initializer data type: {:?}): {:?}", attr_name, @@ -650,10 +663,10 @@ impl<'model> Optimizer<'model> { data_type, value, ); - attributes.push(attribute( + new_proto.set_attribute( attr_names[input_index], - value, - )); + value.to_vec(), + ); } _ => { return Err(OptimizerError::InvalidInputDataType { @@ -667,7 +680,7 @@ impl<'model> Optimizer<'model> { // Some other unspecified input that we do not support yet return Err(OptimizerError::Unsupported(format!( "data_type {} for input {} to op {}", - tensor_proto.get_data_type(), + tensor.shape().data_type, attr_name, op ))); @@ -687,15 +700,9 @@ impl<'model> Optimizer<'model> { } } - // Create new node with extra attributes - new_proto.set_attribute(RepeatedField::from(attributes)); - 
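+                            // Keep only the data input; the optional inputs were moved into attributes above.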
let new_node = Node { inputs: vec![new_inputs[0].clone()], - definition: NodeDefinition::Operator(Box::new(OperatorDefinition { - proto: Cow::Owned(new_proto), - output_shapes: op_def.output_shapes.clone(), - })), + definition: NodeDefinition::Operator(new_proto), }; Ok(Arc::new(new_node)) @@ -707,7 +714,7 @@ impl<'model> Optimizer<'model> { })), } } - NodeDefinition::Tensor(..) | NodeDefinition::Input(..) => { + NodeDefinition::Tensor(..) | NodeDefinition::Input { .. } => { assert!( new_inputs.is_empty(), "non-operator node cannot have inputs" @@ -732,14 +739,14 @@ impl<'model> Optimizer<'model> { ) -> Result { // Start by throwing out all Identity nodes chain.retain(|n| match &n.definition { - NodeDefinition::Operator(op_def) => op_def.proto.get_op_type() != "Identity", + NodeDefinition::Operator(op_def) => op_def.get_op_type() != "Identity", _ => true, }); let names: Vec<&str> = chain .iter() .map(|x| match &x.definition { - NodeDefinition::Operator(op_def) => op_def.proto.get_op_type(), + NodeDefinition::Operator(op_def) => op_def.get_op_type(), _ => "", }) .collect(); @@ -763,38 +770,28 @@ impl<'model> Optimizer<'model> { (&conv.definition, &relu.definition) { // Use the Conv node as template for the new fused Conv[Leaky]Relu node - let mut convrelu_def = *conv_def.clone(); - let mut convrelu_proto = conv_def.proto.clone().into_owned(); - let new_op_type = match relu_def.proto.get_op_type() { + let mut convrelu_def = conv_def.clone(); + let new_op_type = match relu_def.get_op_type() { "LeakyRelu" => "ConvLeakyRelu", "Relu" => "ConvRelu", _ => unreachable!(), }; - convrelu_proto.set_op_type(new_op_type.to_string()); + convrelu_def.set_op_type(new_op_type); // Copy all Relu attributes over to the copy of the Conv node - let mut attributes = conv_def.proto.get_attribute().to_vec(); - attributes.extend(relu_def.proto.get_attribute().iter().cloned()); - convrelu_proto.set_attribute(RepeatedField::from(attributes)); - convrelu_proto.set_name(format!( - "{}+{}", - conv.definition.get_name(), - relu.definition.get_name() - )); + convrelu_def.append_attributes_from(relu_def); log::debug!( "can fuse chain of Conv/[Leaky]Relu to Conv[Leaky]Relu: {:?}: {:?} + {:?} = {}", names, conv.definition(), relu.definition(), - convrelu_proto.get_name() + convrelu_def.get_display_name() ); - convrelu_def.proto = Cow::Owned(convrelu_proto); - let node = Arc::new(Node { inputs: conv.inputs.clone(), - definition: NodeDefinition::Operator(Box::new(convrelu_def)), + definition: NodeDefinition::Operator(convrelu_def), }); chain.remove(0); @@ -820,56 +817,19 @@ static PAD_INPUT_NAMES: &[&str] = &["data", "pads", "constant_value"]; /// Generate the output for a ConstantOfShape node pub fn constant_of_shape_output( - node: &NodeProto, + node: &OperatorDefinition, element_count: usize, -) -> Result { - if let Ok(constant_value_tensor) = node.get_attribute_value::("value", None) { - match ScalarType::from_i32(constant_value_tensor.get_data_type()).map_err(|_| { - OptimizerError::Unsupported(format!( - "unsupported data type {}", - constant_value_tensor.get_data_type() - )) - })? 
{ - ScalarType::F32 => { - let fd = constant_value_tensor.get_float_data(); - if fd.is_empty() { - return Err(OptimizerError::InvalidNode( - "value tensor for ConstantOfShape is empty".to_string(), - )); - } - Ok(OutputTensor::F32(vec![fd[0]; element_count])) - } - ScalarType::I64 => { - let fd = constant_value_tensor.get_int64_data(); - if fd.is_empty() { - return Err(OptimizerError::InvalidNode( - "value tensor for ConstantOfShape is empty".to_string(), - )); - } - Ok(OutputTensor::I64(vec![fd[0]; element_count])) - } - ScalarType::I32 => { - let fd = constant_value_tensor.get_int32_data(); - if fd.is_empty() { - return Err(OptimizerError::InvalidNode( - "value tensor for ConstantOfShape is empty".to_string(), - )); - } - Ok(OutputTensor::I32(vec![fd[0]; element_count])) - } - ScalarType::U8 => { - let fd = constant_value_tensor.get_raw_data(); - if fd.is_empty() { - return Err(OptimizerError::InvalidNode( - "value tensor for ConstantOfShape is empty".to_string(), - )); - } - Ok(OutputTensor::U8(vec![fd[0]; element_count])) - } +) -> Result, OptimizerError> { + if let Ok(constant_value_tensor) = node.get_attribute_value::("value", None) { + match constant_value_tensor.data() { + TensorData::F32(f) => Ok(TensorData::F32(vec![f[0]; element_count].into())), + TensorData::I32(f) => Ok(TensorData::I32(vec![f[0]; element_count].into())), + TensorData::I64(f) => Ok(TensorData::I64(vec![f[0]; element_count].into())), + TensorData::U8(f) => Ok(TensorData::U8(vec![f[0]; element_count].into())), } } else { // The default value is a zero f32 - Ok(OutputTensor::F32(vec![0.0; element_count])) + Ok(TensorData::F32(vec![0.0; element_count].into())) } } @@ -880,7 +840,10 @@ mod test { use crate::{ ir::{self, Node, NodeDefinition}, onnx::AttributeProto, - utils::{attribute, graph, initializer, model, node, tensor}, + onnx_model::{ + onnx_attribute, onnx_graph, onnx_initializer, onnx_model, onnx_node, onnx_tensor, + }, + tensor::TensorData, }; use super::Optimizer; @@ -890,9 +853,9 @@ mod test { NodeDefinition::Outputs { .. 
} => String::from(""), NodeDefinition::Missing => String::from(""), NodeDefinition::Operator(op_def) => { - format!("{}_{}", op_def.proto.get_op_type(), op_def.proto.get_name()) + format!("{}_{}", op_def.get_op_type(), op_def.get_display_name()) } - d => format!("{}", d.get_name()), + d => format!("{}", d.get_display_name()), } } @@ -913,14 +876,14 @@ mod test { pub fn test_optimize_identity_identity() { let _ = env_logger::builder().is_test(true).try_init(); pollster::block_on(async { - let m = model(graph( - vec![tensor("X", &[1])], - vec![tensor("Y", &[1])], - vec![tensor("A", &[1])], + let m = onnx_model(onnx_graph( + vec![onnx_tensor("X", &[1])], + vec![onnx_tensor("Y", &[1])], + vec![onnx_tensor("A", &[1])], vec![], vec![ - node(vec!["X"], vec!["A"], "a", "Identity", vec![]), - node(vec!["A"], vec!["Y"], "b", "Identity", vec![]), + onnx_node(vec!["X"], vec!["A"], "Identity", vec![]), + onnx_node(vec!["A"], vec!["Y"], "Identity", vec![]), ], )); @@ -938,14 +901,14 @@ mod test { pub fn test_optimize_neg_neg() { let _ = env_logger::builder().is_test(true).try_init(); pollster::block_on(async { - let m = model(graph( - vec![tensor("X", &[1])], - vec![tensor("Y", &[1])], - vec![tensor("A", &[1])], + let m = onnx_model(onnx_graph( + vec![onnx_tensor("X", &[1])], + vec![onnx_tensor("Y", &[1])], + vec![onnx_tensor("A", &[1])], vec![], vec![ - node(vec!["X"], vec!["A"], "a", "Neg", vec![]), - node(vec!["A"], vec!["Y"], "b", "Neg", vec![]), + onnx_node(vec!["X"], vec!["A"], "Neg", vec![]), + onnx_node(vec!["A"], vec!["Y"], "Neg", vec![]), ], )); @@ -964,15 +927,15 @@ mod test { pollster::block_on(async { let _ = env_logger::builder().is_test(true).try_init(); - let m = model(graph( - vec![tensor("X", &[1])], - vec![tensor("Y", &[1])], - vec![tensor("A", &[1]), tensor("B", &[1])], + let m = onnx_model(onnx_graph( + vec![onnx_tensor("X", &[1])], + vec![onnx_tensor("Y", &[1])], + vec![onnx_tensor("A", &[1]), onnx_tensor("B", &[1])], vec![], vec![ - node(vec!["X"], vec!["A"], "a", "Neg", vec![]), - node(vec!["A"], vec!["B"], "b", "Neg", vec![]), - node(vec!["B"], vec!["Y"], "c", "Neg", vec![]), + onnx_node(vec!["X"], vec!["A"], "Neg", vec![]), + onnx_node(vec!["A"], vec!["B"], "Neg", vec![]), + onnx_node(vec!["B"], vec!["Y"], "Neg", vec![]), ], )); @@ -984,8 +947,8 @@ mod test { assert_eq!( new_pairs, vec![ - ("Neg_c".to_string(), "".to_string()), - ("X".to_string(), "Neg_c".to_string()) + ("Neg_Y".to_string(), "".to_string()), + ("X".to_string(), "Neg_Y".to_string()) ] ); }); @@ -996,16 +959,20 @@ mod test { pub fn test_optimize_4neg() { let _ = env_logger::builder().is_test(true).try_init(); pollster::block_on(async { - let m = model(graph( - vec![tensor("X", &[1])], - vec![tensor("Y", &[1])], - vec![tensor("A", &[1]), tensor("B", &[1]), tensor("C", &[1])], + let m = onnx_model(onnx_graph( + vec![onnx_tensor("X", &[1])], + vec![onnx_tensor("Y", &[1])], + vec![ + onnx_tensor("A", &[1]), + onnx_tensor("B", &[1]), + onnx_tensor("C", &[1]), + ], vec![], vec![ - node(vec!["X"], vec!["A"], "a", "Neg", vec![]), - node(vec!["A"], vec!["B"], "b", "Neg", vec![]), - node(vec!["B"], vec!["C"], "c", "Neg", vec![]), - node(vec!["C"], vec!["Y"], "d", "Neg", vec![]), + onnx_node(vec!["X"], vec!["A"], "Neg", vec![]), + onnx_node(vec!["A"], vec!["B"], "Neg", vec![]), + onnx_node(vec!["B"], vec!["C"], "Neg", vec![]), + onnx_node(vec!["C"], vec!["Y"], "Neg", vec![]), ], )); @@ -1023,22 +990,22 @@ mod test { pub fn test_optimize_5neg() { let _ = env_logger::builder().is_test(true).try_init(); 
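+            // Five chained Neg ops: four cancel pairwise, so a single Neg should remain after optimization.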
pollster::block_on(async { - let m = model(graph( - vec![tensor("X", &[1])], - vec![tensor("Y", &[1])], + let m = onnx_model(onnx_graph( + vec![onnx_tensor("X", &[1])], + vec![onnx_tensor("Y", &[1])], vec![ - tensor("A", &[1]), - tensor("B", &[1]), - tensor("C", &[1]), - tensor("D", &[1]), + onnx_tensor("A", &[1]), + onnx_tensor("B", &[1]), + onnx_tensor("C", &[1]), + onnx_tensor("D", &[1]), ], vec![], vec![ - node(vec!["X"], vec!["A"], "a", "Neg", vec![]), - node(vec!["A"], vec!["B"], "b", "Neg", vec![]), - node(vec!["B"], vec!["C"], "c", "Neg", vec![]), - node(vec!["C"], vec!["D"], "d", "Neg", vec![]), - node(vec!["D"], vec!["Y"], "e", "Neg", vec![]), + onnx_node(vec!["X"], vec!["A"], "Neg", vec![]), + onnx_node(vec!["A"], vec!["B"], "Neg", vec![]), + onnx_node(vec!["B"], vec!["C"], "Neg", vec![]), + onnx_node(vec!["C"], vec!["D"], "Neg", vec![]), + onnx_node(vec!["D"], vec!["Y"], "Neg", vec![]), ], )); @@ -1050,8 +1017,8 @@ mod test { assert_eq!( new_pairs, vec![ - ("Neg_e".to_string(), "".to_string()), - ("X".to_string(), "Neg_e".to_string()) + ("Neg_Y".to_string(), "".to_string()), + ("X".to_string(), "Neg_Y".to_string()) ] ); }); @@ -1062,14 +1029,14 @@ mod test { pub fn test_optimize_neg_neg_branch() { let _ = env_logger::builder().is_test(true).try_init(); pollster::block_on(async { - let m = model(graph( - vec![tensor("X", &[1])], - vec![tensor("Y", &[1]), tensor("A", &[1])], - vec![tensor("A", &[1])], + let m = onnx_model(onnx_graph( + vec![onnx_tensor("X", &[1])], + vec![onnx_tensor("Y", &[1]), onnx_tensor("A", &[1])], + vec![onnx_tensor("A", &[1])], vec![], vec![ - node(vec!["X"], vec!["A"], "a", "Neg", vec![]), - node(vec!["A"], vec!["Y"], "b", "Neg", vec![]), + onnx_node(vec!["X"], vec!["A"], "Neg", vec![]), + onnx_node(vec!["A"], vec!["Y"], "Neg", vec![]), ], )); @@ -1082,8 +1049,8 @@ mod test { new_pairs, vec![ ("X".to_string(), "".to_string()), - ("Neg_a".to_string(), "".to_string()), - ("X".to_string(), "Neg_a".to_string()) + ("Neg_A".to_string(), "".to_string()), + ("X".to_string(), "Neg_A".to_string()) ] ); }); @@ -1095,15 +1062,15 @@ mod test { let _ = env_logger::builder().is_test(true).try_init(); pollster::block_on(async { - let m = model(graph( - vec![tensor("X", &[1])], - vec![tensor("Y", &[1]), tensor("Z", &[1])], - vec![tensor("A", &[1])], + let m = onnx_model(onnx_graph( + vec![onnx_tensor("X", &[1])], + vec![onnx_tensor("Y", &[1]), onnx_tensor("Z", &[1])], + vec![onnx_tensor("A", &[1])], vec![], vec![ - node(vec!["X"], vec!["A"], "a", "Neg", vec![]), - node(vec!["A"], vec!["Z"], "b", "Identity", vec![]), - node(vec!["A"], vec!["Y"], "c", "Identity", vec![]), + onnx_node(vec!["X"], vec!["A"], "Neg", vec![]), + onnx_node(vec!["A"], vec!["Z"], "Identity", vec![]), + onnx_node(vec!["A"], vec!["Y"], "Identity", vec![]), ], )); @@ -1115,10 +1082,10 @@ mod test { assert_eq!( new_pairs, vec![ - ("Neg_a".to_string(), "".to_string()), - ("Neg_a".to_string(), "".to_string()), - ("X".to_string(), "Neg_a".to_string()), - ("X".to_string(), "Neg_a".to_string()), + ("Neg_A".to_string(), "".to_string()), + ("Neg_A".to_string(), "".to_string()), + ("X".to_string(), "Neg_A".to_string()), + ("X".to_string(), "Neg_A".to_string()), ] ); }); @@ -1130,15 +1097,15 @@ mod test { let _ = env_logger::builder().is_test(true).try_init(); pollster::block_on(async { - let m = model(graph( + let m = onnx_model(onnx_graph( vec![], - vec![tensor("C", &[1])], + vec![onnx_tensor("C", &[1])], vec![], vec![ - initializer("A", vec![21.0], vec![1]), - initializer("B", vec![7.0], vec![1]), + 
onnx_initializer("A", vec![21.0], vec![1]), + onnx_initializer("B", vec![7.0], vec![1]), ], - vec![node(vec!["A", "B"], vec!["C"], "c", "Add", vec![])], + vec![onnx_node(vec!["A", "B"], vec!["C"], "Add", vec![])], )); let root = ir::Node::from_model(&m, None).unwrap(); @@ -1146,7 +1113,10 @@ mod test { let new_root = opt.optimize(root).await.unwrap(); let mut new_pairs = vec![]; traverse(new_root, &mut new_pairs); - assert_eq!(new_pairs, vec![("C".to_string(), "".to_string())]); + assert_eq!( + new_pairs, + vec![("C".to_string(), "".to_string())] + ); }); } @@ -1156,17 +1126,16 @@ mod test { let _ = env_logger::builder().is_test(true).try_init(); pollster::block_on(async { - let m = model(graph( + let m = onnx_model(onnx_graph( vec![], - vec![tensor("Y", &[1])], + vec![onnx_tensor("Y", &[1])], vec![], vec![], - vec![node( + vec![onnx_node( vec![], vec!["Y"], - "y", "Constant", - vec![attribute("value_float", 42.0)], + vec![onnx_attribute("value_float", 42.0)], )], )); @@ -1187,18 +1156,18 @@ mod test { pub fn test_shape_operator() { test_shape_operator_with( &[1, 2, 3], - vec![attribute("start", -3), attribute("end", -2)], + vec![onnx_attribute("start", -3), onnx_attribute("end", -2)], &[1], ); test_shape_operator_with(&[1, 2, 3], vec![], &[1, 2, 3]); - test_shape_operator_with(&[3, 4, 5], vec![attribute("start", 0)], &[3, 4, 5]); - test_shape_operator_with(&[3, 4, 5], vec![attribute("start", 1)], &[4, 5]); - test_shape_operator_with(&[3, 4, 5], vec![attribute("start", -1)], &[5]); - test_shape_operator_with(&[3, 4, 5], vec![attribute("end", 10)], &[3, 4, 5]); - test_shape_operator_with(&[3, 4, 5], vec![attribute("end", 1)], &[3]); + test_shape_operator_with(&[3, 4, 5], vec![onnx_attribute("start", 0)], &[3, 4, 5]); + test_shape_operator_with(&[3, 4, 5], vec![onnx_attribute("start", 1)], &[4, 5]); + test_shape_operator_with(&[3, 4, 5], vec![onnx_attribute("start", -1)], &[5]); + test_shape_operator_with(&[3, 4, 5], vec![onnx_attribute("end", 10)], &[3, 4, 5]); + test_shape_operator_with(&[3, 4, 5], vec![onnx_attribute("end", 1)], &[3]); test_shape_operator_with( &[3, 4, 5], - vec![attribute("start", 10), attribute("end", 10)], + vec![onnx_attribute("start", 10), onnx_attribute("end", 10)], &[], ); } @@ -1211,12 +1180,12 @@ mod test { let _ = env_logger::builder().is_test(true).try_init(); pollster::block_on(async { - let m = model(graph( - vec![tensor("X", input_shape)], - vec![tensor("Y", &[expected.len() as i64])], + let m = onnx_model(onnx_graph( + vec![onnx_tensor("X", input_shape)], + vec![onnx_tensor("Y", &[expected.len() as i64])], vec![], vec![], - vec![node(vec!["X"], vec!["Y"], "y", "Shape", attrs)], + vec![onnx_node(vec!["X"], vec!["Y"], "Shape", attrs)], )); let root = ir::Node::from_model(&m, None).unwrap(); @@ -1224,13 +1193,16 @@ mod test { let new_root = opt.optimize(root).await.unwrap(); let mut new_pairs = vec![]; traverse(new_root.clone(), &mut new_pairs); - assert_eq!(new_pairs, vec![("".to_string(), "".to_string())]); + assert_eq!( + new_pairs, + vec![("Y".to_string(), "".to_string())] + ); let y_node = new_root.inputs[0].source_node.clone(); let NodeDefinition::Tensor(t) = y_node.definition() else { panic!("should be folded to an initializer"); }; - assert_eq!(t.get_int64_data(), expected); + assert_eq!(t.data(), &TensorData::I64(expected.into())); }); } @@ -1246,12 +1218,12 @@ mod test { let _ = env_logger::builder().is_test(true).try_init(); pollster::block_on(async { - let m = model(graph( - vec![tensor("X", input_shape)], - vec![tensor("Y", &[expected.len() as 
i64])],
                vec![],
                vec![],
-                vec![node(vec!["X"], vec!["Y"], "y", "Size", vec![])],
+                vec![onnx_node(vec!["X"], vec!["Y"], "Size", vec![])],
            ));

            let root = ir::Node::from_model(&m, None).unwrap();
            let mut opt = Optimizer::new(13);
            let new_root = opt.optimize(root).await.unwrap();
            let mut new_pairs = vec![];
            traverse(new_root.clone(), &mut new_pairs);
-            assert_eq!(new_pairs, vec![("".to_string(), "".to_string())]);
+            assert_eq!(
+                new_pairs,
+                vec![("Y".to_string(), "".to_string())]
+            );

            let y_node = new_root.inputs[0].source_node.clone();
            let NodeDefinition::Tensor(t) = y_node.definition() else {
                panic!("should be folded to an initializer");
            };
-            assert_eq!(t.get_int64_data(), expected);
+            assert_eq!(t.data(), &TensorData::I64(expected.into()));
        });
    }
}
diff --git a/wonnx/src/tensor.rs b/wonnx/src/tensor.rs
new file mode 100644
index 00000000..50277453
--- /dev/null
+++ b/wonnx/src/tensor.rs
@@ -0,0 +1,413 @@
+//! Various basic data types
+use num::FromPrimitive;
+use serde::Serialize;
+use std::borrow::Cow;
+use std::convert::From;
+use std::convert::TryFrom;
+use std::fmt::Display;
+use thiserror::Error;
+
+/* Minimum size of a buffer you can create with wgpu. Creating buffers smaller than this leads to panic "Validation
+* error: buffer binding size X is less than minimum 64" in Device::create_bind_group */
+pub(crate) const MINIMUM_BUFFER_SIZE_BYTES: u64 = 64;
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct Shape {
+    pub dims: Vec<usize>,
+    pub data_type: ScalarType,
+}
+
+impl Shape {
+    pub fn from(data_type: ScalarType, dims: &[usize]) -> Shape {
+        Shape {
+            data_type,
+            dims: dims.to_vec(),
+        }
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.dims.is_empty()
+    }
+
+    pub fn rank(&self) -> usize {
+        self.dims.len()
+    }
+
+    pub fn element_count(&self) -> usize {
+        self.dims.iter().product()
+    }
+
+    pub fn buffer_bytes_aligned(&self) -> usize {
+        // Round buffer sizes to 16 bytes. If not, things go wrong (e.g. our shaders use vec4 - if a buffer only
+        // has 7 elements, the last vec4 cannot be fully written to the buffer, and the buffer ends up containing zeroes).
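+        // For example, a buffer of 7 f32 elements (28 bytes) is rounded up to 32 bytes.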
+        fn round_to_next_multiple_of_16(n: usize) -> usize {
+            (n + 15) / 16 * 16
+        }
+
+        round_to_next_multiple_of_16(self.element_count() * self.data_type.stride())
+    }
+
+    pub fn dim(&self, idx: usize) -> usize {
+        self.dims[idx]
+    }
+
+    pub fn chunks(&self) -> Vec<usize> {
+        let mut chunk = vec![];
+        let ds = &self.dims;
+        for i in 1..self.dims.len() {
+            chunk.push(ds[i..].iter().product());
+        }
+        chunk.push(1);
+        chunk
+    }
+
+    /// Computes the shape to which all provided shapes can be broadcast (if it exists)
+    /// Inspired by
+    pub fn multi_broadcast(shapes: &[Shape]) -> Option<Shape> {
+        if shapes.is_empty() {
+            return None;
+        }
+
+        let max_rank = shapes.iter().map(|x| x.rank()).max().unwrap_or(0);
+        let mut shape: Vec<usize> = Vec::with_capacity(max_rank);
+
+        // Shapes must all have the same data type
+        let data_type = shapes[0].data_type;
+        for s in shapes {
+            if s.data_type != data_type {
+                return None;
+            }
+        }
+
+        for i in 0..max_rank {
+            let mut wanted_size = 1;
+            for shape in shapes {
+                let rank = shape.rank();
+                let dim = if i < rank { shape.dim(rank - i - 1) } else { 1 };
+
+                if dim != 1 {
+                    if wanted_size != 1 && dim != wanted_size {
+                        return None;
+                    }
+                    wanted_size = dim;
+                }
+            }
+            shape.push(wanted_size);
+        }
+
+        shape.reverse();
+        Some(Shape::from(data_type, &shape))
+    }
+
+    pub(crate) fn left_padded_to(&self, x: usize, rank: usize) -> Shape {
+        let mut dims = self.dims.clone();
+        let current_rank = dims.len();
+        dims.resize(rank, x);
+        if rank > current_rank {
+            dims.rotate_right(rank - current_rank);
+        }
+        Shape {
+            dims,
+            data_type: self.data_type,
+        }
+    }
+}
+
+#[derive(Clone, Serialize, Debug, PartialEq)]
+#[serde(untagged)]
+pub enum TensorData<'a> {
+    F32(Cow<'a, [f32]>),
+    I32(Cow<'a, [i32]>),
+    I64(Cow<'a, [i64]>),
+    U8(Cow<'a, [u8]>),
+}
+
+impl<'a> TensorData<'a> {
+    pub fn into_static(self) -> TensorData<'static> {
+        match self {
+            TensorData::F32(x) => TensorData::F32(Cow::Owned(x.into_owned())),
+            TensorData::I32(x) => TensorData::I32(Cow::Owned(x.into_owned())),
+            TensorData::I64(x) => TensorData::I64(Cow::Owned(x.into_owned())),
+            TensorData::U8(x) => TensorData::U8(Cow::Owned(x.into_owned())),
+        }
+    }
+
+    pub fn scalar_type(&self) -> ScalarType {
+        match self {
+            TensorData::F32(_) => ScalarType::F32,
+            TensorData::I32(_) => ScalarType::I32,
+            TensorData::I64(_) => ScalarType::I64,
+            TensorData::U8(_) => ScalarType::U8,
+        }
+    }
+}
+
+impl<'a> From<Vec<f32>> for TensorData<'a> {
+    fn from(value: Vec<f32>) -> Self {
+        TensorData::F32(Cow::Owned(value))
+    }
+}
+
+impl<'a> From<&'a [f32]> for TensorData<'a> {
+    fn from(a: &'a [f32]) -> Self {
+        TensorData::F32(Cow::Borrowed(a))
+    }
+}
+
+impl<'a> From<&'a [i32]> for TensorData<'a> {
+    fn from(a: &'a [i32]) -> Self {
+        TensorData::I32(Cow::Borrowed(a))
+    }
+}
+
+impl<'a> From<&'a [i64]> for TensorData<'a> {
+    fn from(a: &'a [i64]) -> Self {
+        TensorData::I64(Cow::Borrowed(a))
+    }
+}
+
+#[derive(Error, Debug)]
+pub enum TensorConversionError {
+    #[error("could not convert to the requested type because a value could not be represented in the target type")]
+    OutOfBoundsError,
+
+    #[error("could not return the requested type; conversions cannot be done for slices")]
+    DataTypeError,
+}
+
+impl<'a> TryFrom<TensorData<'a>> for Vec<f32> {
+    type Error = TensorConversionError;
+
+    /// Convert a [`TensorData`] into a `Vec<f32>`, possibly converting integer tensors if the values fit
+    fn try_from(value: TensorData) -> Result<Self, Self::Error> {
+        match value {
+            TensorData::F32(floats) => Ok(floats.to_vec()),
+            TensorData::I32(ints) => ints
+                .iter()
+                .map(|i| f32::from_i32(*i).ok_or(TensorConversionError::OutOfBoundsError))
+                .collect::<Result<_, _>>(),
+            TensorData::I64(ints) => ints
+                .iter()
+                .map(|i| f32::from_i64(*i).ok_or(TensorConversionError::OutOfBoundsError))
+                .collect::<Result<_, _>>(),
+            TensorData::U8(ints) => ints
+                .iter()
+                .map(|i| f32::from_u8(*i).ok_or(TensorConversionError::OutOfBoundsError))
+                .collect::<Result<_, _>>(),
+        }
+    }
+}
+
+/// Convert &TensorData into an &[f32]. Because we cannot store converted results, this operation does not attempt
+/// to convert the tensor if the values are of a different type
+impl<'a> TryFrom<&'a TensorData<'a>> for &'a [f32] {
+    type Error = TensorConversionError;
+
+    fn try_from(value: &'a TensorData) -> Result<Self, Self::Error> {
+        match value {
+            TensorData::F32(floats) => Ok(floats),
+            TensorData::I32(_) | TensorData::I64(_) | TensorData::U8(_) => {
+                Err(TensorConversionError::DataTypeError)
+            }
+        }
+    }
+}
+
+#[derive(Error, Debug)]
+pub enum DataTypeError {
+    #[error("the ONNX scalar data type {0:?} is not supported")]
+    NotSupported(i32),
+
+    #[error("the ONNX data type '{0}' is not recognized")]
+    NotRecognized(i32),
+
+    #[error("encountered parametrized dimensions '{0}'; this is not currently supported (this may be solved by running onnx-simplifier on the model first)")]
+    ParametrizedDimensionUnsupported(String),
+
+    #[error("type is undefined")]
+    Undefined,
+}
+
+/// Data type for a single number
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+pub enum ScalarType {
+    F32,
+    I64,
+    I32,
+    U8,
+}
+
+impl ScalarType {
+    pub fn stride(&self) -> usize {
+        match self {
+            ScalarType::F32 => 4,
+            ScalarType::I32 => 4,
+            ScalarType::I64 => 8,
+            ScalarType::U8 => 1, // ! TODO check this
+        }
+    }
+
+    pub fn wgsl_supported(&self) -> bool {
+        match self {
+            ScalarType::F32 => true,
+            ScalarType::I32 => true,
+            ScalarType::I64 => false,
+            ScalarType::U8 => false, // ! TODO check this
+        }
+    }
+
+    pub fn wgsl_type_name(&self) -> &'static str {
+        match self {
+            ScalarType::F32 => "f32",
+            ScalarType::I32 => "i32",
+            ScalarType::I64 => "i64",
+            ScalarType::U8 => "u8", // ! TODO check this
+        }
+    }
+
+    pub fn is_float(&self) -> bool {
+        match self {
+            ScalarType::F32 => true,
+            ScalarType::I32 | ScalarType::I64 | ScalarType::U8 => false,
+        }
+    }
+}
+
+impl Display for ScalarType {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.wgsl_type_name())
+    }
+}
+
+/// Represents a WGSL data type that can be used in a shader to perform an operation on multiple scalars at once. The
+/// larger the data type, the more efficiently the GPU can perform operations. However, large data types require the size
+/// of the data that is being worked on to be a multiple of the data type (e.g. a vec4 can be used to work on a vector of
+/// 256 elements, but when used on a vector of 255 elements calculation would overflow if the shader doesn't take this
+/// into account). The WGSL declaration looks like:
+///
+/// struct Block {
+///     data: [[stride( dt.stride() )]] dt.wgsl_type_name();
+/// };
+pub(crate) enum MultiType {
+    Scalar(ScalarType),
+    Vec(ScalarType, usize),
+    Mat(ScalarType, usize, usize),
+}
+
+impl MultiType {
+    /// Determine the appropriate data type given the data size
+    pub fn for_size(n: usize, scalar: ScalarType) -> MultiType {
+        let d = num::integer::gcd(n, 4);
+        match d {
+            1 => MultiType::Scalar(scalar),
+            2 | 4 => MultiType::Vec(scalar, d),
+            /* 3 can't occur here because it is not a divisor of 4.
Even so, we wouldn't be able to use vec3, because
+            its stride is 16 instead of the expected 12, which would require padding to work properly. */
+            _ => unreachable!(),
+        }
+    }
+
+    /// Size (in bytes) of the data type (useful for setting the 'stride' in WGSL)
+    pub fn stride(&self) -> usize {
+        match self {
+            MultiType::Scalar(s) => s.stride(),
+
+            // FIXME: this may not always be right!
+            MultiType::Vec(st, n) => st.stride() * n,
+            MultiType::Mat(st, w, h) => st.stride() * w * h,
+        }
+    }
+
+    /// Name of this data type in WGSL
+    pub fn wgsl_type_name(&self) -> String {
+        match self {
+            MultiType::Scalar(s) => s.wgsl_type_name().to_string(),
+            MultiType::Vec(st, n) => format!("vec{}<{}>", n, st.wgsl_type_name()),
+            MultiType::Mat(st, w, h) => format!("mat{}x{}<{}>", w, h, st.wgsl_type_name()),
+        }
+    }
+
+    /// The number of elements in this data type
+    pub fn elements(&self) -> usize {
+        match self {
+            MultiType::Scalar(_) => 1,
+
+            // FIXME: this may not always be right
+            MultiType::Vec(_, n) => *n,
+            &MultiType::Mat(_, w, h) => w * h,
+        }
+    }
+}
+
+impl Display for Shape {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "{}:{}",
+            self.dims
+                .iter()
+                .map(|x| x.to_string())
+                .collect::<Vec<String>>()
+                .join("x"),
+            self.data_type
+        )
+    }
+}
+
+/// Divide `num` by `div`, rounding the result up to the next whole number if there is a remainder (ceiling division).
+pub(crate) fn ceil(num: usize, div: usize) -> usize {
+    num / div + (num % div != 0) as usize
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{ScalarType, Shape};
+
+    // Test cases for Shape::multi_broadcast, some inspired by
+    #[test]
+    pub fn test_multi_broadcast() {
+        fn shape(s: &[usize]) -> Shape {
+            Shape::from(ScalarType::F32, s)
+        }
+
+        assert_eq!(
+            Shape::multi_broadcast(&[shape(&[2, 3, 4, 5]), shape(&[])]),
+            Some(shape(&[2, 3, 4, 5])),
+        );
+
+        assert_eq!(
+            Shape::multi_broadcast(&[shape(&[2, 3, 4, 5]), shape(&[5])]),
+            Some(shape(&[2, 3, 4, 5])),
+        );
+
+        assert_eq!(
+            Shape::multi_broadcast(&[shape(&[2, 3, 4, 5]), shape(&[4, 5])]),
+            Some(shape(&[2, 3, 4, 5])),
+        );
+
+        assert_eq!(
+            Shape::multi_broadcast(&[shape(&[4, 5]), shape(&[2, 3, 4, 5])]),
+            Some(shape(&[2, 3, 4, 5])),
+        );
+
+        assert_eq!(
+            Shape::multi_broadcast(&[shape(&[1, 4, 5]), shape(&[2, 3, 4, 1])]),
+            Some(shape(&[2, 3, 4, 5])),
+        );
+
+        assert_eq!(
+            Shape::multi_broadcast(&[shape(&[3, 4, 5]), shape(&[2, 1, 1, 1])]),
+            Some(shape(&[2, 3, 4, 5])),
+        );
+
+        assert_eq!(
+            Shape::multi_broadcast(&[shape(&[3, 4, 5]), shape(&[2, 4, 1, 1])]),
+            None
+        );
+
+        assert_eq!(
+            Shape::multi_broadcast(&[shape(&[1, 255, 768]), shape(&[1, 255, 1])]),
+            Some(shape(&[1, 255, 768])),
+        );
+    }
+}
diff --git a/wonnx/src/utils.rs b/wonnx/src/utils.rs
deleted file mode 100644
index 80bfc144..00000000
--- a/wonnx/src/utils.rs
+++ /dev/null
@@ -1,887 +0,0 @@
-//!
Various utilities to deal with the ONNX format structure -use protobuf::ProtobufEnum; -use protobuf::RepeatedField; -use serde::Serialize; - -use crate::onnx; -use crate::onnx::ModelProto; -use crate::onnx::OperatorSetIdProto; -use crate::onnx::TensorProto; -use crate::onnx::TensorProto_DataType; -use crate::onnx::TensorShapeProto; -use crate::onnx::TensorShapeProto_Dimension; -use crate::onnx::TypeProto; -use crate::onnx::TypeProto_Tensor; -use crate::onnx::TypeProto_oneof_value; -use crate::onnx::ValueInfoProto; -use num::FromPrimitive; -use std::borrow::Cow; -use std::convert::From; -use std::convert::Into; -use std::convert::TryFrom; -use std::fmt::Display; -use std::str::from_utf8; -use thiserror::Error; - -/* Minimum size of a buffer you can create with wgpu. Creating buffers smaller than this leads to panic "Validation -* error: buffer binding size X is less than minimum 64" in Device::create_bind_group */ -pub(crate) const MINIMUM_BUFFER_SIZE_BYTES: u64 = 64; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Shape { - pub dims: Vec, - pub data_type: ScalarType, -} - -impl Shape { - pub fn from(data_type: ScalarType, dims: &[i64]) -> Shape { - Shape { - data_type, - dims: dims.iter().map(|x| *x as u64).collect(), - } - } - - pub fn is_empty(&self) -> bool { - self.dims.is_empty() - } - - pub fn rank(&self) -> usize { - self.dims.len() - } - - pub fn element_count(&self) -> u64 { - self.dims.iter().product() - } - - pub fn buffer_bytes_aligned(&self) -> usize { - // Round buffer sizes to 16 bytes. If not, things go wrong (i.e. our shaders use vec4 - if a buffer only - // has 7 elements, the last vec4 cannot be fully written to the buffer, and the buffer ends up containing zeroes. - fn round_to_next_multiple_of_16(n: usize) -> usize { - (n + 15) / 16 * 16 - } - - round_to_next_multiple_of_16((self.element_count() as usize) * self.data_type.stride()) - } - - pub fn dim(&self, idx: usize) -> u64 { - self.dims[idx] - } - - pub fn chunks(&self) -> Vec { - let mut chunk = vec![]; - let ds = &self.dims; - for i in 1..self.dims.len() { - chunk.push(ds[i..].iter().product::()); - } - chunk.push(1); - chunk - } - - /// Computes the shape to which all provided shapes can be broadcast (if it exists) - /// Inspired by - pub fn multi_broadcast(shapes: &[Shape]) -> Option { - if shapes.is_empty() { - return None; - } - - let max_rank = shapes.iter().map(|x| x.rank()).max().unwrap_or(0); - let mut shape: Vec = Vec::with_capacity(max_rank); - - // Shapes must all have the same data type - let data_type = shapes[0].data_type; - for s in shapes { - if s.data_type != data_type { - return None; - } - } - - for i in 0..max_rank { - let mut wanted_size = 1; - for shape in shapes { - let rank = shape.rank(); - let dim = if i < rank { shape.dim(rank - i - 1) } else { 1 }; - - if dim != 1 { - if wanted_size != 1 && dim != wanted_size { - return None; - } - wanted_size = dim; - } - } - shape.push(wanted_size as i64); - } - - shape.reverse(); - Some(Shape::from(data_type, &shape)) - } - - pub(crate) fn left_padded_to(&self, x: u64, rank: usize) -> Shape { - let mut dims = self.dims.clone(); - let current_rank = dims.len(); - dims.resize(rank, x); - if rank > current_rank { - dims.rotate_right(rank - current_rank); - } - Shape { - dims, - data_type: self.data_type, - } - } -} - -#[derive(Clone)] -pub enum InputTensor<'a> { - F32(Cow<'a, [f32]>), - I32(Cow<'a, [i32]>), - I64(Cow<'a, [i64]>), - U8(Cow<'a, [u8]>), -} - -impl<'a> From<&'a [f32]> for InputTensor<'a> { - fn from(a: &'a [f32]) -> Self { - 
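The `Shape` type whose removal starts here carries over to the new `tensor` module with `usize` dimensions (see the updated test helper above). Assuming `element_count` and `buffer_bytes_aligned` keep the names and semantics of the old code, the size bookkeeping and broadcasting look like this sketch; note how buffer sizes are rounded up to a multiple of 16 bytes so that a trailing vec4 write in a shader stays inside the buffer:

    use wonnx::tensor::{ScalarType, Shape};

    fn main() {
        // 2*3*4*5 = 120 f32 elements = 480 bytes, already a multiple of 16.
        let a = Shape::from(ScalarType::F32, &[2, 3, 4, 5]);
        assert_eq!(a.element_count(), 120);
        assert_eq!(a.buffer_bytes_aligned(), 480);

        // 7 f32 elements occupy 28 bytes, rounded up to 32.
        assert_eq!(Shape::from(ScalarType::F32, &[7]).buffer_bytes_aligned(), 32);

        // ONNX/NumPy-style multi-directional broadcasting.
        let b = Shape::from(ScalarType::F32, &[1, 4, 5]);
        assert_eq!(
            Shape::multi_broadcast(&[a, b]),
            Some(Shape::from(ScalarType::F32, &[2, 3, 4, 5]))
        );
    }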
InputTensor::F32(Cow::Borrowed(a)) - } -} - -impl<'a> From<&'a [i32]> for InputTensor<'a> { - fn from(a: &'a [i32]) -> Self { - InputTensor::I32(Cow::Borrowed(a)) - } -} - -impl<'a> From<&'a [i64]> for InputTensor<'a> { - fn from(a: &'a [i64]) -> Self { - InputTensor::I64(Cow::Borrowed(a)) - } -} - -impl<'a> TryFrom<&'a TensorProto> for InputTensor<'a> { - type Error = DataTypeError; - - fn try_from(value: &'a TensorProto) -> Result { - Ok(match ScalarType::from_i32(value.get_data_type())? { - ScalarType::F32 => InputTensor::F32(Cow::Borrowed(value.get_float_data())), - ScalarType::I64 => InputTensor::I64(Cow::Borrowed(value.get_int64_data())), - ScalarType::I32 => InputTensor::I32(Cow::Borrowed(value.get_int32_data())), - ScalarType::U8 => InputTensor::U8(Cow::Borrowed(value.get_raw_data())), - }) - } -} - -#[derive(Error, Debug)] -pub enum TensorConversionError { - #[error("could not convert to the requested type becaue a value could not be represented in the target type")] - OutOfBoundsError, - - #[error("cold not return the requested type; conversions cannot be done for slices")] - DataTypeError, -} - -#[derive(Clone, Debug, PartialEq, Serialize)] -#[serde(untagged)] -pub enum OutputTensor { - F32(Vec), - I32(Vec), - I64(Vec), - U8(Vec), -} - -impl TryFrom for Vec { - type Error = TensorConversionError; - - /// Convert OutputTensor into a `Vec`, possibly converting integer tensors if the values fit - fn try_from(value: OutputTensor) -> Result { - match value { - OutputTensor::F32(floats) => Ok(floats), - OutputTensor::I32(ints) => ints - .into_iter() - .map(|i| f32::from_i32(i).ok_or(TensorConversionError::OutOfBoundsError)) - .collect::>(), - OutputTensor::I64(ints) => ints - .into_iter() - .map(|i| f32::from_i64(i).ok_or(TensorConversionError::OutOfBoundsError)) - .collect::>(), - OutputTensor::U8(ints) => ints - .into_iter() - .map(|i| f32::from_u8(i).ok_or(TensorConversionError::OutOfBoundsError)) - .collect::>(), - } - } -} - -/// Convert &OutputTensor into an &[f32]. 
Because we cannot store converted results, this operation does not attempt -/// to convert the tensor if the values are of a different type -impl<'a> TryFrom<&'a OutputTensor> for &'a [f32] { - type Error = TensorConversionError; - - fn try_from(value: &'a OutputTensor) -> Result { - match value { - OutputTensor::F32(floats) => Ok(floats.as_slice()), - OutputTensor::I32(_) | OutputTensor::I64(_) | OutputTensor::U8(_) => { - Err(TensorConversionError::DataTypeError) - } - } - } -} - -impl<'a> From<&InputTensor<'a>> for OutputTensor { - fn from(input: &InputTensor<'a>) -> Self { - match input { - InputTensor::F32(fs) => OutputTensor::F32(fs.to_vec()), - InputTensor::I32(fs) => OutputTensor::I32(fs.to_vec()), - InputTensor::I64(fs) => OutputTensor::I64(fs.to_vec()), - InputTensor::U8(fs) => OutputTensor::U8(fs.to_vec()), - } - } -} - -impl TensorProto { - pub fn from(value: OutputTensor, dims: Vec) -> Self { - let mut tensor = TensorProto::new(); - match value { - OutputTensor::F32(v) => { - tensor.set_data_type(ScalarType::F32.to_datatype().value()); - tensor.set_float_data(v); - } - OutputTensor::I32(v) => { - tensor.set_data_type(ScalarType::I32.to_datatype().value()); - tensor.set_int32_data(v); - } - OutputTensor::I64(v) => { - tensor.set_data_type(ScalarType::I64.to_datatype().value()); - tensor.set_int64_data(v); - } - OutputTensor::U8(v) => { - tensor.set_data_type(ScalarType::U8.to_datatype().value()); - tensor.set_raw_data(v); - } - } - tensor.set_dims(dims); - tensor - } -} - -#[derive(Error, Debug)] -pub enum DataTypeError { - #[error("the ONNX scalar data type '{0:?}' is not supported")] - NotSupported(TensorProto_DataType), - - #[error("the ONNX data type '{0}' is not recognized")] - NotRecognized(i32), - - #[error("encountered parametrized dimensions '{0}'; this is not currently supported (this may be solved by running onnx-simplifier on the model first)")] - ParametrizedDimensionUnsupported(String), - - #[error("type is undefined")] - Undefined, -} - -/// Data type for a single number -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum ScalarType { - F32, - I64, - I32, - U8, -} - -impl ScalarType { - pub fn from_i32(onnx: i32) -> Result { - let onnx_dt = - TensorProto_DataType::from_i32(onnx).ok_or(DataTypeError::NotRecognized(onnx))?; - Self::from(onnx_dt) - } - - pub fn from(onnx: TensorProto_DataType) -> Result { - Ok(match onnx { - TensorProto_DataType::FLOAT => ScalarType::F32, - TensorProto_DataType::INT64 => ScalarType::I64, - TensorProto_DataType::INT32 => ScalarType::I32, - TensorProto_DataType::UINT8 => ScalarType::U8, - _ => return Err(DataTypeError::NotSupported(onnx)), - }) - } - - pub fn to_datatype(&self) -> TensorProto_DataType { - match self { - ScalarType::F32 => TensorProto_DataType::FLOAT, - ScalarType::I64 => TensorProto_DataType::INT64, - ScalarType::I32 => TensorProto_DataType::INT32, - ScalarType::U8 => TensorProto_DataType::UINT8, - } - } - - pub fn stride(&self) -> usize { - match self { - ScalarType::F32 => 4, - ScalarType::I32 => 4, - ScalarType::I64 => 8, - ScalarType::U8 => 1, // ! TODO check this - } - } - - pub fn wgsl_supported(&self) -> bool { - match self { - ScalarType::F32 => true, - ScalarType::I32 => true, - ScalarType::I64 => false, - ScalarType::U8 => false, // ! TODO check this - } - } - - pub fn wgsl_type_name(&self) -> &'static str { - match self { - ScalarType::F32 => "f32", - ScalarType::I32 => "i32", - ScalarType::I64 => "i64", - ScalarType::U8 => "u8", // ! 
TODO check this - } - } - - pub fn is_float(&self) -> bool { - match self { - ScalarType::F32 => true, - ScalarType::I32 | ScalarType::I64 | ScalarType::U8 => false, - } - } -} - -impl Display for ScalarType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.wgsl_type_name()) - } -} - -/// Represents a WGSL data type that can be used in a shader to perform an operation on multiple scalars at once. The -/// larger the data type, the more efficiently the GPU can perform operations. However, large data types require the size -/// of the data that is being worked on to be a multiple of the data type (e.g. a vec4 can be used to work on a vector of -/// 256 elements, but when used on a vector of 255 elements calculation would overflow if the shader doesn't take this -/// into account). The WGSL declaration looks like: -/// -/// struct Block { -/// data: [[stride( dt.size_bytes() )]] dt.wgsl_type_name(); -/// }; -pub(crate) enum MultiType { - Scalar(ScalarType), - Vec(ScalarType, usize), - Mat(ScalarType, usize, usize), -} - -impl MultiType { - /// Determine the appropriate data type given the data size - pub fn for_size(n: usize, scalar: ScalarType) -> MultiType { - let d = num::integer::gcd(n, 4); - match d { - 1 => MultiType::Scalar(scalar), - 2 | 4 => MultiType::Vec(scalar, d), - /* 3 can't occur here because it is not a divisor of 4. Even so, we wouldn't be able to use vec3, because - its stride is 16 instead of the expected 12, which would require padding to work properly. */ - _ => unreachable!(), - } - } - - /// Size (in bytes) of the data type (useful for setting the 'stride' in WGSL) - pub fn stride(&self) -> usize { - match self { - MultiType::Scalar(s) => s.stride(), - - // FIXME: this may not always be right! 
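`MultiType::for_size` (identical in the old and new modules) picks the widest WGSL type whose element count evenly divides the data size. Since `MultiType` is `pub(crate)` and cannot be called from outside the crate, the sketch below mirrors its gcd-based selection rather than invoking it, using the same `num` crate the implementation does:

    // Width of the WGSL type for_size would pick: gcd(n, 4) elements,
    // i.e. 4 => vec4, 2 => vec2, 1 => scalar. vec3 never comes up: 3 is
    // not a divisor of 4, and its stride would be 16 rather than 12 anyway.
    fn wgsl_width(n: usize) -> usize {
        num::integer::gcd(n, 4)
    }

    fn main() {
        assert_eq!(wgsl_width(256), 4); // 64 vec4<f32> reads cover 256 elements
        assert_eq!(wgsl_width(6), 2);   // vec2<f32>
        assert_eq!(wgsl_width(255), 1); // odd count: plain f32 scalars
    }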
- MultiType::Vec(st, n) => st.stride() * n, - MultiType::Mat(st, w, h) => st.stride() * w * h, - } - } - - /// Name of this data type in WGSL - pub fn wgsl_type_name(&self) -> String { - match self { - MultiType::Scalar(s) => s.wgsl_type_name().to_string(), - MultiType::Vec(st, n) => format!("vec{}<{}>", n, st.wgsl_type_name()), - MultiType::Mat(st, w, h) => format!("mat{}x{}<{}>", w, h, st.wgsl_type_name()), - } - } - - /// The number of elements in this data type - pub fn elements(&self) -> usize { - match self { - MultiType::Scalar(_) => 1, - - // FIXME: this may not always be right - MultiType::Vec(_, n) => *n, - &MultiType::Mat(_, w, h) => w * h, - } - } -} - -impl Display for Shape { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}:{}", - self.dims - .iter() - .map(|x| x.to_string()) - .collect::>() - .join("x"), - self.data_type - ) - } -} - -#[derive(Error, Debug)] -#[error("did not find attribute '{attribute}' for node '{node_name}'")] -pub struct AttributeNotFoundError { - attribute: String, - node_name: String, -} - -pub trait NodeAttributes { - fn has_attribute(&self, attribute_name: &str) -> bool; - fn get_attribute_value>( - &self, - attribute: &str, - default: Option, - ) -> Result; -} - -impl NodeAttributes for onnx::NodeProto { - fn has_attribute(&self, attribute_name: &str) -> bool { - self.get_attribute() - .iter() - .any(|attr| attr.get_name() == attribute_name) - } - - fn get_attribute_value>( - &self, - attribute: &str, - default: Option, - ) -> Result { - match ( - self.get_attribute() - .iter() - .find(|attr| attr.get_name() == attribute), - default, - ) { - (Some(attr), _) => Ok(attr.clone().into()), - (None, Some(default_attr)) => Ok(default_attr), - (None, None) => Err(AttributeNotFoundError { - attribute: attribute.to_string(), - node_name: self.get_name().to_string(), - }), - } - } -} - -/// Divide a number by the indicated dividend, then round up to the next multiple of the dividend if there is a rest. -pub(crate) fn ceil(num: u64, div: u64) -> u64 { - num / div + (num % div != 0) as u64 -} - -impl ValueInfoProto { - pub fn get_shape(&self) -> Result { - Ok(match &self.get_field_type().value { - Some(t) => match t { - onnx::TypeProto_oneof_value::tensor_type(tensor_proto) => Shape::from( - ScalarType::from_i32(tensor_proto.get_elem_type())?, - self.get_field_type() - .get_tensor_type() - .get_shape() - .get_dim() - .iter() - .map(|x| { - if x.has_dim_param() { - return Err(DataTypeError::ParametrizedDimensionUnsupported( - x.get_dim_param().to_string(), - )); - } - Ok(x.get_dim_value()) - }) - .collect::, DataTypeError>>()? 
- .as_slice(), - ), - onnx::TypeProto_oneof_value::sequence_type(_) => todo!(), - onnx::TypeProto_oneof_value::map_type(_) => todo!(), - onnx::TypeProto_oneof_value::optional_type(_) => todo!(), - onnx::TypeProto_oneof_value::sparse_tensor_type(_) => todo!(), - }, - None => return Err(DataTypeError::Undefined), - }) - } - - pub fn set_shape(&mut self, shape: &Shape) { - let mut tpt = TypeProto_Tensor::new(); - tpt.set_elem_type(shape.data_type.to_datatype().value()); - - let mut tsp = TensorShapeProto::new(); - tsp.dim.extend(shape.dims.iter().map(|x| { - let mut tspd = TensorShapeProto_Dimension::new(); - tspd.set_dim_value(*x as i64); - tspd - })); - tpt.set_shape(tsp); - - let mut tp = TypeProto::new(); - tp.value = Some(TypeProto_oneof_value::tensor_type(tpt)); - self.set_field_type(tp); - } -} - -/// Shorthand method to define an ONNX tensor with the specified name and shape (data type is f32) -pub fn tensor(name: &str, dimensions: &[i64]) -> onnx::ValueInfoProto { - tensor_of_type(name, dimensions, TensorProto_DataType::FLOAT) -} - -/// Shorthand method to define an ONNX tensor with the specified name, shape and data type -pub fn tensor_of_type( - name: &str, - dimensions: &[i64], - data_type: TensorProto_DataType, -) -> onnx::ValueInfoProto { - let mut dim_value = vec![]; - for dimension in dimensions { - let mut dim_channel = onnx::TensorShapeProto_Dimension::new(); - dim_channel.set_dim_value(*dimension); - dim_value.push(dim_channel); - } - - let mut shape_tensor_proto = onnx::TensorShapeProto::new(); - shape_tensor_proto.set_dim(protobuf::RepeatedField::from(dim_value)); - - let mut type_proto_tensor = onnx::TypeProto_Tensor::new(); - type_proto_tensor.set_elem_type(data_type.value()); - type_proto_tensor.set_shape(shape_tensor_proto); - - let mut type_proto = onnx::TypeProto::new(); - type_proto.set_tensor_type(type_proto_tensor); - - let mut tensor = onnx::ValueInfoProto::new(); - tensor.set_name(name.to_string()); - tensor.set_field_type(type_proto); - - tensor -} - -pub fn initializer(name: &str, data: Vec, dimensions: Vec) -> onnx::TensorProto { - let mut initializer = crate::onnx::TensorProto::new(); - assert_eq!( - dimensions.iter().cloned().product::() as usize, - data.len() - ); - initializer.set_dims(dimensions); - initializer.set_name(name.to_string()); - initializer.set_data_type(TensorProto_DataType::FLOAT.value()); - initializer.set_float_data(data); - initializer -} - -pub fn initializer_int64(name: &str, data: Vec, dimensions: Vec) -> onnx::TensorProto { - let mut initializer = crate::onnx::TensorProto::new(); - assert_eq!( - dimensions.iter().cloned().product::() as usize, - data.len() - ); - initializer.set_name(name.to_string()); - initializer.set_dims(dimensions); - initializer.set_data_type(TensorProto_DataType::INT64.value()); - initializer.set_int64_data(data); - initializer -} - -pub fn attribute(name: &str, inputs: impl Into) -> onnx::AttributeProto { - let mut attributes: onnx::AttributeProto = inputs.into(); - attributes.set_name(name.to_string()); - attributes -} - -pub fn node( - inputs: Vec<&str>, - outputs: Vec<&str>, - name: &str, - op_type: &str, - attributes: Vec, -) -> onnx::NodeProto { - let mut node = crate::onnx::NodeProto::new(); - - node.set_op_type(op_type.to_string()); - node.set_name(name.to_string()); - node.set_input(protobuf::RepeatedField::from( - inputs - .iter() - .map(|s| s.to_string()) - .collect::>(), - )); - node.set_output(protobuf::RepeatedField::from( - outputs - .iter() - .map(|s| s.to_string()) - .collect::>(), - )); - 
node.set_attribute(protobuf::RepeatedField::from(attributes)); - node -} - -pub fn graph( - inputs: Vec, - outputs: Vec, - mut infos: Vec, - initializers: Vec, - nodes: Vec, -) -> onnx::GraphProto { - let mut graph = onnx::GraphProto::new(); - graph.set_node(protobuf::RepeatedField::from(nodes)); - graph.set_input(protobuf::RepeatedField::from(inputs)); - graph.set_output(protobuf::RepeatedField::from(outputs)); - - // Auto-generate tensor information for initializers so users don't have to specify those - for i in &initializers { - infos.push(tensor_of_type( - i.get_name(), - i.get_dims(), - onnx::TensorProto_DataType::from_i32(i.get_data_type()).unwrap(), - )); - } - - graph.set_initializer(protobuf::RepeatedField::from(initializers)); - graph.set_value_info(protobuf::RepeatedField::from(infos)); - graph -} - -pub fn model_with_opset(graph: onnx::GraphProto, opset_version: i64) -> onnx::ModelProto { - let mut model = crate::onnx::ModelProto::new(); - let mut onnx_opset_import = OperatorSetIdProto::new(); - onnx_opset_import.set_domain("".to_string()); - onnx_opset_import.set_version(opset_version); - model.set_opset_import(RepeatedField::from_slice(&[onnx_opset_import])); - model.set_graph(graph); - model -} - -pub fn model(graph: onnx::GraphProto) -> onnx::ModelProto { - model_with_opset(graph, 13) -} - -impl From> for onnx::AttributeProto { - fn from(value: Vec) -> Self { - let mut attributes = crate::onnx::AttributeProto::new(); - attributes.set_ints(value); - attributes - } -} - -impl From> for onnx::AttributeProto { - fn from(value: Vec) -> Self { - let mut attributes = crate::onnx::AttributeProto::new(); - attributes.set_floats(value); - attributes - } -} - -impl From for onnx::AttributeProto { - fn from(value: f32) -> Self { - let mut attributes = crate::onnx::AttributeProto::new(); - attributes.set_f(value); - attributes - } -} - -impl From for onnx::AttributeProto { - fn from(value: i64) -> Self { - let mut attributes = crate::onnx::AttributeProto::new(); - attributes.set_i(value); - attributes - } -} - -impl From for onnx::AttributeProto { - fn from(value: String) -> Self { - let mut attributes = crate::onnx::AttributeProto::new(); - attributes.set_s(value.into_bytes()); - attributes - } -} - -impl From<&str> for onnx::AttributeProto { - fn from(value: &str) -> Self { - let mut attributes = crate::onnx::AttributeProto::new(); - attributes.set_s(value.to_string().into_bytes()); - attributes - } -} - -impl From for onnx::AttributeProto { - fn from(value: TensorProto) -> Self { - let mut attributes = crate::onnx::AttributeProto::new(); - attributes.set_t(value); - attributes - } -} - -impl From for Vec { - fn from(value: onnx::AttributeProto) -> Self { - value.get_ints().to_vec() - } -} - -impl From for TensorProto { - fn from(value: onnx::AttributeProto) -> Self { - value.get_t().clone() - } -} - -impl From for Vec { - fn from(value: onnx::AttributeProto) -> Self { - value.get_floats().to_vec() - } -} - -impl From for f32 { - fn from(value: onnx::AttributeProto) -> Self { - value.get_f() - } -} - -impl From for i64 { - fn from(value: onnx::AttributeProto) -> Self { - value.get_i() - } -} - -impl From for String { - fn from(value: onnx::AttributeProto) -> Self { - from_utf8(value.get_s()).unwrap().to_string() - } -} - -#[derive(Error, Debug)] -pub enum OpsetError { - #[error("more than one ONNX opset was specified: {0} and {1}")] - DuplicateOnnxOpset(i64, i64), - - #[error("the model references an unknown opset: '{0}'")] - UnknownOpset(String), -} - -pub fn 
get_opset_version(model: &ModelProto) -> Result, OpsetError> { - // Find the version of the ONNX operator set this model is using (this is useful because some operators' specifications change over time). - // Note, if any other op set than the ONNX operator set is referenced, we cannot run the model. - // See https://github.com/onnx/onnx/blob/master/docs/Versioning.md#operator-sets - let mut onnx_opset_version = None; - for opset_import in model.get_opset_import() { - match opset_import.get_domain() { - "" => { - // This is a reference to the ONNX specification op set - if let Some(onnx_version) = onnx_opset_version { - if opset_import.get_version() != onnx_version { - return Err(OpsetError::DuplicateOnnxOpset( - onnx_version, - opset_import.get_version(), - )); - } - } else { - onnx_opset_version = Some(opset_import.get_version()); - } - } - some_other_opset => { - return Err(OpsetError::UnknownOpset(some_other_opset.to_string())); - } - } - } - Ok(onnx_opset_version) -} - -#[cfg(test)] -mod tests { - use crate::utils::{ - attribute, graph, initializer, model, node, tensor, OutputTensor, ScalarType, Shape, - }; - - #[test] - fn test_model() { - // USER INPUT - - let n = 5; - let c = 1; - let mut input_data = std::collections::HashMap::new(); - - let data: Vec = (0..25).map(|x| x as f32).collect(); - input_data.insert("X".to_string(), data.as_slice().into()); - - // ONNX INPUTS - let shape = vec![1, c, n, n]; - let kernel_n = 3; - let m = 1; - let data_w: Vec = (0..m * c * kernel_n * kernel_n).map(|_| 1.0f32).collect(); - let conv_model = model(graph( - vec![tensor("X", &shape)], - vec![tensor("Y", &[1, 1, 3, 3])], - vec![], - vec![initializer("W", data_w, vec![m, c, 3, 3])], - vec![node( - vec!["X", "W"], - vec!["Y"], - "conv", - "Conv", - vec![attribute("kernel_shape", vec![3, 3])], - )], - )); - - // LOGIC - - let session = pollster::block_on(crate::Session::from_model(conv_model)) - .expect("Session did not create"); - - let result = pollster::block_on(session.run(&input_data)).unwrap(); - - assert_eq!( - result["Y"], - OutputTensor::F32(vec![54., 63., 72., 99., 108., 117., 144., 153., 162.]) - ); - } - - // Test cases for Shape::multi_broadcast, some inspired by - #[test] - pub fn test_multi_broadcast() { - fn shape(s: &[i64]) -> Shape { - Shape::from(ScalarType::F32, s) - } - - assert_eq!( - Shape::multi_broadcast(&[shape(&[2, 3, 4, 5]), shape(&[])]), - Some(shape(&[2, 3, 4, 5])), - ); - - assert_eq!( - Shape::multi_broadcast(&[shape(&[2, 3, 4, 5]), shape(&[5])]), - Some(shape(&[2, 3, 4, 5])), - ); - - assert_eq!( - Shape::multi_broadcast(&[shape(&[2, 3, 4, 5]), shape(&[4, 5])]), - Some(shape(&[2, 3, 4, 5])), - ); - - assert_eq!( - Shape::multi_broadcast(&[shape(&[4, 5]), shape(&[2, 3, 4, 5])]), - Some(shape(&[2, 3, 4, 5])), - ); - - assert_eq!( - Shape::multi_broadcast(&[shape(&[1, 4, 5]), shape(&[2, 3, 4, 1])]), - Some(shape(&[2, 3, 4, 5])), - ); - - assert_eq!( - Shape::multi_broadcast(&[shape(&[3, 4, 5]), shape(&[2, 1, 1, 1])]), - Some(shape(&[2, 3, 4, 5])), - ); - - assert_eq!( - Shape::multi_broadcast(&[shape(&[3, 4, 5]), shape(&[2, 4, 1, 1])]), - None - ); - - assert_eq!( - Shape::multi_broadcast(&[shape(&[1, 255, 768]), shape(&[1, 255, 1])]), - Some(shape(&[1, 255, 768])), - ); - } -} diff --git a/wonnx/tests/arithmetic.rs b/wonnx/tests/arithmetic.rs index cdef5c0a..2fd2d8d3 100644 --- a/wonnx/tests/arithmetic.rs +++ b/wonnx/tests/arithmetic.rs @@ -2,10 +2,11 @@ use approx::assert_abs_diff_eq; use std::{collections::HashMap, convert::TryInto}; use wonnx::{ 
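`get_opset_version` keeps its `Result<Option<i64>, OpsetError>` shape after the move to `wonnx::onnx_model` (the new import path appears in the wonnx-cli changes earlier in this patch). A short caller sketch, with `model` assumed to be a `ModelProto` loaded elsewhere:

    use wonnx::onnx::ModelProto;
    use wonnx::onnx_model::get_opset_version;

    fn check_opset(model: &ModelProto) -> Result<i64, Box<dyn std::error::Error>> {
        match get_opset_version(model)? {
            Some(version) => Ok(version), // the ONNX operator set the model targets
            None => Err("model declares no ONNX opset".into()),
        }
    }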
onnx::TensorProto_DataType, - utils::{ - graph, initializer, initializer_int64, model, node, tensor, tensor_of_type, InputTensor, - OutputTensor, + onnx_model::{ + onnx_graph, onnx_initializer, onnx_initializer_int64, onnx_model, onnx_node, onnx_tensor, + onnx_tensor_of_type, }, + tensor::TensorData, }; mod common; @@ -20,19 +21,19 @@ fn test_cos() { input_data.insert("X".to_string(), data.as_slice().into()); // Model: X -> Cos -> Y - let model = model(graph( - vec![tensor("X", &shape)], - vec![tensor("Y", &shape)], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &shape)], + vec![onnx_tensor("Y", &shape)], vec![], vec![], - vec![node(vec!["X"], vec!["Y"], "cos", "Cos", vec![])], + vec![onnx_node(vec!["X"], vec!["Y"], "Cos", vec![])], )); let session = pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create"); let result = pollster::block_on(session.run(&input_data)).unwrap(); - assert_eq!(result["Y"], OutputTensor::F32(vec![1.0; 16])); + assert_eq!(result["Y"], TensorData::F32(vec![1.0; 16].into())); } #[test] @@ -46,12 +47,12 @@ fn test_reciprocal() { input_data.insert("X".to_string(), data.as_slice().into()); // Model: X -> Reciprocal -> Y - let model = model(graph( - vec![tensor("X", &shape)], - vec![tensor("Y", &shape)], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &shape)], + vec![onnx_tensor("Y", &shape)], vec![], vec![], - vec![node(vec!["X"], vec!["Y"], "rec", "Reciprocal", vec![])], + vec![onnx_node(vec!["X"], vec!["Y"], "Reciprocal", vec![])], )); let session = @@ -72,22 +73,30 @@ fn test_integer() { let data = vec![21i32; n]; let shape = vec![n as i64]; - input_data.insert("X".to_string(), InputTensor::I32(data.as_slice().into())); + input_data.insert("X".to_string(), TensorData::I32(data.as_slice().into())); // Model: X -> Add -> Y - let model = model(graph( - vec![tensor_of_type("X", &shape, TensorProto_DataType::INT32)], - vec![tensor_of_type("Y", &shape, TensorProto_DataType::INT32)], + let model = onnx_model(onnx_graph( + vec![onnx_tensor_of_type( + "X", + &shape, + TensorProto_DataType::INT32, + )], + vec![onnx_tensor_of_type( + "Y", + &shape, + TensorProto_DataType::INT32, + )], vec![], vec![], - vec![node(vec!["X", "X"], vec!["Y"], "add_ints", "Add", vec![])], + vec![onnx_node(vec!["X", "X"], vec!["Y"], "Add", vec![])], )); let session = pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create"); let result = pollster::block_on(session.run(&input_data)).unwrap(); - assert_eq!(result["Y"], OutputTensor::I32(vec![42; n])); + assert_eq!(result["Y"], TensorData::I32(vec![42; n].into())); } #[test] @@ -99,22 +108,22 @@ fn test_int64_initializers() { let sum: Vec = (0..n).map(|x| (x * 3) as i64).collect(); let dims = vec![n as i64]; - let model = model(graph( - vec![tensor_of_type("X", &dims, TensorProto_DataType::INT64)], - vec![tensor_of_type("Z", &dims, TensorProto_DataType::INT64)], + let model = onnx_model(onnx_graph( + vec![onnx_tensor_of_type("X", &dims, TensorProto_DataType::INT64)], + vec![onnx_tensor_of_type("Z", &dims, TensorProto_DataType::INT64)], vec![], - vec![initializer_int64("Y", right, dims.clone())], - vec![node(vec!["X", "Y"], vec!["Z"], "adder", "Add", vec![])], + vec![onnx_initializer_int64("Y", right, dims.clone())], + vec![onnx_node(vec!["X", "Y"], vec!["Z"], "Add", vec![])], )); let session = pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create"); - let mut input_data: HashMap = HashMap::new(); + let mut input_data: HashMap = 
HashMap::new(); input_data.insert("X".to_string(), left.as_slice().into()); let result = pollster::block_on(session.run(&input_data)).unwrap(); - assert_eq!(result["Z"], OutputTensor::I64(sum)) + assert_eq!(result["Z"], TensorData::I64(sum.into())) } pub fn assert_eq_vector_weak(xs: &[f32], ys: &[f32]) { @@ -138,12 +147,12 @@ fn test_pow() { input_data.insert("Y".to_string(), y.as_slice().into()); // Model: X,Y -> Pow -> Z - let model = model(graph( - vec![tensor("X", &shape), tensor("Y", &shape)], - vec![tensor("Z", &shape)], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &shape), onnx_tensor("Y", &shape)], + vec![onnx_tensor("Z", &shape)], vec![], vec![], - vec![node(vec!["X", "Y"], vec!["Z"], "pow", "Pow", vec![])], + vec![onnx_node(vec!["X", "Y"], vec!["Z"], "Pow", vec![])], )); let session = @@ -173,12 +182,12 @@ fn test_mul_broadcast() { input_data.insert("Y".to_string(), y.as_slice().into()); // Model: X,Y -> Mul -> Z - let model = model(graph( - vec![tensor("X", &shape_x), tensor("Y", &shape_y)], - vec![tensor("Z", &shape_z)], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &shape_x), onnx_tensor("Y", &shape_y)], + vec![onnx_tensor("Z", &shape_z)], vec![], vec![], - vec![node(vec!["X", "Y"], vec!["Z"], "mul", "Mul", vec![])], + vec![onnx_node(vec!["X", "Y"], vec!["Z"], "Mul", vec![])], )); let session = @@ -203,12 +212,12 @@ fn test_prelu() { input_data.insert("X".to_string(), data.into()); input_data.insert("Y".to_string(), slope.into()); - let model = model(graph( - vec![tensor("X", data_shape), tensor("Y", slope_shape)], - vec![tensor("Z", data_shape)], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", data_shape), onnx_tensor("Y", slope_shape)], + vec![onnx_tensor("Z", data_shape)], vec![], vec![], - vec![node(vec!["X", "Y"], vec!["Z"], "prelu", "PRelu", vec![])], + vec![onnx_node(vec!["X", "Y"], vec!["Z"], "PRelu", vec![])], )); let session = @@ -264,12 +273,12 @@ fn test_sign() { input_data.insert("X".to_string(), data.as_slice().into()); // Model: X -> Cos -> Y - let model = model(graph( - vec![tensor("X", &shape)], - vec![tensor("Y", &shape)], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &shape)], + vec![onnx_tensor("Y", &shape)], vec![], vec![], - vec![node(vec!["X"], vec!["Y"], "sign", "Sign", vec![])], + vec![onnx_node(vec!["X"], vec!["Y"], "Sign", vec![])], )); let session = @@ -289,25 +298,24 @@ fn test_sign() { .collect(); let result = pollster::block_on(session.run(&input_data)).unwrap(); - assert_eq!(result["Y"], OutputTensor::F32(expected)); + assert_eq!(result["Y"], TensorData::F32(expected.into())); } #[test] fn test_clip() { // Model: X -> Clip -> Y let shape = vec![1, 1, 2, 2]; - let model = model(graph( - vec![tensor("X", &shape)], - vec![tensor("Y", &shape)], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &shape)], + vec![onnx_tensor("Y", &shape)], vec![], vec![ - initializer("min", vec![0.0], vec![]), - initializer("max", vec![1.0], vec![]), + onnx_initializer("min", vec![0.0], vec![]), + onnx_initializer("max", vec![1.0], vec![]), ], - vec![node( + vec![onnx_node( vec!["X", "min", "max"], vec!["Y"], - "clip", "Clip", vec![], )], @@ -315,12 +323,15 @@ fn test_clip() { let mut input_data = HashMap::new(); input_data.insert( "X".to_string(), - InputTensor::F32([-1.0, 0.0, 1.0, 2.0].as_slice().into()), + TensorData::F32([-1.0, 0.0, 1.0, 2.0].as_slice().into()), ); let session = pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create"); let result = 
pollster::block_on(session.run(&input_data)).unwrap(); - assert_eq!(result["Y"], OutputTensor::F32(vec![0.0, 0.0, 1.0, 1.0])); + assert_eq!( + result["Y"], + TensorData::F32(vec![0.0, 0.0, 1.0, 1.0].into()) + ); } diff --git a/wonnx/tests/batchnormalization.rs b/wonnx/tests/batchnormalization.rs index d79fddbb..acefc9b2 100644 --- a/wonnx/tests/batchnormalization.rs +++ b/wonnx/tests/batchnormalization.rs @@ -1,5 +1,7 @@ use std::{collections::HashMap, convert::TryInto}; -use wonnx::utils::{attribute, graph, initializer, model, node, tensor}; +use wonnx::onnx_model::{ + onnx_attribute, onnx_graph, onnx_initializer, onnx_model, onnx_node, onnx_tensor, +}; mod common; #[test] @@ -30,22 +32,21 @@ fn batch_normalization() { assert_eq!(b.len(), channels); assert_eq!(scale.len(), channels); - let bn_model = model(graph( - vec![tensor("X", &shape)], - vec![tensor("Y", &shape)], + let bn_model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &shape)], + vec![onnx_tensor("Y", &shape)], vec![], vec![ - initializer("scale", scale, vec![channels as i64]), - initializer("B", b, vec![channels as i64]), - initializer("input_mean", mean, vec![channels as i64]), - initializer("input_var", var, vec![channels as i64]), + onnx_initializer("scale", scale, vec![channels as i64]), + onnx_initializer("B", b, vec![channels as i64]), + onnx_initializer("input_mean", mean, vec![channels as i64]), + onnx_initializer("input_var", var, vec![channels as i64]), ], - vec![node( + vec![onnx_node( vec!["X", "scale", "B", "input_mean", "input_var"], vec!["Y"], - "bn", "BatchNormalization", - vec![attribute("epsilon", 0.1)], + vec![onnx_attribute("epsilon", 0.1)], )], )); diff --git a/wonnx/tests/cast.rs b/wonnx/tests/cast.rs index d7948dc3..15f1620c 100644 --- a/wonnx/tests/cast.rs +++ b/wonnx/tests/cast.rs @@ -1,9 +1,14 @@ +#![cfg(test)] + use std::collections::HashMap; use protobuf::ProtobufEnum; use wonnx::{ onnx::TensorProto_DataType, - utils::{attribute, graph, model, node, tensor, tensor_of_type, OutputTensor}, + onnx_model::{ + onnx_attribute, onnx_graph, onnx_model, onnx_node, onnx_tensor, onnx_tensor_of_type, + }, + tensor::TensorData, }; #[test] @@ -17,17 +22,19 @@ fn test_cast() { input_data.insert("X".to_string(), data.as_slice().into()); // Model: X -> Identity -> Y; Y==Z - let model = model(graph( - vec![tensor("X", &dims)], - vec![tensor_of_type("Y", &dims, TensorProto_DataType::INT32)], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &dims)], + vec![onnx_tensor_of_type("Y", &dims, TensorProto_DataType::INT32)], vec![], vec![], - vec![node( + vec![onnx_node( vec!["X"], vec!["Y"], - "a", "Cast", - vec![attribute("to", TensorProto_DataType::INT32.value() as i64)], + vec![onnx_attribute( + "to", + TensorProto_DataType::INT32.value() as i64, + )], )], )); @@ -37,6 +44,6 @@ fn test_cast() { let result = pollster::block_on(session.run(&input_data)).unwrap(); assert_eq!( result["Y"], - OutputTensor::I32(vec![0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5]) + TensorData::I32(vec![0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5].into()) ); } diff --git a/wonnx/tests/conv.rs b/wonnx/tests/conv.rs index 584c2ad4..1dbbeeac 100644 --- a/wonnx/tests/conv.rs +++ b/wonnx/tests/conv.rs @@ -1,6 +1,9 @@ use std::collections::HashMap; use std::convert::TryInto; -use wonnx::utils::{attribute, graph, initializer, model, node, tensor, OutputTensor}; +use wonnx::onnx_model::{ + onnx_attribute, onnx_graph, onnx_initializer, onnx_model, onnx_node, onnx_tensor, +}; +use wonnx::tensor::TensorData; use wonnx::*; mod common; @@ 
-16,19 +19,18 @@ fn conv_pad() { let data_w: Vec = (0..2 * c * 3 * 3).map(|_| 1.0f32).collect(); - let conv_model = model(graph( - vec![tensor("X", &shape)], - vec![tensor("Y", &[2, 2, n, n])], + let conv_model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &shape)], + vec![onnx_tensor("Y", &[2, 2, n, n])], vec![], - vec![initializer("W", data_w, vec![2, c, 3, 3])], - vec![node( + vec![onnx_initializer("W", data_w, vec![2, c, 3, 3])], + vec![onnx_node( vec!["X", "W"], vec!["Y"], - "conv", "Conv", vec![ - attribute("kernel_shape", vec![3, 3]), - attribute("auto_pad", "SAME_UPPER"), + onnx_attribute("kernel_shape", vec![3, 3]), + onnx_attribute("auto_pad", "SAME_UPPER"), ], )], )); @@ -38,16 +40,20 @@ fn conv_pad() { let result = pollster::block_on(session.run(&input_data)).unwrap(); assert_eq!( result["Y"], - OutputTensor::F32(vec![ - 12.0, 21.0, 27.0, 33.0, 24.0, 33.0, 54.0, 63.0, 72.0, 51.0, 63.0, 99.0, 108.0, 117.0, - 81.0, 93.0, 144.0, 153.0, 162.0, 111.0, 72.0, 111.0, 117.0, 123.0, 84.0, 12.0, 21.0, - 27.0, 33.0, 24.0, 33.0, 54.0, 63.0, 72.0, 51.0, 63.0, 99.0, 108.0, 117.0, 81.0, 93.0, - 144.0, 153.0, 162.0, 111.0, 72.0, 111.0, 117.0, 123.0, 84.0, 112.0, 171.0, 177.0, - 183.0, 124.0, 183.0, 279.0, 288.0, 297.0, 201.0, 213.0, 324.0, 333.0, 342.0, 231.0, - 243.0, 369.0, 378.0, 387.0, 261.0, 172.0, 261.0, 267.0, 273.0, 184.0, 112.0, 171.0, - 177.0, 183.0, 124.0, 183.0, 279.0, 288.0, 297.0, 201.0, 213.0, 324.0, 333.0, 342.0, - 231.0, 243.0, 369.0, 378.0, 387.0, 261.0, 172.0, 261.0, 267.0, 273.0, 184.0 - ]) + TensorData::F32( + vec![ + 12.0, 21.0, 27.0, 33.0, 24.0, 33.0, 54.0, 63.0, 72.0, 51.0, 63.0, 99.0, 108.0, + 117.0, 81.0, 93.0, 144.0, 153.0, 162.0, 111.0, 72.0, 111.0, 117.0, 123.0, 84.0, + 12.0, 21.0, 27.0, 33.0, 24.0, 33.0, 54.0, 63.0, 72.0, 51.0, 63.0, 99.0, 108.0, + 117.0, 81.0, 93.0, 144.0, 153.0, 162.0, 111.0, 72.0, 111.0, 117.0, 123.0, 84.0, + 112.0, 171.0, 177.0, 183.0, 124.0, 183.0, 279.0, 288.0, 297.0, 201.0, 213.0, 324.0, + 333.0, 342.0, 231.0, 243.0, 369.0, 378.0, 387.0, 261.0, 172.0, 261.0, 267.0, 273.0, + 184.0, 112.0, 171.0, 177.0, 183.0, 124.0, 183.0, 279.0, 288.0, 297.0, 201.0, 213.0, + 324.0, 333.0, 342.0, 231.0, 243.0, 369.0, 378.0, 387.0, 261.0, 172.0, 261.0, 267.0, + 273.0, 184.0 + ] + .into() + ) ); } @@ -64,17 +70,20 @@ fn conv_without_pad() { let kernel_n = 3; let m = 1; let data_w: Vec = (0..m * c * kernel_n * kernel_n).map(|_| 1.0f32).collect(); - let conv_model = model(graph( - vec![tensor("X", &shape)], - vec![tensor("Y", &[1, 1, 3, 3])], + let conv_model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &shape)], + vec![onnx_tensor("Y", &[1, 1, 3, 3])], vec![], - vec![initializer("W", data_w, vec![m, c, kernel_n, kernel_n])], - vec![node( + vec![onnx_initializer( + "W", + data_w, + vec![m, c, kernel_n, kernel_n], + )], + vec![onnx_node( vec!["X", "W"], vec!["Y"], - "conv", "Conv", - vec![attribute("kernel_shape", vec![3, 3])], + vec![onnx_attribute("kernel_shape", vec![3, 3])], )], )); @@ -94,17 +103,19 @@ fn conv_group_simple() { let shape = vec![1, 2, 1, 1]; input_data.insert("X".to_string(), [1.0, 2.0][..].into()); - let conv_model = model(graph( - vec![tensor("X", &shape)], - vec![tensor("Y", &shape)], + let conv_model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &shape)], + vec![onnx_tensor("Y", &shape)], vec![], - vec![initializer("W", vec![0.5, 2.0], vec![2, 1, 1, 1])], - vec![node( + vec![onnx_initializer("W", vec![0.5, 2.0], vec![2, 1, 1, 1])], + vec![onnx_node( vec!["X", "W"], vec!["Y"], - "conv", "Conv", - vec![attribute("kernel_shape", 
vec![1, 1]), attribute("group", 2)], + vec![ + onnx_attribute("kernel_shape", vec![1, 1]), + onnx_attribute("group", 2), + ], )], )); @@ -132,20 +143,23 @@ fn conv_stride() { // auto_pad.set_name("auto_pad".to_string()); // auto_pad.set_s("SAME_UPPER".to_string().into_bytes()); - let model = model(graph( - vec![tensor("X", &[1, c, 7, 5])], - vec![tensor("Y", &[1, 1, 4, 3])], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &[1, c, 7, 5])], + vec![onnx_tensor("Y", &[1, 1, 4, 3])], vec![], - vec![initializer("W", data_w, vec![m, c, kernel_n, kernel_n])], - vec![node( + vec![onnx_initializer( + "W", + data_w, + vec![m, c, kernel_n, kernel_n], + )], + vec![onnx_node( vec!["X", "W"], vec!["Y"], - "conv", "Conv", vec![ - attribute("strides", vec![2, 2]), - attribute("pads", vec![1, 1, 1, 1]), - attribute("kernel_shape", vec![kernel_n, kernel_n]), + onnx_attribute("strides", vec![2, 2]), + onnx_attribute("pads", vec![1, 1, 1, 1]), + onnx_attribute("kernel_shape", vec![kernel_n, kernel_n]), ], )], )); @@ -177,20 +191,23 @@ fn conv_asymetric_stride() { let m = 1; let data_w: Vec = (0..m * c * kernel_n * kernel_n).map(|_| 1.0f32).collect(); - let model = model(graph( - vec![tensor("X", &[1, c, 7, 5])], - vec![tensor("Y", &[1, 1, 4, 2])], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &[1, c, 7, 5])], + vec![onnx_tensor("Y", &[1, 1, 4, 2])], vec![], - vec![initializer("W", data_w, vec![m, c, kernel_n, kernel_n])], - vec![node( + vec![onnx_initializer( + "W", + data_w, + vec![m, c, kernel_n, kernel_n], + )], + vec![onnx_node( vec!["X", "W"], vec!["Y"], - "conv", "Conv", vec![ - attribute("strides", vec![2, 2]), - attribute("pads", vec![1, 0, 1, 0]), - attribute("kernel_shape", vec![kernel_n, kernel_n]), + onnx_attribute("strides", vec![2, 2]), + onnx_attribute("pads", vec![1, 0, 1, 0]), + onnx_attribute("kernel_shape", vec![kernel_n, kernel_n]), ], )], )); diff --git a/wonnx/tests/gather.rs b/wonnx/tests/gather.rs index a3cf7d00..cb0e90a7 100644 --- a/wonnx/tests/gather.rs +++ b/wonnx/tests/gather.rs @@ -1,5 +1,5 @@ use std::{collections::HashMap, convert::TryInto}; -use wonnx::utils::{attribute, graph, model, node, tensor}; +use wonnx::onnx_model::{onnx_attribute, onnx_graph, onnx_model, onnx_node, onnx_tensor}; mod common; fn assert_gather( @@ -17,17 +17,19 @@ fn assert_gather( input_data.insert("I".to_string(), indices.into()); // Model: (X, I) -> Gather -> Y - let bn_model = model(graph( - vec![tensor("X", data_shape), tensor("I", indices_shape)], - vec![tensor("Y", output_shape)], + let bn_model = onnx_model(onnx_graph( + vec![ + onnx_tensor("X", data_shape), + onnx_tensor("I", indices_shape), + ], + vec![onnx_tensor("Y", output_shape)], vec![], vec![], - vec![node( + vec![onnx_node( vec!["X", "I"], vec!["Y"], - "myGather", "Gather", - vec![attribute("axis", axis)], + vec![onnx_attribute("axis", axis)], )], )); diff --git a/wonnx/tests/globalaveragepool.rs b/wonnx/tests/globalaveragepool.rs index 1b7fb0ba..f92e082a 100644 --- a/wonnx/tests/globalaveragepool.rs +++ b/wonnx/tests/globalaveragepool.rs @@ -1,5 +1,5 @@ use std::{collections::HashMap, convert::TryInto}; -use wonnx::utils::{graph, model, node, tensor}; +use wonnx::onnx_model::{onnx_graph, onnx_model, onnx_node, onnx_tensor}; mod common; #[test] @@ -23,18 +23,12 @@ fn global_average_pool() { input_data.insert("X".to_string(), data.as_slice().into()); // Model: X -> GlobalAveragePool -> Y - let bn_model = model(graph( - vec![tensor("X", &shape)], - vec![tensor("Y", &output_shape)], + let bn_model = 
onnx_model(onnx_graph( + vec![onnx_tensor("X", &shape)], + vec![onnx_tensor("Y", &output_shape)], vec![], vec![], - vec![node( - vec!["X"], - vec!["Y"], - "gap", - "GlobalAveragePool", - vec![], - )], + vec![onnx_node(vec!["X"], vec!["Y"], "GlobalAveragePool", vec![])], )); let session = diff --git a/wonnx/tests/identity.rs b/wonnx/tests/identity.rs index 1b8f79cb..2c1bf989 100644 --- a/wonnx/tests/identity.rs +++ b/wonnx/tests/identity.rs @@ -1,5 +1,9 @@ use std::{collections::HashMap, convert::TryInto}; -use wonnx::utils::{graph, model, node, tensor, OutputTensor}; +use wonnx::{ + onnx_model::{onnx_graph, onnx_model, onnx_node, onnx_tensor}, + tensor::TensorData, +}; + mod common; #[test] @@ -12,12 +16,12 @@ fn test_identity() { input_data.insert("X".to_string(), data.as_slice().into()); // Model: X -> Identity -> Y; Y==Z - let model = model(graph( - vec![tensor("X", &dims)], - vec![tensor("Y", &dims)], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &dims)], + vec![onnx_tensor("Y", &dims)], vec![], vec![], - vec![node(vec!["X"], vec!["Y"], "a", "Identity", vec![])], + vec![onnx_node(vec!["X"], vec!["Y"], "Identity", vec![])], )); let session = @@ -37,14 +41,14 @@ fn test_double_identity() { input_data.insert("X".to_string(), data.as_slice().into()); // Model: X -> Identity -> Y -> Identity -> Z. X==Z - let model = model(graph( - vec![tensor("X", &dims)], - vec![tensor("Z", &dims)], - vec![tensor("Y", &dims)], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &dims)], + vec![onnx_tensor("Z", &dims)], + vec![onnx_tensor("Y", &dims)], vec![], vec![ - node(vec!["X"], vec!["Y"], "a", "Identity", vec![]), - node(vec!["Y"], vec!["Z"], "b", "Identity", vec![]), + onnx_node(vec!["X"], vec!["Y"], "Identity", vec![]), + onnx_node(vec!["Y"], vec!["Z"], "Identity", vec![]), ], )); @@ -66,15 +70,15 @@ fn test_buffer_readability() { input_data.insert("X".to_string(), data.as_slice().into()); // Model: X -> Cos -> Y -> Flatten -> Z -> Flatten -> W - let model = model(graph( - vec![tensor("X", &shape)], - vec![tensor("W", &shape)], - vec![tensor("Y", &shape), tensor("Z", &shape)], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &shape)], + vec![onnx_tensor("W", &shape)], + vec![onnx_tensor("Y", &shape), onnx_tensor("Z", &shape)], vec![], vec![ - node(vec!["X"], vec!["Y"], "cos", "Cos", vec![]), - node(vec!["Y"], vec!["Z"], "rs", "Reshape", vec![]), - node(vec!["Z"], vec!["W"], "rs", "Reshape", vec![]), + onnx_node(vec!["X"], vec!["Y"], "Cos", vec![]), + onnx_node(vec!["Y"], vec!["Z"], "Reshape", vec![]), + onnx_node(vec!["Z"], vec!["W"], "Reshape", vec![]), ], )); @@ -82,5 +86,5 @@ fn test_buffer_readability() { pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create"); let result = pollster::block_on(session.run(&input_data)).unwrap(); - assert_eq!(result["W"], OutputTensor::F32(vec![1.0; 16])); + assert_eq!(result["W"], TensorData::F32(vec![1.0; 16].into())); } diff --git a/wonnx/tests/matrix.rs b/wonnx/tests/matrix.rs index 84138891..9eb0d108 100644 --- a/wonnx/tests/matrix.rs +++ b/wonnx/tests/matrix.rs @@ -1,5 +1,8 @@ use std::{collections::HashMap, convert::TryInto}; -use wonnx::utils::{attribute, graph, initializer, initializer_int64, model, node, tensor}; +use wonnx::onnx_model::{ + onnx_attribute, onnx_graph, onnx_initializer, onnx_initializer_int64, onnx_model, onnx_node, + onnx_tensor, +}; mod common; #[test] @@ -26,12 +29,12 @@ fn test_matmul_square_matrix() { input_data.insert("B".to_string(), 
data_b.as_slice().unwrap().into()); let n = n as i64; - let model = model(graph( - vec![tensor("A", &[n, n]), tensor("B", &[n, n])], - vec![tensor("C", &[n, n])], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("A", &[n, n]), onnx_tensor("B", &[n, n])], + vec![onnx_tensor("C", &[n, n])], vec![], vec![], - vec![node(vec!["A", "B"], vec!["C"], "MatMul", "MatMul", vec![])], + vec![onnx_node(vec!["A", "B"], vec!["C"], "MatMul", vec![])], )); let session = @@ -53,25 +56,23 @@ fn test_transpose_4d_perm(transpose_first: &[i64], transpose_second: &[i64]) { .collect(); // Model: X -> Transpose -> Y -> Transpose -> Z; X==Z - let model = model(graph( - vec![tensor("X", &x_dims)], - vec![tensor("Z", &x_dims)], - vec![tensor("Y", &intermediate_dims)], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &x_dims)], + vec![onnx_tensor("Z", &x_dims)], + vec![onnx_tensor("Y", &intermediate_dims)], vec![], vec![ - node( + onnx_node( vec!["X"], vec!["Y"], "Transpose", - "Transpose", - vec![attribute("perm", transpose_first.to_vec())], + vec![onnx_attribute("perm", transpose_first.to_vec())], ), - node( + onnx_node( vec!["Y"], vec!["Z"], "Transpose", - "Transpose", - vec![attribute("perm", transpose_second.to_vec())], + vec![onnx_attribute("perm", transpose_second.to_vec())], ), ], )); @@ -114,17 +115,16 @@ fn test_transposes_4d_0312() { input_data.insert("X".to_string(), data.as_slice().into()); // Model: X -> Transpose -> Y - let model = model(graph( - vec![tensor("X", &[1, 2, 3, 4])], - vec![tensor("Y", &[1, 4, 2, 3])], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &[1, 2, 3, 4])], + vec![onnx_tensor("Y", &[1, 4, 2, 3])], vec![], vec![], - vec![node( + vec![onnx_node( vec!["X"], vec!["Y"], "Transpose", - "Transpose", - vec![attribute("perm", vec![0, 3, 1, 2])], + vec![onnx_attribute("perm", vec![0, 3, 1, 2])], )], )); @@ -149,14 +149,14 @@ fn test_two_transposes_default_4d() { input_data.insert("X".to_string(), data.as_slice().into()); // Model: X -> Transpose -> Y -> Transpose -> Z; X==Z - let model = model(graph( - vec![tensor("X", &[2, 3, 4])], - vec![tensor("Z", &[2, 3, 4])], - vec![tensor("Y", &[4, 3, 2])], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &[2, 3, 4])], + vec![onnx_tensor("Z", &[2, 3, 4])], + vec![onnx_tensor("Y", &[4, 3, 2])], vec![], vec![ - node(vec!["X"], vec!["Y"], "Transpose", "Transpose", vec![]), - node(vec!["Y"], vec!["Z"], "Transpose", "Transpose", vec![]), + onnx_node(vec!["X"], vec!["Y"], "Transpose", vec![]), + onnx_node(vec!["Y"], vec!["Z"], "Transpose", vec![]), ], )); @@ -174,25 +174,23 @@ fn test_two_transposes() { input_data.insert("X".to_string(), data.as_slice().into()); // Model: X -> Transpose -> Y -> Transpose -> Z; X==Z - let model = model(graph( - vec![tensor("X", &[2, 3, 4])], - vec![tensor("Z", &[2, 3, 4])], - vec![tensor("Y", &[4, 3, 2])], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &[2, 3, 4])], + vec![onnx_tensor("Z", &[2, 3, 4])], + vec![onnx_tensor("Y", &[4, 3, 2])], vec![], vec![ - node( + onnx_node( vec!["X"], vec!["Y"], "Transpose", - "Transpose", - vec![attribute("perm", vec![2, 1, 0])], + vec![onnx_attribute("perm", vec![2, 1, 0])], ), - node( + onnx_node( vec!["Y"], vec!["Z"], "Transpose", - "Transpose", - vec![attribute("perm", vec![2, 1, 0])], + vec![onnx_attribute("perm", vec![2, 1, 0])], ), ], )); @@ -211,17 +209,16 @@ fn test_split() { let data = (1..=2 * 6).map(|x| x as f32).collect::>(); input_data.insert("X".to_string(), data.as_slice().into()); - let model = model(graph( - 
vec![tensor("X", &[2, 6])], - vec![tensor("Y", &[2, 3]), tensor("W", &[2, 3])], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &[2, 6])], + vec![onnx_tensor("Y", &[2, 3]), onnx_tensor("W", &[2, 3])], vec![], vec![], - vec![node( + vec![onnx_node( vec!["X"], vec!["Y", "W"], "Split", - "Split", - vec![attribute("axis", -1)], + vec![onnx_attribute("axis", -1)], )], )); @@ -246,12 +243,12 @@ fn test_pad_example() { ].to_vec(); input_data.insert("X".to_string(), data.as_slice().into()); - let model = model(graph( - vec![tensor("X", &[3, 2])], - vec![tensor("Y", &[3, 4])], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &[3, 2])], + vec![onnx_tensor("Y", &[3, 4])], vec![], - vec![initializer_int64("pads", vec![0, 2, 0, 0], vec![4])], - vec![node(vec!["X", "pads"], vec!["Y"], "Pad", "Pad", vec![])], + vec![onnx_initializer_int64("pads", vec![0, 2, 0, 0], vec![4])], + vec![onnx_node(vec!["X", "pads"], vec!["Y"], "Pad", vec![])], )); let session = @@ -276,18 +273,17 @@ fn test_pad_complex() { input_data.insert("X".to_string(), data.as_slice().into()); let kv = 0.5; - let model = model(graph( - vec![tensor("X", &[1, 3, 2])], - vec![tensor("Y", &[2, 4, 5])], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &[1, 3, 2])], + vec![onnx_tensor("Y", &[2, 4, 5])], vec![], vec![], - vec![node( + vec![onnx_node( vec!["X"], vec!["Y"], "Pad", - "Pad", vec![ - attribute( + onnx_attribute( "pads", vec![ 0, // x1_begin @@ -298,7 +294,7 @@ fn test_pad_complex() { 1, // x3_end ], ), - attribute("constant_value", kv), + onnx_attribute("constant_value", kv), ], )], )); @@ -345,17 +341,16 @@ fn test_resize() { let data = (1..=2 * 4).map(|x| x as f32).collect::>(); input_data.insert("X".to_string(), data.as_slice().into()); - let downsampling_model = model(graph( - vec![tensor("X", &[1, 1, 2, 4])], - vec![tensor("Y", &[1, 1, 1, 2])], + let downsampling_model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &[1, 1, 2, 4])], + vec![onnx_tensor("Y", &[1, 1, 1, 2])], vec![], - vec![initializer("scales", vec![1., 1., 0.6, 0.6], vec![4])], - vec![node( + vec![onnx_initializer("scales", vec![1., 1., 0.6, 0.6], vec![4])], + vec![onnx_node( vec!["X", "" /* roi */, "scales"], vec!["Y"], "Resize", - "Resize", - vec![attribute("nearest_mode", "floor")], + vec![onnx_attribute("nearest_mode", "floor")], )], )); @@ -370,17 +365,16 @@ fn test_resize() { let data = (1..=4).map(|x| x as f32).collect::>(); input_data.insert("X".to_string(), data.as_slice().into()); - let upsampling_model = model(graph( - vec![tensor("X", &[1, 1, 2, 2])], - vec![tensor("Y", &[1, 1, 4, 6])], + let upsampling_model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &[1, 1, 2, 2])], + vec![onnx_tensor("Y", &[1, 1, 4, 6])], vec![], - vec![initializer("scales", vec![1., 1., 2., 3.], vec![4])], - vec![node( + vec![onnx_initializer("scales", vec![1., 1., 2., 3.], vec![4])], + vec![onnx_node( vec!["X", "" /* roi */, "scales"], vec!["Y"], "Resize", - "Resize", - vec![attribute("nearest_mode", "floor")], + vec![onnx_attribute("nearest_mode", "floor")], )], )); @@ -415,12 +409,12 @@ fn test_matmul_square_matrix_small() { input_data.insert("B".to_string(), data_b.as_slice().unwrap().into()); let n = n as i64; - let model = model(graph( - vec![tensor("A", &[n, n]), tensor("B", &[n, n])], - vec![tensor("C", &[n, n])], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("A", &[n, n]), onnx_tensor("B", &[n, n])], + vec![onnx_tensor("C", &[n, n])], vec![], vec![], - vec![node(vec!["A", "B"], vec!["C"], "MatMul", "MatMul", vec![])], 
+ vec![onnx_node(vec!["A", "B"], vec!["C"], "MatMul", vec![])], )); let session = @@ -445,12 +439,12 @@ fn test_matmul_nonsquare_matrix_small() { input_data.insert("A".to_string(), a_data.as_slice().into()); input_data.insert("B".to_string(), b_data.as_slice().into()); - let model = model(graph( - vec![tensor("A", &[4, 4]), tensor("B", &[4, 2])], - vec![tensor("C", &[4, 2])], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("A", &[4, 4]), onnx_tensor("B", &[4, 2])], + vec![onnx_tensor("C", &[4, 2])], vec![], vec![], - vec![node(vec!["A", "B"], vec!["C"], "MatMul", "MatMul", vec![])], + vec![onnx_node(vec!["A", "B"], vec!["C"], "MatMul", vec![])], )); let session = @@ -475,12 +469,15 @@ fn test_matmul_stacks_4d() { input_data.insert("A".to_string(), a_data.as_slice().into()); input_data.insert("B".to_string(), b_data.as_slice().into()); - let model = model(graph( - vec![tensor("A", &[1, 4, 2, 2]), tensor("B", &[1, 4, 2, 2])], - vec![tensor("C", &[1, 4, 2, 2])], + let model = onnx_model(onnx_graph( + vec![ + onnx_tensor("A", &[1, 4, 2, 2]), + onnx_tensor("B", &[1, 4, 2, 2]), + ], + vec![onnx_tensor("C", &[1, 4, 2, 2])], vec![], vec![], - vec![node(vec!["A", "B"], vec!["C"], "MatMul", "MatMul", vec![])], + vec![onnx_node(vec!["A", "B"], vec!["C"], "MatMul", vec![])], )); let session = @@ -510,12 +507,12 @@ fn test_matmul_stacks() { input_data.insert("A".to_string(), a_data.as_slice().into()); input_data.insert("B".to_string(), b_data.as_slice().into()); - let model = model(graph( - vec![tensor("A", &[4, 2, 2]), tensor("B", &[4, 2, 2])], - vec![tensor("C", &[4, 2, 2])], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("A", &[4, 2, 2]), onnx_tensor("B", &[4, 2, 2])], + vec![onnx_tensor("C", &[4, 2, 2])], vec![], vec![], - vec![node(vec!["A", "B"], vec!["C"], "MatMul", "MatMul", vec![])], + vec![onnx_node(vec!["A", "B"], vec!["C"], "MatMul", vec![])], )); let session = @@ -546,12 +543,12 @@ fn test_matmul_1d() { input_data.insert("A".to_string(), a_data.as_slice().into()); input_data.insert("B".to_string(), b_data.as_slice().into()); - let model = model(graph( - vec![tensor("A", &[1, 4]), tensor("B", &[4, 2])], - vec![tensor("C", &[1, 2])], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("A", &[1, 4]), onnx_tensor("B", &[4, 2])], + vec![onnx_tensor("C", &[1, 2])], vec![], vec![], - vec![node(vec!["A", "B"], vec!["C"], "MatMul", "MatMul", vec![])], + vec![onnx_node(vec!["A", "B"], vec!["C"], "MatMul", vec![])], )); let session = @@ -580,16 +577,16 @@ fn test_gemm_matrix_bias() { input_data.insert("B".to_string(), b_data.as_slice().into()); input_data.insert("C".to_string(), c_data.as_slice().into()); - let model = model(graph( + let model = onnx_model(onnx_graph( vec![ - tensor("A", &[4, 6]), - tensor("B", &[6, 4]), - tensor("C", &[4, 4]), + onnx_tensor("A", &[4, 6]), + onnx_tensor("B", &[6, 4]), + onnx_tensor("C", &[4, 4]), ], - vec![tensor("D", &[4, 4])], + vec![onnx_tensor("D", &[4, 4])], vec![], vec![], - vec![node(vec!["A", "B", "C"], vec!["D"], "Gemm", "Gemm", vec![])], + vec![onnx_node(vec!["A", "B", "C"], vec!["D"], "Gemm", vec![])], )); let session = @@ -621,16 +618,16 @@ fn test_gemm_broadcasting_bias() { input_data.insert("B".to_string(), b_data.as_slice().into()); input_data.insert("C".to_string(), c_data.as_slice().into()); - let model = model(graph( + let model = onnx_model(onnx_graph( vec![ - tensor("A", &[4, 6]), - tensor("B", &[6, 4]), - tensor("C", &[1, 4]), + onnx_tensor("A", &[4, 6]), + onnx_tensor("B", &[6, 4]), + onnx_tensor("C", &[1, 4]), ], - 
vec![tensor("D", &[4, 4])], + vec![onnx_tensor("D", &[4, 4])], vec![], vec![], - vec![node(vec!["A", "B", "C"], vec!["D"], "Gemm", "Gemm", vec![])], + vec![onnx_node(vec!["A", "B", "C"], vec!["D"], "Gemm", vec![])], )); let session = @@ -661,16 +658,16 @@ fn test_gemm_broadcasting_second_bias() { input_data.insert("B".to_string(), b_data.as_slice().into()); input_data.insert("C".to_string(), c_data.as_slice().into()); - let model = model(graph( + let model = onnx_model(onnx_graph( vec![ - tensor("A", &[4, 6]), - tensor("B", &[6, 4]), - tensor("C", &[4, 1]), + onnx_tensor("A", &[4, 6]), + onnx_tensor("B", &[6, 4]), + onnx_tensor("C", &[4, 1]), ], - vec![tensor("D", &[4, 4])], + vec![onnx_tensor("D", &[4, 4])], vec![], vec![], - vec![node(vec!["A", "B", "C"], vec!["D"], "Gemm", "Gemm", vec![])], + vec![onnx_node(vec!["A", "B", "C"], vec!["D"], "Gemm", vec![])], )); let session = @@ -701,16 +698,16 @@ fn test_gemm_scalar_bias() { input_data.insert("B".to_string(), b_data.as_slice().into()); input_data.insert("C".to_string(), c_data.as_slice().into()); - let model = model(graph( + let model = onnx_model(onnx_graph( vec![ - tensor("A", &[4, 6]), - tensor("B", &[6, 4]), - tensor("C", &[1]), + onnx_tensor("A", &[4, 6]), + onnx_tensor("B", &[6, 4]), + onnx_tensor("C", &[1]), ], - vec![tensor("D", &[4, 4])], + vec![onnx_tensor("D", &[4, 4])], vec![], vec![], - vec![node(vec!["A", "B", "C"], vec!["D"], "Gemm", "Gemm", vec![])], + vec![onnx_node(vec!["A", "B", "C"], vec!["D"], "Gemm", vec![])], )); let session = diff --git a/wonnx/tests/onehot.rs b/wonnx/tests/onehot.rs index 0ca20702..90250ec2 100644 --- a/wonnx/tests/onehot.rs +++ b/wonnx/tests/onehot.rs @@ -1,7 +1,8 @@ use std::{collections::HashMap, convert::TryInto}; use wonnx::{ onnx::AttributeProto, - utils::{attribute, graph, model, node, tensor, InputTensor}, + onnx_model::{onnx_attribute, onnx_graph, onnx_model, onnx_node, onnx_tensor}, + tensor::TensorData, }; mod common; @@ -14,7 +15,7 @@ fn test_onehot( output: &[f32], output_shape: &[i64], ) { - let mut input_data = HashMap::::new(); + let mut input_data = HashMap::::new(); let depth_tensor: &[i32] = &[depth]; input_data.insert("I".to_string(), indexes.into()); @@ -23,23 +24,22 @@ fn test_onehot( let mut attributes: Vec = vec![]; if let Some(axis) = axis { - attributes.push(attribute("axis", axis)) + attributes.push(onnx_attribute("axis", axis)) } // Model: I, D, V -> OneHot -> Y - let model = model(graph( + let model = onnx_model(onnx_graph( vec![ - tensor("I", indexes_shape), - tensor("D", &[]), - tensor("V", &[values.len() as i64]), + onnx_tensor("I", indexes_shape), + onnx_tensor("D", &[]), + onnx_tensor("V", &[values.len() as i64]), ], - vec![tensor("Y", output_shape)], + vec![onnx_tensor("Y", output_shape)], vec![], vec![], - vec![node( + vec![onnx_node( vec!["I", "D", "V"], vec!["Y"], - "oneHot", "OneHot", attributes, )], diff --git a/wonnx/tests/pretrained_models.rs b/wonnx/tests/pretrained_models.rs index 80a3991b..0d2cc4c7 100644 --- a/wonnx/tests/pretrained_models.rs +++ b/wonnx/tests/pretrained_models.rs @@ -3,7 +3,7 @@ use ndarray::s; use std::collections::HashMap; use std::convert::TryInto; use std::path::Path; -use wonnx::utils::InputTensor; +use wonnx::tensor::TensorData; mod common; #[test] @@ -19,7 +19,7 @@ fn test_relu() { common::assert_eq_vector((&result["y"]).try_into().unwrap(), &[0.0, 1.0]); } -fn infer_mnist(image: InputTensor) -> (usize, f32) { +fn infer_mnist(image: TensorData) -> (usize, f32) { let session = 
pollster::block_on(wonnx::Session::from_path("../data/models/opt-mnist.onnx")) .expect("Session did not create"); diff --git a/wonnx/tests/reduce.rs b/wonnx/tests/reduce.rs index e43308c2..ba130fc4 100644 --- a/wonnx/tests/reduce.rs +++ b/wonnx/tests/reduce.rs @@ -1,7 +1,9 @@ use std::{collections::HashMap, convert::TryInto}; use wonnx::{ onnx::AttributeProto, - utils::{attribute, graph, initializer_int64, model, node, tensor}, + onnx_model::{ + onnx_attribute, onnx_graph, onnx_initializer_int64, onnx_model, onnx_node, onnx_tensor, + }, }; mod common; @@ -20,18 +22,18 @@ fn test_reduce( #[allow(clippy::bool_to_int_with_if)] let mut attributes: Vec = - vec![attribute("keepdims", if keep_dims { 1 } else { 0 })]; + vec![onnx_attribute("keepdims", if keep_dims { 1 } else { 0 })]; if let Some(axes) = axes { - attributes.push(attribute("axes", axes)) + attributes.push(onnx_attribute("axes", axes)) } // Model: X -> ReduceMean -> Y - let model = model(graph( - vec![tensor("X", data_shape)], - vec![tensor("Y", output_shape)], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", data_shape)], + vec![onnx_tensor("Y", output_shape)], vec![], vec![], - vec![node(vec!["X"], vec!["Y"], "myReduce", op_name, attributes)], + vec![onnx_node(vec!["X"], vec!["Y"], op_name, attributes)], )); let session = @@ -255,18 +257,17 @@ fn test_reduce_sum_with_axes_as_input() { ]; input_data.insert("X".to_string(), data.into()); - let attributes: Vec = vec![attribute("keepdims", 1)]; + let attributes: Vec = vec![onnx_attribute("keepdims", 1)]; // Model: X -> ReduceMean -> Y - let model = model(graph( - vec![tensor("X", &[3, 2, 2])], - vec![tensor("Y", &[3, 2])], + let model = onnx_model(onnx_graph( + vec![onnx_tensor("X", &[3, 2, 2])], + vec![onnx_tensor("Y", &[3, 2])], vec![], - vec![initializer_int64("A", vec![-2], vec![1])], - vec![node( + vec![onnx_initializer_int64("A", vec![-2], vec![1])], + vec![onnx_node( vec!["X", "A"], vec!["Y"], - "myReduce", "ReduceSum", attributes, )], diff --git a/wonnx/tests/softmax.rs b/wonnx/tests/softmax.rs index 67122cae..dd95e421 100644 --- a/wonnx/tests/softmax.rs +++ b/wonnx/tests/softmax.rs @@ -1,5 +1,7 @@ use std::{collections::HashMap, convert::TryInto}; -use wonnx::utils::{attribute, graph, model_with_opset, node, tensor}; +use wonnx::onnx_model::{ + onnx_attribute, onnx_graph, onnx_model_with_opset, onnx_node, onnx_tensor, +}; mod common; fn softmax_with_axis(x: &[f32], x_dims: &[i64], axis: i64, expected_y: &[f32], opset_version: i64) { @@ -7,18 +9,17 @@ fn softmax_with_axis(x: &[f32], x_dims: &[i64], axis: i64, expected_y: &[f32], o input_data.insert("X".to_string(), x.into()); // Model: X -> SoftMax -> Y - let model = model_with_opset( - graph( - vec![tensor("X", x_dims)], - vec![tensor("Y", x_dims)], + let model = onnx_model_with_opset( + onnx_graph( + vec![onnx_tensor("X", x_dims)], + vec![onnx_tensor("Y", x_dims)], vec![], vec![], - vec![node( + vec![onnx_node( vec!["X"], vec!["Y"], - "a", "Softmax", - vec![attribute("axis", axis)], + vec![onnx_attribute("axis", axis)], )], ), opset_version,
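Taken together, the test updates above document the renamed builder helpers: `model`/`graph`/`node`/`tensor`/`attribute`/`initializer` become `onnx_*` functions in `wonnx::onnx_model`, and `onnx_node` drops the old node-name argument. As a compact reference, a minimal graph built with the new names, mirroring the Softmax test above (the shapes are illustrative):

    use wonnx::onnx::ModelProto;
    use wonnx::onnx_model::{onnx_attribute, onnx_graph, onnx_model, onnx_node, onnx_tensor};

    fn softmax_example() -> ModelProto {
        // X -> Softmax(axis=1) -> Y
        onnx_model(onnx_graph(
            vec![onnx_tensor("X", &[1, 4])], // graph inputs
            vec![onnx_tensor("Y", &[1, 4])], // graph outputs
            vec![],                          // intermediate value infos
            vec![],                          // initializers
            vec![onnx_node(
                vec!["X"],
                vec!["Y"],
                "Softmax",
                vec![onnx_attribute("axis", 1)],
            )],
        ))
    }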