This repository was archived by the owner on Feb 26, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathndarray.rs
362 lines (327 loc) · 11.5 KB
/
ndarray.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
//! This module implements the [`NDArray`] type for working with *TVM tensors* and for
//! converting a Rust ndarray into a TVM `NDArray`.
//!
//! One can create an empty NDArray given the shape, context and dtype using [`empty`].
//! To fill an NDArray from a mutable buffer on the cpu, use [`copy_from_buffer`].
//! To copy an NDArray to a different context, use [`copy_to_ctx`].
//!
//! Given a [`Rust's dynamic ndarray`], one can convert it to a TVM NDArray as follows:
//!
//! # Example
//!
//! ```
//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
//! .unwrap()
//! .into_dyn(); // Rust's ndarray
//! let nd = NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float")).unwrap();
//! assert_eq!(nd.shape(), Some(&mut [2, 2]));
//! let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
//! assert!(rnd.all_close(&a, 1e-8f32));
//! ```
//!
//! [`Rust's dynamic ndarray`]:https://docs.rs/ndarray/0.12.1/ndarray/
//! [`empty`]:fn.empty.html
//! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer
//! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx
use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice};
use num_traits::Num;
use rust_ndarray::{Array, ArrayD};
use ts;
use Error;
use ErrorKind;
use Result;
use TVMByteArray;
use TVMContext;
use TVMType;
/// See the [`module-level documentation`](../ndarray/index.html) for more details.
///
/// Wrapper around TVM array handle.
///
/// When `is_view` is `false` the handle is owned and is released with `TVMArrayFree`
/// when the `NDArray` is dropped; when `is_view` is `true` the handle is left untouched
/// on drop.
#[derive(Debug)]
pub struct NDArray {
/// Raw handle to the underlying TVM array.
pub(crate) handle: ts::TVMArrayHandle,
/// Whether this `NDArray` is a non-owning view (its handle is not freed on drop).
is_view: bool,
}
impl NDArray {
    /// Wraps a raw TVM array handle.
    ///
    /// When `is_view` is `true` the handle is not freed when the `NDArray` is dropped.
    pub(crate) fn new(handle: ts::TVMArrayHandle, is_view: bool) -> Self {
        NDArray { handle, is_view }
    }

    /// Returns the underlying array handle.
    pub fn handle(&self) -> ts::TVMArrayHandle {
        self.handle
    }

    /// Returns `true` if this `NDArray` is a non-owning view of its handle.
    pub fn is_view(&self) -> bool {
        self.is_view
    }

    /// Returns the shape of the NDArray, or `None` if either the shape pointer or the
    /// data pointer of the underlying array is null.
    ///
    /// NOTE(review): this hands out a `&mut` slice from `&self`, so two live calls can
    /// alias mutably; kept for interface compatibility.
    pub fn shape(&self) -> Option<&mut [usize]> {
        // SAFETY: relies on `self.handle` being a valid array handle while `self` lives.
        let arr = unsafe { *(self.handle) };
        if arr.shape.is_null() || arr.data.is_null() {
            return None;
        }
        // NOTE(review): reinterprets the handle's shape buffer (allocated from `i64`s in
        // `empty`) as `usize`, which assumes a 64-bit target — confirm.
        let dims =
            unsafe { slice::from_raw_parts_mut(arr.shape as *mut usize, arr.ndim as usize) };
        Some(dims)
    }

    /// Returns the total number of entries of the NDArray (the product of its shape),
    /// or `None` if the shape is unavailable.
    pub fn size(&self) -> Option<usize> {
        self.shape().map(|dims| dims.iter().product())
    }

    /// Returns the context which the NDArray was defined.
    pub fn ctx(&self) -> TVMContext {
        unsafe { (*self.handle).ctx.into() }
    }

    /// Returns the type of the entries of the NDArray.
    pub fn dtype(&self) -> TVMType {
        unsafe { (*self.handle).dtype.into() }
    }

    /// Returns the number of dimensions of the NDArray.
    pub fn ndim(&self) -> usize {
        unsafe { (*self.handle).ndim as usize }
    }

    /// Returns the strides of the underlying NDArray (one per dimension), or `None` when
    /// the array carries no explicit strides (null strides pointer).
    pub fn strides(&self) -> Option<&[usize]> {
        // SAFETY: relies on `self.handle` being a valid array handle while `self` lives.
        let arr = unsafe { *(self.handle) };
        if arr.strides.is_null() {
            return None;
        }
        // The slice length is `ndim` *elements*; passing a byte count here would read
        // past the end of the strides buffer.
        let strides =
            unsafe { slice::from_raw_parts(arr.strides as *const usize, arr.ndim as usize) };
        Some(strides)
    }

    /// Shows whether the underlying ndarray is contiguous in memory (row-major order).
    ///
    /// An array without explicit strides is reported as contiguous. Otherwise each stride,
    /// from the innermost dimension outwards, must equal the product of the dimensions
    /// inside it.
    ///
    /// # Errors
    /// Fails if strides are present but the shape is unavailable.
    pub fn is_contiguous(&self) -> Result<bool> {
        let strides = match self.strides() {
            None => return Ok(true),
            Some(strides) => strides,
        };
        // MissingShapeError in case shape is not determined
        let shape = self.shape()?;
        let (contiguous, _) = shape.iter().zip(strides).rfold(
            (true, 1),
            |(is_contig, expected_stride), (&dim, &stride)| {
                (is_contig && stride == expected_stride, expected_stride * dim)
            },
        );
        Ok(contiguous)
    }

    /// Returns the byte offset of the data within the underlying array.
    pub fn byte_offset(&self) -> isize {
        unsafe { (*self.handle).byte_offset as isize }
    }

    /// Number of bytes occupied by one element of this NDArray's dtype
    /// (`bits * lanes`, rounded up to whole bytes).
    fn elem_size_bytes(&self) -> usize {
        let dtype = self.dtype();
        (dtype.inner.bits as usize * dtype.inner.lanes as usize + 7) / 8
    }

    /// Flattens the NDArray to a `Vec` of the same type in cpu.
    ///
    /// ## Example
    ///
    /// ```
    /// let shape = &mut [4];
    /// let mut data = vec![1i32, 2, 3, 4];
    /// let ctx = TVMContext::cpu(0);
    /// let mut ndarray = empty(shape, ctx, TVMType::from("int"));
    /// ndarray.copy_from_buffer(&mut data);
    /// assert_eq!(ndarray.shape(), Some(shape));
    /// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
    /// ```
    ///
    /// # Errors
    /// - `EmptyArray` if the shape is unavailable.
    /// - `TypeMismatch` if `size_of::<T>()` differs from the dtype's element size, which
    ///   would otherwise read past the end of the buffer.
    /// - Any error from copying the data to the cpu.
    pub fn to_vec<T>(&self) -> Result<Vec<T>> {
        let shape = match self.shape() {
            Some(shape) => shape,
            None => {
                bail!("{}", ErrorKind::EmptyArray);
            }
        };
        if mem::size_of::<T>() != self.elem_size_bytes() {
            bail!(
                "{}",
                ErrorKind::TypeMismatch(
                    self.dtype().to_string(),
                    format!("{}-byte element type", mem::size_of::<T>()),
                )
            );
        }
        let len: usize = shape.iter().product();
        let target = self.copy_to_ndarray(empty(shape, TVMContext::cpu(0), self.dtype()))?;
        let data = unsafe { (*target.handle).data } as *const T;
        // Capacity is measured in elements, not bytes.
        let mut v: Vec<T> = Vec::with_capacity(len);
        // SAFETY: `target` is a freshly allocated cpu array holding `len` elements whose
        // size matches `T` (checked above); `v` has room for exactly `len` elements.
        unsafe {
            v.as_mut_ptr().copy_from_nonoverlapping(data, len);
            v.set_len(len);
        }
        Ok(v)
    }

    /// Converts the NDArray to [`TVMByteArray`], containing every byte of its data
    /// (number of entries times the dtype's element size).
    ///
    /// # Errors
    /// `EmptyArray` if the shape is unavailable, or any error from copying to the cpu.
    pub fn to_bytearray(&self) -> Result<TVMByteArray> {
        let shape = match self.shape() {
            Some(shape) => shape,
            None => {
                bail!("{}", ErrorKind::EmptyArray);
            }
        };
        let num_bytes = shape.iter().product::<usize>() * self.elem_size_bytes();
        let target = self.copy_to_ndarray(empty(shape, TVMContext::cpu(0), self.dtype()))?;
        let data = unsafe { (*target.handle).data } as *const u8;
        // SAFETY: `target` is a freshly allocated cpu array of exactly `num_bytes` bytes.
        let bytes = unsafe { slice::from_raw_parts(data, num_bytes) }.to_vec();
        // NOTE(review): if `TVMByteArray::from(&Vec<u8>)` only stores the buffer pointer,
        // the result dangles once `bytes` is dropped here — confirm ownership semantics.
        Ok(TVMByteArray::from(&bytes))
    }

    /// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu.
    ///
    /// ## Example
    ///
    /// ```
    /// let shape = &mut [2];
    /// let mut data = vec![1f32, 2.];
    /// let ctx = TVMContext::cpu(0);
    /// let mut ndarray = empty(shape, ctx, TVMType::from("float"));
    /// ndarray.copy_from_buffer(&mut data);
    /// ```
    ///
    /// *Note*: if something goes wrong during the copy, it will panic
    /// from TVM side. See `TVMArrayCopyFromBytes` in `c_runtime_api.h`.
    pub fn copy_from_buffer<T: Num32>(&mut self, data: &mut [T]) {
        check_call!(ts::TVMArrayCopyFromBytes(
            self.handle,
            data.as_ptr() as *mut _,
            data.len() * mem::size_of::<T>()
        ));
    }

    /// Copies the NDArray into `target` and returns `target`.
    ///
    /// # Errors
    /// `TypeMismatch` if the two arrays have different dtypes.
    pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray> {
        if self.dtype() != target.dtype() {
            bail!(
                "{}",
                ErrorKind::TypeMismatch(self.dtype().to_string(), target.dtype().to_string())
            );
        }
        check_call!(ts::TVMArrayCopyFromTo(
            self.handle,
            target.handle,
            ptr::null_mut() as ts::TVMStreamHandle
        ));
        Ok(target)
    }

    /// Copies the NDArray to a newly allocated array on the target context.
    pub fn copy_to_ctx(&self, target: &TVMContext) -> Result<NDArray> {
        let tmp = empty(self.shape()?, target.clone(), self.dtype());
        self.copy_to_ndarray(tmp)
    }

    /// Converts a Rust's ndarray to TVM NDArray with the given context and dtype.
    ///
    /// NOTE(review): `dtype` is not checked against `T`; callers must pass a matching one.
    pub fn from_rust_ndarray<T: Num32 + Copy>(
        rnd: &ArrayD<T>,
        ctx: TVMContext,
        dtype: TVMType,
    ) -> Result<Self> {
        let mut shape = rnd.shape().to_vec();
        let mut nd = empty(&mut shape, ctx, dtype);
        // `iter()` visits elements in the array's logical order (per the ndarray docs),
        // yielding a contiguous buffer regardless of the source's memory layout.
        let mut buf: Vec<T> = rnd.iter().cloned().collect();
        nd.copy_from_buffer(&mut buf);
        Ok(nd)
    }
}
/// Allocates and creates an empty NDArray given the shape, context and dtype.
///
/// The returned NDArray owns its handle (it is not a view), so the allocation is freed
/// when it is dropped. Errors reported by `TVMArrayAlloc` are handled by `check_call!`.
pub fn empty(shape: &mut [usize], ctx: TVMContext, dtype: TVMType) -> NDArray {
    // TVM takes the shape as `i64`s. Convert element-wise instead of reinterpreting the
    // `usize` buffer, which is only valid on targets where `usize` is 64 bits wide.
    let dims: Vec<i64> = shape.iter().map(|&d| d as i64).collect();
    let mut handle = ptr::null_mut() as ts::TVMArrayHandle;
    check_call!(ts::TVMArrayAlloc(
        dims.as_ptr(),
        dims.len() as c_int,
        dtype.inner.code as c_int,
        dtype.inner.bits as c_int,
        dtype.inner.lanes as c_int,
        ctx.device_type.0 as c_int,
        ctx.device_id as c_int,
        &mut handle as *mut _,
    ));
    NDArray::new(handle, false)
}
macro_rules! impl_from_ndarray_rustndarray {
($type:ty, $type_name:tt) => {
impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> {
type Error = Error;
fn try_from(nd: &NDArray) -> Result<ArrayD<$type>> {
if nd.shape().is_none() {
bail!("{}", ErrorKind::EmptyArray);
}
assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch");
Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?)
}
}
impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> {
type Error = Error;
fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>> {
if nd.shape().is_none() {
bail!("{}", ErrorKind::EmptyArray);
}
assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch");
Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?)
}
}
};
}
impl_from_ndarray_rustndarray!(i32, "int");
impl_from_ndarray_rustndarray!(u32, "uint");
impl_from_ndarray_rustndarray!(f32, "float");
impl Drop for NDArray {
    /// Frees the underlying TVM array unless this `NDArray` is a view.
    fn drop(&mut self) {
        if self.is_view {
            // A view does not own its handle, so nothing is freed here.
            return;
        }
        check_call!(ts::TVMArrayFree(self.handle));
    }
}
/// Marker trait for the 32-bit numeric element types supported by the frontend.
///
/// Implemented below for `i32`, `u32` and `f32`.
pub trait Num32: Num {
    /// Bit width shared by every implementor.
    const BITS: u8 = 32;
}

/// Emits an empty `Num32` impl for each listed type.
macro_rules! impl_num32 {
    ($($t:ty),+) => {
        $(impl Num32 for $t {})+
    };
}

impl_num32!(i32, u32, f32);
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basics() {
        let dims = &mut [1, 2, 3];
        let arr = empty(dims, TVMContext::cpu(0), TVMType::from("int"));
        assert_eq!(arr.shape().unwrap(), dims);
        let expected_size: usize = dims.iter().product();
        assert_eq!(arr.size().unwrap(), expected_size);
        assert_eq!(arr.ndim(), 3);
        assert!(arr.strides().is_none());
        assert_eq!(arr.byte_offset(), 0);
    }

    #[test]
    fn copy() {
        let dims = &mut [4];
        let mut values = vec![1i32, 2, 3, 4];
        let mut src = empty(dims, TVMContext::cpu(0), TVMType::from("int"));
        assert!(src.to_vec::<i32>().is_ok());
        src.copy_from_buffer(&mut values);
        assert_eq!(src.shape().unwrap(), dims);
        assert_eq!(src.to_vec::<i32>().unwrap(), values);
        assert_eq!(src.ndim(), 1);
        assert!(src.is_contiguous().is_ok());
        assert_eq!(src.byte_offset(), 0);

        let mut dst_dims = vec![4];
        let dst = empty(&mut dst_dims, TVMContext::cpu(0), TVMType::from("int"));
        let copied = src.copy_to_ndarray(dst);
        assert!(copied.is_ok());
        assert_eq!(copied.unwrap().to_vec::<i32>().unwrap(), values);
    }

    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err`")]
    fn copy_wrong_dtype() {
        let mut dims = vec![4];
        let mut floats = vec![1f32, 2., 3., 4.];
        let cpu = TVMContext::cpu(0);
        let mut src = empty(&mut dims, cpu.clone(), TVMType::from("float"));
        src.copy_from_buffer(&mut floats);
        let dst = empty(&mut dims, cpu, TVMType::from("int"));
        src.copy_to_ndarray(dst).unwrap();
    }

    #[test]
    fn rust_ndarray() {
        let original = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
            .unwrap()
            .into_dyn();
        let cpu = TVMContext::cpu(0);
        let nd = NDArray::from_rust_ndarray(&original, cpu, TVMType::from("float")).unwrap();
        assert_eq!(nd.shape().unwrap(), &mut [2, 2]);
        let roundtrip: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
        assert!(roundtrip.all_close(&original, 1e-8f32));
    }
}