This repository was archived by the owner on Feb 26, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmodule.rs
114 lines (100 loc) · 3.35 KB
/
module.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
//! Provides the [`Module`] type and methods for working with runtime TVM modules.
use std::{
ffi::CString,
mem,
os::raw::{c_char, c_int},
path::Path,
ptr,
};
use ts;
use function::Function;
use internal_api;
use ErrorKind;
use Result;
/// Name looked up by `Module::entry_func` to find a module's entry function.
const ENTRY_FUNC: &str = "__tvm_main__";
/// Wrapper around a TVM module handle, optionally paired with an entry function.
///
/// The entry function is looked up and cached by [`entry_func`], and
/// [`is_released`] reports the flag that the `Drop` implementation checks
/// before freeing the handle.
///
/// [`entry_func`]:struct.Module.html#method.entry_func
/// [`is_released`]:struct.Module.html#method.is_released
// NOTE(review): this type derives `Clone` while its `Drop` impl calls
// `TVMModFree` on `handle`. A clone carries its own `is_released` flag but
// (presumably, if `TVMModuleHandle` is a raw pointer) the same handle value,
// so dropping both copies looks like a double free — confirm and consider
// dropping the derive or sharing ownership.
#[derive(Debug, Clone)]
pub struct Module {
// Raw TVM handle passed to the `ts::TVMMod*` calls.
pub(crate) handle: ts::TVMModuleHandle,
// When `true`, `Drop` skips `TVMModFree` for this value.
is_released: bool,
// Cached entry function; `None` until `entry_func` finds one.
pub(crate) entry: Option<Function>,
}
impl Module {
    /// Builds a `Module` from a raw handle, its release flag and an optional
    /// entry function. No TVM call is made.
    pub(crate) fn new(
        handle: ts::TVMModuleHandle,
        is_released: bool,
        entry: Option<Function>,
    ) -> Self {
        Self {
            handle,
            is_released,
            entry,
        }
    }

    /// Sets the entry function of a module.
    ///
    /// Does nothing when an entry function is already cached. Otherwise looks
    /// up `ENTRY_FUNC` on this module (without querying imports); a failed
    /// lookup is discarded and leaves the entry as `None`.
    pub fn entry_func(&mut self) {
        if self.entry.is_none() {
            self.entry = self.get_function(ENTRY_FUNC, false).ok();
        }
    }

    /// Gets a function by name from a registered module.
    ///
    /// `query_import` is forwarded to `TVMModGetFunction` as a C int.
    ///
    /// # Errors
    ///
    /// Returns an error if `name` contains an interior NUL byte, or
    /// `ErrorKind::NullHandle` (carrying the name) if TVM hands back a null
    /// function handle. How `check_call!` reports a failing TVM call is
    /// defined elsewhere in the crate.
    pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function> {
        let name = CString::new(name)?;
        let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle;
        check_call!(ts::TVMModGetFunction(
            self.handle,
            name.as_ptr() as *const c_char,
            query_import as c_int,
            &mut fhandle as *mut _
        ));
        if fhandle.is_null() {
            bail!(ErrorKind::NullHandle(name.into_string()?))
        } else {
            // The name is only read while `TVMModGetFunction` runs, so release
            // the buffer now instead of leaking it on every lookup.
            mem::drop(name);
            Ok(Function::new(fhandle, false, false))
        }
    }

    /// Imports a dependent module such as `.ptx` for gpu.
    ///
    /// `dependent_module` is taken by value, so it is dropped (running
    /// `Module`'s `Drop`) when this method returns.
    // NOTE(review): presumably TVM keeps its own reference to the imported
    // module after `TVMModImport` — confirm against the C API.
    pub fn import_module(&self, dependent_module: Module) {
        check_call!(ts::TVMModImport(self.handle, dependent_module.handle))
    }

    /// Loads a module shared library from path.
    ///
    /// The path and its extension are passed as strings to the
    /// `module._LoadFromFile` packed function.
    ///
    /// # Errors
    ///
    /// Propagates a failure (via `?`) when the path is not valid UTF-8, when it
    /// has no extension, or when the packed call itself fails.
    pub fn load(path: &Path) -> Result<Module> {
        // Borrow straight from `path`: the strings below are independent
        // allocations, so no owned (and previously leaked) `PathBuf` is needed.
        let path_str = path.to_str()?.to_owned();
        let ext = path.extension()?.to_str()?.to_owned();
        let func = internal_api::get_api("module._LoadFromFile".to_owned());
        let ret = call_packed!(func, &path_str, &ext)?;
        Ok(ret.to_module())
    }

    /// Checks if a target device is enabled for a module.
    ///
    /// Calls the `module._Enabled` packed function with `target` and treats a
    /// non-zero integer result as "enabled".
    ///
    /// # Panics
    ///
    /// Panics if the packed call does not produce a value, because the result
    /// is unwrapped.
    pub fn enabled(&self, target: &str) -> bool {
        let func = internal_api::get_api("module._Enabled".to_owned());
        // NOTE(review): the signature has no error channel, so a failed call
        // surfaces as a panic here.
        let ret = call_packed!(func, target).unwrap();
        ret.to_int() != 0
    }

    /// Returns the underlying module handle.
    pub fn handle(&self) -> ts::TVMModuleHandle {
        self.handle
    }

    /// Returns the release flag: `true` once the handle has been (or is
    /// considered) released, in which case `Drop` will not free it.
    pub fn is_released(&self) -> bool {
        self.is_released
    }
}
impl Drop for Module {
    /// Frees the TVM module handle unless it is already flagged as released,
    /// then sets the flag.
    fn drop(&mut self) {
        // Already released: nothing to free.
        if self.is_released {
            return;
        }
        check_call!(ts::TVMModFree(self.handle));
        self.is_released = true;
    }
}