Skip to content

Commit

Permalink
fix: hold operator handles in session
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Apr 3, 2024
1 parent 69c191d commit de3bca4
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 17 deletions.
6 changes: 6 additions & 0 deletions src/operator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,9 @@ impl OperatorDomain {
Ok(self)
}
}

impl Drop for OperatorDomain {
fn drop(&mut self) {
ortsys![unsafe ReleaseCustomOpDomain(self.ptr.as_ptr())];
}
}
55 changes: 39 additions & 16 deletions src/session/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::os::windows::ffi::OsStrExt;
#[cfg(feature = "fetch-models")]
use std::path::PathBuf;
use std::{
any::Any,
ffi::CString,
marker::PhantomData,
path::Path,
Expand Down Expand Up @@ -43,7 +44,7 @@ pub struct SessionBuilder {
pub(crate) session_options_ptr: NonNull<ort_sys::OrtSessionOptions>,
memory_info: Option<Rc<MemoryInfo>>,
#[cfg(feature = "operator-libraries")]
custom_runtime_handles: Vec<*mut std::os::raw::c_void>,
custom_runtime_handles: Vec<Arc<LibHandle>>,
operator_domains: Vec<Arc<OperatorDomain>>,
execution_providers: Vec<ExecutionProviderDispatch>
}
Expand All @@ -67,11 +68,6 @@ impl Clone for SessionBuilder {

impl Drop for SessionBuilder {
fn drop(&mut self) {
#[cfg(feature = "operator-libraries")]
for &handle in &self.custom_runtime_handles {
close_lib_handle(handle);
}

ortsys![unsafe ReleaseSessionOptions(self.session_options_ptr.as_ptr())];
}
}
Expand Down Expand Up @@ -218,18 +214,19 @@ impl SessionBuilder {

let status = ortsys![unsafe RegisterCustomOpsLibrary(self.session_options_ptr.as_ptr(), path_cstr.as_ptr(), &mut handle)];

let handle = LibHandle(handle);
// per RegisterCustomOpsLibrary docs, release handle if there was an error and the handle
// is non-null
if let Err(e) = status_to_result(status).map_err(Error::CreateSessionOptions) {
if !handle.is_null() {
// handle was written to, should release it
close_lib_handle(handle);
drop(handle);
}

return Err(e);
}

self.custom_runtime_handles.push(handle);
self.custom_runtime_handles.push(Arc::new(handle));

Ok(self)
}
Expand Down Expand Up @@ -295,7 +292,7 @@ impl SessionBuilder {
}

/// Loads an ONNX model from a file and builds the session.
pub fn commit_from_file<P>(self, model_filepath_ref: P) -> Result<Session>
pub fn commit_from_file<P>(mut self, model_filepath_ref: P) -> Result<Session>
where
P: AsRef<Path>
{
Expand Down Expand Up @@ -354,10 +351,16 @@ impl SessionBuilder {
.map(|i| dangerous::extract_output(session_ptr, &allocator, i))
.collect::<Result<Vec<Output>>>()?;

let extras = self.operator_domains.drain(..).map(|d| Box::new(d) as Box<dyn Any>);
#[cfg(feature = "operator-libraries")]
let extras = extras.chain(self.custom_runtime_handles.drain(..).map(|d| Box::new(d) as Box<dyn Any>));
let extras: Vec<Box<dyn Any>> = extras.collect();

Ok(Session {
inner: Arc::new(SharedSessionInner {
session_ptr,
allocator,
_extras: extras,
_environment: Arc::clone(env)
}),
inputs,
Expand Down Expand Up @@ -389,7 +392,7 @@ impl SessionBuilder {
}

/// Load an ONNX graph from memory and commit the session.
pub fn commit_from_memory(self, model_bytes: &[u8]) -> Result<Session> {
pub fn commit_from_memory(mut self, model_bytes: &[u8]) -> Result<Session> {
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut();

let env = get_environment()?;
Expand Down Expand Up @@ -429,10 +432,16 @@ impl SessionBuilder {
.map(|i| dangerous::extract_output(session_ptr, &allocator, i))
.collect::<Result<Vec<Output>>>()?;

let extras = self.operator_domains.drain(..).map(|d| Box::new(d) as Box<dyn Any>);
#[cfg(feature = "operator-libraries")]
let extras = extras.chain(self.custom_runtime_handles.drain(..).map(|d| Box::new(d) as Box<dyn Any>));
let extras: Vec<Box<dyn Any>> = extras.collect();

let session = Session {
inner: Arc::new(SharedSessionInner {
session_ptr,
allocator,
_extras: extras,
_environment: Arc::clone(env)
}),
inputs,
Expand Down Expand Up @@ -533,12 +542,26 @@ impl From<GraphOptimizationLevel> for ort_sys::GraphOptimizationLevel {
}
}

#[cfg(all(unix, feature = "operator-libraries"))]
fn close_lib_handle(handle: *mut std::os::raw::c_void) {
unsafe { libc::dlclose(handle) };
#[cfg(feature = "operator-libraries")]
struct LibHandle(*mut std::os::raw::c_void);

#[cfg(feature = "operator-libraries")]
impl LibHandle {
pub(self) fn is_null(&self) -> bool {
self.0.is_null()
}
}

#[cfg(all(windows, feature = "operator-libraries"))]
fn close_lib_handle(handle: *mut std::os::raw::c_void) {
unsafe { winapi::um::libloaderapi::FreeLibrary(handle as winapi::shared::minwindef::HINSTANCE) };
#[cfg(feature = "operator-libraries")]
impl Drop for LibHandle {
fn drop(&mut self) {
#[cfg(unix)]
unsafe {
libc::dlclose(self.0)
};
#[cfg(windows)]
unsafe {
winapi::um::libloaderapi::FreeLibrary(self.0 as winapi::shared::minwindef::HINSTANCE)
};
}
}
5 changes: 4 additions & 1 deletion src/session/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Contains the [`Session`] and [`SessionBuilder`] types for managing ONNX Runtime sessions and performing inference.

use std::{ffi::CString, marker::PhantomData, ops::Deref, os::raw::c_char, ptr::NonNull, sync::Arc};
use std::{any::Any, ffi::CString, marker::PhantomData, ops::Deref, os::raw::c_char, ptr::NonNull, sync::Arc};

use super::{
char_p_to_string,
Expand Down Expand Up @@ -34,6 +34,9 @@ pub use self::{
pub struct SharedSessionInner {
pub(crate) session_ptr: NonNull<ort_sys::OrtSession>,
allocator: Allocator,
/// Additional things we may need to hold onto for the duration of this session, like [`crate::OperatorDomain`]s and
/// DLL handles for operator libraries.
_extras: Vec<Box<dyn Any>>,
_environment: Arc<Environment>
}

Expand Down

0 comments on commit de3bca4

Please sign in to comment.