From dfa19d7cb8cef19999410aed825ff4c1fd05dfc5 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 2 Jun 2017 15:26:12 +0900 Subject: [PATCH 01/17] Add QR_ --- src/impl2/mod.rs | 2 ++ src/impl2/opnorm.rs | 2 +- src/impl2/qr.rs | 55 +++++++++++++++++++++++++++++++++++++++++++++ src/layout.rs | 24 ++++++++++++++++++++ 4 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 src/impl2/qr.rs diff --git a/src/impl2/mod.rs b/src/impl2/mod.rs index 89a38823..2e9b2f21 100644 --- a/src/impl2/mod.rs +++ b/src/impl2/mod.rs @@ -1,6 +1,8 @@ pub mod opnorm; +pub mod qr; pub use self::opnorm::*; +pub use self::qr::*; pub trait LapackScalar: OperatorNorm_ {} impl LapackScalar for A where A: OperatorNorm_ {} diff --git a/src/impl2/opnorm.rs b/src/impl2/opnorm.rs index e2687ebf..b1efd9f7 100644 --- a/src/impl2/opnorm.rs +++ b/src/impl2/opnorm.rs @@ -4,7 +4,7 @@ use lapack::c; use lapack::c::Layout::ColumnMajor as cm; use types::*; -use layout::*; +use layout::Layout; #[repr(u8)] pub enum NormType { diff --git a/src/impl2/qr.rs b/src/impl2/qr.rs new file mode 100644 index 00000000..db8f21d9 --- /dev/null +++ b/src/impl2/qr.rs @@ -0,0 +1,55 @@ +//! Implement QR decomposition + +use std::cmp::min; +use num_traits::Zero; +use lapack::c; + +use types::*; +use error::*; +use layout::Layout; + +pub trait QR_: Sized { + fn householder(Layout, a: &mut [Self]) -> Result>; + fn q(Layout, a: &mut [Self], tau: &[Self]) -> Result<()>; + fn qr(Layout, a: &mut [Self]) -> Result>; +} + +macro_rules! impl_qr { + ($scalar:ty, $qrf:path, $gqr:path) => { +impl QR_ for $scalar { + fn householder(l: Layout, mut a: &mut [Self]) -> Result> { + let (row, col) = l.size(); + let k = min(row, col); + let mut tau = vec![Self::zero(); k as usize]; + let info = $qrf(l.lapacke_layout(), row, col, &mut a, l.lda(), &mut tau); + if info == 0 { + Ok(tau) + } else { + Err(LapackError::new(info).into()) + } + } + + fn q(l: Layout, mut a: &mut [Self], tau: &[Self]) -> Result<()> { + let (row, col) = l.size(); + let k = min(row, col); + let info = $gqr(l.lapacke_layout(), row, k, k, &mut a, l.lda(), &tau); + if info == 0 { + Ok(()) + } else { + Err(LapackError::new(info).into()) + } + } + + fn qr(l: Layout, mut a: &mut [Self]) -> Result> { + let tau = Self::householder(l, a)?; + let r = Vec::from(&*a); + Self::q(l, a, &tau)?; + Ok(r) + } +} +}} // endmacro + +impl_qr!(f64, c::dgeqrf, c::dorgqr); +impl_qr!(f32, c::sgeqrf, c::sorgqr); +impl_qr!(c64, c::zgeqrf, c::zungqr); +impl_qr!(c32, c::cgeqrf, c::cungqr); diff --git a/src/layout.rs b/src/layout.rs index 66e56c19..f1fb8eb1 100644 --- a/src/layout.rs +++ b/src/layout.rs @@ -1,5 +1,6 @@ use ndarray::*; +use lapack::c; use super::error::*; @@ -7,6 +8,7 @@ pub type LDA = i32; pub type Col = i32; pub type Row = i32; +#[derive(Debug, Clone, Copy)] pub enum Layout { C((Row, LDA)), F((Col, LDA)), @@ -19,6 +21,28 @@ impl Layout { Layout::F((col, lda)) => (lda, col), } } + + pub fn row(&self) -> Row { + self.size().0 + } + + pub fn col(&self) -> Col { + self.size().1 + } + + pub fn lda(&self) -> LDA { + match *self { + Layout::C((_, lda)) => lda, + Layout::F((_, lda)) => lda, + } + } + + pub fn lapacke_layout(&self) -> c::Layout { + match *self { + Layout::C(_) => c::Layout::RowMajor, + Layout::F(_) => c::Layout::ColumnMajor, + } + } } pub trait AllocatedArray { From bafa1e6dbc6063ebf7ebbc8bc1e16003ec69a8a2 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 2 Jun 2017 15:29:53 +0900 Subject: [PATCH 02/17] Remove row/col --- src/layout.rs | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/layout.rs b/src/layout.rs index f1fb8eb1..e948f63e 100644 --- a/src/layout.rs +++ b/src/layout.rs @@ -22,14 +22,6 @@ impl Layout { } } - pub fn row(&self) -> Row { - self.size().0 - } - - pub fn col(&self) -> Col { - self.size().1 - } - pub fn lda(&self) -> LDA { match *self { Layout::C((_, lda)) => lda, From d162408428f441374901f5c8d7a27976822df16a Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 2 Jun 2017 17:56:24 +0900 Subject: [PATCH 03/17] Add QR implementation (without slice drop) --- src/impl2/mod.rs | 4 ++-- src/layout.rs | 22 ++++++++++++++++++++++ src/traits.rs | 16 ++++++++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/src/impl2/mod.rs b/src/impl2/mod.rs index 2e9b2f21..1feca0dd 100644 --- a/src/impl2/mod.rs +++ b/src/impl2/mod.rs @@ -4,5 +4,5 @@ pub mod qr; pub use self::opnorm::*; pub use self::qr::*; -pub trait LapackScalar: OperatorNorm_ {} -impl LapackScalar for A where A: OperatorNorm_ {} +pub trait LapackScalar: OperatorNorm_ + QR_ {} +impl LapackScalar for A where A: OperatorNorm_ + QR_ {} diff --git a/src/layout.rs b/src/layout.rs index e948f63e..522af18c 100644 --- a/src/layout.rs +++ b/src/layout.rs @@ -44,6 +44,10 @@ pub trait AllocatedArray { fn as_allocated(&self) -> Result<&[Self::Scalar]>; } +pub trait AllocatedArrayMut: AllocatedArray { + fn as_allocated_mut(&mut self) -> Result<&mut [Self::Scalar]>; +} + impl AllocatedArray for ArrayBase where S: Data { @@ -76,3 +80,21 @@ impl AllocatedArray for ArrayBase Ok(slice) } } + +impl AllocatedArrayMut for ArrayBase + where S: DataMut +{ + fn as_allocated_mut(&mut self) -> Result<&mut [A]> { + let slice = self.as_slice_memory_order_mut().ok_or(MemoryContError::new())?; + Ok(slice) + } +} + +pub fn reconstruct(l: Layout, a: Vec) -> Result> + where S: DataOwned +{ + Ok(match l { + Layout::C((row, col)) => ArrayBase::from_shape_vec((row as usize, col as usize), a)?, + Layout::F((col, row)) => ArrayBase::from_shape_vec((row as usize, col as usize).f(), a)?, + }) +} diff --git a/src/traits.rs b/src/traits.rs index 5dada9c6..ba5ace99 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -34,3 +34,19 @@ impl OperationNorm for ArrayBase Ok(A::opnorm(t, l, a)) } } + +pub trait QR { + fn qr(self) -> Result<(Q, R)>; +} + +impl QR, ArrayBase> for ArrayBase + where A: LapackScalar, + Sq: DataMut, + Sr: DataOwned +{ + fn qr(mut self) -> Result<(ArrayBase, ArrayBase)> { + let l = self.layout()?; + let r = A::qr(l, self.as_allocated_mut()?)?; + Ok((self, reconstruct(l, r)?)) + } +} From 65337f467484e3b5642736a3511683a18b8f316d Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 4 Jun 2017 14:26:45 +0900 Subject: [PATCH 04/17] Rename to qr2 temporally --- src/traits.rs | 8 +++++--- tests/qr.rs | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/traits.rs b/src/traits.rs index ba5ace99..4d771688 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -36,7 +36,7 @@ impl OperationNorm for ArrayBase } pub trait QR { - fn qr(self) -> Result<(Q, R)>; + fn qr2(self) -> Result<(Q, R)>; } impl QR, ArrayBase> for ArrayBase @@ -44,9 +44,11 @@ impl QR, ArrayBase> for ArrayBase, Sr: DataOwned { - fn qr(mut self) -> Result<(ArrayBase, ArrayBase)> { + fn qr2(mut self) -> Result<(ArrayBase, ArrayBase)> { let l = self.layout()?; let r = A::qr(l, self.as_allocated_mut()?)?; - Ok((self, reconstruct(l, r)?)) + let r = reconstruct(l, r)?; + let q = self; + Ok((q, r)) } } diff --git a/tests/qr.rs b/tests/qr.rs index c298f30b..7387ea7e 100644 --- a/tests/qr.rs +++ b/tests/qr.rs @@ -10,7 +10,7 @@ fn $funcname() { let a = $random($n, $m, $t); let ans = a.clone(); println!("a = \n{:?}", &a); - let (q, r) = a.qr().unwrap(); + let (q, r) : (_, Array2) = a.qr2().unwrap(); println!("q = \n{:?}", &q); println!("r = \n{:?}", &r); assert_close_l2!(&q.t().dot(&q), &Array::eye(min($n, $m)), 1e-7); From a83df318f23ccd581ebaa4ff62d7e3e2315f01f7 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 4 Jun 2017 16:30:13 +0900 Subject: [PATCH 05/17] Take slice of QR --- src/traits.rs | 27 ++++++++++++++++++++++----- tests/qr.rs | 2 +- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/src/traits.rs b/src/traits.rs index 4d771688..4fb567a2 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -2,6 +2,7 @@ pub use impl2::LapackScalar; pub use impl2::NormType; +use num_traits::Zero; use ndarray::*; use super::types::*; @@ -39,16 +40,32 @@ pub trait QR { fn qr2(self) -> Result<(Q, R)>; } -impl QR, ArrayBase> for ArrayBase - where A: LapackScalar, - Sq: DataMut, - Sr: DataOwned +impl QR, ArrayBase> for ArrayBase + where A: LapackScalar + Copy + Zero, + S: DataMut, + Sq: DataOwned + DataMut, + Sr: DataOwned + DataMut { fn qr2(mut self) -> Result<(ArrayBase, ArrayBase)> { + let n = self.rows(); + let m = self.cols(); + let k = ::std::cmp::min(n, m); let l = self.layout()?; + // calc QR decomposition let r = A::qr(l, self.as_allocated_mut()?)?; - let r = reconstruct(l, r)?; + let r: Array2<_> = reconstruct(l, r)?; let q = self; + // get slice + let qv = q.slice(s![..n as isize, ..k as isize]); + let mut q = unsafe { ArrayBase::uninitialized((n, k)) }; + q.assign(&qv); + let rv = r.slice(s![..k as isize, ..m as isize]); + let mut r = ArrayBase::zeros((k, m)); + for ((i, j), val) in r.indexed_iter_mut() { + if i <= j { + *val = rv[(i, j)]; + } + } Ok((q, r)) } } diff --git a/tests/qr.rs b/tests/qr.rs index 7387ea7e..079fccc8 100644 --- a/tests/qr.rs +++ b/tests/qr.rs @@ -10,7 +10,7 @@ fn $funcname() { let a = $random($n, $m, $t); let ans = a.clone(); println!("a = \n{:?}", &a); - let (q, r) : (_, Array2) = a.qr2().unwrap(); + let (q, r) : (Array2<_>, Array2<_>) = a.qr2().unwrap(); println!("q = \n{:?}", &q); println!("r = \n{:?}", &r); assert_close_l2!(&q.t().dot(&q), &Array::eye(min($n, $m)), 1e-7); From 0a6f9594ab83822a2006ff14fc8ae50cb567d680 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 4 Jun 2017 16:31:55 +0900 Subject: [PATCH 06/17] Remove old QR --- src/matrix.rs | 42 ++---------------------------------------- src/traits.rs | 4 ++-- tests/qr.rs | 2 +- 3 files changed, 5 insertions(+), 43 deletions(-) diff --git a/src/matrix.rs b/src/matrix.rs index ea6c5f79..21f530f2 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -6,12 +6,11 @@ use ndarray::DataMut; use lapack::c::Layout; use super::error::{LinalgError, StrideError}; -use super::impls::qr::ImplQR; use super::impls::svd::ImplSVD; use super::impls::solve::ImplSolve; -pub trait MFloat: ImplQR + ImplSVD + ImplSolve + NdFloat {} -impl MFloat for A {} +pub trait MFloat: ImplSVD + ImplSolve + NdFloat {} +impl MFloat for A {} /// Methods for general matrices pub trait Matrix: Sized { @@ -24,8 +23,6 @@ pub trait Matrix: Sized { fn layout(&self) -> Result; /// singular-value decomposition (SVD) fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError>; - /// QR decomposition - fn qr(self) -> Result<(Self, Self), LinalgError>; /// LU decomposition fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError>; /// permutate matrix (inplace) @@ -89,37 +86,6 @@ impl Matrix for Array { Layout::ColumnMajor => Ok((ua.reversed_axes(), sv, va.reversed_axes())), } } - fn qr(self) -> Result<(Self, Self), LinalgError> { - let (n, m) = self.size(); - let strides = self.strides(); - let k = min(n, m); - let layout = self.layout()?; - let (q, r) = ImplQR::qr(layout, m, n, self.clone().into_raw_vec())?; - let (qa, ra) = if strides[0] < strides[1] { - (Array::from_vec(q).into_shape((m, n)).unwrap().reversed_axes(), - Array::from_vec(r).into_shape((m, n)).unwrap().reversed_axes()) - } else { - (Array::from_vec(q).into_shape((n, m)).unwrap(), Array::from_vec(r).into_shape((n, m)).unwrap()) - }; - let qm = if m > k { - let (qsl, _) = qa.view().split_at(Axis(1), k); - qsl.to_owned() - } else { - qa - }; - let mut rm = if n > k { - let (rsl, _) = ra.view().split_at(Axis(0), k); - rsl.to_owned() - } else { - ra - }; - for ((i, j), val) in rm.indexed_iter_mut() { - if i > j { - *val = A::zero(); - } - } - Ok((qm, rm)) - } fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> { let (n, m) = self.size(); let k = min(n, m); @@ -167,10 +133,6 @@ impl Matrix for RcArray { let (u, s, v) = self.into_owned().svd()?; Ok((u.into_shared(), s.into_shared(), v.into_shared())) } - fn qr(self) -> Result<(Self, Self), LinalgError> { - let (q, r) = self.into_owned().qr()?; - Ok((q.into_shared(), r.into_shared())) - } fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> { let (p, l, u) = self.into_owned().lu()?; Ok((p, l.into_shared(), u.into_shared())) diff --git a/src/traits.rs b/src/traits.rs index 4fb567a2..4e13a526 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -37,7 +37,7 @@ impl OperationNorm for ArrayBase } pub trait QR { - fn qr2(self) -> Result<(Q, R)>; + fn qr(self) -> Result<(Q, R)>; } impl QR, ArrayBase> for ArrayBase @@ -46,7 +46,7 @@ impl QR, ArrayBase> for ArrayBase + DataMut, Sr: DataOwned + DataMut { - fn qr2(mut self) -> Result<(ArrayBase, ArrayBase)> { + fn qr(mut self) -> Result<(ArrayBase, ArrayBase)> { let n = self.rows(); let m = self.cols(); let k = ::std::cmp::min(n, m); diff --git a/tests/qr.rs b/tests/qr.rs index 079fccc8..3232782f 100644 --- a/tests/qr.rs +++ b/tests/qr.rs @@ -10,7 +10,7 @@ fn $funcname() { let a = $random($n, $m, $t); let ans = a.clone(); println!("a = \n{:?}", &a); - let (q, r) : (Array2<_>, Array2<_>) = a.qr2().unwrap(); + let (q, r) : (Array2<_>, Array2<_>) = a.qr().unwrap(); println!("q = \n{:?}", &q); println!("r = \n{:?}", &r); assert_close_l2!(&q.t().dot(&q), &Array::eye(min($n, $m)), 1e-7); From a17ec9d8d9d4bcba0727b580ad9afc1212833fe6 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 4 Jun 2017 16:56:07 +0900 Subject: [PATCH 07/17] Split slice --- src/traits.rs | 67 +++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 54 insertions(+), 13 deletions(-) diff --git a/src/traits.rs b/src/traits.rs index 4e13a526..a0165cbe 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -45,27 +45,68 @@ impl QR, ArrayBase> for ArrayBase, Sq: DataOwned + DataMut, Sr: DataOwned + DataMut +{ + fn qr(mut self) -> Result<(ArrayBase, ArrayBase)> { + (&mut self).qr() + } +} + +fn take_slice(a: &ArrayBase, n: usize, m: usize) -> ArrayBase + where A: Copy, + S1: Data, + S2: DataMut + DataOwned +{ + let av = a.slice(s![..n as isize, ..m as isize]); + let mut a = unsafe { ArrayBase::uninitialized((n, m)) }; + a.assign(&av); + a +} + +fn take_slice_upper(a: &ArrayBase, n: usize, m: usize) -> ArrayBase + where A: Copy + Zero, + S1: Data, + S2: DataMut + DataOwned +{ + let av = a.slice(s![..n as isize, ..m as isize]); + let mut a = unsafe { ArrayBase::uninitialized((n, m)) }; + for ((i, j), val) in a.indexed_iter_mut() { + *val = if i <= j { av[(i, j)] } else { A::zero() }; + } + a +} + +impl<'a, A, S, Sq, Sr> QR, ArrayBase> for &'a mut ArrayBase + where A: LapackScalar + Copy + Zero, + S: DataMut, + Sq: DataOwned + DataMut, + Sr: DataOwned + DataMut { fn qr(mut self) -> Result<(ArrayBase, ArrayBase)> { let n = self.rows(); let m = self.cols(); let k = ::std::cmp::min(n, m); let l = self.layout()?; - // calc QR decomposition let r = A::qr(l, self.as_allocated_mut()?)?; let r: Array2<_> = reconstruct(l, r)?; let q = self; - // get slice - let qv = q.slice(s![..n as isize, ..k as isize]); - let mut q = unsafe { ArrayBase::uninitialized((n, k)) }; - q.assign(&qv); - let rv = r.slice(s![..k as isize, ..m as isize]); - let mut r = ArrayBase::zeros((k, m)); - for ((i, j), val) in r.indexed_iter_mut() { - if i <= j { - *val = rv[(i, j)]; - } - } - Ok((q, r)) + Ok((take_slice(q, n, k), take_slice_upper(&r, k, m))) + } +} + +impl<'a, A, S, Sq, Sr> QR, ArrayBase> for &'a ArrayBase + where A: LapackScalar + Copy + Zero, + S: Data, + Sq: DataOwned + DataMut, + Sr: DataOwned + DataMut +{ + fn qr(self) -> Result<(ArrayBase, ArrayBase)> { + let n = self.rows(); + let m = self.cols(); + let k = ::std::cmp::min(n, m); + let l = self.layout()?; + let mut q = self.to_owned(); + let r = A::qr(l, q.as_allocated_mut()?)?; + let r: Array2<_> = reconstruct(l, r)?; + Ok((take_slice(&q, n, k), take_slice_upper(&r, k, m))) } } From ddc3bb93114de9d72261737c3a667560af434f96 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 4 Jun 2017 17:12:20 +0900 Subject: [PATCH 08/17] Add moc of SVD --- src/impl2/mod.rs | 3 +++ src/impl2/svd.rs | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+) create mode 100644 src/impl2/svd.rs diff --git a/src/impl2/mod.rs b/src/impl2/mod.rs index 1feca0dd..fabfe1c6 100644 --- a/src/impl2/mod.rs +++ b/src/impl2/mod.rs @@ -1,8 +1,11 @@ pub mod opnorm; pub mod qr; +pub mod svd; + pub use self::opnorm::*; pub use self::qr::*; +pub use self::svd::*; pub trait LapackScalar: OperatorNorm_ + QR_ {} impl LapackScalar for A where A: OperatorNorm_ + QR_ {} diff --git a/src/impl2/svd.rs b/src/impl2/svd.rs new file mode 100644 index 00000000..e6c0d729 --- /dev/null +++ b/src/impl2/svd.rs @@ -0,0 +1,19 @@ +//! Implement Operator norms for matrices + +use lapack::c; + +use types::*; +use error::*; +use layout::Layout; + +#[repr(u8)] +pub enum FlagSVD { + All = b'A', + OverWrite = b'O', + Separately = b'S', + No = b'N', +} + +pub trait SVD_: Sized { + fn svd(Layout, u_flag: FlagSVD, v_flag: FlagSVD, a: &[Self]) -> Result<()>; +} From 4e65a4703662f44a48b28862a516d30169738c06 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Mon, 5 Jun 2017 15:44:35 +0900 Subject: [PATCH 09/17] impl SVD_ --- src/impl2/svd.rs | 57 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 52 insertions(+), 5 deletions(-) diff --git a/src/impl2/svd.rs b/src/impl2/svd.rs index e6c0d729..b1f8de4a 100644 --- a/src/impl2/svd.rs +++ b/src/impl2/svd.rs @@ -1,19 +1,66 @@ //! Implement Operator norms for matrices use lapack::c; +use num_traits::Zero; use types::*; use error::*; use layout::Layout; #[repr(u8)] -pub enum FlagSVD { +enum FlagSVD { All = b'A', - OverWrite = b'O', - Separately = b'S', + // OverWrite = b'O', + // Separately = b'S', No = b'N', } -pub trait SVD_: Sized { - fn svd(Layout, u_flag: FlagSVD, v_flag: FlagSVD, a: &[Self]) -> Result<()>; +pub struct SVDOutput { + pub s: Vec, + pub u: Option>, + pub vt: Option>, } + +pub trait SVD_: AssociatedReal { + fn svd(Layout, calc_u: bool, calc_vt: bool, a: &mut [Self]) -> Result>; +} + +macro_rules! impl_svd { + ($scalar:ty, $gesvd:path) => { + +impl SVD_ for $scalar { + fn svd(l: Layout, calc_u: bool, calc_vt: bool, mut a: &mut [Self]) -> Result> { + let (n, m) = l.size(); + let k = ::std::cmp::min(n, m); + let lda = l.lda(); + let (ju, ldu, mut u) = if calc_u { + (FlagSVD::All, m, vec![Self::zero(); (m*m) as usize]) + } else { + (FlagSVD::No, 0, Vec::new()) + }; + let (jvt, ldvt, mut vt) = if calc_vt { + (FlagSVD::All, m, vec![Self::zero(); (m*m) as usize]) + } else { + (FlagSVD::No, 0, Vec::new()) + }; + let mut s = vec![Self::Real::zero(); k as usize]; + let mut superb = vec![Self::Real::zero(); (k-2) as usize]; + let info = $gesvd(l.lapacke_layout(), ju as u8, jvt as u8, m, n, &mut a, lda, &mut s, &mut u, ldu, &mut vt, ldvt, &mut superb); + if info == 0 { + Ok(SVDOutput { + s: s, + u: if ldu > 0 { Some(u) } else { None }, + vt: if ldvt > 0 { Some(vt) } else { None }, + }) + } else { + Err(LapackError::new(info).into()) + } + } +} + +}} // impl_svd! + +impl_svd!(f64, c::dgesvd); +impl_svd!(f32, c::sgesvd); +impl_svd!(c64, c::zgesvd); +impl_svd!(c32, c::cgesvd); From ec4ff5a038998d8f4603b3b96a43eba50d9f6648 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Mon, 5 Jun 2017 15:48:06 +0900 Subject: [PATCH 10/17] Split traits.rs --- src/lib.rs | 3 ++- src/opnorm.rs | 36 ++++++++++++++++++++++++++++++++++++ src/prelude.rs | 4 +++- src/{traits.rs => qr.rs} | 31 +------------------------------ 4 files changed, 42 insertions(+), 32 deletions(-) create mode 100644 src/opnorm.rs rename src/{traits.rs => qr.rs} (77%) diff --git a/src/lib.rs b/src/lib.rs index 1d4f537e..c0026c80 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,7 +48,8 @@ pub mod layout; pub mod impls; pub mod impl2; -pub mod traits; +pub mod qr; +pub mod opnorm; pub mod vector; pub mod matrix; diff --git a/src/opnorm.rs b/src/opnorm.rs new file mode 100644 index 00000000..e2996d22 --- /dev/null +++ b/src/opnorm.rs @@ -0,0 +1,36 @@ + +use ndarray::*; + +use super::types::*; +use super::error::*; +use super::layout::*; + +pub use impl2::NormType; +use impl2::LapackScalar; + +pub trait OperationNorm { + type Output; + fn opnorm(&self, t: NormType) -> Self::Output; + fn opnorm_one(&self) -> Self::Output { + self.opnorm(NormType::One) + } + fn opnorm_inf(&self) -> Self::Output { + self.opnorm(NormType::Infinity) + } + fn opnorm_fro(&self) -> Self::Output { + self.opnorm(NormType::Frobenius) + } +} + +impl OperationNorm for ArrayBase + where A: LapackScalar + AssociatedReal, + S: Data +{ + type Output = Result; + + fn opnorm(&self, t: NormType) -> Self::Output { + let l = self.layout()?; + let a = self.as_allocated()?; + Ok(A::opnorm(t, l, a)) + } +} diff --git a/src/prelude.rs b/src/prelude.rs index 9fb296b2..29d6187a 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -5,4 +5,6 @@ pub use hermite::HermiteMatrix; pub use triangular::*; pub use util::*; pub use assert::*; -pub use traits::*; + +pub use qr::*; +pub use opnorm::*; diff --git a/src/traits.rs b/src/qr.rs similarity index 77% rename from src/traits.rs rename to src/qr.rs index a0165cbe..032e1f72 100644 --- a/src/traits.rs +++ b/src/qr.rs @@ -1,40 +1,11 @@ -pub use impl2::LapackScalar; -pub use impl2::NormType; - use num_traits::Zero; use ndarray::*; -use super::types::*; use super::error::*; use super::layout::*; -pub trait OperationNorm { - type Output; - fn opnorm(&self, t: NormType) -> Self::Output; - fn opnorm_one(&self) -> Self::Output { - self.opnorm(NormType::One) - } - fn opnorm_inf(&self) -> Self::Output { - self.opnorm(NormType::Infinity) - } - fn opnorm_fro(&self) -> Self::Output { - self.opnorm(NormType::Frobenius) - } -} - -impl OperationNorm for ArrayBase - where A: LapackScalar + AssociatedReal, - S: Data -{ - type Output = Result; - - fn opnorm(&self, t: NormType) -> Self::Output { - let l = self.layout()?; - let a = self.as_allocated()?; - Ok(A::opnorm(t, l, a)) - } -} +use impl2::LapackScalar; pub trait QR { fn qr(self) -> Result<(Q, R)>; From a07846b9337a32e5d280f0a53f42eeab1c29f294 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Mon, 5 Jun 2017 16:18:12 +0900 Subject: [PATCH 11/17] Replace SVD --- src/impl2/mod.rs | 4 ++-- src/impl2/svd.rs | 4 ++-- src/lib.rs | 1 + src/matrix.rs | 18 ------------------ src/prelude.rs | 1 + src/svd.rs | 40 ++++++++++++++++++++++++++++++++++++++++ tests/svd.rs | 12 +++++++----- 7 files changed, 53 insertions(+), 27 deletions(-) create mode 100644 src/svd.rs diff --git a/src/impl2/mod.rs b/src/impl2/mod.rs index fabfe1c6..2ecd7a0a 100644 --- a/src/impl2/mod.rs +++ b/src/impl2/mod.rs @@ -7,5 +7,5 @@ pub use self::opnorm::*; pub use self::qr::*; pub use self::svd::*; -pub trait LapackScalar: OperatorNorm_ + QR_ {} -impl LapackScalar for A where A: OperatorNorm_ + QR_ {} +pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ {} +impl LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ {} diff --git a/src/impl2/svd.rs b/src/impl2/svd.rs index b1f8de4a..54fc1828 100644 --- a/src/impl2/svd.rs +++ b/src/impl2/svd.rs @@ -30,7 +30,7 @@ macro_rules! impl_svd { impl SVD_ for $scalar { fn svd(l: Layout, calc_u: bool, calc_vt: bool, mut a: &mut [Self]) -> Result> { - let (n, m) = l.size(); + let (m, n) = l.size(); let k = ::std::cmp::min(n, m); let lda = l.lda(); let (ju, ldu, mut u) = if calc_u { @@ -39,7 +39,7 @@ impl SVD_ for $scalar { (FlagSVD::No, 0, Vec::new()) }; let (jvt, ldvt, mut vt) = if calc_vt { - (FlagSVD::All, m, vec![Self::zero(); (m*m) as usize]) + (FlagSVD::All, n, vec![Self::zero(); (n*n) as usize]) } else { (FlagSVD::No, 0, Vec::new()) }; diff --git a/src/lib.rs b/src/lib.rs index c0026c80..bda5683e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,6 +49,7 @@ pub mod impls; pub mod impl2; pub mod qr; +pub mod svd; pub mod opnorm; pub mod vector; diff --git a/src/matrix.rs b/src/matrix.rs index 21f530f2..b135b6f7 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -21,8 +21,6 @@ pub trait Matrix: Sized { fn size(&self) -> (usize, usize); /// Layout (C/Fortran) of matrix fn layout(&self) -> Result; - /// singular-value decomposition (SVD) - fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError>; /// LU decomposition fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError>; /// permutate matrix (inplace) @@ -74,18 +72,6 @@ impl Matrix for Array { fn layout(&self) -> Result { check_layout(self.strides()) } - fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError> { - let (n, m) = self.size(); - let layout = self.layout()?; - let (u, s, vt) = ImplSVD::svd(layout, m, n, self.clone().into_raw_vec())?; - let sv = Array::from_vec(s); - let ua = Array::from_vec(u).into_shape((n, n)).unwrap(); - let va = Array::from_vec(vt).into_shape((m, m)).unwrap(); - match layout { - Layout::RowMajor => Ok((ua, sv, va)), - Layout::ColumnMajor => Ok((ua.reversed_axes(), sv, va.reversed_axes())), - } - } fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> { let (n, m) = self.size(); let k = min(n, m); @@ -129,10 +115,6 @@ impl Matrix for RcArray { fn layout(&self) -> Result { check_layout(self.strides()) } - fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError> { - let (u, s, v) = self.into_owned().svd()?; - Ok((u.into_shared(), s.into_shared(), v.into_shared())) - } fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> { let (p, l, u) = self.into_owned().lu()?; Ok((p, l.into_shared(), u.into_shared())) diff --git a/src/prelude.rs b/src/prelude.rs index 29d6187a..06e15551 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -7,4 +7,5 @@ pub use util::*; pub use assert::*; pub use qr::*; +pub use svd::*; pub use opnorm::*; diff --git a/src/svd.rs b/src/svd.rs new file mode 100644 index 00000000..27533940 --- /dev/null +++ b/src/svd.rs @@ -0,0 +1,40 @@ + +use ndarray::*; + +use super::error::*; +use super::layout::{Layout, AllocatedArray, AllocatedArrayMut}; +use impl2::LapackScalar; + +pub trait SVD { + fn svd(self, calc_u: bool, calc_vt: bool) -> Result<(Option, S, Option)>; +} + +impl SVD, ArrayBase, ArrayBase> for ArrayBase + where A: LapackScalar, + S: DataMut, + Su: DataOwned, + Svt: DataOwned, + Ss: DataOwned +{ + fn svd(mut self, + calc_u: bool, + calc_vt: bool) + -> Result<(Option>, ArrayBase, Option>)> { + let n = self.rows(); + let m = self.cols(); + let l = self.layout()?; + let svd_res = A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)?; + let (u, vt) = match l { + Layout::C(_) => { + (svd_res.u.map(|u| ArrayBase::from_shape_vec((n, n), u).unwrap()), + svd_res.vt.map(|vt| ArrayBase::from_shape_vec((m, m), vt).unwrap())) + } + Layout::F(_) => { + (svd_res.u.map(|u| ArrayBase::from_shape_vec((n, n).f(), u).unwrap()), + svd_res.vt.map(|vt| ArrayBase::from_shape_vec((m, m).f(), vt).unwrap())) + } + }; + let s = ArrayBase::from_vec(svd_res.s); + Ok((u, s, vt)) + } +} diff --git a/tests/svd.rs b/tests/svd.rs index 1a4d8b2c..ccd16048 100644 --- a/tests/svd.rs +++ b/tests/svd.rs @@ -9,11 +9,13 @@ fn $funcname() { use ndarray_linalg::prelude::*; let a = $random($n, $m, $t); let answer = a.clone(); - println!("a = \n{}", &a); - let (u, s, vt) = a.svd().unwrap(); - println!("u = \n{}", &u); - println!("s = \n{}", &s); - println!("v = \n{}", &vt); + println!("a = \n{:?}", &a); + let (u, s, vt): (_, Array1<_>, _) = a.svd(true, true).unwrap(); + let u: Array2<_> = u.unwrap(); + let vt: Array2<_> = vt.unwrap(); + println!("u = \n{:?}", &u); + println!("s = \n{:?}", &s); + println!("v = \n{:?}", &vt); let mut sm = Array::zeros(($n, $m)); for i in 0..min($n, $m) { sm[(i, i)] = s[i]; From 80b4cf7d3d60998039dedea8704da505b0560aac Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Mon, 5 Jun 2017 16:34:34 +0900 Subject: [PATCH 12/17] Add Layout::resized --- src/layout.rs | 7 +++++++ src/svd.rs | 17 ++++------------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/layout.rs b/src/layout.rs index 522af18c..6c098115 100644 --- a/src/layout.rs +++ b/src/layout.rs @@ -22,6 +22,13 @@ impl Layout { } } + pub fn resized(&self, row: Row, col: Col) -> Layout { + match *self { + Layout::C(_) => Layout::C((row, col)), + Layout::F(_) => Layout::F((col, row)), + } + } + pub fn lda(&self) -> LDA { match *self { Layout::C((_, lda)) => lda, diff --git a/src/svd.rs b/src/svd.rs index 27533940..204ef75a 100644 --- a/src/svd.rs +++ b/src/svd.rs @@ -2,7 +2,7 @@ use ndarray::*; use super::error::*; -use super::layout::{Layout, AllocatedArray, AllocatedArrayMut}; +use super::layout::*; use impl2::LapackScalar; pub trait SVD { @@ -20,20 +20,11 @@ impl SVD, ArrayBase, ArrayBase Result<(Option>, ArrayBase, Option>)> { - let n = self.rows(); - let m = self.cols(); let l = self.layout()?; let svd_res = A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)?; - let (u, vt) = match l { - Layout::C(_) => { - (svd_res.u.map(|u| ArrayBase::from_shape_vec((n, n), u).unwrap()), - svd_res.vt.map(|vt| ArrayBase::from_shape_vec((m, m), vt).unwrap())) - } - Layout::F(_) => { - (svd_res.u.map(|u| ArrayBase::from_shape_vec((n, n).f(), u).unwrap()), - svd_res.vt.map(|vt| ArrayBase::from_shape_vec((m, m).f(), vt).unwrap())) - } - }; + let (n, m) = l.size(); + let u = svd_res.u.map(|u| reconstruct(l.resized(n, n), u).unwrap()); + let vt = svd_res.vt.map(|vt| reconstruct(l.resized(m, m), vt).unwrap()); let s = ArrayBase::from_vec(svd_res.s); Ok((u, s, vt)) } From 368d577c9111b1bd3eb2b71f7fabbcd9ece23a0e Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Mon, 5 Jun 2017 16:39:53 +0900 Subject: [PATCH 13/17] Add SVD impl for &ArrayBase, &mut ArrayBase --- src/svd.rs | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/svd.rs b/src/svd.rs index 204ef75a..1aa1488c 100644 --- a/src/svd.rs +++ b/src/svd.rs @@ -15,6 +15,38 @@ impl SVD, ArrayBase, ArrayBase, Svt: DataOwned, Ss: DataOwned +{ + fn svd(mut self, + calc_u: bool, + calc_vt: bool) + -> Result<(Option>, ArrayBase, Option>)> { + (&mut self).svd(calc_u, calc_vt) + } +} + +impl<'a, A, S, Su, Svt, Ss> SVD, ArrayBase, ArrayBase> for &'a ArrayBase + where A: LapackScalar + Clone, + S: Data, + Su: DataOwned, + Svt: DataOwned, + Ss: DataOwned +{ + fn svd(self, + calc_u: bool, + calc_vt: bool) + -> Result<(Option>, ArrayBase, Option>)> { + let a = self.to_owned(); + a.svd(calc_u, calc_vt) + } +} + +impl<'a, A, S, Su, Svt, Ss> SVD, ArrayBase, ArrayBase> + for &'a mut ArrayBase + where A: LapackScalar, + S: DataMut, + Su: DataOwned, + Svt: DataOwned, + Ss: DataOwned { fn svd(mut self, calc_u: bool, From 06d6f76439ef576331a172022b91148c075df201 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 7 Jun 2017 16:58:52 +0900 Subject: [PATCH 14/17] Impl solve --- src/impl2/mod.rs | 12 ++++++++++ src/impl2/solve.rs | 57 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) create mode 100644 src/impl2/solve.rs diff --git a/src/impl2/mod.rs b/src/impl2/mod.rs index 2ecd7a0a..8351f662 100644 --- a/src/impl2/mod.rs +++ b/src/impl2/mod.rs @@ -2,10 +2,22 @@ pub mod opnorm; pub mod qr; pub mod svd; +pub mod solve; pub use self::opnorm::*; pub use self::qr::*; pub use self::svd::*; +pub use self::solve::*; + +use super::error::*; pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ {} impl LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ {} + +pub fn into_result(info: i32, val: T) -> Result { + if info == 0 { + Ok(val) + } else { + Err(LapackError::new(info).into()) + } +} diff --git a/src/impl2/solve.rs b/src/impl2/solve.rs new file mode 100644 index 00000000..e763f704 --- /dev/null +++ b/src/impl2/solve.rs @@ -0,0 +1,57 @@ + +use lapack::c; + +use types::*; +use error::*; +use layout::Layout; + +use super::into_result; + +pub type Pivot = Vec; + +#[repr(u8)] +pub enum Transpose { + No = b'N', + Transpose = b'T', + Hermite = b'C', +} + +pub trait Solve_: Sized { + fn lu(Layout, a: &mut [Self]) -> Result; + fn inv(Layout, a: &mut [Self], &Pivot) -> Result<()>; + fn solve(Layout, Transpose, a: &[Self], &Pivot, b: &mut [Self]) -> Result<()>; +} + +macro_rules! impl_solve { + ($scalar:ty, $getrf:path, $getri:path, $getrs:path) => { + +impl Solve_ for $scalar { + fn lu(l: Layout, a: &mut [Self]) -> Result { + let (row, col) = l.size(); + let k = ::std::cmp::min(row, col); + let mut ipiv = vec![0; k as usize]; + let info = $getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv); + into_result(info, ipiv) + } + + fn inv(l: Layout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { + let (n, _) = l.size(); + let info = $getri(l.lapacke_layout(), n, a, l.lda(), ipiv); + into_result(info, ()) + } + + fn solve(l: Layout, t: Transpose, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> { + let (n, _) = l.size(); + let nrhs = 1; + let ldb = 1; + let info = $getrs(l.lapacke_layout(), t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb); + into_result(info, ()) + } +} + +}} // impl_solve! + +impl_solve!(f64, c::dgetrf, c::dgetri, c::dgetrs); +impl_solve!(f32, c::sgetrf, c::sgetri, c::sgetrs); +impl_solve!(c64, c::zgetrf, c::zgetri, c::zgetrs); +impl_solve!(c32, c::cgetrf, c::cgetri, c::cgetrs); From e9d745102baf795ceaed00f56c6218e99fbd69a0 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 7 Jun 2017 17:01:15 +0900 Subject: [PATCH 15/17] Use into_result --- src/impl2/qr.rs | 14 ++++---------- src/impl2/svd.rs | 16 +++++++--------- 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/src/impl2/qr.rs b/src/impl2/qr.rs index db8f21d9..714135d0 100644 --- a/src/impl2/qr.rs +++ b/src/impl2/qr.rs @@ -8,6 +8,8 @@ use types::*; use error::*; use layout::Layout; +use super::into_result; + pub trait QR_: Sized { fn householder(Layout, a: &mut [Self]) -> Result>; fn q(Layout, a: &mut [Self], tau: &[Self]) -> Result<()>; @@ -22,22 +24,14 @@ impl QR_ for $scalar { let k = min(row, col); let mut tau = vec![Self::zero(); k as usize]; let info = $qrf(l.lapacke_layout(), row, col, &mut a, l.lda(), &mut tau); - if info == 0 { - Ok(tau) - } else { - Err(LapackError::new(info).into()) - } + into_result(info, tau) } fn q(l: Layout, mut a: &mut [Self], tau: &[Self]) -> Result<()> { let (row, col) = l.size(); let k = min(row, col); let info = $gqr(l.lapacke_layout(), row, k, k, &mut a, l.lda(), &tau); - if info == 0 { - Ok(()) - } else { - Err(LapackError::new(info).into()) - } + into_result(info, ()) } fn qr(l: Layout, mut a: &mut [Self]) -> Result> { diff --git a/src/impl2/svd.rs b/src/impl2/svd.rs index 54fc1828..1151f8c8 100644 --- a/src/impl2/svd.rs +++ b/src/impl2/svd.rs @@ -7,6 +7,8 @@ use types::*; use error::*; use layout::Layout; +use super::into_result; + #[repr(u8)] enum FlagSVD { All = b'A', @@ -46,15 +48,11 @@ impl SVD_ for $scalar { let mut s = vec![Self::Real::zero(); k as usize]; let mut superb = vec![Self::Real::zero(); (k-2) as usize]; let info = $gesvd(l.lapacke_layout(), ju as u8, jvt as u8, m, n, &mut a, lda, &mut s, &mut u, ldu, &mut vt, ldvt, &mut superb); - if info == 0 { - Ok(SVDOutput { - s: s, - u: if ldu > 0 { Some(u) } else { None }, - vt: if ldvt > 0 { Some(vt) } else { None }, - }) - } else { - Err(LapackError::new(info).into()) - } + into_result(info, SVDOutput { + s: s, + u: if ldu > 0 { Some(u) } else { None }, + vt: if ldvt > 0 { Some(vt) } else { None }, + }) } } From 7e51ce6f2b36d0ff9da7c4dbda732ef521bbf0d6 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 7 Jun 2017 18:18:50 +0900 Subject: [PATCH 16/17] Revert "Impl solve" This reverts commit 06d6f76439ef576331a172022b91148c075df201. --- src/impl2/mod.rs | 12 ---------- src/impl2/solve.rs | 57 ---------------------------------------------- 2 files changed, 69 deletions(-) delete mode 100644 src/impl2/solve.rs diff --git a/src/impl2/mod.rs b/src/impl2/mod.rs index 8351f662..2ecd7a0a 100644 --- a/src/impl2/mod.rs +++ b/src/impl2/mod.rs @@ -2,22 +2,10 @@ pub mod opnorm; pub mod qr; pub mod svd; -pub mod solve; pub use self::opnorm::*; pub use self::qr::*; pub use self::svd::*; -pub use self::solve::*; - -use super::error::*; pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ {} impl LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ {} - -pub fn into_result(info: i32, val: T) -> Result { - if info == 0 { - Ok(val) - } else { - Err(LapackError::new(info).into()) - } -} diff --git a/src/impl2/solve.rs b/src/impl2/solve.rs deleted file mode 100644 index e763f704..00000000 --- a/src/impl2/solve.rs +++ /dev/null @@ -1,57 +0,0 @@ - -use lapack::c; - -use types::*; -use error::*; -use layout::Layout; - -use super::into_result; - -pub type Pivot = Vec; - -#[repr(u8)] -pub enum Transpose { - No = b'N', - Transpose = b'T', - Hermite = b'C', -} - -pub trait Solve_: Sized { - fn lu(Layout, a: &mut [Self]) -> Result; - fn inv(Layout, a: &mut [Self], &Pivot) -> Result<()>; - fn solve(Layout, Transpose, a: &[Self], &Pivot, b: &mut [Self]) -> Result<()>; -} - -macro_rules! impl_solve { - ($scalar:ty, $getrf:path, $getri:path, $getrs:path) => { - -impl Solve_ for $scalar { - fn lu(l: Layout, a: &mut [Self]) -> Result { - let (row, col) = l.size(); - let k = ::std::cmp::min(row, col); - let mut ipiv = vec![0; k as usize]; - let info = $getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv); - into_result(info, ipiv) - } - - fn inv(l: Layout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { - let (n, _) = l.size(); - let info = $getri(l.lapacke_layout(), n, a, l.lda(), ipiv); - into_result(info, ()) - } - - fn solve(l: Layout, t: Transpose, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> { - let (n, _) = l.size(); - let nrhs = 1; - let ldb = 1; - let info = $getrs(l.lapacke_layout(), t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb); - into_result(info, ()) - } -} - -}} // impl_solve! - -impl_solve!(f64, c::dgetrf, c::dgetri, c::dgetrs); -impl_solve!(f32, c::sgetrf, c::sgetri, c::sgetrs); -impl_solve!(c64, c::zgetrf, c::zgetri, c::zgetrs); -impl_solve!(c32, c::cgetrf, c::cgetri, c::cgetrs); From 0e2d9a2f0d6b29984547ef314ae607431b92c453 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 7 Jun 2017 18:19:33 +0900 Subject: [PATCH 17/17] Add into_result --- src/impl2/mod.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/impl2/mod.rs b/src/impl2/mod.rs index 2ecd7a0a..2c6fcbcc 100644 --- a/src/impl2/mod.rs +++ b/src/impl2/mod.rs @@ -7,5 +7,15 @@ pub use self::opnorm::*; pub use self::qr::*; pub use self::svd::*; +use super::error::*; + pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ {} impl LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ {} + +pub fn into_result(info: i32, val: T) -> Result { + if info == 0 { + Ok(val) + } else { + Err(LapackError::new(info).into()) + } +}