Skip to content

Commit

Permalink
pack A now operates on tensor views
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Nov 12, 2020
1 parent f1e899f commit 7f1b8ae
Show file tree
Hide file tree
Showing 8 changed files with 260 additions and 134 deletions.
51 changes: 23 additions & 28 deletions core/src/ops/cnn/conv/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,37 +78,32 @@ impl ConvUnary {
}
}

// returns an Array of Tensors. shape of the array is [group]
fn kernel_as_packed_as<T: Datum + Copy + Zero>(
&self,
packer: &PackA,
) -> TractResult<ArrayD<Arc<Tensor>>> {
let kernel_g_o_ihw = self.kernel_as_group_o_ihw()?;
let kernel = kernel_g_o_ihw.to_array_view::<T>()?;
let mut packed_as = Array1::from(
kernel
.outer_iter()
.map(|subkernel| {
let mut packed = unsafe {
Tensor::uninitialized_aligned::<T>(&[packer.len()], packer.alignment())?
};
unsafe {
fn kernel_as_packed_as(&self, packer: &PackA) -> TractResult<ArrayD<Arc<Tensor>>> {
let kernel = self.kernel_as_group_o_ihw()?;
unsafe {
let mut packed_as = Array1::from(
(0..self.group)
.map(|g| {
let mut packed = Tensor::uninitialized_aligned_dt(
kernel.datum_type(),
&[packer.len()],
packer.alignment(),
)?;
packer.pack(
&mut TensorViewMut::at_prefix(&mut packed, &[]),
subkernel.as_ptr() as _,
subkernel.strides()[0],
subkernel.strides()[1],
&TensorView::at_prefix(&kernel, &[g]),
false,
);
}
Ok(packed.into_arc_tensor())
})
.collect::<TractResult<Vec<_>>>()?,
)
.into_dyn();
if self.pool_spec.data_format.has_n() {
packed_as.insert_axis_inplace(Axis(0));
Ok(packed.into_arc_tensor())
})
.collect::<TractResult<Vec<_>>>()?,
)
.into_dyn();
if self.pool_spec.data_format.has_n() {
packed_as.insert_axis_inplace(Axis(0));
}
Ok(packed_as.into_dyn())
}
Ok(packed_as)
}

fn bias_as_non_linear<T>(&self) -> TractResult<Option<ArrayD<Vec<FusedSpec>>>>
Expand Down Expand Up @@ -275,7 +270,7 @@ impl ConvUnary {
strides.insert(0, *output_shape.n_stride().unwrap() as isize);
}

let kernels = self.kernel_as_packed_as::<TA>(&mmm.a_pack())?;
let kernels = self.kernel_as_packed_as(&mmm.a_pack())?;
wire = model.wire_node(
format!("{}.matmatmul", name),
matmul::lir::MatMatMulUnaryFinite::<TA, TB, TC, TI> {
Expand Down
48 changes: 23 additions & 25 deletions core/src/ops/matmul/mir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ where
TI: Datum + Copy + Add + Mul + Zero + fmt::Debug,
MMM: Fn(usize, usize, usize) -> Box<dyn MatMatMul>,
{
use tract_linalg::frame::PackA;
use tract_linalg::frame::PackB;
unsafe {
let rank = a.rank();
Expand All @@ -127,21 +126,19 @@ where
Tensor::uninitialized_aligned_dt(b.datum_type(), &[b_pack.len()], b_pack.alignment())?;

for prefix in indices(&c_shape[..rank - 2]).into_iter() {
let mut pa = a.as_bytes().as_ptr();
let mut pb = b.as_bytes().as_ptr();
let mut c = c.view_mut();
let mut a_prefix = tvec!();
for (axis, &dim) in prefix.slice().iter().enumerate() {
let d = dim.min(a.shape()[axis] - 1);
pa = pa.offset((a.strides()[axis] * d * a.datum_type().size_of()) as isize);
a_prefix.push(dim.min(a.shape()[axis] - 1));
let d = dim.min(b.shape()[axis] - 1);
pb = pb.offset((b.strides()[axis] * d * b.datum_type().size_of()) as isize);
c.slice_axis_inplace(Axis(axis), (dim..=dim).into());
}
a_pack.pack(
&mut TensorViewMut::at_prefix(&mut packed_a, &[]),
pa as _,
a.strides()[prefix.ndim() + a_trans as usize] as isize,
a.strides()[prefix.ndim() + !a_trans as usize] as isize,
&TensorView::at_prefix(&a, &a_prefix),
a_trans,
);
fn pack_b<T: Datum + Copy>(
packer: &PackB,
Expand Down Expand Up @@ -701,24 +698,19 @@ where

let mut mm = mmm(m, k, n);
let c_shape = compute_shape(&a.shape(), b_shape, a_trans, b_trans, c_trans)?;
let a = a.to_array_view::<TA>()?;
let packed_as = Array::from_shape_fn(&a.shape()[0..a.ndim() - 2], |a_prefix| {
let mut a = a.view();
for x in a_prefix.slice() {
a.index_axis_inplace(Axis(0), *x);
}
unsafe {
let mut pa =
Tensor::uninitialized_aligned::<TA>(&[mm.a_pack().len()], mm.a_pack().alignment())
.unwrap();
mm.a_pack().pack(
&mut TensorViewMut::at_prefix(&mut pa, &[]),
a.as_ptr() as _,
a.strides()[a_trans as usize],
a.strides()[!a_trans as usize],
);
pa.into_arc_tensor()
}
let packed_as = Array::from_shape_fn(&a.shape()[0..a.rank() - 2], |a_prefix| unsafe {
let mut pa = Tensor::uninitialized_aligned_dt(
a.datum_type(),
&[mm.a_pack().len()],
mm.a_pack().alignment(),
)
.unwrap();
mm.a_pack().pack(
&mut TensorViewMut::at_prefix(&mut pa, &[]),
&TensorView::at_prefix(&a, a_prefix.slice()),
a_trans,
);
pa.into_arc_tensor()
});
unsafe {
if n == 1 {
Expand Down Expand Up @@ -820,6 +812,12 @@ mod test {

#[test]
fn bin() {
// 0
// 1
// 2
//
// 0 1 2 5
// 3 4 5 14
let a = rctensor2(&[[0f32, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let b = rctensor2(&[[0f32], [1.0], [2.0]]);
let c = rctensor2(&[[5f32], [14.0]]);
Expand Down
20 changes: 17 additions & 3 deletions data/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,16 @@ impl Tensor {
Ok(tensor)
}

/// Overwrites every element of the tensor with `T::zero()`.
///
/// # Safety
///
/// The elements are reached through `as_slice_mut_unchecked::<T>()`, and this
/// method performs no check of its own on `T`.
/// NOTE(review): presumably `T` must be the tensor's actual datum type —
/// confirm against `as_slice_mut_unchecked`.
pub unsafe fn clear<T: Datum + num_traits::Zero>(&mut self) {
    for slot in self.as_slice_mut_unchecked::<T>() {
        *slot = T::zero();
    }
}

pub fn zero<T: Datum + num_traits::Zero>(shape: &[usize]) -> anyhow::Result<Tensor> {
let mut t = unsafe { Tensor::uninitialized::<T>(shape)? };
t.as_slice_mut::<T>().unwrap().iter_mut().for_each(|item| *item = T::zero());
Ok(t)
unsafe {
let mut t = Tensor::uninitialized::<T>(shape)?;
t.clear::<T>();
Ok(t)
}
}

pub fn zero_dt(dt: DatumType, shape: &[usize]) -> anyhow::Result<Tensor> {
Expand Down Expand Up @@ -737,6 +743,14 @@ impl Tensor {
}
dispatch_datum!(slice_t(self.datum_type())(&self, axis, start, end))
}

/// Returns a read-only view over the whole tensor, i.e. a view built with an
/// empty index prefix.
pub fn view(&self) -> view::TensorView {
    let no_prefix: [usize; 0] = [];
    // SAFETY: an empty prefix fixes no axis, so no index can be out of range.
    unsafe { view::TensorView::at_prefix(self, &no_prefix) }
}

/// Returns a mutable view over the whole tensor, i.e. a view built with an
/// empty index prefix.
pub fn view_mut(&mut self) -> view::TensorViewMut {
    let no_prefix: [usize; 0] = [];
    // SAFETY: an empty prefix fixes no axis, so no index can be out of range.
    unsafe { view::TensorViewMut::at_prefix(self, &no_prefix) }
}
}

impl PartialEq for Tensor {
Expand Down
87 changes: 87 additions & 0 deletions data/src/tensor/view.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use crate::tensor::*;

/// Read-only view of a sub-tensor obtained by fixing the leading axes of a
/// [`Tensor`] to given indices.
pub struct TensorView<'a> {
    /// Tensor the view borrows from.
    tensor: &'a Tensor,
    /// Offset of the first viewed element, counted in elements (not bytes).
    offset: usize,
    /// Number of leading axes of `tensor` that are fixed by the view.
    prefix_len: usize,
}

impl<'a> TensorView<'a> {
    /// Builds the view of `tensor` obtained by fixing its first `prefix.len()`
    /// axes to the indices in `prefix`.
    ///
    /// The element offset is the sum of `prefix[i] * tensor.strides()[i]`.
    ///
    /// # Safety
    ///
    /// The caller must ensure that `prefix.len() <= tensor.rank()` and that
    /// each `prefix[i]` is a valid index along axis `i`. Neither is checked in
    /// release builds: an over-long prefix makes `rank()` underflow and
    /// `shape()` panic, and an out-of-range index makes the pointer accessors
    /// return an address outside the selected sub-tensor.
    pub unsafe fn at_prefix(tensor: &'a Tensor, prefix: &[usize]) -> TensorView<'a> {
        debug_assert!(
            prefix.len() <= tensor.rank(),
            "prefix has {} indices but tensor rank is {}",
            prefix.len(),
            tensor.rank()
        );
        debug_assert!(
            prefix.iter().zip(tensor.shape()).all(|(ix, dim)| ix < dim),
            "prefix {:?} is out of bounds for shape {:?}",
            prefix,
            tensor.shape()
        );
        let offset = prefix.iter().zip(tensor.strides()).map(|(a, b)| a * b).sum();
        TensorView { tensor, prefix_len: prefix.len(), offset }
    }

    /// Datum type of the underlying tensor.
    pub fn datum_type(&self) -> DatumType {
        self.tensor.datum_type()
    }

    /// Number of axes left in the view: the tensor rank minus the prefix length.
    pub fn rank(&self) -> usize {
        self.tensor.rank() - self.prefix_len
    }

    /// Shape of the view: the tensor shape with the prefix axes removed.
    pub fn shape(&self) -> &[usize] {
        &self.tensor.shape()[self.prefix_len..]
    }

    /// Access the data as a pointer to the first element of the view.
    ///
    /// # Errors
    ///
    /// Returns the error raised by the tensor's `check_for_access::<D>()`.
    pub fn as_ptr<D: Datum>(&self) -> anyhow::Result<*const D> {
        self.tensor.check_for_access::<D>()?;
        Ok(unsafe { self.as_ptr_unchecked() })
    }

    /// Access the data as a pointer to the first element of the view, without
    /// checking `D`.
    ///
    /// # Safety
    ///
    /// `D` is not checked against the tensor's datum type; the address is
    /// computed from the tensor's own datum size, whatever `D` is.
    pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
        // The offset is stored in elements: scale it to bytes before applying it.
        (self.tensor.data as *const u8)
            .offset(self.offset as isize * self.tensor.datum_type().size_of() as isize)
            as *const D
    }
}

/// Mutable view of a sub-tensor obtained by fixing the leading axes of a
/// [`Tensor`] to given indices.
pub struct TensorViewMut<'a> {
    /// Tensor the view mutably borrows from.
    tensor: &'a mut Tensor,
    /// Offset of the first viewed element, counted in elements (not bytes).
    offset: usize,
    /// Number of leading axes of `tensor` that are fixed by the view.
    prefix_len: usize,
}

impl<'a> TensorViewMut<'a> {
    /// Builds the mutable view of `tensor` obtained by fixing its first
    /// `prefix.len()` axes to the indices in `prefix`.
    ///
    /// The element offset is the sum of `prefix[i] * tensor.strides()[i]`.
    ///
    /// # Safety
    ///
    /// The caller must ensure that `prefix.len() <= tensor.rank()` and that
    /// each `prefix[i]` is a valid index along axis `i`. Neither is checked in
    /// release builds: an over-long prefix makes `rank()` underflow and
    /// `shape()` panic, and an out-of-range index makes the pointer accessors
    /// return an address outside the selected sub-tensor.
    pub unsafe fn at_prefix(tensor: &'a mut Tensor, prefix: &[usize]) -> TensorViewMut<'a> {
        debug_assert!(
            prefix.len() <= tensor.rank(),
            "prefix has {} indices but tensor rank is {}",
            prefix.len(),
            tensor.rank()
        );
        debug_assert!(
            prefix.iter().zip(tensor.shape()).all(|(ix, dim)| ix < dim),
            "prefix {:?} is out of bounds for shape {:?}",
            prefix,
            tensor.shape()
        );
        let offset = prefix.iter().zip(tensor.strides()).map(|(a, b)| a * b).sum();
        TensorViewMut { tensor, prefix_len: prefix.len(), offset }
    }

    /// Datum type of the underlying tensor.
    pub fn datum_type(&self) -> DatumType {
        self.tensor.datum_type()
    }

    /// Shape of the view: the tensor shape with the prefix axes removed.
    pub fn shape(&self) -> &[usize] {
        &self.tensor.shape()[self.prefix_len..]
    }

    /// Number of axes left in the view: the tensor rank minus the prefix length.
    pub fn rank(&self) -> usize {
        self.tensor.rank() - self.prefix_len
    }

    /// Access the data as a const pointer to the first element of the view.
    ///
    /// # Errors
    ///
    /// Returns the error raised by the tensor's `check_for_access::<D>()`.
    pub fn as_ptr<D: Datum>(&self) -> anyhow::Result<*const D> {
        self.tensor.check_for_access::<D>()?;
        Ok(unsafe { self.as_ptr_unchecked() })
    }

    /// Access the data as a const pointer to the first element of the view,
    /// without checking `D`.
    ///
    /// # Safety
    ///
    /// `D` is not checked against the tensor's datum type; the address is
    /// computed from the tensor's own datum size, whatever `D` is.
    pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
        // The offset is stored in elements: scale it to bytes before applying it.
        (self.tensor.data as *const u8)
            .offset(self.offset as isize * self.tensor.datum_type().size_of() as isize)
            as *const D
    }

    /// Access the data as a mutable pointer to the first element of the view,
    /// without checking `D`.
    ///
    /// # Safety
    ///
    /// Same contract as [`Self::as_ptr_unchecked`]: `D` is not checked.
    pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
        self.as_ptr_unchecked::<D>() as *mut D
    }

    /// Access the data as a mutable pointer to the first element of the view.
    ///
    /// # Errors
    ///
    /// Returns the error raised by the tensor's `check_for_access::<D>()`.
    pub fn as_ptr_mut<D: Datum>(&mut self) -> anyhow::Result<*mut D> {
        self.as_ptr::<D>().map(|p| p as *mut D)
    }
}
43 changes: 30 additions & 13 deletions linalg/benches/mm_for_wavenet_hw.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,39 @@
#[macro_use]
extern crate criterion;
extern crate tract_linalg;
use criterion::Criterion;

pub fn vec(len: usize, align: usize) -> *mut f32 {
let layout =
std::alloc::Layout::from_size_align(len * std::mem::size_of::<f32>(), align).unwrap();
unsafe { std::alloc::alloc_zeroed(layout) as *mut f32 }
use tract_data::internal::*;

pub fn vec(len: usize, align: usize) -> Tensor {
unsafe {
let mut tensor = Tensor::uninitialized_aligned::<f32>(&[len], align).unwrap();
tensor.clear::<f32>();
tensor
}
}

fn pack_a(c: &mut Criterion, m: usize, k: usize, n: usize) {
c.bench_function(&format!("pack_a_{}x{}x{}", m, k, n), move |b| {
let mm = (tract_linalg::ops().mmm_f32)(m, k, n);
let a = vec(m * k, 4);
let pa = vec(mm.a_pack().len(), mm.a_pack().alignment());
b.iter(move || mm.a_pack().pack(pa, a, k as _, 1))
let a = Tensor::zero::<f32>(&[m, k]).unwrap();
let mut pa = vec(mm.a_pack().len(), mm.a_pack().alignment());
b.iter(move || unsafe { mm.a_pack().pack(&mut pa.view_mut(), &a.view(), false) })
});
}

fn pack_b(c: &mut Criterion, m: usize, k: usize, n: usize) {
c.bench_function(&format!("pack_b_{}x{}x{}", m, k, n), move |be| {
c.bench_function(&format!("pack_b_{}x{}x{}", m, k, n), move |be| unsafe {
let mm = (tract_linalg::ops().mmm_f32)(m, k, n);
let b = vec(n * k, 4);
let pb = vec(mm.b_pack().len(), mm.b_pack().alignment());
be.iter(move || mm.b_pack().pack(pb, b, n as _, 1))
let b = Tensor::zero::<f32>(&[k, n]).unwrap();
let mut pb = Tensor::uninitialized_aligned::<f32>(&[n * k], 4).unwrap();
be.iter(move || {
mm.b_pack().pack(
pb.as_ptr_mut_unchecked::<f32>(),
b.as_ptr_unchecked::<f32>(),
n as _,
1,
)
})
});
}

Expand All @@ -34,7 +44,14 @@ fn mat_mul_prepacked(c: &mut Criterion, m: usize, k: usize, n: usize) {
let pb = vec(mm.b_pack().len(), mm.b_pack().alignment());
let mut c = vec![0.0; m * n];
unsafe {
be.iter(move || mm.run(pa as _, pb as _, c.as_mut_ptr() as _, &[]));
be.iter(move || {
mm.run(
pa.as_ptr_unchecked::<f32>() as _,
pb.as_ptr_unchecked::<f32>() as _,
c.as_mut_ptr() as _,
&[],
)
});
}
});
}
Expand Down
Loading

0 comments on commit 7f1b8ae

Please sign in to comment.