Skip to content

Commit

Permalink
Merge Tridiagonal_ into Lapack
Browse files Browse the repository at this point in the history
  • Loading branch information
termoshtt committed Oct 4, 2022
1 parent d9c52a2 commit 4913818
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 148 deletions.
53 changes: 42 additions & 11 deletions lax/src/lib.rs
Expand Up @@ -84,14 +84,14 @@ extern crate openblas_src as _src;
#[cfg(any(feature = "netlib-system", feature = "netlib-static"))]
extern crate netlib_src as _src;

pub mod error;
pub mod flags;
pub mod layout;

pub mod alloc;
pub mod cholesky;
pub mod eig;
pub mod eigh;
pub mod eigh_generalized;
pub mod error;
pub mod flags;
pub mod layout;
pub mod least_squares;
pub mod opnorm;
pub mod qr;
Expand All @@ -101,16 +101,12 @@ pub mod solveh;
pub mod svd;
pub mod svddc;
pub mod triangular;
pub mod tridiagonal;

mod alloc;
mod tridiagonal;

pub use self::cholesky::*;
pub use self::flags::*;
pub use self::least_squares::LeastSquaresOwned;
pub use self::opnorm::*;
pub use self::svd::{SvdOwned, SvdRef};
pub use self::tridiagonal::*;
pub use self::tridiagonal::{LUFactorizedTridiagonal, Tridiagonal};

use self::{alloc::*, error::*, layout::*};
use cauchy::*;
Expand All @@ -120,7 +116,7 @@ pub type Pivot = Vec<i32>;

#[cfg_attr(doc, katexit::katexit)]
/// Trait for primitive types which implements LAPACK subroutines
pub trait Lapack: Tridiagonal_ {
pub trait Lapack: Scalar {
/// Compute right eigenvalue and eigenvectors for a general matrix
fn eig(
calc_v: bool,
Expand Down Expand Up @@ -306,6 +302,19 @@ pub trait Lapack: Tridiagonal_ {
a: &[Self],
b: &mut [Self],
) -> Result<()>;

/// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
/// partial pivoting with row interchanges.
fn lu_tridiagonal(a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>>;

fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real>;

fn solve_tridiagonal(
lu: &LUFactorizedTridiagonal<Self>,
bl: MatrixLayout,
t: Transpose,
b: &mut [Self],
) -> Result<()>;
}

macro_rules! impl_lapack {
Expand Down Expand Up @@ -491,6 +500,28 @@ macro_rules! impl_lapack {
use triangular::*;
SolveTriangularImpl::solve_triangular(al, bl, uplo, d, a, b)
}

fn lu_tridiagonal(a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>> {
use tridiagonal::*;
let work = LuTridiagonalWork::<$s>::new(a.l);
work.eval(a)
}

fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real> {
use tridiagonal::*;
let mut work = RcondTridiagonalWork::<$s>::new(lu.a.l);
work.calc(lu)
}

fn solve_tridiagonal(
lu: &LUFactorizedTridiagonal<Self>,
bl: MatrixLayout,
t: Transpose,
b: &mut [Self],
) -> Result<()> {
use tridiagonal::*;
SolveTridiagonalImpl::solve_tridiagonal(lu, bl, t, b)
}
}
};
}
Expand Down
137 changes: 0 additions & 137 deletions lax/src/tridiagonal/mod.rs
Expand Up @@ -10,140 +10,3 @@ pub use lu::*;
pub use matrix::*;
pub use rcond::*;
pub use solve::*;

use crate::{error::*, layout::*, *};
use cauchy::*;
use num_traits::Zero;

/// Wraps `*gttrf`, `*gtcon` and `*gttrs`
pub trait Tridiagonal_: Scalar + Sized {
/// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
/// partial pivoting with row interchanges.
fn lu_tridiagonal(a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>>;

fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real>;

fn solve_tridiagonal(
lu: &LUFactorizedTridiagonal<Self>,
bl: MatrixLayout,
t: Transpose,
b: &mut [Self],
) -> Result<()>;
}

macro_rules! impl_tridiagonal {
(@real, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, iwork);
};
(@complex, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, );
};
(@body, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path, $($iwork:ident)*) => {
impl Tridiagonal_ for $scalar {
fn lu_tridiagonal(mut a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>> {
let (n, _) = a.l.size();
let mut du2 = vec_uninit( (n - 2) as usize);
let mut ipiv = vec_uninit( n as usize);
// We have to calc one-norm before LU factorization
let a_opnorm_one = a.opnorm_one();
let mut info = 0;
unsafe {
$gttrf(
&n,
AsPtr::as_mut_ptr(&mut a.dl),
AsPtr::as_mut_ptr(&mut a.d),
AsPtr::as_mut_ptr(&mut a.du),
AsPtr::as_mut_ptr(&mut du2),
AsPtr::as_mut_ptr(&mut ipiv),
&mut info,
)
};
info.as_lapack_result()?;
let du2 = unsafe { du2.assume_init() };
let ipiv = unsafe { ipiv.assume_init() };
Ok(LUFactorizedTridiagonal {
a,
du2,
ipiv,
a_opnorm_one,
})
}

fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real> {
let (n, _) = lu.a.l.size();
let ipiv = &lu.ipiv;
let mut work: Vec<MaybeUninit<Self>> = vec_uninit(2 * n as usize);
$(
let mut $iwork: Vec<MaybeUninit<i32>> = vec_uninit(n as usize);
)*
let mut rcond = Self::Real::zero();
let mut info = 0;
unsafe {
$gtcon(
NormType::One.as_ptr(),
&n,
AsPtr::as_ptr(&lu.a.dl),
AsPtr::as_ptr(&lu.a.d),
AsPtr::as_ptr(&lu.a.du),
AsPtr::as_ptr(&lu.du2),
ipiv.as_ptr(),
&lu.a_opnorm_one,
&mut rcond,
AsPtr::as_mut_ptr(&mut work),
$(AsPtr::as_mut_ptr(&mut $iwork),)*
&mut info,
);
}
info.as_lapack_result()?;
Ok(rcond)
}

fn solve_tridiagonal(
lu: &LUFactorizedTridiagonal<Self>,
b_layout: MatrixLayout,
t: Transpose,
b: &mut [Self],
) -> Result<()> {
let (n, _) = lu.a.l.size();
let ipiv = &lu.ipiv;
// Transpose if b is C-continuous
let mut b_t = None;
let b_layout = match b_layout {
MatrixLayout::C { .. } => {
let (layout, t) = transpose(b_layout, b);
b_t = Some(t);
layout
}
MatrixLayout::F { .. } => b_layout,
};
let (ldb, nrhs) = b_layout.size();
let mut info = 0;
unsafe {
$gttrs(
t.as_ptr(),
&n,
&nrhs,
AsPtr::as_ptr(&lu.a.dl),
AsPtr::as_ptr(&lu.a.d),
AsPtr::as_ptr(&lu.a.du),
AsPtr::as_ptr(&lu.du2),
ipiv.as_ptr(),
AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)),
&ldb,
&mut info,
);
}
info.as_lapack_result()?;
if let Some(b_t) = b_t {
transpose_over(b_layout, &b_t, b);
}
Ok(())
}
}
};
} // impl_tridiagonal!

impl_tridiagonal!(@real, f64, lapack_sys::dgttrf_, lapack_sys::dgtcon_, lapack_sys::dgttrs_);
impl_tridiagonal!(@real, f32, lapack_sys::sgttrf_, lapack_sys::sgtcon_, lapack_sys::sgttrs_);
impl_tridiagonal!(@complex, c64, lapack_sys::zgttrf_, lapack_sys::zgtcon_, lapack_sys::zgttrs_);
impl_tridiagonal!(@complex, c32, lapack_sys::cgttrf_, lapack_sys::cgtcon_, lapack_sys::cgttrs_);

0 comments on commit 4913818

Please sign in to comment.