Skip to content

Commit

Permalink
[WIP] Allowed lazy transposing in matrix multiplication\n(breaking ch…
Browse files Browse the repository at this point in the history
…ange)
  • Loading branch information
unic0rn9k committed Feb 6, 2022
1 parent af6f9ea commit 87c66a9
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 9 deletions.
6 changes: 4 additions & 2 deletions src/backends.rs
Expand Up @@ -71,7 +71,9 @@ use crate::prelude::*;
use paste::paste;

macro_rules! impl_operations {
($_t:ident $($name: ident $($op: ident ($($generics: tt)*) ($($generics_use: tt)*) ($($arg: ident : $arg_ty: ty),*) where ($($where_ty:ty : $implements: path),*) -> $t: ty),*);*;) => {
($_t:ident $($name: ident $($op: ident ($($generics: tt)*) ($($generics_use: tt)*) ($($arg: ident : $arg_ty: ty),*)
where ($($where_ty:ty : $implements: path),*) -> $t: ty),*);*;) => {

pub trait Backend<$_t>: Default{
$($(
fn $op<$($generics)*>(&self, $($arg : $arg_ty),*) -> paste!( <Self as operations::$name<$_t>>::[<$op:camel Output>] )
Expand Down Expand Up @@ -109,7 +111,7 @@ impl_operations!(T
MatrixMul
matrix_mul(A: StaticVec<T, ALEN>, B: StaticVec<T, BLEN>, C: StaticVec<T, CLEN>, const ALEN: usize, const BLEN: usize, const CLEN: usize)
(A, B, C, ALEN, BLEN, CLEN)
(a: &A, b: &B, buffer: &mut C, m: usize, n: usize, k: usize)
(a: &A, b: &B, buffer: &mut C, m: usize, n: usize, k: usize, a_trans: bool, b_trans: bool)
where (
A: Sized,
B: Sized,
Expand Down
7 changes: 5 additions & 2 deletions src/backends/blas.rs
Expand Up @@ -80,16 +80,19 @@ macro_rules! impl_gemm {
m: usize,
n: usize,
k: usize,
a_trans: bool,
b_trans: bool,
) where
A: Sized,
B: Sized,
{
use cblas_sys::CBLAS_TRANSPOSE::*;
unsafe {
// TODO: gemv should be used here when other's dimensions are a transpose of self.
cblas_sys::$f(
cblas_sys::CBLAS_LAYOUT::CblasRowMajor,
cblas_sys::CBLAS_TRANSPOSE::CblasNoTrans,
cblas_sys::CBLAS_TRANSPOSE::CblasNoTrans,
if a_trans { CblasTrans } else { CblasNoTrans },
if b_trans { CblasTrans } else { CblasNoTrans },
m as i32,
n as i32,
k as i32,
Expand Down
12 changes: 7 additions & 5 deletions src/tensor.rs
@@ -1,6 +1,6 @@
use crate::{backends::*, prelude::*};
use paste::paste;
use std::hint::unreachable_unchecked;
use std::mem::transmute;

/// Tensor shape with static dimensions but with optionally dynamic shape.
/// To achieve a static shape, the trait should be const-implemented.
Expand Down Expand Up @@ -49,7 +49,7 @@ impl<const LEN: usize> Shape<LEN> for [usize] {

fn slice(&self) -> &[usize; LEN] {
assert_eq!(self.len(), LEN);
unsafe { std::mem::transmute(self.as_ptr()) }
unsafe { transmute(self.as_ptr()) }
}
}

Expand All @@ -61,7 +61,7 @@ impl<const M: usize, const K: usize> const Shape<2> for MatrixShape<M, K> {
match n {
0 => K,
1 => M,
_ => unsafe { unreachable_unchecked() },
_ => panic!("Cannot get len of axis higher than 1, as a matrix only has 2 axies (rows and columns)"),
}
}
fn volume(&self) -> usize {
Expand Down Expand Up @@ -174,13 +174,13 @@ macro_rules! impl_index_slice {
assert!(i < self.shape.axis_len(0));

unsafe {
std::mem::transmute::<*const T, &'a $($mut)? [T; LEN]>(
transmute::<*const T, &'a $($mut)? [T; LEN]>(
self.data
.[< as $(_$mut)? _ptr>]()
.add(i * (self.shape.volume() / self.shape.axis_len(NDIM - 1))),
)
.[<reshape_unchecked_ref $(_$mut)? >](
std::mem::transmute::<*const usize, &[usize; NDIM - 1]>(
transmute::<*const usize, &[usize; NDIM - 1]>(
self.shape.slice()[0..NDIM - 1].as_ptr(),
),
B::default(),
Expand Down Expand Up @@ -227,6 +227,8 @@ impl<
m,
n,
k,
false,
false,
);
}

Expand Down

0 comments on commit 87c66a9

Please sign in to comment.