-
Notifications
You must be signed in to change notification settings - Fork 16
/
onnx_environment.rs
93 lines (83 loc) · 2.72 KB
/
onnx_environment.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
//! This module defines the ONNX environment for the execution of ONNX models.
use once_cell::sync::Lazy;
use ort::{Environment, ExecutionProvider};
use std::sync::Arc;
/// The ONNX Runtime shared library (`libonnxruntime.dylib`), embedded into the binary at
/// compile time for macOS builds so it can be extracted and loaded at runtime.
#[cfg(all(
    target_os = "macos",
    not(any(doc, onnx_runtime_env_var_set, onnx_statically_linked))
))]
pub static LIB_BYTES: &[u8] = include_bytes!("../../libonnxruntime.dylib");
/// The ONNX Runtime shared library (`libonnxruntime.so`), embedded into the binary at
/// compile time for Linux and Android builds so it can be extracted and loaded at runtime.
#[cfg(all(
    any(target_os = "linux", target_os = "android"),
    not(any(doc, onnx_runtime_env_var_set, onnx_statically_linked))
))]
pub static LIB_BYTES: &[u8] = include_bytes!("../../libonnxruntime.so");
/// The ONNX Runtime shared library (`libonnxruntime.dll`), embedded into the binary at
/// compile time for Windows builds so it can be extracted and loaded at runtime.
#[cfg(all(
    target_os = "windows",
    not(any(doc, onnx_runtime_env_var_set, onnx_statically_linked))
))]
pub static LIB_BYTES: &[u8] = include_bytes!("../../libonnxruntime.dll");
/// Empty placeholder used whenever no ONNX Runtime library is embedded. This covers:
///
/// - targets other than macOS, Linux, Android and Windows;
/// - documentation builds (`doc`);
/// - builds with the `onnx_runtime_env_var_set` cfg (presumably set by the build script
///   when the library path is supplied externally; confirm against `build.rs`);
/// - builds with the `onnx_statically_linked` cfg.
#[cfg(any(
    not(any(
        target_os = "macos",
        target_os = "linux",
        target_os = "android",
        target_os = "windows"
    )),
    doc,
    onnx_runtime_env_var_set,
    onnx_statically_linked
))]
pub static LIB_BYTES: &[u8] = &[];
/// The process-wide ONNX Runtime environment, created on first access.
///
/// The library is located according to the first matching case:
///
/// 1. **Static linking.** With the `onnx_statically_linked` cfg, the environment is built
///    directly.
/// 2. **Explicit path.** If the `ONNXRUNTIME_LIB_PATH` environment variable is set at
///    runtime, its value is exported as `ORT_DYLIB_PATH` and left set. The environment is
///    then built.
/// 3. **Embedded library.** Otherwise the library in [`LIB_BYTES`] is written to a
///    process-specific file in the current working directory, and `ORT_DYLIB_PATH` points
///    at that file while the environment is built. Afterwards the variable is removed and
///    the file is deleted (best effort). If nothing is embedded, or the file cannot be
///    written, the environment is built without touching `ORT_DYLIB_PATH`, and library
///    discovery is left to `ort`.
///
/// In every case only the CPU execution provider is registered.
///
/// # Panics
///
/// Panics on first access if the environment fails to build.
pub static ENVIRONMENT: Lazy<Arc<Environment>> = Lazy::new(|| {
    if cfg!(onnx_statically_linked) {
        return build_cpu_environment();
    }
    // A user-supplied library takes precedence over the embedded copy, so no extraction
    // is needed in that case.
    if let Ok(path) = std::env::var("ONNXRUNTIME_LIB_PATH") {
        std::env::set_var("ORT_DYLIB_PATH", path);
        return build_cpu_environment();
    }
    match extract_embedded_library() {
        Some(lib_path) => {
            std::env::set_var("ORT_DYLIB_PATH", &lib_path);
            let environment = build_cpu_environment();
            std::env::remove_var("ORT_DYLIB_PATH");
            // Best-effort cleanup. A failed removal (for example, if the platform keeps a
            // loaded library file locked) only leaves a stray file behind.
            let _ = std::fs::remove_file(&lib_path);
            environment
        }
        // Pointing `ORT_DYLIB_PATH` at a missing or empty file could never load, so the
        // variable is not set at all here.
        None => build_cpu_environment(),
    }
});

/// Builds an ONNX Runtime environment that registers only the CPU execution provider.
///
/// # Panics
///
/// Panics if the environment fails to build.
fn build_cpu_environment() -> Arc<Environment> {
    Arc::new(
        Environment::builder()
            .with_execution_providers([ExecutionProvider::CPU(Default::default())])
            .build()
            .expect("failed to build the ONNX Runtime environment"),
    )
}

/// Writes the embedded [`LIB_BYTES`] to a new file in the current working directory and
/// returns that file's path.
///
/// The file name combines the process id with the platform's shared-library suffix
/// (`.dylib`, `.so` or `.dll`). This keeps it from clashing with a concurrent process
/// doing the same extraction, or with an unrelated `libonnxruntime.*` file that the caller
/// later deletes.
///
/// Returns `None` in any of these cases; a partially written file is removed first:
///
/// - no library is embedded;
/// - the current directory cannot be determined;
/// - the file cannot be written.
fn extract_embedded_library() -> Option<std::path::PathBuf> {
    if LIB_BYTES.is_empty() {
        return None;
    }
    let file_name = format!(
        "libonnxruntime-{}{}",
        std::process::id(),
        std::env::consts::DLL_SUFFIX
    );
    let path = std::env::current_dir().ok()?.join(file_name);
    if std::fs::write(&path, LIB_BYTES).is_err() {
        let _ = std::fs::remove_file(&path);
        return None;
    }
    Some(path)
}