pack struct is no longer generic
kali committed Oct 21, 2020
1 parent 83a1351 commit 08da9d6
Showing 8 changed files with 120 additions and 144 deletions.
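The common thread across these files: the compile-time element type parameter is erased from the pack-related ops (Im2Col<T>, PackB<T>, MatMatMulPackB<T>), and the element type is instead chosen at run time inside eval. A minimal, self-contained sketch of that shape, using hypothetical toy types rather than tract's actual Tensor/DatumType API:

#[derive(Debug, Clone)]
enum Tensor {
    F32(Vec<f32>),
    I32(Vec<i32>),
}

#[derive(Debug, Clone)]
struct PackOp {
    block: usize, // configuration that previously lived on a generic struct
}

impl PackOp {
    // Monomorphized worker: still generic over T, but only instantiated for
    // the types named by the runtime dispatch below.
    fn pack_t<T: Copy + Default>(&self, input: &[T]) -> Vec<T> {
        let padded = (input.len() + self.block - 1) / self.block * self.block;
        let mut out = vec![T::default(); padded];
        out[..input.len()].copy_from_slice(input);
        out
    }

    // Non-generic entry point: a match on the runtime element type replaces
    // the old type parameter on the struct.
    fn eval(&self, input: &Tensor) -> Tensor {
        match input {
            Tensor::F32(v) => Tensor::F32(self.pack_t(v)),
            Tensor::I32(v) => Tensor::I32(self.pack_t(v)),
        }
    }
}

fn main() {
    let op = PackOp { block: 4 };
    println!("{:?}", op.eval(&Tensor::F32(vec![1.0, 2.0, 3.0])));
}

In the diff itself, the role of this match is played by the dispatch_copy_by_size! macro, and that of the toy Tensor by tract's type-erased Tensor.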
2 changes: 1 addition & 1 deletion core/benches/im2col_inception.rs
@@ -48,7 +48,7 @@ fn b(
unsafe {
unary.wire_as_im2col_pair(&mut m, "", wire, false).unwrap();
}
let im2col = m.node(1).op_as::<Im2Col<f32>>().unwrap();
let im2col = m.node(1).op_as::<Im2Col>().unwrap();
let args = tvec!(image.into());
c.bench_function(name, move |b| {
b.iter(|| im2col.eval(args.clone()).unwrap())
81 changes: 46 additions & 35 deletions core/src/ops/cnn/conv/im2col.rs
@@ -10,7 +10,7 @@ use num_traits::Zero;

#[derive(Debug, Clone, Educe)]
#[educe(Hash)]
pub struct Im2Col<T: Copy + Datum + Zero> {
pub struct Im2Col {
pub patch: Patch,
pub input_shape: DataShape,
pub output_shape: DataShape,
@@ -19,19 +19,19 @@ pub struct Im2Col<T: Copy + Datum + Zero> {
pub n: usize,
pub group: usize,
pub ci_per_group: usize,
pub b_pack: PackB<T>,
pub b_pack: PackB,
patcher: Patcher,
pad_value: Tensor,
}

impl<T: Copy + Datum + Zero> DynHash for Im2Col<T> {
impl DynHash for Im2Col {
fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
dyn_hash(self, state)
}
}

impl<T: Copy + Datum + Zero> PartialEq for Im2Col<T> {
fn eq(&self, other: &Im2Col<T>) -> bool {
impl PartialEq for Im2Col {
fn eq(&self, other: &Im2Col) -> bool {
self.patch == other.patch
&& self.m == other.m
&& self.n == other.n
@@ -42,7 +42,7 @@ impl<T: Copy + Datum + Zero> PartialEq for Im2Col<T> {
}
}

impl<T: Copy + Datum + Zero> Im2Col<T> {
impl Im2Col {
pub fn new(
patch: Patch,
input_shape: DataShape,
@@ -51,9 +51,9 @@ impl<T: Copy + Datum + Zero> Im2Col<T> {
n: usize,
group: usize,
ci_per_group: usize,
b_pack: PackB<T>,
pad_value: T,
) -> TractResult<Im2Col<T>> {
b_pack: PackB,
pad_value: Tensor,
) -> TractResult<Im2Col> {
let patcher = if !patch.padded && patch.rank() == 2 {
Patcher::Valid2d
} else if patch.rank() == 2 {
@@ -68,7 +68,6 @@ impl<T: Copy + Datum + Zero> Im2Col<T> {
group,
b_pack.len()
))?;
let pad_value = tensor0(pad_value);
Ok(Im2Col {
patch,
input_shape,
@@ -88,17 +87,19 @@ impl<T: Copy + Datum + Zero> Im2Col<T> {
&self.output_shape.shape
}

pub(super) fn im2col<'i>(&'i self, input: &'i ArrayViewD<'i, T>) -> TractResult<Tensor> {
let mut packed = unsafe {
Tensor::uninitialized_aligned::<T>(&*self.output_shape.shape, self.b_pack.alignment())?
};
pub(super) unsafe fn im2col<'i, T: Copy + Datum>(
&'i self,
input: &'i Tensor,
packed: &mut Tensor,
) {
let input = input.to_array_view_unchecked::<T>();
if self.output_shape.shape.iter().any(|d| d.is_zero()) {
return Ok(packed);
return;
}
let pad_value = *self.pad_value.to_scalar()?;
let pad_value = *self.pad_value.to_scalar_unchecked();
for i in 0..*self.input_shape.n_dim().unwrap_or(&1) {
for g in 0..self.group {
let mut packed = packed.to_array_view_mut::<T>()?;
let mut packed = packed.to_array_view_mut_unchecked::<T>();
packed.slice_axis_inplace(Axis(0), (i..=i).into());
packed.slice_axis_inplace(Axis(1), (g..=g).into());
let input = if let Some(ref n_axis) = self.input_shape.n_axis() {
@@ -109,11 +110,10 @@ impl<T: Copy + Datum + Zero> Im2Col<T> {
self.patcher.patch(self, &input, packed.as_slice_mut().unwrap(), g, pad_value);
}
}
Ok(packed)
}
}

impl<T: Copy + Datum + Zero> Op for Im2Col<T> {
impl Op for Im2Col {
fn name(&self) -> Cow<str> {
"Im2col".into()
}
@@ -132,22 +132,33 @@ impl<T: Copy + Datum + Zero> Op for Im2Col<T> {
op_as_typed_op!();
}

impl<T: Copy + Datum + Zero> EvalOp for Im2Col<T> {
impl EvalOp for Im2Col {
fn is_stateless(&self) -> bool {
true
}

fn eval(&self, inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
let tensor = self.im2col(&inputs[0].to_array_view()?)?;
Ok(tvec!(tensor.into()))
unsafe {
let mut tensor = Tensor::uninitialized_aligned_dt(
inputs[0].datum_type(),
&*self.output_shape.shape,
self.b_pack.alignment(),
)?;
dispatch_copy_by_size!(Self::im2col(inputs[0].datum_type())(
self,
&inputs[0],
&mut tensor
));
Ok(tvec!(tensor.into()))
}
}
}

impl<T: Copy + Datum + Zero> TypedOp for Im2Col<T> {
impl TypedOp for Im2Col {
as_op!();

fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
Ok(tvec!(TypedFact::dt_shape(T::datum_type(), &*self.output_shape.shape)?))
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
Ok(tvec!(TypedFact::dt_shape(inputs[0].datum_type, &*self.output_shape.shape)?))
}
}

@@ -160,9 +171,9 @@ enum Patcher {
}

impl Patcher {
fn patch<'i, 'p, T: Copy + Datum + Zero>(
fn patch<'i, 'p, T: Copy + Datum>(
&self,
im2col: &'i Im2Col<T>,
im2col: &'i Im2Col,
input: &'i ArrayViewD<'i, T>,
pack: &'p mut [T],
g: usize,
@@ -193,8 +204,8 @@
}

#[inline(never)]
fn generic<'i, 'p, T: Copy + Datum + Zero>(
im2col: &'i Im2Col<T>,
fn generic<'i, 'p, T: Copy + Datum>(
im2col: &'i Im2Col,
input: &'i ArrayViewD<'i, T>,
pack: &'p mut [T],
g: usize,
@@ -228,8 +239,8 @@ impl Patcher {
}

#[inline(never)]
fn valid_1d<'i, 'p, T: Copy + Datum + Zero>(
im2col: &'i Im2Col<T>,
fn valid_1d<'i, 'p, T: Copy + Datum>(
im2col: &'i Im2Col,
input: &'i ArrayView2<'i, T>,
pack: &'p mut [T],
g: usize,
@@ -253,8 +264,8 @@ impl Patcher {
}

#[inline(never)]
fn padded_2d<'i, 'p, T: Copy + Datum + Zero>(
im2col: &'i Im2Col<T>,
fn padded_2d<'i, 'p, T: Copy + Datum>(
im2col: &'i Im2Col,
input: &'i ArrayView3<'i, T>,
pack: &'p mut [T],
g: usize,
@@ -302,8 +313,8 @@ impl Patcher {
}

#[inline(never)]
fn valid_2d<'i, 'p, T: Copy + Datum + Zero>(
im2col: &'i Im2Col<T>,
fn valid_2d<'i, 'p, T: Copy + Datum>(
im2col: &'i Im2Col,
input: &'i ArrayView3<'i, T>,
pack: &'p mut [T],
g: usize,
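In the new eval above, the packed buffer is allocated from the input's runtime datum type (Tensor::uninitialized_aligned_dt) and filled by the generic im2col helper selected with dispatch_copy_by_size!. Dispatching on element size rather than element type is enough because the kernel only moves bytes around; a rough sketch of that idea, with a hypothetical toy buffer type standing in for tract's Tensor and macro:

use std::mem::size_of;

// Toy type-erased buffer: data kept in u64 storage so every element size
// up to 8 bytes stays correctly aligned.
struct ErasedBuf {
    storage: Vec<u64>,
    len: usize,       // element count
    elem_size: usize, // bytes per element
}

impl ErasedBuf {
    fn from_slice<T: Copy>(s: &[T]) -> ErasedBuf {
        let bytes = s.len() * size_of::<T>();
        let mut storage = vec![0u64; (bytes + 7) / 8];
        unsafe {
            std::ptr::copy_nonoverlapping(
                s.as_ptr() as *const u8,
                storage.as_mut_ptr() as *mut u8,
                bytes,
            );
        }
        ErasedBuf { storage, len: s.len(), elem_size: size_of::<T>() }
    }

    fn as_slice<T: Copy>(&self) -> &[T] {
        assert_eq!(size_of::<T>(), self.elem_size);
        unsafe { std::slice::from_raw_parts(self.storage.as_ptr() as *const T, self.len) }
    }
}

// Copy-only worker: rearranges elements without ever interpreting their values,
// so one instantiation per size covers every datum type of that size.
fn reverse_t<T: Copy>(buf: &ErasedBuf) -> Vec<T> {
    let mut v: Vec<T> = buf.as_slice::<T>().to_vec();
    v.reverse();
    v
}

// Runtime dispatch on element size, standing in for dispatch_copy_by_size!.
fn reverse_erased(buf: &ErasedBuf) -> ErasedBuf {
    match buf.elem_size {
        1 => ErasedBuf::from_slice(&reverse_t::<u8>(buf)),
        2 => ErasedBuf::from_slice(&reverse_t::<u16>(buf)),
        4 => ErasedBuf::from_slice(&reverse_t::<u32>(buf)),
        8 => ErasedBuf::from_slice(&reverse_t::<u64>(buf)),
        s => panic!("unsupported element size {}", s),
    }
}

fn main() {
    let input = ErasedBuf::from_slice(&[1.0f32, 2.0, 3.0]);
    let reversed = reverse_erased(&input);
    // Reinterpret back as f32: the u32 instantiation moved the bytes verbatim.
    println!("{:?}", reversed.as_slice::<f32>());
}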
18 changes: 10 additions & 8 deletions core/src/ops/cnn/conv/unary.rs
@@ -81,7 +81,7 @@ impl ConvUnary {

fn kernel_as_packed_as<T: Datum + Copy + Zero>(
&self,
packer: &PackA<T>,
packer: &PackA,
) -> TractResult<ArrayD<Arc<Tensor>>> {
let kernel = self.kernel_as_group_o_ihw()?;
let kernel = kernel.to_array_view::<T>()?;
@@ -232,12 +232,14 @@ impl ConvUnary {
self.group,
c_dim / self.group,
mmm.as_mmm().b_pack(),
self.q_params
.as_ref()
.and_then(|q| q.zero_point_b.as_ref())
.map(|t| t.to_scalar::<TB>().map(|x| *x))
.transpose()?
.unwrap_or(TB::default()),
tensor0(
self.q_params
.as_ref()
.and_then(|q| q.zero_point_b.as_ref())
.map(|t| t.to_scalar::<TB>().map(|x| *x))
.transpose()?
.unwrap_or(TB::zero()),
),
)?,
&[wire],
)?[0];
@@ -263,7 +265,7 @@ impl ConvUnary {
bc_c_shape: output_shape.shape.clone(),
c_fact: TypedFact::dt_shape(TC::datum_type(), &*output_shape.shape)?,
c_prefix_dim_and_stride,
packed_as: self.kernel_as_packed_as(&mmm.as_mmm().a_pack())?,
packed_as: self.kernel_as_packed_as::<TA>(&mmm.as_mmm().a_pack())?,
fused_ops: self.bias_as_non_linear()?,
mmm,
},
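The tensor0(...) wrapping above mirrors the new pad_value: Tensor field on Im2Col: the zero-point / pad value now travels as a rank-0 tensor on a non-generic op, and the typed worker recovers the concrete scalar once its T is known (to_scalar / to_scalar_unchecked in the im2col diff). A small sketch with a hypothetical stand-in for the rank-0 tensor:

use std::any::Any;

// Stand-in for a rank-0 Tensor as built by tract's tensor0(v); purely illustrative.
struct Scalar0(Box<dyn Any>);

fn scalar0<T: Copy + 'static>(v: T) -> Scalar0 {
    Scalar0(Box::new(v))
}

impl Scalar0 {
    // Stand-in for Tensor::to_scalar::<T>().
    fn to_scalar<T: Copy + 'static>(&self) -> Option<T> {
        self.0.downcast_ref::<T>().copied()
    }
}

struct PadOp {
    pad_value: Scalar0, // dynamically typed: no <T> needed on the op itself
}

impl PadOp {
    // Monomorphized worker: pulls the concrete pad value out once T is known.
    fn pad_to<T: Copy + 'static>(&self, data: &mut Vec<T>, len: usize) {
        let pad: T = self.pad_value.to_scalar().expect("pad value has wrong datum type");
        data.resize(len, pad);
    }
}

fn main() {
    let op = PadOp { pad_value: scalar0(0.0f32) };
    let mut v = vec![1.0f32, 2.0];
    op.pad_to(&mut v, 4);
    println!("{:?}", v); // [1.0, 2.0, 0.0, 0.0]
}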
64 changes: 26 additions & 38 deletions core/src/ops/matmul/pack_b.rs
@@ -1,35 +1,24 @@
use num_traits::Zero;

use crate::internal::*;
use ndarray::*;

use tract_linalg::frame::PackB;

#[derive(Debug, Clone, PartialEq, Educe)]
#[educe(Hash)]
pub struct MatMatMulPackB<T>
where
T: Copy + Datum + Zero,
{
pub(crate) pack_b: PackB<T>,
pub struct MatMatMulPackB {
pub(crate) pack_b: PackB,
pub(crate) row_stride: isize,
pub(crate) col_stride: isize,
pub(crate) output_shape: TVec<usize>,
}

impl<T> DynHash for MatMatMulPackB<T>
where
T: Copy + Datum + Zero,
{
impl DynHash for MatMatMulPackB {
fn dyn_hash(&self, hasher: &mut dyn std::hash::Hasher) {
dyn_hash(&self, hasher)
}
}

impl<T> Op for MatMatMulPackB<T>
where
T: Copy + Datum + Zero,
{
impl Op for MatMatMulPackB {
fn name(&self) -> Cow<str> {
"MatMatMulPackB".into()
}
@@ -42,43 +31,42 @@ where
op_as_typed_op!();
}

impl<T> EvalOp for MatMatMulPackB<T>
where
T: Copy + Datum + Zero,
{
impl MatMatMulPackB {
unsafe fn pack_t<T: Datum + Copy>(&self, b: &Tensor, packed: &mut Tensor) {
let b_prefix = &b.shape()[..b.shape().len() - 2];
let b = b.to_array_view_unchecked::<T>();
for prefix in indices(b_prefix).into_iter() {
let mut b = b.view();
let mut p = packed.to_array_view_mut_unchecked();
for &dim in prefix.slice() {
b.index_axis_inplace(Axis(0), dim);
p.index_axis_inplace(Axis(0), dim);
}
self.pack_b.pack(p.as_mut_ptr(), b.as_ptr(), self.row_stride, self.col_stride)
}
}
}

impl EvalOp for MatMatMulPackB {
fn is_stateless(&self) -> bool {
true
}

fn eval(&self, mut inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
let b = args_1!(inputs);
let dt = b.datum_type();
let mut packed = unsafe {
Tensor::uninitialized_aligned::<T>(&*self.output_shape, self.pack_b.alignment())
Tensor::uninitialized_aligned_dt(dt, &*self.output_shape, self.pack_b.alignment())
.unwrap()
};
if b.shape()[..b.shape().len() - 2].iter().any(|d| *d > 1) {
let b = b.to_array_view::<T>()?;
let b_prefix = &b.shape()[..b.shape().len() - 2];
for prefix in indices(b_prefix).into_iter() {
let mut b = b.view();
let mut p = packed.to_array_view_mut()?;
for &dim in prefix.slice() {
b.index_axis_inplace(Axis(0), dim);
p.index_axis_inplace(Axis(0), dim);
}
self.pack_b.pack(p.as_mut_ptr(), b.as_ptr(), self.row_stride, self.col_stride)
}
} else {
self.pack_b.pack(packed.as_ptr_mut()?, b.as_ptr()?, self.row_stride, self.col_stride)
unsafe {
dispatch_copy_by_size!(Self::pack_t(dt)(self, &b, &mut packed));
}
Ok(tvec!(packed.into_arc_tensor()))
}
}

impl<T> TypedOp for MatMatMulPackB<T>
where
T: Copy + Datum + Zero,
{
impl TypedOp for MatMatMulPackB {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
Ok(tvec!(TypedFact::dt_shape(inputs[0].datum_type, &*self.output_shape)?))
}
