From d285bd655f7ac2c2c60d245de32e747036652cf6 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 12 Nov 2016 03:11:30 +0900 Subject: [PATCH 01/14] Use Ix1/Ix2 instead of Ix #15 --- src/hermite.rs | 5 ++--- src/matrix.rs | 7 +++---- src/square.rs | 5 ++--- src/vector.rs | 5 ++--- tests/inv.rs | 2 +- tests/qr.rs | 2 +- tests/ssqrt.rs | 2 +- tests/svd.rs | 2 +- 8 files changed, 13 insertions(+), 17 deletions(-) diff --git a/src/hermite.rs b/src/hermite.rs index 1c97c4b9..42ad3796 100644 --- a/src/hermite.rs +++ b/src/hermite.rs @@ -1,7 +1,6 @@ //! Define trait for Hermite matrices -use ndarray::prelude::*; -use ndarray::LinalgScalar; +use ndarray::{Ix2, Array, LinalgScalar}; use num_traits::float::Float; use matrix::Matrix; @@ -21,7 +20,7 @@ pub trait HermiteMatrix: SquareMatrix + Matrix { fn ssqrt(self) -> Result; } -impl HermiteMatrix for Array +impl HermiteMatrix for Array where A: ImplQR + ImplSVD + ImplNorm + ImplSolve + ImplEigh + LinalgScalar + Float { fn eigh(self) -> Result<(Self::Vector, Self), LinalgError> { diff --git a/src/matrix.rs b/src/matrix.rs index cc515de0..5e83608b 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -1,8 +1,7 @@ //! Define trait for general matrix use std::cmp::min; -use ndarray::prelude::*; -use ndarray::LinalgScalar; +use ndarray::{Ix1, Ix2, Array, Axis, LinalgScalar}; use error::LapackError; use qr::ImplQR; @@ -27,11 +26,11 @@ pub trait Matrix: Sized { fn qr(self) -> Result<(Self, Self), LapackError>; } -impl Matrix for Array +impl Matrix for Array where A: ImplQR + ImplSVD + ImplNorm + LinalgScalar { type Scalar = A; - type Vector = Array; + type Vector = Array; fn size(&self) -> (usize, usize) { (self.rows(), self.cols()) } diff --git a/src/square.rs b/src/square.rs index e54be4f1..35bc3ae1 100644 --- a/src/square.rs +++ b/src/square.rs @@ -1,7 +1,6 @@ //! Define trait for Hermite matrices -use ndarray::prelude::*; -use ndarray::LinalgScalar; +use ndarray::{Ix2, Array, LinalgScalar}; use num_traits::float::Float; use matrix::Matrix; @@ -37,7 +36,7 @@ pub trait SquareMatrix: Matrix { } } -impl SquareMatrix for Array +impl SquareMatrix for Array where A: ImplQR + ImplNorm + ImplSVD + ImplSolve + LinalgScalar + Float { fn inv(self) -> Result { diff --git a/src/vector.rs b/src/vector.rs index 6142cfa7..f264eb05 100644 --- a/src/vector.rs +++ b/src/vector.rs @@ -1,7 +1,6 @@ //! Define trait for vectors -use ndarray::prelude::*; -use ndarray::LinalgScalar; +use ndarray::{LinalgScalar, Array, Ix1}; use num_traits::float::Float; /// Methods for vectors @@ -11,7 +10,7 @@ pub trait Vector { fn norm(&self) -> Self::Scalar; } -impl Vector for Array { +impl Vector for Array { type Scalar = A; fn norm(&self) -> Self::Scalar { self.dot(&self).sqrt() diff --git a/tests/inv.rs b/tests/inv.rs index 723bba2d..4c4d73cf 100644 --- a/tests/inv.rs +++ b/tests/inv.rs @@ -9,7 +9,7 @@ use ndarray_linalg::prelude::*; use rand::distributions::*; use ndarray_rand::RandomExt; -fn all_close(a: Array, b: Array) { +fn all_close(a: Array, b: Array) { if !a.all_close(&b, 1.0e-7) { panic!("\nTwo matrices are not equal:\na = \n{:?}\nb = \n{:?}\n", a, diff --git a/tests/qr.rs b/tests/qr.rs index 670eef1f..c6677d5d 100644 --- a/tests/qr.rs +++ b/tests/qr.rs @@ -9,7 +9,7 @@ use ndarray_linalg::prelude::*; use rand::distributions::*; use ndarray_rand::RandomExt; -fn all_close(a: Array, b: Array) { +fn all_close(a: Array, b: Array) { if !a.all_close(&b, 1.0e-7) { panic!("\nTwo matrices are not equal:\na = \n{:?}\nb = \n{:?}\n", a, diff --git a/tests/ssqrt.rs b/tests/ssqrt.rs index 136c2987..af2cb889 100644 --- a/tests/ssqrt.rs +++ b/tests/ssqrt.rs @@ -9,7 +9,7 @@ use ndarray::prelude::*; use ndarray_linalg::prelude::*; use ndarray_rand::RandomExt; -fn all_close(a: &Array, b: &Array) { +fn all_close(a: &Array, b: &Array) { if !a.all_close(b, 1.0e-7) { panic!("\nTwo matrices are not equal:\na = \n{:?}\nb = \n{:?}\n", a, diff --git a/tests/svd.rs b/tests/svd.rs index 2ac67c67..87a37f4b 100644 --- a/tests/svd.rs +++ b/tests/svd.rs @@ -10,7 +10,7 @@ use ndarray_linalg::prelude::*; use rand::distributions::*; use ndarray_rand::RandomExt; -fn all_close(a: Array, b: Array) { +fn all_close(a: Array, b: Array) { if !a.all_close(&b, 1.0e-7) { panic!("\nTwo matrices are not equal:\na = \n{:?}\nb = \n{:?}\n", a, From 7452d8c704ff611060cee4710e904d284121da2a Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 22 Nov 2016 13:28:32 +0900 Subject: [PATCH 02/14] Use ndarray v0.7 --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 591db190..c5bfbd00 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,9 +14,9 @@ lapack = "0.11.1" num-traits = "0.1.36" [dependencies.ndarray] -version = "0.6.9" +version = "0.7" features = ["blas"] [dev-dependencies] rand = "0.3.14" -ndarray-rand = "0.2.0" +ndarray-rand = "0.3" From b3dd41c6669aad559b7536df50acedabba969331 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 22 Nov 2016 13:34:10 +0900 Subject: [PATCH 03/14] s/(Ix, Ix)/Ix2/g --- tests/cholesky.rs | 2 +- tests/lu.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/cholesky.rs b/tests/cholesky.rs index ecb2dcf9..27f95b1e 100644 --- a/tests/cholesky.rs +++ b/tests/cholesky.rs @@ -9,7 +9,7 @@ use ndarray::prelude::*; use ndarray_linalg::prelude::*; use ndarray_rand::RandomExt; -fn all_close(a: Array, b: Array) { +fn all_close(a: Array, b: Array) { if !a.all_close(&b, 1.0e-7) { panic!("\nTwo matrices are not equal:\na = \n{:?}\nb = \n{:?}\n", a, diff --git a/tests/lu.rs b/tests/lu.rs index a4ef1a2d..a1e84713 100644 --- a/tests/lu.rs +++ b/tests/lu.rs @@ -9,7 +9,7 @@ use ndarray_linalg::prelude::*; use rand::distributions::*; use ndarray_rand::RandomExt; -fn all_close(a: Array, b: Array) { +fn all_close(a: Array, b: Array) { if !a.all_close(&b, 1.0e-7) { panic!("\nTwo matrices are not equal:\na = \n{:?}\nb = \n{:?}\n", a, @@ -68,7 +68,7 @@ test_permutate_t!(permutate_4x3_t, &[[1., 4., 7., 10.], [2., 5., 8., 11.], [3., 6., 9., 12.]], &[[10., 11., 12.], [4., 5., 6.], [7., 8., 9.], [1., 2., 3.]]); -fn test_lu(a: Array) { +fn test_lu(a: Array) { println!("a = \n{:?}", &a); let (p, l, u) = a.clone().lu().unwrap(); println!("P = \n{:?}", &p); From 74496ae78c2ea8d785fa9a74e62d90e12980fcc8 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 22 Nov 2016 13:36:00 +0900 Subject: [PATCH 04/14] Bug fixed --- src/matrix.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/matrix.rs b/src/matrix.rs index 0a094e72..c868afd8 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -46,7 +46,7 @@ impl Matrix for Array where A: ImplQR + ImplSVD + ImplNorm + ImplSolve + LinalgScalar + Debug { type Scalar = A; - type Vector = Array; + type Vector = Array; type Permutator = Vec; fn size(&self) -> (usize, usize) { From 8ec64f5fee93a77640adbc7cb6a91e6fd1ee9dbb Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 22 Nov 2016 14:12:43 +0900 Subject: [PATCH 05/14] Add StrideError --- src/error.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/error.rs b/src/error.rs index 1fb7f5f5..b6f552e4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -44,10 +44,26 @@ impl error::Error for NotSquareError { } } +#[derive(Debug)] +pub struct StrideError {} + +impl fmt::Display for StrideError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "invalid stride") + } +} + +impl error::Error for StrideError { + fn description(&self) -> &str { + "invalid stride" + } +} + #[derive(Debug)] pub enum LinalgError { NotSquare(NotSquareError), Lapack(LapackError), + Stride(StrideError), } impl fmt::Display for LinalgError { @@ -55,6 +71,7 @@ impl fmt::Display for LinalgError { match *self { LinalgError::NotSquare(ref err) => err.fmt(f), LinalgError::Lapack(ref err) => err.fmt(f), + LinalgError::Stride(ref err) => err.fmt(f), } } } @@ -64,6 +81,7 @@ impl error::Error for LinalgError { match *self { LinalgError::NotSquare(ref err) => err.description(), LinalgError::Lapack(ref err) => err.description(), + LinalgError::Stride(ref err) => err.description(), } } } @@ -79,3 +97,9 @@ impl From for LinalgError { LinalgError::Lapack(err) } } + +impl From for LinalgError { + fn from(err: StrideError) -> LinalgError { + LinalgError::Stride(err) + } +} From 6295216dcc00257438ba7df3986da6e33c412c70 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 22 Nov 2016 14:20:56 +0900 Subject: [PATCH 06/14] Change signature of layout() (stride check is not implemented) --- src/hermite.rs | 3 +-- src/matrix.rs | 28 ++++++++++++++-------------- src/square.rs | 2 +- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/hermite.rs b/src/hermite.rs index 3d1e76b4..e490e332 100644 --- a/src/hermite.rs +++ b/src/hermite.rs @@ -49,9 +49,8 @@ impl HermiteMatrix for Array } fn cholesky(self) -> Result { try!(self.check_square()); - println!("layout = {:?}", self.layout()); let (n, _) = self.size(); - let layout = self.layout(); + let layout = self.layout()?; let a = try!(ImplCholesky::cholesky(layout, n, self.into_raw_vec())); let mut c = match layout { Layout::RowMajor => Array::from_vec(a).into_shape((n, n)).unwrap(), diff --git a/src/matrix.rs b/src/matrix.rs index 7979f767..f3dcf20f 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -6,7 +6,7 @@ use ndarray::prelude::*; use ndarray::LinalgScalar; use lapack::c::Layout; -use error::LapackError; +use error::{LinalgError, StrideError}; use qr::ImplQR; use svd::ImplSVD; use norm::ImplNorm; @@ -20,7 +20,7 @@ pub trait Matrix: Sized { /// number of (rows, columns) fn size(&self) -> (usize, usize); /// Layout (C/Fortran) of matrix - fn layout(&self) -> Layout; + fn layout(&self) -> Result; /// Operator norm for L-1 norm fn norm_1(&self) -> Self::Scalar; /// Operator norm for L-inf norm @@ -28,11 +28,11 @@ pub trait Matrix: Sized { /// Frobenius norm fn norm_f(&self) -> Self::Scalar; /// singular-value decomposition (SVD) - fn svd(self) -> Result<(Self, Self::Vector, Self), LapackError>; + fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError>; /// QR decomposition - fn qr(self) -> Result<(Self, Self), LapackError>; + fn qr(self) -> Result<(Self, Self), LinalgError>; /// LU decomposition - fn lu(self) -> Result<(Self::Permutator, Self, Self), LapackError>; + fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError>; /// permutate matrix (inplace) fn permutate(&mut self, p: &Self::Permutator); /// permutate matrix (outplace) @@ -52,12 +52,12 @@ impl Matrix for Array fn size(&self) -> (usize, usize) { (self.rows(), self.cols()) } - fn layout(&self) -> Layout { + fn layout(&self) -> Result { let strides = self.strides(); if strides[0] < strides[1] { - Layout::ColumnMajor + Ok(Layout::ColumnMajor) } else { - Layout::RowMajor + Ok(Layout::RowMajor) } } fn norm_1(&self) -> Self::Scalar { @@ -82,7 +82,7 @@ impl Matrix for Array let (m, n) = self.size(); ImplNorm::norm_f(m, n, self.clone().into_raw_vec()) } - fn svd(self) -> Result<(Self, Self::Vector, Self), LapackError> { + fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError> { let strides = self.strides(); let (m, n) = if strides[0] > strides[1] { self.size() @@ -102,7 +102,7 @@ impl Matrix for Array Ok((ua, sv, va)) } } - fn qr(self) -> Result<(Self, Self), LapackError> { + fn qr(self) -> Result<(Self, Self), LinalgError> { let (n, m) = self.size(); let strides = self.strides(); let k = min(n, m); @@ -136,19 +136,19 @@ impl Matrix for Array } Ok((qm, rm)) } - fn lu(self) -> Result<(Self::Permutator, Self, Self), LapackError> { + fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> { let (n, m) = self.size(); println!("n={}, m={}", n, m); let k = min(n, m); - let (p, mut a) = match self.layout() { + let (p, mut a) = match self.layout()? { Layout::ColumnMajor => { println!("ColumnMajor"); - let (p, l) = ImplSolve::lu(self.layout(), n, m, self.clone().into_raw_vec())?; + let (p, l) = ImplSolve::lu(self.layout()?, n, m, self.clone().into_raw_vec())?; (p, Array::from_vec(l).into_shape((m, n)).unwrap().reversed_axes()) } Layout::RowMajor => { println!("RowMajor"); - let (p, l) = ImplSolve::lu(self.layout(), n, m, self.clone().into_raw_vec())?; + let (p, l) = ImplSolve::lu(self.layout()?, n, m, self.clone().into_raw_vec())?; (p, Array::from_vec(l).into_shape((n, m)).unwrap()) } }; diff --git a/src/square.rs b/src/square.rs index ba4a21cd..666003e7 100644 --- a/src/square.rs +++ b/src/square.rs @@ -43,7 +43,7 @@ impl SquareMatrix for Array try!(self.check_square()); let (n, _) = self.size(); let is_fortran_align = self.strides()[0] > self.strides()[1]; - let a = try!(ImplSolve::inv(self.layout(), n, self.into_raw_vec())); + let a = ImplSolve::inv(self.layout()?, n, self.into_raw_vec())?; let m = Array::from_vec(a).into_shape((n, n)).unwrap(); if is_fortran_align { Ok(m) From 6654b3f69d6492e5efc6896c4e73e855b45c4b2a Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 22 Nov 2016 14:26:00 +0900 Subject: [PATCH 07/14] Add stride check --- src/error.rs | 8 ++++++-- src/matrix.rs | 6 ++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/error.rs b/src/error.rs index b6f552e4..126fc237 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,6 +2,7 @@ use std::error; use std::fmt; +use ndarray::Ixs; #[derive(Debug)] pub struct LapackError { @@ -45,11 +46,14 @@ impl error::Error for NotSquareError { } #[derive(Debug)] -pub struct StrideError {} +pub struct StrideError { + pub s0: Ixs, + pub s1: Ixs, +} impl fmt::Display for StrideError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "invalid stride") + write!(f, "invalid stride: s0={}, s1={}", self.s0, self.s1) } } diff --git a/src/matrix.rs b/src/matrix.rs index f3dcf20f..eb1e1131 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -54,6 +54,12 @@ impl Matrix for Array } fn layout(&self) -> Result { let strides = self.strides(); + if min(strides[0], strides[1]) != 1 { + return Err(StrideError { + s0: strides[0], + s1: strides[1], + });; + } if strides[0] < strides[1] { Ok(Layout::ColumnMajor) } else { From b12d68e7eea419943fa1906b91e13ea11e68aea0 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 22 Nov 2016 20:30:19 +0900 Subject: [PATCH 08/14] Small rev for layout --- src/square.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/square.rs b/src/square.rs index 666003e7..07067790 100644 --- a/src/square.rs +++ b/src/square.rs @@ -3,6 +3,7 @@ use ndarray::{Ix2, Array, LinalgScalar}; use std::fmt::Debug; use num_traits::float::Float; +use lapack::c::Layout; use matrix::Matrix; use error::{LinalgError, NotSquareError}; @@ -42,13 +43,12 @@ impl SquareMatrix for Array fn inv(self) -> Result { try!(self.check_square()); let (n, _) = self.size(); - let is_fortran_align = self.strides()[0] > self.strides()[1]; - let a = ImplSolve::inv(self.layout()?, n, self.into_raw_vec())?; + let layout = self.layout()?; + let a = ImplSolve::inv(layout, n, self.into_raw_vec())?; let m = Array::from_vec(a).into_shape((n, n)).unwrap(); - if is_fortran_align { - Ok(m) - } else { - Ok(m.reversed_axes()) + match layout { + Layout::RowMajor => Ok(m), + Layout::ColumnMajor => Ok(m.reversed_axes()), } } fn trace(&self) -> Result { From 31da57b4e6c16263f503596b16190bf37df54824 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 22 Nov 2016 20:38:27 +0900 Subject: [PATCH 09/14] Use Layout in eigh --- src/eigh.rs | 18 ++++-------------- src/hermite.rs | 8 ++++++-- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/src/eigh.rs b/src/eigh.rs index 70bd55b2..bc5f37d2 100644 --- a/src/eigh.rs +++ b/src/eigh.rs @@ -1,30 +1,20 @@ //! Implement eigenvalue decomposition of Hermite matrix -use lapack::fortran::*; +use lapack::c::*; use num_traits::Zero; use error::LapackError; pub trait ImplEigh: Sized { - fn eigh(n: usize, mut a: Vec) -> Result<(Vec, Vec), LapackError>; + fn eigh(layout: Layout, n: usize, mut a: Vec) -> Result<(Vec, Vec), LapackError>; } macro_rules! impl_eigh { ($scalar:ty, $syev:path) => { impl ImplEigh for $scalar { - fn eigh(n: usize, mut a: Vec) -> Result<(Vec, Vec), LapackError> { + fn eigh(layout: Layout, n: usize, mut a: Vec) -> Result<(Vec, Vec), LapackError> { let mut w = vec![Self::zero(); n]; - let mut work = vec![Self::zero(); 4 * n]; - let mut info = 0; - $syev(b'V', - b'U', - n as i32, - &mut a, - n as i32, - &mut w, - &mut work, - 4 * n as i32, - &mut info); + let info = $syev(layout, b'V', b'U', n as i32, &mut a, n as i32, &mut w); if info == 0 { Ok((w, a)) } else { diff --git a/src/hermite.rs b/src/hermite.rs index e490e332..1b24bf67 100644 --- a/src/hermite.rs +++ b/src/hermite.rs @@ -30,10 +30,14 @@ impl HermiteMatrix for Array { fn eigh(self) -> Result<(Self::Vector, Self), LinalgError> { try!(self.check_square()); + let layout = self.layout()?; let (rows, cols) = self.size(); - let (w, a) = try!(ImplEigh::eigh(rows, self.into_raw_vec())); + let (w, a) = ImplEigh::eigh(layout, rows, self.into_raw_vec())?; let ea = Array::from_vec(w); - let va = Array::from_vec(a).into_shape((rows, cols)).unwrap().reversed_axes(); + let va = match layout { + Layout::ColumnMajor => Array::from_vec(a).into_shape((rows, cols)).unwrap().reversed_axes(), + Layout::RowMajor => Array::from_vec(a).into_shape((rows, cols)).unwrap(), + }; Ok((ea, va)) } fn ssqrt(self) -> Result { From 133de4a35a7950d15a97efdd363f6934585a76d4 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 22 Nov 2016 21:05:08 +0900 Subject: [PATCH 10/14] Use layout at SVD --- src/matrix.rs | 24 +++++++-------------- src/svd.rs | 58 +++++++++++++-------------------------------------- 2 files changed, 23 insertions(+), 59 deletions(-) diff --git a/src/matrix.rs b/src/matrix.rs index eb1e1131..e34dd451 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -89,23 +89,15 @@ impl Matrix for Array ImplNorm::norm_f(m, n, self.clone().into_raw_vec()) } fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError> { - let strides = self.strides(); - let (m, n) = if strides[0] > strides[1] { - self.size() - } else { - let (n, m) = self.size(); - (m, n) - }; - let (u, s, vt) = try!(ImplSVD::svd(m, n, self.clone().into_raw_vec())); + let (n, m) = self.size(); + let layout = self.layout()?; + let (u, s, vt) = ImplSVD::svd(layout, m, n, self.clone().into_raw_vec())?; let sv = Array::from_vec(s); - if strides[0] > strides[1] { - let ua = Array::from_vec(u).into_shape((n, n)).unwrap(); - let va = Array::from_vec(vt).into_shape((m, m)).unwrap(); - Ok((va, sv, ua)) - } else { - let ua = Array::from_vec(u).into_shape((n, n)).unwrap().reversed_axes(); - let va = Array::from_vec(vt).into_shape((m, m)).unwrap().reversed_axes(); - Ok((ua, sv, va)) + let ua = Array::from_vec(u).into_shape((n, n)).unwrap(); + let va = Array::from_vec(vt).into_shape((m, m)).unwrap(); + match layout { + Layout::RowMajor => Ok((ua, sv, va)), + Layout::ColumnMajor => Ok((ua.reversed_axes(), sv, va.reversed_axes())), } } fn qr(self) -> Result<(Self, Self), LinalgError> { diff --git a/src/svd.rs b/src/svd.rs index 1386f417..ec2d90d9 100644 --- a/src/svd.rs +++ b/src/svd.rs @@ -1,65 +1,37 @@ //! Implement SVD -use lapack::fortran::*; +use std::cmp::min; +use lapack::c::*; use num_traits::Zero; use error::LapackError; pub trait ImplSVD: Sized { - fn svd(n: usize, m: usize, mut a: Vec) -> Result<(Vec, Vec, Vec), LapackError>; + fn svd(layout: Layout, + n: usize, + m: usize, + mut a: Vec) + -> Result<(Vec, Vec, Vec), LapackError>; } macro_rules! impl_svd { ($scalar:ty, $gesvd:path) => { impl ImplSVD for $scalar { - fn svd(n: usize, - m: usize, - mut a: Vec) - -> Result<(Vec, Vec, Vec), LapackError> { - let mut info = 0; + fn svd(layout: Layout, n: usize, m: usize, mut a: Vec) -> Result<(Vec, Vec, Vec), LapackError> { + let k = min(n, m); let n = n as i32; let m = m as i32; - let lda = m; + let lda = match layout { + Layout::RowMajor => n, + Layout::ColumnMajor => m, + }; let ldu = m; let ldvt = n; - let lwork = -1; - let lw_default = 1000; let mut u = vec![Self::zero(); (ldu * m) as usize]; let mut vt = vec![Self::zero(); (ldvt * n) as usize]; let mut s = vec![Self::zero(); n as usize]; - let mut work = vec![Self::zero(); lw_default]; - $gesvd('A' as u8, - 'A' as u8, - m, - n, - &mut a, - lda, - &mut s, - &mut u, - ldu, - &mut vt, - ldvt, - &mut work, - lwork, - &mut info); // calc optimal work - let lwork = work[0] as i32; - if lwork > lw_default as i32 { - work = vec![Self::zero(); lwork as usize]; - } - $gesvd('A' as u8, - 'A' as u8, - m, - n, - &mut a, - lda, - &mut s, - &mut u, - ldu, - &mut vt, - ldvt, - &mut work, - lwork, - &mut info); + let mut superb = vec![Self::zero(); k-2]; + let info = $gesvd(layout, 'A' as u8, 'A' as u8, m, n, &mut a, lda, &mut s, &mut u, ldu, &mut vt, ldvt, &mut superb); if info == 0 { Ok((u, s, vt)) } else { From eb2ad98eb1753646643492229fd18c2537e6d414 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 22 Nov 2016 21:22:39 +0900 Subject: [PATCH 11/14] Use layout in QR factorization --- src/matrix.rs | 9 ++-- src/qr.rs | 122 +++++++++++--------------------------------------- 2 files changed, 29 insertions(+), 102 deletions(-) diff --git a/src/matrix.rs b/src/matrix.rs index e34dd451..a2f076ce 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -96,7 +96,7 @@ impl Matrix for Array let ua = Array::from_vec(u).into_shape((n, n)).unwrap(); let va = Array::from_vec(vt).into_shape((m, m)).unwrap(); match layout { - Layout::RowMajor => Ok((ua, sv, va)), + Layout::RowMajor => Ok((ua, sv, va)), Layout::ColumnMajor => Ok((ua.reversed_axes(), sv, va.reversed_axes())), } } @@ -104,11 +104,8 @@ impl Matrix for Array let (n, m) = self.size(); let strides = self.strides(); let k = min(n, m); - let (q, r) = if strides[0] < strides[1] { - try!(ImplQR::qr(m, n, self.clone().into_raw_vec())) - } else { - try!(ImplQR::lq(n, m, self.clone().into_raw_vec())) - }; + let layout = self.layout()?; + let (q, r) = try!(ImplQR::qr(layout, m, n, self.clone().into_raw_vec())); let (qa, ra) = if strides[0] < strides[1] { (Array::from_vec(q).into_shape((m, n)).unwrap().reversed_axes(), Array::from_vec(r).into_shape((m, n)).unwrap().reversed_axes()) diff --git a/src/qr.rs b/src/qr.rs index 055bb400..1ee864e5 100644 --- a/src/qr.rs +++ b/src/qr.rs @@ -1,111 +1,41 @@ //! Implement QR decomposition use std::cmp::min; -use lapack::fortran::*; +use lapack::c::*; use num_traits::Zero; use error::LapackError; pub trait ImplQR: Sized { - fn qr(n: usize, m: usize, mut a: Vec) -> Result<(Vec, Vec), LapackError>; - fn lq(n: usize, m: usize, mut a: Vec) -> Result<(Vec, Vec), LapackError>; + fn qr(layout: Layout, n: usize, m: usize, mut a: Vec) -> Result<(Vec, Vec), LapackError>; } macro_rules! impl_qr { - ($geqrf:path, $orgqr:path, $gelqf:path, $orglq:path) => { -// XXX These codes are most same, but the argument of $orgqr and $orglq are different! -fn qr(n: usize, m: usize, mut a: Vec) -> Result<(Vec, Vec), LapackError> { - let n = n as i32; - let m = m as i32; - let mut info = 0; - let k = min(m, n); - let lda = m; - let lw_default = 1000; - let mut tau = vec![Self::zero(); k as usize]; - let mut work = vec![Self::zero(); lw_default]; -// estimate lwork - $geqrf(m, n, &mut a, lda, &mut tau, &mut work, -1, &mut info); - let lwork_r = work[0] as i32; - if lwork_r > lw_default as i32 { - work = vec![Self::zero(); lwork_r as usize]; - } -// calc R - $geqrf(m, n, &mut a, lda, &mut tau, &mut work, lwork_r, &mut info); - if info != 0 { - return Err(From::from(info)); - } - let r = a.clone(); -// re-estimate lwork - $orgqr(m, k, k, &mut a, lda, &mut tau, &mut work, -1, &mut info); - let lwork_q = work[0] as i32; - if lwork_q > lwork_r { - work = vec![Self::zero(); lwork_q as usize]; - } -// calc Q - $orgqr(m, - k, - k, - &mut a, - lda, - &mut tau, - &mut work, - lwork_q, - &mut info); - if info == 0 { - Ok((a, r)) - } else { - Err(From::from(info)) - } -} -fn lq(n: usize, m: usize, mut a: Vec) -> Result<(Vec, Vec), LapackError> { - let n = n as i32; - let m = m as i32; - let mut info = 0; - let k = min(m, n); - let lda = m; - let lw_default = 1000; - let mut tau = vec![Self::zero(); k as usize]; - let mut work = vec![Self::zero(); lw_default]; -// estimate lwork - $gelqf(m, n, &mut a, lda, &mut tau, &mut work, -1, &mut info); - let lwork_r = work[0] as i32; - if lwork_r > lw_default as i32 { - work = vec![Self::zero(); lwork_r as usize]; - } -// calc R - $gelqf(m, n, &mut a, lda, &mut tau, &mut work, lwork_r, &mut info); - if info != 0 { - return Err(From::from(info)); - } - let r = a.clone(); -// re-estimate lwork - $orglq(k, n, k, &mut a, lda, &mut tau, &mut work, -1, &mut info); - let lwork_q = work[0] as i32; - if lwork_q > lwork_r { - work = vec![Self::zero(); lwork_q as usize]; - } -// calc Q - $orglq(k, - n, - k, - &mut a, - lda, - &mut tau, - &mut work, - lwork_q, - &mut info); - if info == 0 { - Ok((a, r)) - } else { - Err(From::from(info)) + ($scalar:ty, $geqrf:path, $orgqr:path) => { +impl ImplQR for $scalar { + fn qr(layout: Layout, n: usize, m: usize, mut a: Vec) -> Result<(Vec, Vec), LapackError> { + let n = n as i32; + let m = m as i32; + let k = min(m, n); + let lda = match layout { + Layout::ColumnMajor => m, + Layout::RowMajor => n, + }; + let mut tau = vec![Self::zero(); k as usize]; + let info = $geqrf(layout, m, n, &mut a, lda, &mut tau); + if info != 0 { + return Err(From::from(info)); + } + let r = a.clone(); + let info = $orgqr(layout, m, k, k, &mut a, lda, &mut tau); + if info == 0 { + Ok((a, r)) + } else { + Err(From::from(info)) + } } } }} // endmacro -impl ImplQR for f64 { - impl_qr!(dgeqrf, dorgqr, dgelqf, dorglq); -} - -impl ImplQR for f32 { - impl_qr!(sgeqrf, sorgqr, sgelqf, sorglq); -} +impl_qr!(f64, dgeqrf, dorgqr); +impl_qr!(f32, sgeqrf, sorgqr); From 3f0b26208171bda9b4b10590af057e2249367f70 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 22 Nov 2016 21:25:47 +0900 Subject: [PATCH 12/14] Remove debug code in LU --- src/matrix.rs | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/src/matrix.rs b/src/matrix.rs index a2f076ce..fd9d683e 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -133,21 +133,12 @@ impl Matrix for Array } fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> { let (n, m) = self.size(); - println!("n={}, m={}", n, m); let k = min(n, m); - let (p, mut a) = match self.layout()? { - Layout::ColumnMajor => { - println!("ColumnMajor"); - let (p, l) = ImplSolve::lu(self.layout()?, n, m, self.clone().into_raw_vec())?; - (p, Array::from_vec(l).into_shape((m, n)).unwrap().reversed_axes()) - } - Layout::RowMajor => { - println!("RowMajor"); - let (p, l) = ImplSolve::lu(self.layout()?, n, m, self.clone().into_raw_vec())?; - (p, Array::from_vec(l).into_shape((n, m)).unwrap()) - } + let (p, l) = ImplSolve::lu(self.layout()?, n, m, self.clone().into_raw_vec())?; + let mut a = match self.layout()? { + Layout::ColumnMajor => Array::from_vec(l).into_shape((m, n)).unwrap().reversed_axes(), + Layout::RowMajor => Array::from_vec(l).into_shape((n, m)).unwrap(), }; - println!("a (after LU) = \n{:?}", &a); let mut lm = Array::zeros((n, k)); for ((i, j), val) in lm.indexed_iter_mut() { if i > j { @@ -166,7 +157,6 @@ impl Matrix for Array } else { a }; - println!("am = \n{:?}", am); Ok((p, lm, am)) } fn permutate(&mut self, ipiv: &Self::Permutator) { From d5fa0bfea97a65806c218de420c15d8a66c3df04 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 22 Nov 2016 21:29:52 +0900 Subject: [PATCH 13/14] Use ColumnMajor (fixed) in norm_* --- src/norm.rs | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/norm.rs b/src/norm.rs index 6ca4920b..7b4d94a7 100644 --- a/src/norm.rs +++ b/src/norm.rs @@ -1,7 +1,6 @@ //! Implement Norms for matrices -use lapack::fortran::*; -use num_traits::Zero; +use lapack::c::*; pub trait ImplNorm: Sized { fn norm_1(m: usize, n: usize, mut a: Vec) -> Self; @@ -13,16 +12,13 @@ macro_rules! impl_norm { ($scalar:ty, $lange:path) => { impl ImplNorm for $scalar { fn norm_1(m: usize, n: usize, mut a: Vec) -> Self { - let mut work = Vec::::new(); - $lange(b'o', m as i32, n as i32, &mut a, m as i32, &mut work) + $lange(Layout::ColumnMajor, b'o', m as i32, n as i32, &mut a, m as i32) } fn norm_i(m: usize, n: usize, mut a: Vec) -> Self { - let mut work = vec![Self::zero(); m]; - $lange(b'i', m as i32, n as i32, &mut a, m as i32, &mut work) + $lange(Layout::ColumnMajor, b'i', m as i32, n as i32, &mut a, m as i32) } fn norm_f(m: usize, n: usize, mut a: Vec) -> Self { - let mut work = Vec::::new(); - $lange(b'f', m as i32, n as i32, &mut a, m as i32, &mut work) + $lange(Layout::ColumnMajor, b'f', m as i32, n as i32, &mut a, m as i32) } } }} // end macro_rules From 43bc59096aab753b6f062dc468133464287b14e3 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 22 Nov 2016 21:35:02 +0900 Subject: [PATCH 14/14] try! -> ? --- .gitignore | 3 +++ rustfmt.toml | 1 + src/hermite.rs | 8 ++++---- src/matrix.rs | 2 +- src/square.rs | 4 ++-- 5 files changed, 11 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index cb14a420..3151ac6a 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,6 @@ # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries # More information here http://doc.crates.io/guide.html#cargotoml-vs-cargolock Cargo.lock + +# cargo fmt +*.bk diff --git a/rustfmt.toml b/rustfmt.toml index 75306517..113fecc8 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1 +1,2 @@ max_width = 120 +use_try_shorthand = true diff --git a/src/hermite.rs b/src/hermite.rs index 1b24bf67..eb2ed74d 100644 --- a/src/hermite.rs +++ b/src/hermite.rs @@ -29,7 +29,7 @@ impl HermiteMatrix for Array where A: ImplQR + ImplSVD + ImplNorm + ImplSolve + ImplEigh + ImplCholesky + LinalgScalar + Float + Debug { fn eigh(self) -> Result<(Self::Vector, Self), LinalgError> { - try!(self.check_square()); + self.check_square()?; let layout = self.layout()?; let (rows, cols) = self.size(); let (w, a) = ImplEigh::eigh(layout, rows, self.into_raw_vec())?; @@ -42,7 +42,7 @@ impl HermiteMatrix for Array } fn ssqrt(self) -> Result { let (n, _) = self.size(); - let (e, v) = try!(self.eigh()); + let (e, v) = self.eigh()?; let mut res = Array::zeros((n, n)); for i in 0..n { for j in 0..n { @@ -52,10 +52,10 @@ impl HermiteMatrix for Array Ok(v.dot(&res)) } fn cholesky(self) -> Result { - try!(self.check_square()); + self.check_square()?; let (n, _) = self.size(); let layout = self.layout()?; - let a = try!(ImplCholesky::cholesky(layout, n, self.into_raw_vec())); + let a = ImplCholesky::cholesky(layout, n, self.into_raw_vec())?; let mut c = match layout { Layout::RowMajor => Array::from_vec(a).into_shape((n, n)).unwrap(), Layout::ColumnMajor => Array::from_vec(a).into_shape((n, n)).unwrap().reversed_axes(), diff --git a/src/matrix.rs b/src/matrix.rs index fd9d683e..42157780 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -105,7 +105,7 @@ impl Matrix for Array let strides = self.strides(); let k = min(n, m); let layout = self.layout()?; - let (q, r) = try!(ImplQR::qr(layout, m, n, self.clone().into_raw_vec())); + let (q, r) = ImplQR::qr(layout, m, n, self.clone().into_raw_vec())?; let (qa, ra) = if strides[0] < strides[1] { (Array::from_vec(q).into_shape((m, n)).unwrap().reversed_axes(), Array::from_vec(r).into_shape((m, n)).unwrap().reversed_axes()) diff --git a/src/square.rs b/src/square.rs index 07067790..81887f5d 100644 --- a/src/square.rs +++ b/src/square.rs @@ -41,7 +41,7 @@ impl SquareMatrix for Array where A: ImplQR + ImplNorm + ImplSVD + ImplSolve + LinalgScalar + Float + Debug { fn inv(self) -> Result { - try!(self.check_square()); + self.check_square()?; let (n, _) = self.size(); let layout = self.layout()?; let a = ImplSolve::inv(layout, n, self.into_raw_vec())?; @@ -52,7 +52,7 @@ impl SquareMatrix for Array } } fn trace(&self) -> Result { - try!(self.check_square()); + self.check_square()?; let (n, _) = self.size(); Ok((0..n).fold(A::zero(), |sum, i| sum + self[(i, i)])) }