Skip to content

Commit

Permalink
fix: aarch64 build, closes #98
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Oct 5, 2023
1 parent 3e73acc commit 99c5d43
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
20 changes: 14 additions & 6 deletions src/io_binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@ use std::{
sync::Arc
};

use crate::{memory::MemoryInfo, ortsys, sys, value::Value, OrtError, OrtResult, Session};
use crate::{
memory::MemoryInfo,
ortsys,
sys::{self, size_t},
value::Value,
OrtError, OrtResult, Session
};

#[derive(Debug)]
pub struct IoBinding<'s> {
Expand Down Expand Up @@ -53,21 +59,23 @@ impl<'s> IoBinding<'s> {
nonNull(names_ptr)
];
if count > 0 {
let lengths = unsafe { std::slice::from_raw_parts(lengths_ptr, count).to_vec() };
let output_names = unsafe { ManuallyDrop::new(String::from_raw_parts(names_ptr as *mut u8, lengths.iter().sum(), lengths.iter().sum())) };
let lengths = unsafe { std::slice::from_raw_parts(lengths_ptr, count as _).to_vec() };
let output_names = unsafe {
ManuallyDrop::new(String::from_raw_parts(names_ptr as *mut u8, lengths.iter().sum::<size_t>() as _, lengths.iter().sum::<size_t>() as _))
};
let mut output_names_chars = output_names.chars();

let output_names = lengths
.into_iter()
.map(|length| output_names_chars.by_ref().take(length).collect::<String>())
.map(|length| output_names_chars.by_ref().take(length as _).collect::<String>())
.collect::<Vec<_>>();

ortsys![unsafe AllocatorFree(self.session.allocator(), names_ptr as *mut c_void) -> OrtError::CreateIoBinding];

let mut output_values_ptr: *mut *mut sys::OrtValue = vec![ptr::null_mut(); count].as_mut_ptr();
let mut output_values_ptr: *mut *mut sys::OrtValue = vec![ptr::null_mut(); count as _].as_mut_ptr();
ortsys![unsafe GetBoundOutputValues(self.ptr, self.session.allocator(), &mut output_values_ptr, &mut count) -> OrtError::CreateIoBinding; nonNull(output_values_ptr)];

let output_values_ptr = unsafe { std::slice::from_raw_parts(output_values_ptr, count).to_vec() }
let output_values_ptr = unsafe { std::slice::from_raw_parts(output_values_ptr, count as _).to_vec() }
.into_iter()
.map(|v| Value::from_raw(v, Arc::clone(&self.session.session_ptr)));

Expand Down
4 changes: 2 additions & 2 deletions src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,7 @@ mod dangerous {
}

fn extract_io_count(
f: extern_system_fn! { unsafe fn(*const sys::OrtSession, *mut usize) -> *mut sys::OrtStatus },
f: extern_system_fn! { unsafe fn(*const sys::OrtSession, *mut size_t) -> *mut sys::OrtStatus },
session_ptr: *mut sys::OrtSession
) -> OrtResult<usize> {
let mut num_nodes = 0;
Expand All @@ -794,7 +794,7 @@ mod dangerous {
(num_nodes != 0)
.then_some(())
.ok_or_else(|| OrtError::GetInOutCount(OrtApiError::Msg("No nodes in model".to_owned())))?;
Ok(num_nodes)
Ok(num_nodes as _)
}

fn extract_input_name(session_ptr: *mut sys::OrtSession, allocator_ptr: *mut sys::OrtAllocator, i: size_t) -> OrtResult<String> {
Expand Down
2 changes: 1 addition & 1 deletion src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ impl<'v> Value<'v> {
let mut len = 0;
ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> OrtError::GetTensorShapeElementCount];

let data = T::extract_data(shape, len, self.ptr())?;
let data = T::extract_data(shape, len as _, self.ptr())?;
Ok(OrtOwnedTensor { data })
}
};
Expand Down

0 comments on commit 99c5d43

Please sign in to comment.