From 50b796c143ad5b4bc2551708d7d540fd0cd460a7 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 22 Aug 2017 16:26:57 +0900 Subject: [PATCH 1/2] Add Solve trait and drop Transpose from solve (split to functions) --- examples/solve.rs | 2 +- src/solve.rs | 71 +++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 69 insertions(+), 4 deletions(-) diff --git a/examples/solve.rs b/examples/solve.rs index 2bcba2d8..2abc07e2 100644 --- a/examples/solve.rs +++ b/examples/solve.rs @@ -11,7 +11,7 @@ fn factorize() -> Result<(), error::LinalgError> { let f = a.factorize_into()?; // LU factorize A (A is consumed) for _ in 0..10 { let b: Array1 = random(3); - let x = f.solve(Transpose::No, b)?; // solve Ax=b using factorized L, U + let _x = f.solve(&b)?; // solve Ax=b using factorized L, U } Ok(()) } diff --git a/src/solve.rs b/src/solve.rs index 52919b63..71e142ee 100644 --- a/src/solve.rs +++ b/src/solve.rs @@ -9,24 +9,89 @@ use super::types::*; pub use lapack_traits::{Pivot, Transpose}; +pub trait Solve { + fn solve>(&self, a: &ArrayBase) -> Result> { + let mut a = replicate(a); + self.solve_mut(&mut a)?; + Ok(a) + } + fn solve_into>(&self, mut a: ArrayBase) -> Result> { + self.solve_mut(&mut a)?; + Ok(a) + } + fn solve_mut<'a, S: DataMut>(&self, &'a mut ArrayBase) -> Result<&'a mut ArrayBase>; + + fn solve_t>(&self, a: &ArrayBase) -> Result> { + let mut a = replicate(a); + self.solve_t_mut(&mut a)?; + Ok(a) + } + fn solve_t_into>(&self, mut a: ArrayBase) -> Result> { + self.solve_t_mut(&mut a)?; + Ok(a) + } + fn solve_t_mut<'a, S: DataMut>(&self, &'a mut ArrayBase) -> Result<&'a mut ArrayBase>; + + fn solve_h>(&self, a: &ArrayBase) -> Result> { + let mut a = replicate(a); + self.solve_h_mut(&mut a)?; + Ok(a) + } + fn solve_h_into>(&self, mut a: ArrayBase) -> Result> { + self.solve_h_mut(&mut a)?; + Ok(a) + } + fn solve_h_mut<'a, S: DataMut>(&self, &'a mut ArrayBase) -> Result<&'a mut ArrayBase>; +} + pub struct Factorized { pub a: ArrayBase, pub ipiv: Pivot, } -impl Factorized +impl Solve for Factorized where A: Scalar, S: Data, { - pub fn solve(&self, t: Transpose, mut rhs: ArrayBase) -> Result> + fn solve_mut<'a, Sb>(&self, mut rhs: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + unsafe { + A::solve( + self.a.square_layout()?, + Transpose::No, + self.a.as_allocated()?, + &self.ipiv, + rhs.as_slice_mut().unwrap(), + )? + }; + Ok(rhs) + } + fn solve_t_mut<'a, Sb>(&self, mut rhs: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + unsafe { + A::solve( + self.a.square_layout()?, + Transpose::Transpose, + self.a.as_allocated()?, + &self.ipiv, + rhs.as_slice_mut().unwrap(), + )? + }; + Ok(rhs) + } + fn solve_h_mut<'a, Sb>(&self, mut rhs: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> where Sb: DataMut, { unsafe { A::solve( self.a.square_layout()?, - t, + Transpose::Hermite, self.a.as_allocated()?, &self.ipiv, rhs.as_slice_mut().unwrap(), From e785cdfffbdb68369d91880358b4d02797e363ae Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 22 Aug 2017 16:29:14 +0900 Subject: [PATCH 2/2] Add utility for solve: factorize and solve in one function --- examples/solve.rs | 11 ++++++++++- src/solve.rs | 28 ++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/examples/solve.rs b/examples/solve.rs index 2abc07e2..46f53dec 100644 --- a/examples/solve.rs +++ b/examples/solve.rs @@ -5,17 +5,26 @@ extern crate ndarray_linalg; use ndarray::*; use ndarray_linalg::*; +// Solve `Ax=b` +fn solve() -> Result<(), error::LinalgError> { + let a: Array2 = random((3, 3)); + let b: Array1 = random(3); + let _x = a.solve(&b)?; + Ok(()) +} + // Solve `Ax=b` for many b with fixed A fn factorize() -> Result<(), error::LinalgError> { let a: Array2 = random((3, 3)); let f = a.factorize_into()?; // LU factorize A (A is consumed) for _ in 0..10 { let b: Array1 = random(3); - let _x = f.solve(&b)?; // solve Ax=b using factorized L, U + let _x = f.solve_into(b)?; // solve Ax=b using factorized L, U } Ok(()) } fn main() { + solve().unwrap(); factorize().unwrap(); } diff --git a/src/solve.rs b/src/solve.rs index 71e142ee..3ff12253 100644 --- a/src/solve.rs +++ b/src/solve.rs @@ -101,6 +101,34 @@ where } } +impl Solve for ArrayBase +where + A: Scalar, + S: Data, +{ + fn solve_mut<'a, Sb>(&self, mut rhs: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize()?; + f.solve_mut(rhs) + } + fn solve_t_mut<'a, Sb>(&self, mut rhs: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize()?; + f.solve_t_mut(rhs) + } + fn solve_h_mut<'a, Sb>(&self, mut rhs: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize()?; + f.solve_h_mut(rhs) + } +} + impl Factorized where A: Scalar,