Skip to content

Commit

Permalink
Added transmute_elements function for StaticVecUnion and [WIP] ma…
Browse files Browse the repository at this point in the history
…trix transpose opereration
  • Loading branch information
unic0rn9k committed Feb 23, 2022
1 parent e237955 commit 1e4fb97
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 6 deletions.
4 changes: 4 additions & 0 deletions src/backends.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ impl_operations!(T
C: Sized,
T: Copy
) -> ();

Transpose
transpose_inplace(const LEN: usize)()(a: &mut impl StaticVec<T, LEN>) where () -> (),
transpose(const LEN: usize)()(a: &impl StaticVec<T, LEN>, buffer: &mut impl StaticVec<T, LEN>) where () -> ();
);

/// Perform opertaions on a [`StaticVec`] with a static backend.
Expand Down
30 changes: 25 additions & 5 deletions src/backends/rust.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#[derive(Default)]
pub struct Rust;
use super::*;
use operations::*;
use std::simd::Simd;

// TODO: This needs to check if SIMD is available at compile time.
Expand All @@ -15,7 +16,7 @@ macro_rules! impl_dot {
/// use slas::prelude::*;
/// assert!(slas_backend::Rust.dot(&[1., 2., 3.], &moo![f32: -1, 2, -1]) == 0.);
/// ```
impl operations::DotProduct<$t> for Rust {
impl DotProduct<$t> for Rust {
fn dot<const LEN: usize>(
&self,
a: &impl StaticVec<$t, LEN>,
Expand All @@ -41,20 +42,20 @@ macro_rules! impl_dot {

macro_rules! impl_norm {
($t: ty) => {
impl operations::Normalize<$t> for Rust {
impl Normalize<$t> for Rust {
type NormOutput = $t;
fn norm<const LEN: usize>(&self, a: &impl StaticVec<$t, LEN>) -> $t {
//TODO: Use hypot function here. This will require implementing hypot for all float types first.
a.moo_ref().iter().map(|&n| n * n).sum::<$t>().sqrt_()
}

fn normalize<const LEN: usize>(&self, a: &mut impl StaticVec<$t, LEN>) {
let norm = operations::Normalize::norm(self, a);
let norm = Normalize::norm(self, a);
a.mut_moo_ref().iter_mut().for_each(|n| *n /= norm);
}
}

impl operations::Normalize<Complex<$t>> for Rust {
impl Normalize<Complex<$t>> for Rust {
type NormOutput = $t;
fn norm<const LEN: usize>(&self, a: &impl StaticVec<Complex<$t>, LEN>) -> $t {
//TODO: Use hypot function here. This will require implementing hypot for all float types first.
Expand All @@ -67,7 +68,7 @@ macro_rules! impl_norm {
}

fn normalize<const LEN: usize>(&self, a: &mut impl StaticVec<Complex<$t>, LEN>) {
let norm = operations::Normalize::norm(self, a);
let norm = Normalize::norm(self, a);
a.mut_moo_ref()
.iter_mut()
.for_each(|n| *n = *n / norm.into());
Expand All @@ -76,6 +77,25 @@ macro_rules! impl_norm {
};
}

//macro_rules! impl_transpose {
// ($t: ty) => {
//
// };
//}

impl Transpose<f32> for Rust {
fn transpose_inplace<const LEN: usize>(&self, _a: &mut impl StaticVec<f32, LEN>) -> () {
todo!()
}

fn transpose<const LEN: usize>(
&self,
_a: &impl StaticVec<f32, LEN>,
_buffer: &mut impl StaticVec<f32, LEN>,
) -> () {
}
}

impl_norm!(f32);
impl_norm!(f64);

Expand Down
12 changes: 11 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,10 @@ pub mod tensor;
pub mod backends;
pub use num;

use std::{mem::transmute, ops::*};
use std::{
mem::{size_of, transmute},
ops::*,
};
#[cfg(feature = "blis-sys")]
extern crate blis_src;
extern crate cblas_sys;
Expand All @@ -215,6 +218,13 @@ impl<'a, T: Copy, const LEN: usize> StaticVecUnion<'a, T, LEN> {
pub fn slice(&'a self) -> &'a [T; LEN] {
unsafe { &*(self.as_ptr() as *const [T; LEN]) }
}

pub const unsafe fn transmute_elements<U: Copy>(&'a self) -> &'a StaticVecUnion<'a, U, LEN> {
if size_of::<T>() == size_of::<U>() {
panic!("Cannot transmute between vectors of different sizes")
}
transmute(self)
}
}

impl<'a, T: Copy + PartialEq, const LEN: usize> std::cmp::PartialEq<StaticVecUnion<'a, T, LEN>>
Expand Down

0 comments on commit 1e4fb97

Please sign in to comment.