This repository was archived by the owner on Feb 26, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathlib.rs
100 lines (88 loc) · 2.48 KB
/
lib.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
//! [TVM](https://github.com/dmlc/tvm) is a compiler stack for deep learning systems.
//!
//! This crate provides an idiomatic Rust API for the TVM runtime frontend.
//!
//! One particular use case: given optimized deep learning model artifacts
//! already compiled with TVM (a shared library `lib.so`, a `graph.json`, and a
//! byte-array `param.params`), one can load them in Rust to create a graph
//! runtime, then run the model on some inputs and get the desired predictions,
//! all in Rust.
//!
//! Check out the `examples` repository for more details.
#![crate_name = "tvm_frontend"]
#![recursion_limit = "1024"]
#![allow(non_camel_case_types, unused_unsafe)]
#![feature(try_from, try_trait, fn_traits, unboxed_closures, box_syntax)]
#[macro_use]
extern crate error_chain;
extern crate tvm_sys as ts;
#[macro_use]
extern crate lazy_static;
extern crate ndarray as rust_ndarray;
extern crate num_traits;
use std::{
ffi::{CStr, CString},
str,
};
// Evaluates a call into the TVM runtime shared library inside an `unsafe`
// block and panics with the message from `get_last_error()` when the call
// returns a non-zero status code.
macro_rules! check_call {
    ($e:expr) => {{
        let status = unsafe { $e };
        if status != 0 {
            panic!("{}", $crate::get_last_error());
        }
    }};
}
/// Returns the most recent error message reported by the TVM runtime.
///
/// The message is read from the C string returned by `TVMGetLastError`. If its
/// bytes are not valid UTF-8, the fixed string `"Invalid UTF-8 message"` is
/// returned instead.
pub fn get_last_error() -> &'static str {
    // NOTE(review): presumably the runtime always hands back a non-null,
    // NUL-terminated buffer that stays valid for as long as the returned
    // `&'static str` is used — confirm against the TVM C API.
    let raw = unsafe { CStr::from_ptr(ts::TVMGetLastError()) };
    raw.to_str().unwrap_or("Invalid UTF-8 message")
}
/// Stores `err`'s display text as the TVM runtime's last error message.
///
/// The text is converted to a C string and handed to `TVMAPISetLastError`.
/// Interior NUL bytes cannot be represented in a C string, so they are removed
/// first; this keeps error reporting from panicking on an unusual message.
pub(crate) fn set_last_error(err: &Error) {
    let mut bytes = err.to_string().into_bytes();
    // Drop NUL bytes so that the `CString` conversion below cannot fail.
    bytes.retain(|&b| b != 0);
    let c_string = CString::new(bytes).expect("interior NUL bytes were removed");
    unsafe {
        ts::TVMAPISetLastError(c_string.as_ptr());
    }
}
#[macro_use]
pub mod function;
pub mod bytearray;
pub mod context;
pub mod errors;
mod internal_api;
pub mod module;
pub mod ndarray;
pub mod ty;
pub mod value;
pub use bytearray::TVMByteArray;
pub use context::{TVMContext, TVMDeviceType};
pub use errors::*;
pub use function::Function;
pub use module::Module;
pub use ndarray::{empty, NDArray};
pub use ty::TVMType;
pub use value::{TVMArgValue, TVMRetValue};
/// Outputs the current TVM version.
///
/// The version is decoded from the `TVM_VERSION` byte constant exported by
/// `tvm_sys`. Any trailing NUL bytes are stripped so a C string terminator
/// carried by the constant never ends up in the returned text. If the constant
/// is not valid UTF-8, the fixed string `"Invalid UTF-8 string"` is returned.
pub fn version() -> &'static str {
    // NOTE(review): presumably `TVM_VERSION` is a generated byte-string
    // constant that includes the C terminator — the trim is a no-op otherwise.
    str::from_utf8(ts::TVM_VERSION)
        .map(|s| s.trim_end_matches('\0'))
        .unwrap_or("Invalid UTF-8 string")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: `version()` can be called and its result printed.
    #[test]
    fn print_version() {
        let current = version();
        println!("TVM version: {}", current);
    }

    /// An error stored through `set_last_error` is reported back by
    /// `get_last_error`, ignoring surrounding whitespace.
    #[test]
    fn set_error() {
        let expected = ErrorKind::EmptyArray.to_string();
        set_last_error(&ErrorKind::EmptyArray.into());
        assert_eq!(get_last_error().trim(), expected);
    }
}