Skip to content

Commit

Permalink
Merge pull request #331 from rust-ndarray/revise-flag-naming
Browse files Browse the repository at this point in the history
Revise enum namings
  • Loading branch information
termoshtt committed Sep 3, 2022
2 parents 08aae3b + 86c61c3 commit a4b3118
Show file tree
Hide file tree
Showing 9 changed files with 186 additions and 191 deletions.
12 changes: 6 additions & 6 deletions lax/src/eig.rs
Expand Up @@ -35,11 +35,11 @@ macro_rules! impl_eig_complex {
// eigenvalues are the eigenvalues computed with `A`.
let (jobvl, jobvr) = if calc_v {
match l {
MatrixLayout::C { .. } => (EigenVectorFlag::Calc, EigenVectorFlag::Not),
MatrixLayout::F { .. } => (EigenVectorFlag::Not, EigenVectorFlag::Calc),
MatrixLayout::C { .. } => (JobEv::All, JobEv::None),
MatrixLayout::F { .. } => (JobEv::None, JobEv::All),
}
} else {
(EigenVectorFlag::Not, EigenVectorFlag::Not)
(JobEv::None, JobEv::None)
};
let mut eigs: Vec<MaybeUninit<Self>> = unsafe { vec_uninit(n as usize) };
let mut rwork: Vec<MaybeUninit<Self::Real>> = unsafe { vec_uninit(2 * n as usize) };
Expand Down Expand Up @@ -143,11 +143,11 @@ macro_rules! impl_eig_real {
// `sgeev`/`dgeev`.
let (jobvl, jobvr) = if calc_v {
match l {
MatrixLayout::C { .. } => (EigenVectorFlag::Calc, EigenVectorFlag::Not),
MatrixLayout::F { .. } => (EigenVectorFlag::Not, EigenVectorFlag::Calc),
MatrixLayout::C { .. } => (JobEv::All, JobEv::None),
MatrixLayout::F { .. } => (JobEv::None, JobEv::All),
}
} else {
(EigenVectorFlag::Not, EigenVectorFlag::Not)
(JobEv::None, JobEv::None)
};
let mut eig_re: Vec<MaybeUninit<Self>> = unsafe { vec_uninit(n as usize) };
let mut eig_im: Vec<MaybeUninit<Self>> = unsafe { vec_uninit(n as usize) };
Expand Down
4 changes: 2 additions & 2 deletions lax/src/eigh.rs
Expand Up @@ -41,7 +41,7 @@ macro_rules! impl_eigh {
) -> Result<Vec<Self::Real>> {
assert_eq!(layout.len(), layout.lda());
let n = layout.len();
let jobz = if calc_v { EigenVectorFlag::Calc } else { EigenVectorFlag::Not };
let jobz = if calc_v { JobEv::All } else { JobEv::None };
let mut eigs: Vec<MaybeUninit<Self::Real>> = unsafe { vec_uninit(n as usize) };

$(
Expand Down Expand Up @@ -100,7 +100,7 @@ macro_rules! impl_eigh {
) -> Result<Vec<Self::Real>> {
assert_eq!(layout.len(), layout.lda());
let n = layout.len();
let jobz = if calc_v { EigenVectorFlag::Calc } else { EigenVectorFlag::Not };
let jobz = if calc_v { JobEv::All } else { JobEv::None };
let mut eigs: Vec<MaybeUninit<Self::Real>> = unsafe { vec_uninit(n as usize) };

$(
Expand Down
137 changes: 137 additions & 0 deletions lax/src/flags.rs
@@ -0,0 +1,137 @@
//! Charactor flags, e.g. `'T'`, used in LAPACK API

/// Upper/Lower specification for seveal usages
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum UPLO {
Upper = b'U',
Lower = b'L',
}

impl UPLO {
pub fn t(self) -> Self {
match self {
UPLO::Upper => UPLO::Lower,
UPLO::Lower => UPLO::Upper,
}
}

/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const UPLO as *const i8
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum Transpose {
No = b'N',
Transpose = b'T',
Hermite = b'C',
}

impl Transpose {
/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const Transpose as *const i8
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum NormType {
One = b'O',
Infinity = b'I',
Frobenius = b'F',
}

impl NormType {
pub fn transpose(self) -> Self {
match self {
NormType::One => NormType::Infinity,
NormType::Infinity => NormType::One,
NormType::Frobenius => NormType::Frobenius,
}
}

/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const NormType as *const i8
}
}

/// Flag for calculating eigenvectors or not
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum JobEv {
/// Calculate eigenvectors in addition to eigenvalues
All = b'V',
/// Do not calculate eigenvectors. Only calculate eigenvalues.
None = b'N',
}

impl JobEv {
pub fn is_calc(&self) -> bool {
match self {
JobEv::All => true,
JobEv::None => false,
}
}

pub fn then<T, F: FnOnce() -> T>(&self, f: F) -> Option<T> {
if self.is_calc() {
Some(f())
} else {
None
}
}

/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const JobEv as *const i8
}
}

/// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned.
///
/// For an input array of shape *m*×*n*, the following are computed:
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum JobSvd {
/// All *m* columns of *U* and all *n* rows of *V*ᵀ.
All = b'A',
/// The first min(*m*,*n*) columns of *U* and the first min(*m*,*n*) rows of *V*ᵀ.
Some = b'S',
/// No columns of *U* or rows of *V*ᵀ.
None = b'N',
}

impl JobSvd {
pub fn from_bool(calc_uv: bool) -> Self {
if calc_uv {
JobSvd::All
} else {
JobSvd::None
}
}

pub fn as_ptr(&self) -> *const i8 {
self as *const JobSvd as *const i8
}
}

/// Specify whether input triangular matrix is unit or not
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum Diag {
/// Unit triangular matrix, i.e. all diagonal elements of the matrix are `1`
Unit = b'U',
/// Non-unit triangular matrix. Its diagonal elements may be different from `1`
NonUnit = b'N',
}

impl Diag {
pub fn as_ptr(&self) -> *const i8 {
self as *const Diag as *const i8
}
}
92 changes: 2 additions & 90 deletions lax/src/lib.rs
Expand Up @@ -69,6 +69,7 @@ extern crate openblas_src as _src;
extern crate netlib_src as _src;

pub mod error;
pub mod flags;
pub mod layout;

mod cholesky;
Expand All @@ -88,6 +89,7 @@ mod tridiagonal;
pub use self::cholesky::*;
pub use self::eig::*;
pub use self::eigh::*;
pub use self::flags::*;
pub use self::least_squares::*;
pub use self::opnorm::*;
pub use self::qr::*;
Expand Down Expand Up @@ -173,96 +175,6 @@ impl<T> VecAssumeInit for Vec<MaybeUninit<T>> {
}
}

/// Upper/Lower specification for seveal usages
#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum UPLO {
Upper = b'U',
Lower = b'L',
}

impl UPLO {
pub fn t(self) -> Self {
match self {
UPLO::Upper => UPLO::Lower,
UPLO::Lower => UPLO::Upper,
}
}

/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const UPLO as *const i8
}
}

#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum Transpose {
No = b'N',
Transpose = b'T',
Hermite = b'C',
}

impl Transpose {
/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const Transpose as *const i8
}
}

#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum NormType {
One = b'O',
Infinity = b'I',
Frobenius = b'F',
}

impl NormType {
pub fn transpose(self) -> Self {
match self {
NormType::One => NormType::Infinity,
NormType::Infinity => NormType::One,
NormType::Frobenius => NormType::Frobenius,
}
}

/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const NormType as *const i8
}
}

/// Flag for calculating eigenvectors or not
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum EigenVectorFlag {
Calc = b'V',
Not = b'N',
}

impl EigenVectorFlag {
pub fn is_calc(&self) -> bool {
match self {
EigenVectorFlag::Calc => true,
EigenVectorFlag::Not => false,
}
}

pub fn then<T, F: FnOnce() -> T>(&self, f: F) -> Option<T> {
if self.is_calc() {
Some(f())
} else {
None
}
}

/// To use Fortran LAPACK API in lapack-sys crate
pub fn as_ptr(&self) -> *const i8 {
self as *const EigenVectorFlag as *const i8
}
}

/// Create a vector without initialization
///
/// Safety
Expand Down
43 changes: 11 additions & 32 deletions lax/src/svd.rs
@@ -1,32 +1,9 @@
//! Singular-value decomposition

use crate::{error::*, layout::MatrixLayout, *};
use super::{error::*, layout::*, *};
use cauchy::*;
use num_traits::{ToPrimitive, Zero};

#[repr(u8)]
#[derive(Debug, Copy, Clone)]
enum FlagSVD {
All = b'A',
// OverWrite = b'O',
// Separately = b'S',
No = b'N',
}

impl FlagSVD {
fn from_bool(calc_uv: bool) -> Self {
if calc_uv {
FlagSVD::All
} else {
FlagSVD::No
}
}

fn as_ptr(&self) -> *const i8 {
self as *const FlagSVD as *const i8
}
}

/// Result of SVD
pub struct SVDOutput<A: Scalar> {
/// diagonal values
Expand Down Expand Up @@ -55,24 +32,26 @@ macro_rules! impl_svd {
impl SVD_ for $scalar {
fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self],) -> Result<SVDOutput<Self>> {
let ju = match l {
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u),
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt),
MatrixLayout::F { .. } => JobSvd::from_bool(calc_u),
MatrixLayout::C { .. } => JobSvd::from_bool(calc_vt),
};
let jvt = match l {
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt),
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u),
MatrixLayout::F { .. } => JobSvd::from_bool(calc_vt),
MatrixLayout::C { .. } => JobSvd::from_bool(calc_u),
};

let m = l.lda();
let mut u = match ju {
FlagSVD::All => Some(unsafe { vec_uninit( (m * m) as usize) }),
FlagSVD::No => None,
JobSvd::All => Some(unsafe { vec_uninit( (m * m) as usize) }),
JobSvd::None => None,
_ => unimplemented!("SVD with partial vector output is not supported yet")
};

let n = l.len();
let mut vt = match jvt {
FlagSVD::All => Some(unsafe { vec_uninit( (n * n) as usize) }),
FlagSVD::No => None,
JobSvd::All => Some(unsafe { vec_uninit( (n * n) as usize) }),
JobSvd::None => None,
_ => unimplemented!("SVD with partial vector output is not supported yet")
};

let k = std::cmp::min(m, n);
Expand Down

0 comments on commit a4b3118

Please sign in to comment.