Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 117 additions & 22 deletions src/solve.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,49 @@
//! Solve linear problems
//! Solve systems of linear equations and invert matrices
//!
//! # Examples
//!
//! Solve `A * x = b`:
//!
//! ```
//! #[macro_use]
//! extern crate ndarray;
//! extern crate ndarray_linalg;
//!
//! use ndarray::prelude::*;
//! use ndarray_linalg::Solve;
//! # fn main() {
//!
//! let a: Array2<f64> = array![[3., 2., -1.], [2., -2., 4.], [-2., 1., -2.]];
//! let b: Array1<f64> = array![1., -2., 0.];
//! let x = a.solve_into(b).unwrap();
//! assert!(x.all_close(&array![1., -2., -2.], 1e-9));
//!
//! # }
//! ```
//!
//! There are also special functions for solving `A^T * x = b` and
//! `A^H * x = b`.
//!
//! If you are solving multiple systems of linear equations with the same
//! coefficient matrix `A`, it's faster to compute the LU factorization once at
//! the beginning than solving directly using `A`:
//!
//! ```
//! # extern crate ndarray;
//! # extern crate ndarray_linalg;
//! use ndarray::prelude::*;
//! use ndarray_linalg::*;
//! # fn main() {
//!
//! let a: Array2<f64> = random((3, 3));
//! let f = a.factorize_into().unwrap(); // LU factorize A (A is consumed)
//! for _ in 0..10 {
//! let b: Array1<f64> = random(3);
//! let x = f.solve_into(b).unwrap(); // Solve A * x = b using factorized L, U
//! }
//!
//! # }
//! ```

use ndarray::*;

Expand All @@ -9,43 +54,82 @@ use super::types::*;

pub use lapack_traits::{Pivot, Transpose};

/// An interface for solving systems of linear equations.
///
/// There are three groups of methods:
///
/// * `solve*` (normal) methods solve `A * x = b` for `x`.
/// * `solve_t*` (transpose) methods solve `A^T * x = b` for `x`.
/// * `solve_h*` (Hermitian conjugate) methods solve `A^H * x = b` for `x`.
///
/// Within each group, there are three methods that handle ownership differently:
///
/// * `*` methods take a reference to `b` and return `x` as a new array.
/// * `*_into` methods take ownership of `b`, store the result in it, and return it.
/// * `*_mut` methods take a mutable reference to `b` and store the result in that array.
///
/// If you plan to solve many equations with the same `A` matrix but different
/// `b` vectors, it's faster to factor the `A` matrix once using the
/// `Factorize` trait, and then solve using the `Factorized` struct.
pub trait Solve<A: Scalar> {
fn solve<S: Data<Elem = A>>(&self, a: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
let mut a = replicate(a);
self.solve_mut(&mut a)?;
Ok(a)
/// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
/// is the argument, and `x` is the successful result.
fn solve<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
let mut b = replicate(b);
self.solve_mut(&mut b)?;
Ok(b)
}
fn solve_into<S: DataMut<Elem = A>>(&self, mut a: ArrayBase<S, Ix1>) -> Result<ArrayBase<S, Ix1>> {
self.solve_mut(&mut a)?;
Ok(a)
/// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
/// is the argument, and `x` is the successful result.
fn solve_into<S: DataMut<Elem = A>>(&self, mut b: ArrayBase<S, Ix1>) -> Result<ArrayBase<S, Ix1>> {
self.solve_mut(&mut b)?;
Ok(b)
}
/// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
/// is the argument, and `x` is the successful result.
fn solve_mut<'a, S: DataMut<Elem = A>>(&self, &'a mut ArrayBase<S, Ix1>) -> Result<&'a mut ArrayBase<S, Ix1>>;

fn solve_t<S: Data<Elem = A>>(&self, a: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
let mut a = replicate(a);
self.solve_t_mut(&mut a)?;
Ok(a)
/// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
/// is the argument, and `x` is the successful result.
fn solve_t<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
let mut b = replicate(b);
self.solve_t_mut(&mut b)?;
Ok(b)
}
fn solve_t_into<S: DataMut<Elem = A>>(&self, mut a: ArrayBase<S, Ix1>) -> Result<ArrayBase<S, Ix1>> {
self.solve_t_mut(&mut a)?;
Ok(a)
/// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
/// is the argument, and `x` is the successful result.
fn solve_t_into<S: DataMut<Elem = A>>(&self, mut b: ArrayBase<S, Ix1>) -> Result<ArrayBase<S, Ix1>> {
self.solve_t_mut(&mut b)?;
Ok(b)
}
/// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
/// is the argument, and `x` is the successful result.
fn solve_t_mut<'a, S: DataMut<Elem = A>>(&self, &'a mut ArrayBase<S, Ix1>) -> Result<&'a mut ArrayBase<S, Ix1>>;

fn solve_h<S: Data<Elem = A>>(&self, a: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
let mut a = replicate(a);
self.solve_h_mut(&mut a)?;
Ok(a)
/// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
/// is the argument, and `x` is the successful result.
fn solve_h<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
let mut b = replicate(b);
self.solve_h_mut(&mut b)?;
Ok(b)
}
fn solve_h_into<S: DataMut<Elem = A>>(&self, mut a: ArrayBase<S, Ix1>) -> Result<ArrayBase<S, Ix1>> {
self.solve_h_mut(&mut a)?;
Ok(a)
/// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
/// is the argument, and `x` is the successful result.
fn solve_h_into<S: DataMut<Elem = A>>(&self, mut b: ArrayBase<S, Ix1>) -> Result<ArrayBase<S, Ix1>> {
self.solve_h_mut(&mut b)?;
Ok(b)
}
/// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
/// is the argument, and `x` is the successful result.
fn solve_h_mut<'a, S: DataMut<Elem = A>>(&self, &'a mut ArrayBase<S, Ix1>) -> Result<&'a mut ArrayBase<S, Ix1>>;
}

/// Represents the LU factorization of a matrix `A` as `A = P*L*U`.
pub struct Factorized<S: Data> {
/// The factors `L` and `U`; the unit diagonal elements of `L` are not
/// stored.
pub a: ArrayBase<S, Ix2>,
/// The pivot indices that define the permutation matrix `P`.
pub ipiv: Pivot,
}

Expand Down Expand Up @@ -134,6 +218,7 @@ where
A: Scalar,
S: DataMut<Elem = A>,
{
/// Computes the inverse of the factorized matrix.
pub fn into_inverse(mut self) -> Result<ArrayBase<S, Ix2>> {
unsafe {
A::inv(
Expand All @@ -146,11 +231,17 @@ where
}
}

/// An interface for computing LU factorizations of matrix refs.
pub trait Factorize<S: Data> {
/// Computes the LU factorization `A = P*L*U`, where `P` is a permutation
/// matrix.
fn factorize(&self) -> Result<Factorized<S>>;
}

/// An interface for computing LU factorizations of matrices.
pub trait FactorizeInto<S: Data> {
/// Computes the LU factorization `A = P*L*U`, where `P` is a permutation
/// matrix.
fn factorize_into(self) -> Result<Factorized<S>>;
}

Expand Down Expand Up @@ -180,13 +271,17 @@ where
}
}

/// An interface for inverting matrix refs.
pub trait Inverse {
type Output;
/// Computes the inverse of the matrix.
fn inv(&self) -> Result<Self::Output>;
}

/// An interface for inverting matrices.
pub trait InverseInto {
type Output;
/// Computes the inverse of the matrix.
fn inv_into(self) -> Result<Self::Output>;
}

Expand Down