diff --git a/src/backends.rs b/src/backends.rs index 9a22181..27b8d2a 100644 --- a/src/backends.rs +++ b/src/backends.rs @@ -71,7 +71,9 @@ use crate::prelude::*; use paste::paste; macro_rules! impl_operations { - ($_t:ident $($name: ident $($op: ident ($($generics: tt)*) ($($generics_use: tt)*) ($($arg: ident : $arg_ty: ty),*) where ($($where_ty:ty : $implements: path),*) -> $t: ty),*);*;) => { + ($_t:ident $($name: ident $($op: ident ($($generics: tt)*) ($($generics_use: tt)*) ($($arg: ident : $arg_ty: ty),*) + where ($($where_ty:ty : $implements: path),*) -> $t: ty),*);*;) => { + pub trait Backend<$_t>: Default{ $($( fn $op<$($generics)*>(&self, $($arg : $arg_ty),*) -> paste!( >::[<$op:camel Output>] ) @@ -109,7 +111,7 @@ impl_operations!(T MatrixMul matrix_mul(A: StaticVec, B: StaticVec, C: StaticVec, const ALEN: usize, const BLEN: usize, const CLEN: usize) (A, B, C, ALEN, BLEN, CLEN) - (a: &A, b: &B, buffer: &mut C, m: usize, n: usize, k: usize) + (a: &A, b: &B, buffer: &mut C, m: usize, n: usize, k: usize, a_trans: bool, b_trans: bool) where ( A: Sized, B: Sized, diff --git a/src/backends/blas.rs b/src/backends/blas.rs index 62acc61..660b98f 100644 --- a/src/backends/blas.rs +++ b/src/backends/blas.rs @@ -80,16 +80,19 @@ macro_rules! impl_gemm { m: usize, n: usize, k: usize, + a_trans: bool, + b_trans: bool, ) where A: Sized, B: Sized, { + use cblas_sys::CBLAS_TRANSPOSE::*; unsafe { // TODO: gemv should be used here when other's dimensions are a transpose of self. cblas_sys::$f( cblas_sys::CBLAS_LAYOUT::CblasRowMajor, - cblas_sys::CBLAS_TRANSPOSE::CblasNoTrans, - cblas_sys::CBLAS_TRANSPOSE::CblasNoTrans, + if a_trans { CblasTrans } else { CblasNoTrans }, + if b_trans { CblasTrans } else { CblasNoTrans }, m as i32, n as i32, k as i32, diff --git a/src/tensor.rs b/src/tensor.rs index 7309693..ecf2323 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -1,6 +1,6 @@ use crate::{backends::*, prelude::*}; use paste::paste; -use std::hint::unreachable_unchecked; +use std::mem::transmute; /// Tensor shape with static dimensions but with optionally dynamic shape. /// To achive a static shape the trait should be const implemented. @@ -49,7 +49,7 @@ impl Shape for [usize] { fn slice(&self) -> &[usize; LEN] { assert_eq!(self.len(), LEN); - unsafe { std::mem::transmute(self.as_ptr()) } + unsafe { transmute(self.as_ptr()) } } } @@ -61,7 +61,7 @@ impl const Shape<2> for MatrixShape { match n { 0 => K, 1 => M, - _ => unsafe { unreachable_unchecked() }, + _ => panic!("Cannot get len of axis higher than 1, as a matrix only has 2 axies (rows and columns)"), } } fn volume(&self) -> usize { @@ -174,13 +174,13 @@ macro_rules! impl_index_slice { assert!(i < self.shape.axis_len(0)); unsafe { - std::mem::transmute::<*const T, &'a $($mut)? [T; LEN]>( + transmute::<*const T, &'a $($mut)? [T; LEN]>( self.data .[< as $(_$mut)? _ptr>]() .add(i * (self.shape.volume() / self.shape.axis_len(NDIM - 1))), ) .[]( - std::mem::transmute::<*const usize, &[usize; NDIM - 1]>( + transmute::<*const usize, &[usize; NDIM - 1]>( self.shape.slice()[0..NDIM - 1].as_ptr(), ), B::default(), @@ -227,6 +227,8 @@ impl< m, n, k, + false, + false, ); }