Skip to content

Commit

Permalink
less explicit types
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Nov 12, 2020
1 parent 8e8a63f commit a3d707c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
20 changes: 12 additions & 8 deletions linalg/src/frame/mmm/tests.rs
Expand Up @@ -32,7 +32,7 @@ macro_rules! mmm_frame_tests {
}

#[test]
fn conv_prepacked_prop(pb in strat_conv_1d()) {
fn conv_prepacked_prop(pb in strat_conv_1d::<$ta, $tb>()) {
if $cond {
let found = pb.run::<$ker, $tc, $ti>();
let expected = pb.expected::<$tc, $ti>();
Expand Down Expand Up @@ -773,6 +773,8 @@ macro_rules! qmmm_frame_tests {
use $crate::frame::mmm::tests::*;
use $crate::frame::mmm::QuantizedParam;

type QProblem = QMatMulProblem<$ta, $tb, $tc, $ti>;

proptest::proptest! {
#[test]
fn q_mat_mul_prop(pb in any::<QMatMulProblem<$ta, $tb, $tc, $ti>>()) {
Expand All @@ -785,7 +787,7 @@ macro_rules! qmmm_frame_tests {
#[test]
fn q_mat_mul_1() {
if $cond {
let pb = QMatMulProblem {
let pb = QProblem {
m: 1,
k: 1,
n: 1,
Expand All @@ -802,7 +804,7 @@ macro_rules! qmmm_frame_tests {
#[test]
fn q_mat_mul_sat_1() {
if $cond {
let pb = QMatMulProblem {
let pb = QProblem {
m: 1,
k: 1,
n: 1,
Expand All @@ -819,7 +821,7 @@ macro_rules! qmmm_frame_tests {

fn q_mat_mul_sat_2() {
if $cond {
let pb = QMatMulProblem {
let pb = QProblem {
m: 1,
k: 1,
n: 1,
Expand All @@ -836,7 +838,7 @@ macro_rules! qmmm_frame_tests {
#[test]
fn q_mat_mul_n2() {
if $cond {
let pb = QMatMulProblem {
let pb = QProblem {
m: 1,
k: 1,
n: 2,
Expand All @@ -853,7 +855,7 @@ macro_rules! qmmm_frame_tests {
#[test]
fn q_mat_mul_k2() {
if $cond {
let pb = QMatMulProblem {
let pb = QProblem {
m: 1,
k: 2,
n: 1,
Expand All @@ -879,10 +881,12 @@ macro_rules! qmmm_s_frame_tests {
use $crate::frame::mmm::tests::*;
use $crate::frame::mmm::QuantizedParam;

type QProblem = QMatMulProblem<$ta, $tb, $tc, $ti>;

#[test]
fn q_mat_mul_1_1_5() {
if $cond {
let pb = QMatMulProblem {
let pb = QProblem {
m: 1,
k: 1,
n: 5,
Expand All @@ -899,7 +903,7 @@ macro_rules! qmmm_s_frame_tests {
#[test]
fn q_mat_mul_1_1_1() {
if $cond {
let pb = QMatMulProblem {
let pb = QProblem {
m: 1,
k: 1,
n: 1,
Expand Down
10 changes: 9 additions & 1 deletion linalg/src/generic/mmm.rs
Expand Up @@ -468,6 +468,8 @@ where
let mut ab = [[TI::zero(); 2]; 3];
match (*spec.a, *spec.b, *spec.linear) {
(Packed { ptr: a }, Packed { ptr: b }, Mul { k }) => {
let a = a as *const TA;
let b = b as *const TB;
for i in 0..k {
let a = std::slice::from_raw_parts(a.offset(3 * i as isize), 3);
let b = std::slice::from_raw_parts(b.offset(2 * i as isize), 2);
Expand All @@ -480,6 +482,8 @@ where
}
}
(Packed { ptr: a }, OffsetsAndPtrs { row_byte_offsets, col_ptrs }, Mul { k }) => {
let a = a as *const TA;
let col_ptrs = col_ptrs as *const *const TB;
let pb0 = *(col_ptrs.offset(0));
let pb1 = *(col_ptrs.offset(1));
for i in 0..k {
Expand All @@ -497,6 +501,8 @@ where
}
}
(Packed { ptr: a }, VecStride { ptr: b, byte_stride, .. }, Mul { k }) => {
let a = a as *const TA;
let b = b as *const TB;
for i in 0..k {
let a = std::slice::from_raw_parts(a.offset(3 * i as isize), 3);
let b = *b
Expand All @@ -517,9 +523,10 @@ where
FusedKerSpec::Done => break,
FusedKerSpec::AddC => match *spec.c {
Strides { ptr: c, row_byte_stride, col_byte_stride, .. } => {
let c = c as *const TC;
let rsc = row_byte_stride as usize / std::mem::size_of::<TC>();
let csc = col_byte_stride as usize / std::mem::size_of::<TC>();
let c = std::slice::from_raw_parts_mut(c, 1 + 1 * csc + 2 * rsc);
let c = std::slice::from_raw_parts(c, 1 + 1 * csc + 2 * rsc);
ab[0][0] += c[0 * csc + 0 * rsc].as_();
ab[0][1] += c[1 * csc + 0 * rsc].as_();
ab[1][0] += c[0 * csc + 1 * rsc].as_();
Expand Down Expand Up @@ -609,6 +616,7 @@ where
}
match *spec.c {
Strides { ptr: c, row_byte_stride, col_byte_stride, .. } => {
let c = c as *mut TC;
let rsc = row_byte_stride as usize / std::mem::size_of::<TC>();
let csc = col_byte_stride as usize / std::mem::size_of::<TC>();
let c = std::slice::from_raw_parts_mut(c, 1 + 3 * csc + 3 * rsc);
Expand Down

0 comments on commit a3d707c

Please sign in to comment.