Skip to content

Commit

Permalink
feat: extract sequences & maps. closes #30
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Dec 28, 2023
1 parent dd8dcae commit af97600
Show file tree
Hide file tree
Showing 7 changed files with 298 additions and 164 deletions.
32 changes: 16 additions & 16 deletions examples/model-info/examples/model-info.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
use std::{env, process};

use ort::{Session, TensorElementDataType, ValueType};
use ort::{Session, TensorElementType, ValueType};

fn display_element_type(t: TensorElementDataType) -> &'static str {
fn display_element_type(t: TensorElementType) -> &'static str {
match t {
TensorElementDataType::Bfloat16 => "bf16",
TensorElementDataType::Bool => "bool",
TensorElementDataType::Float16 => "f16",
TensorElementDataType::Float32 => "f32",
TensorElementDataType::Float64 => "f64",
TensorElementDataType::Int16 => "i16",
TensorElementDataType::Int32 => "i32",
TensorElementDataType::Int64 => "i64",
TensorElementDataType::Int8 => "i8",
TensorElementDataType::String => "str",
TensorElementDataType::Uint16 => "u16",
TensorElementDataType::Uint32 => "u32",
TensorElementDataType::Uint64 => "u64",
TensorElementDataType::Uint8 => "u8"
TensorElementType::Bfloat16 => "bf16",
TensorElementType::Bool => "bool",
TensorElementType::Float16 => "f16",
TensorElementType::Float32 => "f32",
TensorElementType::Float64 => "f64",
TensorElementType::Int16 => "i16",
TensorElementType::Int32 => "i32",
TensorElementType::Int64 => "i64",
TensorElementType::Int8 => "i8",
TensorElementType::String => "str",
TensorElementType::Uint16 => "u16",
TensorElementType::Uint32 => "u32",
TensorElementType::Uint64 => "u64",
TensorElementType::Uint8 => "u8"
}
}

Expand Down
22 changes: 17 additions & 5 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{convert::Infallible, io, path::PathBuf, string};

use thiserror::Error;

use super::{char_p_to_string, ortsys, tensor::TensorElementDataType};
use super::{char_p_to_string, ortsys, tensor::TensorElementType, ValueType};

/// Type alias for the Result type returned by ORT functions.
pub type Result<T, E = Error> = std::result::Result<T, E>;
Expand Down Expand Up @@ -113,7 +113,7 @@ pub enum Error {
DownloadError(#[from] FetchModelError),
/// Type of input data and the ONNX model do not match.
#[error("Data types do not match: expected {model:?}, got {input:?}")]
NonMatchingDataTypes { input: TensorElementDataType, model: TensorElementDataType },
NonMatchingDataTypes { input: TensorElementType, model: TensorElementType },
/// Dimensions of input data and the ONNX model do not match.
#[error("Dimensions do not match: {0:?}")]
NonMatchingDimensions(NonMatchingDimensionsError),
Expand Down Expand Up @@ -152,9 +152,9 @@ pub enum Error {
#[error("Data type mismatch: was {actual:?}, tried to convert to {requested:?}")]
DataTypeMismatch {
/// The actual type of the ort output
actual: TensorElementDataType,
actual: TensorElementType,
/// The type corresponding to the attempted conversion into a Rust type, not equal to `actual`
requested: TensorElementDataType
requested: TensorElementType
},
#[error("Error trying to load symbol `{symbol}` from dynamic library: {error}")]
DlLoad { symbol: &'static str, error: String },
Expand All @@ -181,7 +181,19 @@ pub enum Error {
#[error("Failed to clear IO binding: {0}")]
ClearBinding(ErrorInternal),
#[error("Error when retrieving session outputs from `IoBinding`: {0}")]
GetBoundOutputs(ErrorInternal)
GetBoundOutputs(ErrorInternal),
#[error("Cannot use `extract_sequence` on a value that is {0:?}")]
NotSequence(ValueType),
#[error("Cannot use `extract_map` on a value that is {0:?}")]
NotMap(ValueType),
#[error("Tried to extract a map with a key type of {expected:?}, but the map has key type {actual:?}")]
InvalidMapKeyType { expected: TensorElementType, actual: TensorElementType },
#[error("Tried to extract a map with a value type of {expected:?}, but the map has value type {actual:?}")]
InvalidMapValueType { expected: TensorElementType, actual: TensorElementType },
#[error("Error occurred while attempting to extract data from sequence value: {0}")]
ExtractSequence(ErrorInternal),
#[error("Error occurred while attempting to extract data from map value: {0}")]
ExtractMap(ErrorInternal)
}

impl From<Infallible> for Error {
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ pub use self::session::{InMemorySession, Session, SessionBuilder, SessionInputs,
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
pub use self::tensor::{ArrayExtensions, ArrayViewHolder, Tensor, TensorData};
pub use self::tensor::{ExtractTensorData, IntoTensorElementDataType, TensorElementDataType};
pub use self::value::{Value, ValueType};
pub use self::tensor::{ExtractTensorData, IntoTensorElementType, TensorElementType};
pub use self::value::{Value, ValueRef, ValueType};

#[cfg(not(all(target_arch = "x86", target_os = "windows")))]
macro_rules! extern_system_fn {
Expand Down
66 changes: 4 additions & 62 deletions src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -668,68 +668,10 @@ fn close_lib_handle(handle: *mut std::os::raw::c_void) {
/// `SessionBuilder::with_model_from_file()` method.
mod dangerous {
use super::*;
use crate::ortfree;

/// Builds a `ValueType::Tensor` description from an ORT tensor type-and-shape info handle.
///
/// Queries, in order, the element type, the number of dimensions and the dimensions
/// themselves, then packs the element type and dimensions into `ValueType::Tensor`.
///
/// # Safety
/// `info_ptr` is passed to the ORT API calls without any check in this function;
/// presumably it must point to a live `OrtTensorTypeAndShapeInfo` — TODO confirm against callers.
///
/// # Errors
/// NOTE(review): each `ortsys![... -> Error::X]` invocation presumably returns early with
/// `Error::X` when the underlying API call fails — confirm against the macro definition.
///
/// # Panics
/// Panics if the element type is still `ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED` after the
/// `GetTensorElementType` call.
unsafe fn extract_data_type_from_tensor_info(info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo) -> Result<ValueType> {
// Start from the UNDEFINED sentinel so the assertion below can tell that no concrete type was reported.
let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
ortsys![GetTensorElementType(info_ptr, &mut type_sys) -> Error::GetTensorElementType];
assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
// The element type is read from GetTensorElementType, which we must trust.
// NOTE(review): an earlier version of this comment referred to a transmute, but none is
// visible in this function; the raw value is converted with `.into()` when building the result.
let mut num_dims = 0;
ortsys![GetDimensionsCount(info_ptr, &mut num_dims) -> Error::GetDimensionsCount];

// Zero-filled buffer with one slot per reported dimension; its raw pointer is handed to
// GetDimensions, which presumably fills it in.
let mut node_dims: Vec<i64> = vec![0; num_dims as _];
ortsys![GetDimensions(info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions];

Ok(ValueType::Tensor {
ty: type_sys.into(),
dimensions: node_dims
})
}

/// Builds a `ValueType::Sequence` description from an ORT sequence type info handle.
///
/// Fetches the type info of the sequence's elements, asks ORT which ONNX type it is, and
/// wraps the element description (tensor or map) in `ValueType::Sequence`.
///
/// # Safety
/// `info_ptr` is passed to the ORT API without any check in this function; presumably it
/// must point to a live `OrtSequenceTypeInfo` — TODO confirm against callers.
///
/// # Errors
/// Returns `Error::GetOnnxTypeFromTypeInfo` when `status_to_result` reports a failure for the
/// `GetOnnxTypeFromTypeInfo` call, and propagates any error from the tensor/map helpers.
/// NOTE(review): the `ortsys![... -> Error::X]` invocations presumably return early with
/// `Error::X` on failure — confirm against the macro definition.
///
/// # Panics
/// Hits `unreachable!()` if the element type is anything other than `ONNX_TYPE_TENSOR` or
/// `ONNX_TYPE_MAP`.
unsafe fn extract_data_type_from_sequence_info(info_ptr: *const ort_sys::OrtSequenceTypeInfo) -> Result<ValueType> {
// NOTE(review): `element_type_info` is never released in this function — confirm whether
// the ORT API expects the caller to free it.
let mut element_type_info: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut();
ortsys![GetSequenceElementType(info_ptr, &mut element_type_info) -> Error::GetSequenceElementType];

// Determine which kind of ONNX value the sequence elements are.
let mut ty: ort_sys::ONNXType = ort_sys::ONNXType::ONNX_TYPE_UNKNOWN;
let status = ortsys![unsafe GetOnnxTypeFromTypeInfo(element_type_info, &mut ty)];
status_to_result(status).map_err(Error::GetOnnxTypeFromTypeInfo)?;

match ty {
ort_sys::ONNXType::ONNX_TYPE_TENSOR => {
// Shadows the outer `info_ptr`: from here on it refers to the element's tensor info.
let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
ortsys![unsafe CastTypeInfoToTensorInfo(element_type_info, &mut info_ptr) -> Error::CastTypeInfoToTensorInfo; nonNull(info_ptr)];
let ty = unsafe { extract_data_type_from_tensor_info(info_ptr)? };
Ok(ValueType::Sequence(Box::new(ty)))
}
ort_sys::ONNXType::ONNX_TYPE_MAP => {
// Shadows the outer `info_ptr`: from here on it refers to the element's map info.
let mut info_ptr: *const ort_sys::OrtMapTypeInfo = std::ptr::null_mut();
ortsys![unsafe CastTypeInfoToMapTypeInfo(element_type_info, &mut info_ptr) -> Error::CastTypeInfoToMapTypeInfo; nonNull(info_ptr)];
let ty = unsafe { extract_data_type_from_map_info(info_ptr)? };
Ok(ValueType::Sequence(Box::new(ty)))
}
// NOTE(review): any other element type panics here rather than returning an error —
// confirm that sequences of other ONNX types can never reach this function.
_ => unreachable!()
}
}

/// Builds a `ValueType::Map` description from an ORT map type info handle.
///
/// Reads the map's key element type directly, then obtains the value's type info, casts it
/// to tensor info and reads that tensor's element type. Only the two element types end up in
/// the result; no dimensions are queried for the value.
///
/// # Safety
/// `info_ptr` is passed to the ORT API without any check in this function; presumably it
/// must point to a live `OrtMapTypeInfo` — TODO confirm against callers.
///
/// # Errors
/// NOTE(review): each `ortsys![... -> Error::X]` invocation presumably returns early with
/// `Error::X` when the underlying API call fails — confirm against the macro definition.
///
/// # Panics
/// Panics if either the key type or the value type is still
/// `ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED` after being queried.
unsafe fn extract_data_type_from_map_info(info_ptr: *const ort_sys::OrtMapTypeInfo) -> Result<ValueType> {
// Start from the UNDEFINED sentinel so the assertion can tell that no concrete key type was reported.
let mut key_type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
ortsys![GetMapKeyType(info_ptr, &mut key_type_sys) -> Error::GetMapKeyType];
assert_ne!(key_type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);

// NOTE(review): `value_type_info` is never released in this function — confirm whether
// the ORT API expects the caller to free it.
let mut value_type_info: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut();
ortsys![GetMapValueType(info_ptr, &mut value_type_info) -> Error::GetMapValueType];
// The value type is cast to tensor info unconditionally; its ONNX type is not inspected first.
// NOTE(review): presumably map values are always tensors here — confirm.
let mut value_info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
ortsys![unsafe CastTypeInfoToTensorInfo(value_type_info, &mut value_info_ptr) -> Error::CastTypeInfoToTensorInfo; nonNull(value_info_ptr)];
let mut value_type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
ortsys![GetTensorElementType(value_info_ptr, &mut value_type_sys) -> Error::GetTensorElementType];
assert_ne!(value_type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);

Ok(ValueType::Map {
key: key_type_sys.into(),
value: value_type_sys.into()
})
}
use crate::{
ortfree,
value::{extract_data_type_from_map_info, extract_data_type_from_sequence_info, extract_data_type_from_tensor_info}
};

pub(super) fn extract_inputs_count(session_ptr: *mut ort_sys::OrtSession) -> Result<usize> {
let f = api().SessionGetInputCount.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion src/tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use std::{fmt::Debug, ptr};
#[cfg(feature = "ndarray")]
use ::ndarray::{ArrayView, IxDyn};

pub use self::types::{ExtractTensorData, IntoTensorElementDataType, TensorElementDataType, Utf8Data};
pub use self::types::{ExtractTensorData, IntoTensorElementType, TensorElementType, Utf8Data};
#[cfg(feature = "ndarray")]
pub use self::{ndarray::ArrayExtensions, types::TensorData};
use crate::ortsys;
Expand Down
84 changes: 42 additions & 42 deletions src/tensor/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use super::{ortsys, Error, Result};

/// Enum mapping ONNX Runtime's supported tensor data types.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum TensorElementDataType {
pub enum TensorElementType {
/// 32-bit floating point number, equivalent to Rust's `f32`.
Float32,
/// Unsigned 8-bit integer, equivalent to Rust's `u8`.
Expand Down Expand Up @@ -47,67 +47,67 @@ pub enum TensorElementDataType {
Bfloat16
}

impl From<TensorElementDataType> for ort_sys::ONNXTensorElementDataType {
fn from(val: TensorElementDataType) -> Self {
impl From<TensorElementType> for ort_sys::ONNXTensorElementDataType {
fn from(val: TensorElementType) -> Self {
match val {
TensorElementDataType::Float32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
TensorElementDataType::Uint8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
TensorElementDataType::Int8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
TensorElementDataType::Uint16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
TensorElementDataType::Int16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
TensorElementDataType::Int32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
TensorElementDataType::Int64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
TensorElementDataType::String => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
TensorElementDataType::Bool => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL,
TensorElementType::Float32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
TensorElementType::Uint8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
TensorElementType::Int8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
TensorElementType::Uint16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
TensorElementType::Int16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
TensorElementType::Int32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
TensorElementType::Int64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
TensorElementType::String => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
TensorElementType::Bool => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL,
#[cfg(feature = "half")]
TensorElementDataType::Float16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
TensorElementDataType::Float64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
TensorElementDataType::Uint32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
TensorElementDataType::Uint64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
TensorElementType::Float16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
TensorElementType::Float64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
TensorElementType::Uint32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
TensorElementType::Uint64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
// TensorElementDataType::Complex64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64,
// TensorElementDataType::Complex128 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128,
#[cfg(feature = "half")]
TensorElementDataType::Bfloat16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16
TensorElementType::Bfloat16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16
}
}
}
impl From<ort_sys::ONNXTensorElementDataType> for TensorElementDataType {
impl From<ort_sys::ONNXTensorElementDataType> for TensorElementType {
fn from(val: ort_sys::ONNXTensorElementDataType) -> Self {
match val {
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => TensorElementDataType::Float32,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 => TensorElementDataType::Uint8,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 => TensorElementDataType::Int8,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 => TensorElementDataType::Uint16,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 => TensorElementDataType::Int16,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 => TensorElementDataType::Int32,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 => TensorElementDataType::Int64,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING => TensorElementDataType::String,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL => TensorElementDataType::Bool,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => TensorElementType::Float32,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 => TensorElementType::Uint8,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 => TensorElementType::Int8,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 => TensorElementType::Uint16,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 => TensorElementType::Int16,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 => TensorElementType::Int32,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 => TensorElementType::Int64,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING => TensorElementType::String,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL => TensorElementType::Bool,
#[cfg(feature = "half")]
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 => TensorElementDataType::Float16,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => TensorElementDataType::Float64,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 => TensorElementDataType::Uint32,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => TensorElementDataType::Uint64,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 => TensorElementType::Float16,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => TensorElementType::Float64,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 => TensorElementType::Uint32,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => TensorElementType::Uint64,
// ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 => TensorElementDataType::Complex64,
// ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 => TensorElementDataType::Complex128,
#[cfg(feature = "half")]
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 => TensorElementDataType::Bfloat16,
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 => TensorElementType::Bfloat16,
_ => panic!("Invalid ONNXTensorElementDataType value")
}
}
}

/// Trait used to map Rust types (for example `f32`) to ONNX tensor element data types (for example `Float`).
pub trait IntoTensorElementDataType {
pub trait IntoTensorElementType {
/// Returns the ONNX tensor element data type corresponding to the given Rust type.
fn into_tensor_element_data_type() -> TensorElementDataType;
fn into_tensor_element_type() -> TensorElementType;
}

macro_rules! impl_type_trait {
($type_:ty, $variant:ident) => {
impl IntoTensorElementDataType for $type_ {
fn into_tensor_element_data_type() -> TensorElementDataType {
TensorElementDataType::$variant
impl IntoTensorElementType for $type_ {
fn into_tensor_element_type() -> TensorElementType {
TensorElementType::$variant
}
}
};
Expand Down Expand Up @@ -159,7 +159,7 @@ impl<'a> Utf8Data for &'a str {
/// Trait used to map ONNX Runtime types to Rust types.
pub trait ExtractTensorData: Sized + fmt::Debug + Clone {
/// The tensor element type that this type can extract from.
fn tensor_element_data_type() -> TensorElementDataType;
fn tensor_element_type() -> TensorElementType;

#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
Expand Down Expand Up @@ -198,8 +198,8 @@ pub enum TensorData<'t, T> {
macro_rules! impl_prim_type_from_ort_trait {
($type_: ty, $variant: ident) => {
impl ExtractTensorData for $type_ {
fn tensor_element_data_type() -> TensorElementDataType {
TensorElementDataType::$variant
fn tensor_element_type() -> TensorElementType {
TensorElementType::$variant
}

#[cfg(feature = "ndarray")]
Expand Down Expand Up @@ -251,8 +251,8 @@ impl_prim_type_from_ort_trait!(i64, Int64);
impl_prim_type_from_ort_trait!(bool, Bool);

impl ExtractTensorData for String {
fn tensor_element_data_type() -> TensorElementDataType {
TensorElementDataType::String
fn tensor_element_type() -> TensorElementType {
TensorElementType::String
}

#[cfg(feature = "ndarray")]
Expand Down

0 comments on commit af97600

Please sign in to comment.