Skip to content

Commit

Permalink
Merge pull request #341 from rust-ndarray/lax-eigh-generalized-work
Browse files Browse the repository at this point in the history
Merge `Eigh_` into `Lapack` trait, add working memory management
  • Loading branch information
termoshtt committed Sep 25, 2022
2 parents 33e2dc3 + bef1083 commit f672b07
Show file tree
Hide file tree
Showing 3 changed files with 388 additions and 122 deletions.
248 changes: 129 additions & 119 deletions lax/src/eigh.rs
@@ -1,180 +1,190 @@
//! Eigenvalue problem for symmetric/Hermite matricies
//!
//! LAPACK correspondance
//! ----------------------
//!
//! | f32 | f64 | c32 | c64 |
//! |:------|:------|:------|:------|
//! | ssyev | dsyev | cheev | zheev |

use super::*;
use crate::{error::*, layout::MatrixLayout};
use cauchy::*;
use num_traits::{ToPrimitive, Zero};

#[cfg_attr(doc, katexit::katexit)]
/// Eigenvalue problem for symmetric/hermite matrix
pub trait Eigh_: Scalar {
/// Compute right eigenvalue and eigenvectors $Ax = \lambda x$
///
/// LAPACK correspondance
/// ----------------------
///
/// | f32 | f64 | c32 | c64 |
/// |:------|:------|:------|:------|
/// | ssyev | dsyev | cheev | zheev |
///
fn eigh(
calc_eigenvec: bool,
layout: MatrixLayout,
uplo: UPLO,
a: &mut [Self],
) -> Result<Vec<Self::Real>>;
pub struct EighWork<T: Scalar> {
pub n: i32,
pub jobz: JobEv,
pub eigs: Vec<MaybeUninit<T::Real>>,
pub work: Vec<MaybeUninit<T>>,
pub rwork: Option<Vec<MaybeUninit<T::Real>>>,
}

/// Compute generalized right eigenvalue and eigenvectors $Ax = \lambda B x$
///
/// LAPACK correspondance
/// ----------------------
///
/// | f32 | f64 | c32 | c64 |
/// |:------|:------|:------|:------|
/// | ssygv | dsygv | chegv | zhegv |
///
fn eigh_generalized(
calc_eigenvec: bool,
layout: MatrixLayout,
uplo: UPLO,
a: &mut [Self],
b: &mut [Self],
) -> Result<Vec<Self::Real>>;
pub trait EighWorkImpl: Sized {
type Elem: Scalar;
fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self>;
fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem])
-> Result<&[<Self::Elem as Scalar>::Real]>;
fn eval(self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<Vec<<Self::Elem as Scalar>::Real>>;
}

macro_rules! impl_eigh {
(@real, $scalar:ty, $ev:path, $evg:path) => {
impl_eigh!(@body, $scalar, $ev, $evg, );
};
(@complex, $scalar:ty, $ev:path, $evg:path) => {
impl_eigh!(@body, $scalar, $ev, $evg, rwork);
};
(@body, $scalar:ty, $ev:path, $evg:path, $($rwork_ident:ident),*) => {
impl Eigh_ for $scalar {
fn eigh(
calc_v: bool,
layout: MatrixLayout,
uplo: UPLO,
a: &mut [Self],
) -> Result<Vec<Self::Real>> {
macro_rules! impl_eigh_work_c {
($c:ty, $ev:path) => {
impl EighWorkImpl for EighWork<$c> {
type Elem = $c;

fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self> {
assert_eq!(layout.len(), layout.lda());
let n = layout.len();
let jobz = if calc_v { JobEv::All } else { JobEv::None };
let mut eigs: Vec<MaybeUninit<Self::Real>> = vec_uninit(n as usize);

$(
let mut $rwork_ident: Vec<MaybeUninit<Self::Real>> = vec_uninit(3 * n as usize - 2 as usize);
)*

// calc work size
let jobz = if calc_eigenvectors {
JobEv::All
} else {
JobEv::None
};
let mut eigs = vec_uninit(n as usize);
let mut rwork = vec_uninit(3 * n as usize - 2 as usize);
let mut info = 0;
let mut work_size = [Self::zero()];
let mut work_size = [Self::Elem::zero()];
unsafe {
$ev(
jobz.as_ptr() ,
uplo.as_ptr(),
jobz.as_ptr(),
UPLO::Upper.as_ptr(), // dummy, working memory is not affected by UPLO
&n,
AsPtr::as_mut_ptr(a),
std::ptr::null_mut(),
&n,
AsPtr::as_mut_ptr(&mut eigs),
AsPtr::as_mut_ptr(&mut work_size),
&(-1),
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
AsPtr::as_mut_ptr(&mut rwork),
&mut info,
);
}
info.as_lapack_result()?;

// actual ev
let lwork = work_size[0].to_usize().unwrap();
let mut work: Vec<MaybeUninit<Self>> = vec_uninit(lwork);
let lwork = lwork as i32;
let work = vec_uninit(lwork);
Ok(EighWork {
n,
eigs,
jobz,
work,
rwork: Some(rwork),
})
}

fn calc(
&mut self,
uplo: UPLO,
a: &mut [Self::Elem],
) -> Result<&[<Self::Elem as Scalar>::Real]> {
let lwork = self.work.len().to_i32().unwrap();
let mut info = 0;
unsafe {
$ev(
jobz.as_ptr(),
self.jobz.as_ptr(),
uplo.as_ptr(),
&n,
&self.n,
AsPtr::as_mut_ptr(a),
&n,
AsPtr::as_mut_ptr(&mut eigs),
AsPtr::as_mut_ptr(&mut work),
&self.n,
AsPtr::as_mut_ptr(&mut self.eigs),
AsPtr::as_mut_ptr(&mut self.work),
&lwork,
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
&mut info,
);
}
info.as_lapack_result()?;

let eigs = unsafe { eigs.assume_init() };
Ok(eigs)
Ok(unsafe { self.eigs.slice_assume_init_ref() })
}

fn eigh_generalized(
calc_v: bool,
layout: MatrixLayout,
fn eval(
mut self,
uplo: UPLO,
a: &mut [Self],
b: &mut [Self],
) -> Result<Vec<Self::Real>> {
assert_eq!(layout.len(), layout.lda());
let n = layout.len();
let jobz = if calc_v { JobEv::All } else { JobEv::None };
let mut eigs: Vec<MaybeUninit<Self::Real>> = vec_uninit(n as usize);
a: &mut [Self::Elem],
) -> Result<Vec<<Self::Elem as Scalar>::Real>> {
let _eig = self.calc(uplo, a)?;
Ok(unsafe { self.eigs.assume_init() })
}
}
};
}
impl_eigh_work_c!(c64, lapack_sys::zheev_);
impl_eigh_work_c!(c32, lapack_sys::cheev_);

$(
let mut $rwork_ident: Vec<MaybeUninit<Self::Real>> = vec_uninit(3 * n as usize - 2);
)*
macro_rules! impl_eigh_work_r {
($f:ty, $ev:path) => {
impl EighWorkImpl for EighWork<$f> {
type Elem = $f;

// calc work size
fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self> {
assert_eq!(layout.len(), layout.lda());
let n = layout.len();
let jobz = if calc_eigenvectors {
JobEv::All
} else {
JobEv::None
};
let mut eigs = vec_uninit(n as usize);
let mut info = 0;
let mut work_size = [Self::zero()];
let mut work_size = [Self::Elem::zero()];
unsafe {
$evg(
&1, // ITYPE A*x = (lambda)*B*x
$ev(
jobz.as_ptr(),
uplo.as_ptr(),
UPLO::Upper.as_ptr(), // dummy, working memory is not affected by UPLO
&n,
AsPtr::as_mut_ptr(a),
&n,
AsPtr::as_mut_ptr(b),
std::ptr::null_mut(),
&n,
AsPtr::as_mut_ptr(&mut eigs),
AsPtr::as_mut_ptr(&mut work_size),
&(-1),
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
&mut info,
);
}
info.as_lapack_result()?;

// actual evg
let lwork = work_size[0].to_usize().unwrap();
let mut work: Vec<MaybeUninit<Self>> = vec_uninit(lwork);
let lwork = lwork as i32;
let work = vec_uninit(lwork);
Ok(EighWork {
n,
eigs,
jobz,
work,
rwork: None,
})
}

fn calc(
&mut self,
uplo: UPLO,
a: &mut [Self::Elem],
) -> Result<&[<Self::Elem as Scalar>::Real]> {
let lwork = self.work.len().to_i32().unwrap();
let mut info = 0;
unsafe {
$evg(
&1, // ITYPE A*x = (lambda)*B*x
jobz.as_ptr(),
$ev(
self.jobz.as_ptr(),
uplo.as_ptr(),
&n,
&self.n,
AsPtr::as_mut_ptr(a),
&n,
AsPtr::as_mut_ptr(b),
&n,
AsPtr::as_mut_ptr(&mut eigs),
AsPtr::as_mut_ptr(&mut work),
&self.n,
AsPtr::as_mut_ptr(&mut self.eigs),
AsPtr::as_mut_ptr(&mut self.work),
&lwork,
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
&mut info,
);
}
info.as_lapack_result()?;
let eigs = unsafe { eigs.assume_init() };
Ok(eigs)
Ok(unsafe { self.eigs.slice_assume_init_ref() })
}

fn eval(
mut self,
uplo: UPLO,
a: &mut [Self::Elem],
) -> Result<Vec<<Self::Elem as Scalar>::Real>> {
let _eig = self.calc(uplo, a)?;
Ok(unsafe { self.eigs.assume_init() })
}
}
};
} // impl_eigh!

impl_eigh!(@real, f64, lapack_sys::dsyev_, lapack_sys::dsygv_);
impl_eigh!(@real, f32, lapack_sys::ssyev_, lapack_sys::ssygv_);
impl_eigh!(@complex, c64, lapack_sys::zheev_, lapack_sys::zhegv_);
impl_eigh!(@complex, c32, lapack_sys::cheev_, lapack_sys::chegv_);
}
impl_eigh_work_r!(f64, lapack_sys::dsyev_);
impl_eigh_work_r!(f32, lapack_sys::ssyev_);

0 comments on commit f672b07

Please sign in to comment.