diff --git a/core/benches/im2col_inception.rs b/core/benches/im2col_inception.rs index 35d75350e4..428ac7f456 100644 --- a/core/benches/im2col_inception.rs +++ b/core/benches/im2col_inception.rs @@ -48,7 +48,7 @@ fn b( unsafe { unary.wire_as_im2col_pair(&mut m, "", wire, false).unwrap(); } - let im2col = m.node(1).op_as::>().unwrap(); + let im2col = m.node(1).op_as::().unwrap(); let args = tvec!(image.into()); c.bench_function(name, move |b| { b.iter(|| im2col.eval(args.clone()).unwrap()) diff --git a/core/src/ops/cnn/conv/im2col.rs b/core/src/ops/cnn/conv/im2col.rs index fc6fe0872a..1dc61cfc9f 100644 --- a/core/src/ops/cnn/conv/im2col.rs +++ b/core/src/ops/cnn/conv/im2col.rs @@ -10,7 +10,7 @@ use num_traits::Zero; #[derive(Debug, Clone, Educe)] #[educe(Hash)] -pub struct Im2Col { +pub struct Im2Col { pub patch: Patch, pub input_shape: DataShape, pub output_shape: DataShape, @@ -19,19 +19,19 @@ pub struct Im2Col { pub n: usize, pub group: usize, pub ci_per_group: usize, - pub b_pack: PackB, + pub b_pack: PackB, patcher: Patcher, pad_value: Tensor, } -impl DynHash for Im2Col { +impl DynHash for Im2Col { fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) { dyn_hash(self, state) } } -impl PartialEq for Im2Col { - fn eq(&self, other: &Im2Col) -> bool { +impl PartialEq for Im2Col { + fn eq(&self, other: &Im2Col) -> bool { self.patch == other.patch && self.m == other.m && self.n == other.n @@ -42,7 +42,7 @@ impl PartialEq for Im2Col { } } -impl Im2Col { +impl Im2Col { pub fn new( patch: Patch, input_shape: DataShape, @@ -51,9 +51,9 @@ impl Im2Col { n: usize, group: usize, ci_per_group: usize, - b_pack: PackB, - pad_value: T, - ) -> TractResult> { + b_pack: PackB, + pad_value: Tensor, + ) -> TractResult { let patcher = if !patch.padded && patch.rank() == 2 { Patcher::Valid2d } else if patch.rank() == 2 { @@ -68,7 +68,6 @@ impl Im2Col { group, b_pack.len() ))?; - let pad_value = tensor0(pad_value); Ok(Im2Col { patch, input_shape, @@ -88,17 +87,19 @@ impl Im2Col { &self.output_shape.shape } - pub(super) fn im2col<'i>(&'i self, input: &'i ArrayViewD<'i, T>) -> TractResult { - let mut packed = unsafe { - Tensor::uninitialized_aligned::(&*self.output_shape.shape, self.b_pack.alignment())? - }; + pub(super) unsafe fn im2col<'i, T: Copy + Datum>( + &'i self, + input: &'i Tensor, + packed: &mut Tensor, + ) { + let input = input.to_array_view_unchecked::(); if self.output_shape.shape.iter().any(|d| d.is_zero()) { - return Ok(packed); + return; } - let pad_value = *self.pad_value.to_scalar()?; + let pad_value = *self.pad_value.to_scalar_unchecked(); for i in 0..*self.input_shape.n_dim().unwrap_or(&1) { for g in 0..self.group { - let mut packed = packed.to_array_view_mut::()?; + let mut packed = packed.to_array_view_mut_unchecked::(); packed.slice_axis_inplace(Axis(0), (i..=i).into()); packed.slice_axis_inplace(Axis(1), (g..=g).into()); let input = if let Some(ref n_axis) = self.input_shape.n_axis() { @@ -109,11 +110,10 @@ impl Im2Col { self.patcher.patch(self, &input, packed.as_slice_mut().unwrap(), g, pad_value); } } - Ok(packed) } } -impl Op for Im2Col { +impl Op for Im2Col { fn name(&self) -> Cow { "Im2col".into() } @@ -132,22 +132,33 @@ impl Op for Im2Col { op_as_typed_op!(); } -impl EvalOp for Im2Col { +impl EvalOp for Im2Col { fn is_stateless(&self) -> bool { true } fn eval(&self, inputs: TVec>) -> TractResult>> { - let tensor = self.im2col(&inputs[0].to_array_view()?)?; - Ok(tvec!(tensor.into())) + unsafe { + let mut tensor = Tensor::uninitialized_aligned_dt( + inputs[0].datum_type(), + &*self.output_shape.shape, + self.b_pack.alignment(), + )?; + dispatch_copy_by_size!(Self::im2col(inputs[0].datum_type())( + self, + &inputs[0], + &mut tensor + )); + Ok(tvec!(tensor.into())) + } } } -impl TypedOp for Im2Col { +impl TypedOp for Im2Col { as_op!(); - fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult> { - Ok(tvec!(TypedFact::dt_shape(T::datum_type(), &*self.output_shape.shape)?)) + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + Ok(tvec!(TypedFact::dt_shape(inputs[0].datum_type, &*self.output_shape.shape)?)) } } @@ -160,9 +171,9 @@ enum Patcher { } impl Patcher { - fn patch<'i, 'p, T: Copy + Datum + Zero>( + fn patch<'i, 'p, T: Copy + Datum>( &self, - im2col: &'i Im2Col, + im2col: &'i Im2Col, input: &'i ArrayViewD<'i, T>, pack: &'p mut [T], g: usize, @@ -193,8 +204,8 @@ impl Patcher { } #[inline(never)] - fn generic<'i, 'p, T: Copy + Datum + Zero>( - im2col: &'i Im2Col, + fn generic<'i, 'p, T: Copy + Datum>( + im2col: &'i Im2Col, input: &'i ArrayViewD<'i, T>, pack: &'p mut [T], g: usize, @@ -228,8 +239,8 @@ impl Patcher { } #[inline(never)] - fn valid_1d<'i, 'p, T: Copy + Datum + Zero>( - im2col: &'i Im2Col, + fn valid_1d<'i, 'p, T: Copy + Datum>( + im2col: &'i Im2Col, input: &'i ArrayView2<'i, T>, pack: &'p mut [T], g: usize, @@ -253,8 +264,8 @@ impl Patcher { } #[inline(never)] - fn padded_2d<'i, 'p, T: Copy + Datum + Zero>( - im2col: &'i Im2Col, + fn padded_2d<'i, 'p, T: Copy + Datum>( + im2col: &'i Im2Col, input: &'i ArrayView3<'i, T>, pack: &'p mut [T], g: usize, @@ -302,8 +313,8 @@ impl Patcher { } #[inline(never)] - fn valid_2d<'i, 'p, T: Copy + Datum + Zero>( - im2col: &'i Im2Col, + fn valid_2d<'i, 'p, T: Copy + Datum>( + im2col: &'i Im2Col, input: &'i ArrayView3<'i, T>, pack: &'p mut [T], g: usize, diff --git a/core/src/ops/cnn/conv/unary.rs b/core/src/ops/cnn/conv/unary.rs index bd8a6aa313..1966e0d3da 100644 --- a/core/src/ops/cnn/conv/unary.rs +++ b/core/src/ops/cnn/conv/unary.rs @@ -81,7 +81,7 @@ impl ConvUnary { fn kernel_as_packed_as( &self, - packer: &PackA, + packer: &PackA, ) -> TractResult>> { let kernel = self.kernel_as_group_o_ihw()?; let kernel = kernel.to_array_view::()?; @@ -232,12 +232,14 @@ impl ConvUnary { self.group, c_dim / self.group, mmm.as_mmm().b_pack(), - self.q_params - .as_ref() - .and_then(|q| q.zero_point_b.as_ref()) - .map(|t| t.to_scalar::().map(|x| *x)) - .transpose()? - .unwrap_or(TB::default()), + tensor0( + self.q_params + .as_ref() + .and_then(|q| q.zero_point_b.as_ref()) + .map(|t| t.to_scalar::().map(|x| *x)) + .transpose()? + .unwrap_or(TB::zero()), + ), )?, &[wire], )?[0]; @@ -263,7 +265,7 @@ impl ConvUnary { bc_c_shape: output_shape.shape.clone(), c_fact: TypedFact::dt_shape(TC::datum_type(), &*output_shape.shape)?, c_prefix_dim_and_stride, - packed_as: self.kernel_as_packed_as(&mmm.as_mmm().a_pack())?, + packed_as: self.kernel_as_packed_as::(&mmm.as_mmm().a_pack())?, fused_ops: self.bias_as_non_linear()?, mmm, }, diff --git a/core/src/ops/matmul/pack_b.rs b/core/src/ops/matmul/pack_b.rs index a8f774e46a..fba92d824f 100644 --- a/core/src/ops/matmul/pack_b.rs +++ b/core/src/ops/matmul/pack_b.rs @@ -1,5 +1,3 @@ -use num_traits::Zero; - use crate::internal::*; use ndarray::*; @@ -7,29 +5,20 @@ use tract_linalg::frame::PackB; #[derive(Debug, Clone, PartialEq, Educe)] #[educe(Hash)] -pub struct MatMatMulPackB -where - T: Copy + Datum + Zero, -{ - pub(crate) pack_b: PackB, +pub struct MatMatMulPackB { + pub(crate) pack_b: PackB, pub(crate) row_stride: isize, pub(crate) col_stride: isize, pub(crate) output_shape: TVec, } -impl DynHash for MatMatMulPackB -where - T: Copy + Datum + Zero, -{ +impl DynHash for MatMatMulPackB { fn dyn_hash(&self, hasher: &mut dyn std::hash::Hasher) { dyn_hash(&self, hasher) } } -impl Op for MatMatMulPackB -where - T: Copy + Datum + Zero, -{ +impl Op for MatMatMulPackB { fn name(&self) -> Cow { "MatMatMulPackB".into() } @@ -42,43 +31,42 @@ where op_as_typed_op!(); } -impl EvalOp for MatMatMulPackB -where - T: Copy + Datum + Zero, -{ +impl MatMatMulPackB { + unsafe fn pack_t(&self, b: &Tensor, packed: &mut Tensor) { + let b_prefix = &b.shape()[..b.shape().len() - 2]; + let b = b.to_array_view_unchecked::(); + for prefix in indices(b_prefix).into_iter() { + let mut b = b.view(); + let mut p = packed.to_array_view_mut_unchecked(); + for &dim in prefix.slice() { + b.index_axis_inplace(Axis(0), dim); + p.index_axis_inplace(Axis(0), dim); + } + self.pack_b.pack(p.as_mut_ptr(), b.as_ptr(), self.row_stride, self.col_stride) + } + } +} + +impl EvalOp for MatMatMulPackB { fn is_stateless(&self) -> bool { true } fn eval(&self, mut inputs: TVec>) -> TractResult>> { let b = args_1!(inputs); + let dt = b.datum_type(); let mut packed = unsafe { - Tensor::uninitialized_aligned::(&*self.output_shape, self.pack_b.alignment()) + Tensor::uninitialized_aligned_dt(dt, &*self.output_shape, self.pack_b.alignment()) .unwrap() }; - if b.shape()[..b.shape().len() - 2].iter().any(|d| *d > 1) { - let b = b.to_array_view::()?; - let b_prefix = &b.shape()[..b.shape().len() - 2]; - for prefix in indices(b_prefix).into_iter() { - let mut b = b.view(); - let mut p = packed.to_array_view_mut()?; - for &dim in prefix.slice() { - b.index_axis_inplace(Axis(0), dim); - p.index_axis_inplace(Axis(0), dim); - } - self.pack_b.pack(p.as_mut_ptr(), b.as_ptr(), self.row_stride, self.col_stride) - } - } else { - self.pack_b.pack(packed.as_ptr_mut()?, b.as_ptr()?, self.row_stride, self.col_stride) + unsafe { + dispatch_copy_by_size!(Self::pack_t(dt)(self, &b, &mut packed)); } Ok(tvec!(packed.into_arc_tensor())) } } -impl TypedOp for MatMatMulPackB -where - T: Copy + Datum + Zero, -{ +impl TypedOp for MatMatMulPackB { fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { Ok(tvec!(TypedFact::dt_shape(inputs[0].datum_type, &*self.output_shape)?)) } diff --git a/linalg/src/frame/lut.rs b/linalg/src/frame/lut.rs index 71c8db8bed..2d11644224 100644 --- a/linalg/src/frame/lut.rs +++ b/linalg/src/frame/lut.rs @@ -42,35 +42,28 @@ where } fn run(&self, buf: &mut [u8]) { - let align = K::input_alignment_bytes(); - let aligned_start = (buf.as_ptr() as usize + align - 1) / align * align; - let prefix = (aligned_start - buf.as_ptr() as usize).min(buf.len()); - for i in 0..(prefix as isize) { - unsafe { + unsafe { + let table: *const u8 = self.table.as_ptr_unchecked(); + let align = K::input_alignment_bytes(); + let aligned_start = (buf.as_ptr() as usize + align - 1) / align * align; + let prefix = (aligned_start - buf.as_ptr() as usize).min(buf.len()); + for i in 0..(prefix as isize) { let ptr = buf.as_mut_ptr().offset(i); - *ptr = self.table.as_slice_unchecked()[*ptr as usize]; + *ptr = *table.offset(*ptr as isize); } - } - let remaining = buf.len() - prefix; - if remaining == 0 { - return; - } - let n = K::n(); - let aligned_len = remaining / n * n; - if aligned_len > 0 { - unsafe { - K::run( - buf.as_mut_ptr().offset(prefix as isize), - aligned_len, - self.table.as_ptr_unchecked(), - ); + let remaining = buf.len() - prefix; + if remaining == 0 { + return; } - } - let remaining = buf.len() - aligned_len - prefix; - for i in 0..remaining { - unsafe { + let n = K::n(); + let aligned_len = remaining / n * n; + if aligned_len > 0 { + K::run(buf.as_mut_ptr().offset(prefix as isize), aligned_len, table); + } + let remaining = buf.len() - aligned_len - prefix; + for i in 0..remaining { let ptr = buf.as_mut_ptr().offset((i + prefix + aligned_len) as isize); - *ptr = self.table.as_slice_unchecked()[*ptr as usize]; + *ptr = *table.offset(*ptr as isize); } } } diff --git a/linalg/src/frame/mmm/mmm.rs b/linalg/src/frame/mmm/mmm.rs index b205a13fc5..6586e07954 100644 --- a/linalg/src/frame/mmm/mmm.rs +++ b/linalg/src/frame/mmm/mmm.rs @@ -19,8 +19,8 @@ where TC: Copy + Debug + 'static, TI: Copy + Add + Mul + Zero + Debug + 'static, { - fn a_pack(&self) -> PackA; - fn b_pack(&self) -> PackB; + fn a_pack(&self) -> PackA; + fn b_pack(&self) -> PackB; fn a_storage(&self) -> &MatrixStoreSpec; fn b_storage(&self) -> &MatrixStoreSpec; @@ -137,11 +137,11 @@ where TI: Copy + Add + Mul + Zero + Debug + 'static, K: MatMatMulKer + 'static, { - fn a_pack(&self) -> PackA { + fn a_pack(&self) -> PackA { PackA::new(self.k, self.m, K::mr(), K::alignment_bytes_packed_a()) } - fn b_pack(&self) -> PackB { + fn b_pack(&self) -> PackB { PackB::new(self.k, self.n, K::nr(), K::alignment_bytes_packed_b()) } diff --git a/linalg/src/frame/pack_a.rs b/linalg/src/frame/pack_a.rs index 79b501cf16..72d8c4d0e0 100644 --- a/linalg/src/frame/pack_a.rs +++ b/linalg/src/frame/pack_a.rs @@ -1,20 +1,17 @@ -use num_traits::Zero; use std::fmt::Debug; -use std::marker::PhantomData; #[derive(Clone, Debug, Eq, PartialEq, Educe)] #[educe(Hash)] -pub struct PackA { +pub struct PackA { k: usize, m: usize, mr: usize, alignment: usize, - _boo: PhantomData, } -impl PackA { - pub fn new(k: usize, m: usize, mr: usize, alignment: usize) -> PackA { - PackA { k, m, mr, alignment, _boo: PhantomData } +impl PackA { + pub fn new(k: usize, m: usize, mr: usize, alignment: usize) -> PackA { + PackA { k, m, mr, alignment } } pub fn alignment(&self) -> usize { self.alignment @@ -24,7 +21,7 @@ impl PackA { (self.m + self.mr - 1) / self.mr * self.mr * self.k } - fn pack_panel_a(&self, pa: *mut T, a: *const T, rsa: isize, csa: isize, rows: usize) { + fn pack_panel_a(&self, pa: *mut T, a: *const T, rsa: isize, csa: isize, rows: usize) { let mr = self.mr; for i in 0..self.k { for j in 0..rows { @@ -33,16 +30,10 @@ impl PackA { *a.offset(i as isize * csa + j as isize * rsa) } } - #[cfg(debug_assertions)] - for j in rows..mr { - unsafe { - *pa.offset((i * mr + j) as isize) = T::zero(); - } - } } } - pub fn pack(&self, pa: *mut T, a: *const T, rsa: isize, csa: isize) { + pub fn pack(&self, pa: *mut T, a: *const T, rsa: isize, csa: isize) { let mr = self.mr; assert!(pa as usize % self.alignment == 0); unsafe { diff --git a/linalg/src/frame/pack_b.rs b/linalg/src/frame/pack_b.rs index 05e47b8e42..8bc01133f8 100644 --- a/linalg/src/frame/pack_b.rs +++ b/linalg/src/frame/pack_b.rs @@ -1,20 +1,17 @@ -use num_traits::Zero; -use std::fmt::Debug; use std::marker::PhantomData; #[derive(Clone, Debug, Eq, PartialEq, Educe)] #[educe(Hash)] -pub struct PackB { +pub struct PackB { k: usize, n: usize, nr: usize, alignment: usize, - _boo: PhantomData, } -impl PackB { - pub fn new(k: usize, n: usize, nr: usize, alignment: usize) -> PackB { - PackB { k, n, nr, alignment, _boo: PhantomData } +impl PackB { + pub fn new(k: usize, n: usize, nr: usize, alignment: usize) -> PackB { + PackB { k, n, nr, alignment } } pub fn alignment(&self) -> usize { @@ -25,7 +22,7 @@ impl PackB { (self.n + self.nr - 1) / self.nr * self.nr * self.k } - pub fn pack(&self, pb: *mut T, b: *const T, rsb: isize, csb: isize) { + pub fn pack(&self, pb: *mut T, b: *const T, rsb: isize, csb: isize) { let nr = self.nr; assert!(pb as usize % self.alignment == 0); unsafe { @@ -50,7 +47,7 @@ impl PackB { } } - fn pack_panel_b(&self, pb: *mut T, b: *const T, rsb: isize, csb: isize, cols: usize) { + fn pack_panel_b(&self, pb: *mut T, b: *const T, rsb: isize, csb: isize, cols: usize) { let nr = self.nr; for i in 0..self.k { for j in 0..cols { @@ -59,16 +56,10 @@ impl PackB { *b.offset(j as isize * csb + i as isize * rsb) } } - #[cfg(debug_assertions)] - for j in cols..nr { - unsafe { - *pb.offset((i * nr + j) as isize) = T::zero(); - } - } } } - pub fn write_packed_by_rows<'p>(&self, pb: &'p mut [T]) -> PackedWriter<'p, T> { + pub fn write_packed_by_rows<'p, T: Copy>(&self, pb: &'p mut [T]) -> PackedWriter<'p, T> { PackedWriter::new(pb, self.nr, self.n, self.k) } } @@ -76,7 +67,7 @@ impl PackB { #[derive(Debug)] pub struct PackedWriter<'p, T> where - T: Copy + Debug, + T: Copy, { ptr: *mut T, panels: usize, @@ -91,7 +82,7 @@ where impl<'p, T> PackedWriter<'p, T> where - T: Copy + Debug, + T: Copy, { pub fn new(data: &'p mut [T], panel_width: usize, mn: usize, k: usize) -> PackedWriter<'p, T> { let panels = (mn + panel_width - 1) / panel_width;