Skip to content

Commit

Permalink
refactor: improve onnxruntime logging
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Dec 5, 2023
1 parent 8e47753 commit 534a42a
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 93 deletions.
20 changes: 11 additions & 9 deletions ort-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,17 @@ fn fetch_file(source_url: &str) -> Vec<u8> {
let resp = ureq::get(source_url)
.timeout(std::time::Duration::from_secs(1800))
.call()
.unwrap_or_else(|err| panic!("[ort] failed to download {source_url}: {err:?}"));
.unwrap_or_else(|err| panic!("Failed to GET `{source_url}`: {err}"));

let len = resp.header("Content-Length").and_then(|s| s.parse::<usize>().ok()).unwrap();
let len = resp
.header("Content-Length")
.and_then(|s| s.parse::<usize>().ok())
.expect("Content-Length header should be present on archive response");
let mut reader = resp.into_reader();
let mut buffer = Vec::new();
reader.read_to_end(&mut buffer).unwrap();
reader
.read_to_end(&mut buffer)
.unwrap_or_else(|err| panic!("Failed to download from `{source_url}`: {err}"));
assert_eq!(buffer.len(), len);
buffer
}
Expand Down Expand Up @@ -48,7 +53,7 @@ fn extract_tgz(buf: &[u8], output: &Path) {
let buf: std::io::BufReader<&[u8]> = std::io::BufReader::new(buf);
let tar = flate2::read::GzDecoder::new(buf);
let mut archive = tar::Archive::new(tar);
archive.unpack(output).unwrap();
archive.unpack(output).expect("Failed to extract .tgz file");
}

#[cfg(feature = "copy-dylibs")]
Expand All @@ -59,13 +64,10 @@ fn copy_libraries(lib_dir: &Path, out_dir: &Path) {
#[cfg(windows)]
let mut copy_fallback = false;

let lib_files = std::fs::read_dir(lib_dir).unwrap();
let lib_files = std::fs::read_dir(lib_dir).unwrap_or_else(|_| panic!("Failed to read contents of `{}` (does it exist?)", lib_dir.display()));
for lib_file in lib_files.filter(|e| {
e.as_ref().ok().map_or(false, |e| {
e.file_type().map_or(false, |e| !e.is_dir())
&& [".dll", ".so", ".dylib"]
.into_iter()
.any(|v| e.path().into_os_string().into_string().unwrap().contains(v))
e.file_type().map_or(false, |e| !e.is_dir()) && [".dll", ".so", ".dylib"].into_iter().any(|v| e.path().to_string_lossy().contains(v))
})
}) {
let lib_file = lib_file.unwrap();
Expand Down
37 changes: 17 additions & 20 deletions src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use tracing::debug;
use super::{
custom_logger,
error::{Error, Result},
ortsys, ExecutionProviderDispatch, LoggingLevel
ortsys, ExecutionProviderDispatch
};

static G_ENV: OnceLock<EnvironmentSingleton> = OnceLock::new();
Expand Down Expand Up @@ -48,7 +48,6 @@ pub struct EnvironmentGlobalThreadPoolOptions {
/// times, the last value will have precedence.
pub struct EnvironmentBuilder {
name: String,
log_level: LoggingLevel,
execution_providers: Vec<ExecutionProviderDispatch>,
global_thread_pool_options: Option<EnvironmentGlobalThreadPoolOptions>
}
Expand All @@ -57,7 +56,6 @@ impl Default for EnvironmentBuilder {
fn default() -> Self {
EnvironmentBuilder {
name: "default".to_string(),
log_level: LoggingLevel::Error,
execution_providers: vec![],
global_thread_pool_options: None
}
Expand All @@ -78,16 +76,6 @@ impl EnvironmentBuilder {
self
}

/// Configure the environment with a given log level
///
/// **NOTE**: Since ONNX can only define one environment per process, creating multiple environments using multiple
/// [`EnvironmentBuilder`]s will end up re-using the same environment internally; a new one will _not_ be created.
/// New parameters will be ignored.
pub fn with_log_level(mut self, log_level: LoggingLevel) -> EnvironmentBuilder {
self.log_level = log_level;
self
}

/// Configures a list of execution providers sessions created under this environment will use by default. Sessions
/// may override these via
/// [`SessionBuilder::with_execution_providers`](crate::SessionBuilder::with_execution_providers).
Expand Down Expand Up @@ -164,7 +152,14 @@ impl EnvironmentBuilder {
ortsys![unsafe SetGlobalIntraOpThreadAffinity(thread_options, cstr.as_ptr()) -> Error::CreateEnvironment];
}

ortsys![unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, self.log_level.into(), cname.as_ptr(), thread_options, &mut env_ptr) -> Error::CreateEnvironment; nonNull(env_ptr)];
ortsys![unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools(
logging_function,
logger_param,
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
cname.as_ptr(),
thread_options,
&mut env_ptr
) -> Error::CreateEnvironment; nonNull(env_ptr)];
ortsys![unsafe ReleaseThreadingOptions(thread_options)];
env_ptr
} else {
Expand All @@ -173,7 +168,13 @@ impl EnvironmentBuilder {
// FIXME: What should go here?
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
let cname = CString::new(self.name.clone()).unwrap();
ortsys![unsafe CreateEnvWithCustomLogger(logging_function, logger_param, self.log_level.into(), cname.as_ptr(), &mut env_ptr) -> Error::CreateEnvironment; nonNull(env_ptr)];
ortsys![unsafe CreateEnvWithCustomLogger(
logging_function,
logger_param,
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
cname.as_ptr(),
&mut env_ptr
) -> Error::CreateEnvironment; nonNull(env_ptr)];
env_ptr
};
debug!(env_ptr = format!("{:?}", env_ptr).as_str(), "Environment created");
Expand Down Expand Up @@ -225,11 +226,7 @@ mod tests {
assert!(!is_env_initialized());
assert_eq!(env_ptr(), None);

EnvironmentBuilder::default()
.with_name("env_is_initialized")
.with_log_level(LoggingLevel::Warning)
.commit()
.unwrap();
EnvironmentBuilder::default().with_name("env_is_initialized").commit().unwrap();
assert!(is_env_initialized());
assert_ne!(env_ptr(), None);
}
Expand Down
78 changes: 21 additions & 57 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use std::{
};

use once_cell::sync::Lazy;
use tracing::warn;
use tracing::{warn, Level};

pub use self::environment::{init, EnvironmentBuilder};
#[cfg(feature = "fetch-models")]
Expand Down Expand Up @@ -92,7 +92,8 @@ pub(crate) static G_ORT_LIB: Lazy<Arc<Mutex<AtomicPtr<libloading::Library>>>> =
.join(&path);
if relative.exists() { relative } else { path }
};
let lib = libloading::Library::new(&absolute_path).unwrap_or_else(|e| panic!("could not load the library at `{}`: {e:?}", absolute_path.display()));
let lib = libloading::Library::new(&absolute_path)
.unwrap_or_else(|e| panic!("An error occurred while attempting to load the ONNX Runtime binary at `{}`: {e}", absolute_path.display()));
Arc::new(Mutex::new(AtomicPtr::new(Box::leak(Box::new(lib)) as *mut _)))
}
});
Expand All @@ -102,7 +103,7 @@ pub(crate) static G_ORT_API: Lazy<Arc<Mutex<AtomicPtr<ort_sys::OrtApi>>>> = Lazy
unsafe {
let dylib = *G_ORT_LIB
.lock()
.expect("failed to acquire ONNX Runtime dylib lock; another thread panicked?")
.expect("Failed to acquire global ONNX Runtime dylib lock; did another thread using ort panic?")
.get_mut();
let base_getter: libloading::Symbol<unsafe extern "C" fn() -> *const ort_sys::OrtApiBase> = (*dylib)
.get(b"OrtGetApiBase")
Expand All @@ -114,6 +115,8 @@ pub(crate) static G_ORT_API: Lazy<Arc<Mutex<AtomicPtr<ort_sys::OrtApi>>>> = Lazy
(*base).GetVersionString.expect("`GetVersionString` must be present in `OrtApiBase`");
let version_string = get_version_string();
let version_string = CStr::from_ptr(version_string).to_string_lossy();
tracing::info!("Using ONNX Runtime version '{version_string}'");

let lib_minor_version = version_string.split('.').nth(1).map(|x| x.parse::<u32>().unwrap_or(0)).unwrap_or(0);
match lib_minor_version.cmp(&16) {
std::cmp::Ordering::Less => panic!(
Expand Down Expand Up @@ -146,7 +149,9 @@ pub(crate) static G_ORT_API: Lazy<Arc<Mutex<AtomicPtr<ort_sys::OrtApi>>>> = Lazy
///
/// Panics if another thread panicked while holding the API lock, or if the ONNX Runtime API could not be initialized.
pub fn ort() -> ort_sys::OrtApi {
let mut api_ref = G_ORT_API.lock().expect("failed to acquire OrtApi lock; another thread panicked?");
let mut api_ref = G_ORT_API
.lock()
.expect("Failed to acquire global ONNX Runtime API lock; did another thread using ort panic?");
let api_ref_mut: &mut *mut ort_sys::OrtApi = api_ref.get_mut();
let api_ptr_mut: *mut ort_sys::OrtApi = *api_ref_mut;

Expand Down Expand Up @@ -237,71 +242,30 @@ impl<'a> From<&'a str> for CodeLocation<'a> {

extern_system_fn! {
/// Callback from C that will handle ONNX logging, forwarding ONNX's logs to the `tracing` crate.
pub(crate) fn custom_logger(_params: *mut ffi::c_void, severity: ort_sys::OrtLoggingLevel, category: *const c_char, log_id: *const c_char, code_location: *const c_char, message: *const c_char) {
use tracing::{span, Level, trace, debug, warn, info, error};

let log_level = match severity {
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => Level::TRACE,
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => Level::DEBUG,
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING => Level::INFO,
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR => Level::WARN,
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => Level::ERROR
};

pub(crate) fn custom_logger(_params: *mut ffi::c_void, severity: ort_sys::OrtLoggingLevel, category: *const c_char, _: *const c_char, code_location: *const c_char, message: *const c_char) {
assert_ne!(category, ptr::null());
let category = unsafe { CStr::from_ptr(category) };
assert_ne!(code_location, ptr::null());
let code_location = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap_or("unknown");
let code_location_str = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap();
assert_ne!(message, ptr::null());
let message = unsafe { CStr::from_ptr(message) }.to_str().unwrap_or("<invalid>");
assert_ne!(log_id, ptr::null());
let log_id = unsafe { CStr::from_ptr(log_id) };
let message = unsafe { CStr::from_ptr(message) }.to_str().unwrap();

let code_location = CodeLocation::from(code_location);
let span = span!(
let code_location = CodeLocation::from(code_location_str);
let span = tracing::span!(
Level::TRACE,
"ort",
category = category.to_str().unwrap_or("<unknown>"),
file = code_location.file,
line = code_location.line,
function = code_location.function,
log_id = log_id.to_str().unwrap_or("<unknown>")
function = code_location.function
);
let _enter = span.enter();

match log_level {
Level::TRACE => trace!("{}", message),
Level::DEBUG => debug!("{}", message),
Level::INFO => info!("{}", message),
Level::WARN => warn!("{}", message),
Level::ERROR => error!("{}", message)
}
}
}

/// The minimum logging level. Logs will be handled by the `tracing` crate.
#[derive(Debug)]
pub enum LoggingLevel {
/// Verbose logging level. This will log *a lot* of messages!
Verbose,
/// Info logging level.
Info,
/// Warning logging level. Recommended to receive potentially important warnings.
Warning,
/// Error logging level.
Error,
/// Fatal logging level.
Fatal
}

impl From<LoggingLevel> for ort_sys::OrtLoggingLevel {
fn from(logging_level: LoggingLevel) -> Self {
match logging_level {
LoggingLevel::Verbose => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
LoggingLevel::Info => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
LoggingLevel::Warning => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING,
LoggingLevel::Error => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR,
LoggingLevel::Fatal => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL
match severity {
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => tracing::event!(parent: &span, Level::TRACE, "{message}"),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => tracing::event!(parent: &span, Level::DEBUG, "{message}"),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING => tracing::event!(parent: &span, Level::INFO, "{message}"),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR => tracing::event!(parent: &span, Level::WARN, "{message}"),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL=> tracing::event!(parent: &span, Level::ERROR, "{message}")
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions tests/mnist.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use std::path::Path;

use image::{imageops::FilterType, ImageBuffer, Luma, Pixel};
use ort::{download::vision::DomainBasedImageClassification, inputs, ArrayExtensions, GraphOptimizationLevel, LoggingLevel, Session, Tensor};
use ort::{download::vision::DomainBasedImageClassification, inputs, ArrayExtensions, GraphOptimizationLevel, Session, Tensor};
use test_log::test;

#[test]
fn mnist_5() -> ort::Result<()> {
const IMAGE_TO_LOAD: &str = "mnist_5.jpg";

ort::init().with_name("integration_test").with_log_level(LoggingLevel::Warning).commit()?;
ort::init().with_name("integration_test").commit()?;

let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level1)?
Expand Down
4 changes: 2 additions & 2 deletions tests/squeezenet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ use std::{

use image::{imageops::FilterType, ImageBuffer, Pixel, Rgb};
use ndarray::s;
use ort::{download::vision::ImageClassification, inputs, ArrayExtensions, FetchModelError, GraphOptimizationLevel, LoggingLevel, Session, Tensor};
use ort::{download::vision::ImageClassification, inputs, ArrayExtensions, FetchModelError, GraphOptimizationLevel, Session, Tensor};
use test_log::test;

#[test]
fn squeezenet_mushroom() -> ort::Result<()> {
const IMAGE_TO_LOAD: &str = "mushroom.png";

ort::init().with_name("integration_test").with_log_level(LoggingLevel::Warning).commit()?;
ort::init().with_name("integration_test").commit()?;

let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level1)?
Expand Down
6 changes: 3 additions & 3 deletions tests/upsample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::path::Path;

use image::RgbImage;
use ndarray::{Array, CowArray, Ix4};
use ort::{inputs, GraphOptimizationLevel, LoggingLevel, Session, Tensor};
use ort::{inputs, GraphOptimizationLevel, Session, Tensor};
use test_log::test;

fn load_input_image<P: AsRef<Path>>(name: P) -> RgbImage {
Expand Down Expand Up @@ -44,7 +44,7 @@ fn convert_image_to_cow_array(img: &RgbImage) -> CowArray<'_, f32, Ix4> {
fn upsample() -> ort::Result<()> {
const IMAGE_TO_LOAD: &str = "mushroom.png";

ort::init().with_name("integration_test").with_log_level(LoggingLevel::Warning).commit()?;
ort::init().with_name("integration_test").commit()?;

let session_data =
std::fs::read(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.onnx")).expect("Could not open model from file");
Expand Down Expand Up @@ -85,7 +85,7 @@ fn upsample() -> ort::Result<()> {
fn upsample_with_ort_model() -> ort::Result<()> {
const IMAGE_TO_LOAD: &str = "mushroom.png";

ort::init().with_name("integration_test").with_log_level(LoggingLevel::Warning).commit()?;
ort::init().with_name("integration_test").commit()?;

let session_data =
std::fs::read(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.ort")).expect("Could not open model from file");
Expand Down

0 comments on commit 534a42a

Please sign in to comment.