Skip to content
Permalink
master
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
executable file 559 lines (510 sloc) 16.7 KB
use std::cmp;
use std::ops::Mul;
use num;
use num::{ Float, Signed };
use ApproxEq;
use Matrix;
use internalutil::{alloc_dirty_vec, hypot};
/// Singular Value Decomposition.
///
/// Ported from JAMA (with changes).
///
/// For an m-by-n matrix A, the singular value decomposition is
/// an m-by-m orthogonal matrix U, an m-by-n block diagonal matrix S, and
/// an n-by-n orthogonal matrix V so that A = U*S*V'.
///
/// The singular values, sigma[k] = S[k][k], are non-negative. When the
/// decomposition is built by `SVD::new` they are ordered so that
/// sigma[0] >= sigma[1] >= ... >= sigma[min(m, n) - 1]
/// (`SVD::direct` does not sort them itself).
///
/// The singular value decomposition always exists. The matrix condition number
/// and the effective numerical rank can be computed from this decomposition.
pub struct SVD<T> {
    /// Left singular vectors `U` (m-by-m).
    u: Matrix<T>,
    /// Singular values on the diagonal of the m-by-n matrix `S`.
    s: Matrix<T>,
    /// Right singular vectors `V` (n-by-n).
    v: Matrix<T>,
}
impl<T: Float + Signed + ApproxEq<T>> SVD<T> {
    /// Calculates the singular value decomposition `A = U*S*V'` of `a`.
    ///
    /// Inputs with `rows >= cols` are first reduced to bidiagonal form with
    /// Householder transformations (accumulated into `U` and `V`), then the
    /// bidiagonal is diagonalised by deflation and shifted QR steps. Inputs with
    /// `rows < cols` are handled by decomposing `a'` and transposing the result.
    /// On return the diagonal of `S` holds non-negative singular values sorted in
    /// descending order.
    ///
    /// NOTE(review): a matrix with zero rows or zero columns is not handled —
    /// `m - 1` and `p - 1` below underflow; confirm whether `Matrix` permits
    /// such shapes. As in JAMA there is no iteration cap on the QR loop, so
    /// input that never converges (e.g. containing NaN) may loop forever —
    /// TODO confirm.
    pub fn new(a: &Matrix<T>) -> SVD<T> {
        // A = USV'
        if a.rows() < a.cols() {
            // The algorithm below requires m >= n, so decompose the transpose:
            // a' = (usv')' = vs'u'
            let svd = SVD::new(&a.t());
            // `svd` is owned, so its factors are moved out instead of cloned.
            return SVD {
                u: svd.v,
                s: svd.s.t(),
                v: svd.u,
            };
        }
        // Derived from LINPACK code.
        // Initialize.
        let mut adata = a.get_data().clone();
        let m = a.rows();
        let n = a.cols();
        assert!(m >= n);
        let slen = cmp::min(m, n);
        let mut sdata: Vec<T> = alloc_dirty_vec(slen);
        let ulen = m * m;
        let mut udata = alloc_dirty_vec(ulen);
        let vlen = n * n;
        let mut vdata = alloc_dirty_vec(vlen);
        let mut edata = alloc_dirty_vec(n);
        let mut workdata: Vec<T> = alloc_dirty_vec(m);
        // Reduce A to bidiagonal form, storing the diagonal elements
        // in s and the super-diagonal elements in e.
        let nct = cmp::min(m - 1, n);
        let nrt = cmp::max(0 as isize, cmp::min((n as isize) - 2, m as isize)) as usize;
        for k in 0..cmp::max(nct, nrt) {
            if k < nct {
                // Compute the transformation for the k-th column and
                // place the k-th diagonal in s[k].
                // Compute 2-norm of k-th column without under/overflow.
                sdata[k] = num::zero();
                for i in k..m {
                    sdata[k] = hypot(sdata[k], adata[i * n + k]);
                }
                if sdata[k] != num::zero() {
                    if adata[k * n + k] < num::zero() {
                        sdata[k] = -sdata[k];
                    }
                    for i in k..m {
                        adata[i * n + k] = adata[i * n + k] / sdata[k];
                    }
                    adata[k * n + k] = adata[k * n + k] + num::one();
                }
                sdata[k] = -sdata[k];
            }
            for j in (k + 1)..n {
                if (k < nct) && (sdata[k] != num::zero()) {
                    // Apply the transformation.
                    let mut t: T = num::zero();
                    for i in k..m {
                        t = t + adata[i * n + k] * adata[i * n + j];
                    }
                    t = -t / adata[k * n + k];
                    for i in k..m {
                        adata[i * n + j] = adata[i * n + j] + t * adata[i * n + k];
                    }
                }
                // Place the k-th row of A into e for the
                // subsequent calculation of the row transformation.
                edata[j] = adata[k * n + j];
            }
            if k < nct {
                // Place the transformation in U for subsequent back multiplication.
                for i in k..m {
                    udata[i * m + k] = adata[i * n + k];
                }
            }
            if k < nrt {
                // Compute the k-th row transformation and place the k-th super-diagonal in e[k].
                // Compute 2-norm without under/overflow.
                edata[k] = num::zero();
                for i in (k + 1)..n {
                    edata[k] = hypot(edata[k], edata[i]);
                }
                if edata[k] != num::zero() {
                    if edata[k + 1] < num::zero() {
                        edata[k] = -edata[k];
                    }
                    for i in (k + 1)..n {
                        edata[i] = edata[i] / edata[k];
                    }
                    edata[k + 1] = edata[k + 1] + num::one();
                }
                edata[k] = -edata[k];
                if (k + 1 < m) && (edata[k] != num::zero()) {
                    // Apply the transformation.
                    for i in (k + 1)..m {
                        workdata[i] = num::zero();
                    }
                    for j in (k + 1)..n {
                        for i in (k + 1)..m {
                            workdata[i] = workdata[i] + edata[j] * adata[i * n + j];
                        }
                    }
                    for j in (k + 1)..n {
                        let t = -edata[j] / edata[k + 1];
                        for i in (k + 1)..m {
                            adata[i * n + j] = adata[i * n + j] + t * workdata[i];
                        }
                    }
                }
                // Place the transformation in V for subsequent back multiplication.
                for i in (k + 1)..n {
                    vdata[i * n + k] = edata[i];
                }
            }
        }
        // Set up the final bidiagonal matrix of order p.
        let mut p = cmp::min(n, m + 1);
        if nct < n {
            sdata[nct] = adata[nct * n + nct];
        }
        if m < p {
            sdata[p - 1] = num::zero();
        }
        if (nrt + 1) < p {
            edata[nrt] = adata[nrt * n + (p - 1)];
        }
        edata[p - 1] = num::zero();
        // Generate U.
        for j in nct..m {
            for i in 0..m {
                udata[i * m + j] = num::zero();
            }
            udata[j * m + j] = num::one();
        }
        for k in (0..nct).rev() {
            if sdata[k] != num::zero() {
                for j in (k + 1)..m {
                    let mut t: T = num::zero();
                    for i in k..m {
                        t = t + udata[i * m + k] * udata[i * m + j];
                    }
                    t = -t / udata[k * m + k];
                    for i in k..m {
                        udata[i * m + j] = udata[i * m + j] + t * udata[i * m + k];
                    }
                }
                for i in k..m {
                    udata[i * m + k] = -udata[i * m + k];
                }
                udata[k * m + k] = num::one::<T>() + udata[k * m + k];
                // Zero every entry of column k above the diagonal (rows 0..k).
                for i in 0..k {
                    udata[i * m + k] = num::zero();
                }
            } else {
                for i in 0..m {
                    udata[i * m + k] = num::zero();
                }
                udata[k * m + k] = num::one();
            }
        }
        // Generate V.
        for k in (0..n).rev() {
            if (k < nrt) && (edata[k] != num::zero()) {
                for j in (k + 1)..n {
                    let mut t: T = num::zero();
                    for i in (k + 1)..n {
                        t = t + vdata[i * n + k] * vdata[i * n + j];
                    }
                    t = -t / vdata[(k + 1) * n + k];
                    for i in (k + 1)..n {
                        vdata[i * n + j] = vdata[i * n + j] + t * vdata[i * n + k];
                    }
                }
            }
            for i in 0..n {
                vdata[i * n + k] = num::zero();
            }
            vdata[k * n + k] = num::one();
        }
        // Main iteration loop for the singular values.
        let pp = p - 1;
        // 2^-52 is the f64 machine epsilon; 2^-966 guards against underflow.
        // Both are used for every T.
        let eps: T = num::cast(2.0f64.powf(-52.0)).unwrap();
        let tiny: T = num::cast(2.0f64.powf(-966.0)).unwrap();
        while p > 0 {
            // Here is where a test for too many iterations would go.
            // This section of the program inspects for
            // negligible elements in the s and e arrays. On
            // completion the variables kase and k are set as follows.
            // kase = 1 if s(p) and e[k-1] are negligible and k<p
            // kase = 2 if s(k) is negligible and k<p
            // kase = 3 if e[k-1] is negligible, k<p, and
            //          s(k), ..., s(p) are not negligible (qr step).
            // kase = 4 if e(p-1) is negligible (convergence).
            let kase;
            let mut k = (p as isize) - 2;
            while k >= 0 {
                if num::abs(edata[k as usize])
                    <= (tiny + eps * (num::abs(sdata[k as usize]) + num::abs(sdata[(k + 1) as usize])))
                {
                    edata[k as usize] = num::zero();
                    break;
                }
                k -= 1;
            }
            if k == ((p as isize) - 2) {
                kase = 4;
            } else {
                let mut ks = (p as isize) - 1;
                while ks > k {
                    let t = (if ks != (p as isize) { num::abs(edata[ks as usize]) } else { num::zero() })
                        + (if ks != (k + 1) { num::abs(edata[(ks - 1) as usize]) } else { num::zero() });
                    if num::abs(sdata[ks as usize]) <= (tiny + eps * t) {
                        sdata[ks as usize] = num::zero();
                        break;
                    }
                    ks -= 1;
                }
                if ks == k {
                    kase = 3;
                } else if ks == ((p as isize) - 1) {
                    kase = 1;
                } else {
                    kase = 2;
                    k = ks;
                }
            }
            k += 1;
            // Perform the task indicated by kase.
            if kase == 1 {
                // Deflate negligible s(p).
                let mut f = edata[p - 2];
                edata[p - 2] = num::zero();
                let mut j = (p as isize) - 2;
                while j >= k {
                    let mut t = hypot(sdata[j as usize], f);
                    let cs = sdata[j as usize] / t;
                    let sn = f / t;
                    sdata[j as usize] = t;
                    if j != k {
                        f = -sn * edata[(j - 1) as usize];
                        edata[(j - 1) as usize] = cs * edata[(j - 1) as usize];
                    }
                    for i in 0..n {
                        t = cs * vdata[i * n + (j as usize)] + sn * vdata[i * n + (p - 1)];
                        vdata[i * n + (p - 1)] = -sn * vdata[i * n + (j as usize)] + cs * vdata[i * n + (p - 1)];
                        vdata[i * n + (j as usize)] = t;
                    }
                    j -= 1;
                }
            } else if kase == 2 {
                // Split at negligible s(k).
                let mut f = edata[(k - 1) as usize];
                edata[(k - 1) as usize] = num::zero();
                for j in k..(p as isize) {
                    let mut t = hypot(sdata[j as usize], f);
                    let cs = sdata[j as usize] / t;
                    let sn = f / t;
                    sdata[j as usize] = t;
                    f = -sn * edata[j as usize];
                    edata[j as usize] = cs * edata[j as usize];
                    for i in 0..m {
                        t = cs * udata[i * m + (j as usize)] + sn * udata[i * m + ((k as usize) - 1)];
                        udata[i * m + ((k as usize) - 1)] = -sn * udata[i * m + (j as usize)] + cs * udata[i * m + ((k as usize) - 1)];
                        udata[i * m + (j as usize)] = t;
                    }
                }
            } else if kase == 3 {
                // Perform one qr step.
                // Calculate the shift.
                let scale = num::abs(sdata[p - 1])
                    .max(num::abs(sdata[p - 2]))
                    .max(num::abs(edata[p - 2]))
                    .max(num::abs(sdata[k as usize]))
                    .max(num::abs(edata[k as usize]));
                let sp = sdata[p - 1] / scale;
                let spm1 = sdata[p - 2] / scale;
                let epm1 = edata[p - 2] / scale;
                let sk = sdata[k as usize] / scale;
                let ek = edata[k as usize] / scale;
                let b = ((spm1 + sp) * (spm1 - sp) + epm1 * epm1) / num::cast(2.0).unwrap();
                let c = (sp * epm1) * (sp * epm1);
                let mut shift = num::zero();
                if (b != num::zero()) || (c != num::zero()) {
                    shift = (b * b + c).sqrt();
                    if b < num::zero() {
                        shift = -shift;
                    }
                    shift = c / (b + shift);
                }
                let mut f = (sk + sp) * (sk - sp) + shift;
                let mut g = sk * ek;
                // Chase zeros.
                for j in k..((p as isize) - 1) {
                    let mut t = hypot(f, g);
                    let mut cs = f / t;
                    let mut sn = g / t;
                    if j != k {
                        edata[(j - 1) as usize] = t;
                    }
                    f = cs * sdata[j as usize] + sn * edata[j as usize];
                    edata[j as usize] = cs * edata[j as usize] - sn * sdata[j as usize];
                    g = sn * sdata[(j + 1) as usize];
                    sdata[(j + 1) as usize] = cs * sdata[(j + 1) as usize];
                    for i in 0..n {
                        t = cs * vdata[i * n + (j as usize)] + sn * vdata[i * n + ((j as usize) + 1)];
                        vdata[i * n + ((j as usize) + 1)] = -sn * vdata[i * n + (j as usize)] + cs * vdata[i * n + ((j as usize) + 1)];
                        vdata[i * n + (j as usize)] = t;
                    }
                    t = hypot(f, g);
                    cs = f / t;
                    sn = g / t;
                    sdata[j as usize] = t;
                    f = cs * edata[j as usize] + sn * sdata[(j + 1) as usize];
                    sdata[(j + 1) as usize] = -sn * edata[j as usize] + cs * sdata[(j + 1) as usize];
                    g = sn * edata[(j + 1) as usize];
                    edata[(j + 1) as usize] = cs * edata[(j + 1) as usize];
                    if j < ((m as isize) - 1) {
                        for i in 0..m {
                            t = cs * udata[i * m + (j as usize)] + sn * udata[i * m + ((j as usize) + 1)];
                            udata[i * m + ((j as usize) + 1)] = -sn * udata[i * m + (j as usize)] + cs * udata[i * m + ((j as usize) + 1)];
                            udata[i * m + (j as usize)] = t;
                        }
                    }
                }
                edata[p - 2] = f;
            } else if kase == 4 {
                // Convergence.
                // Make the singular values positive.
                if sdata[k as usize] <= num::zero() {
                    sdata[k as usize] = if sdata[k as usize] < num::zero() { -sdata[k as usize] } else { num::zero() };
                    for i in 0..(pp + 1) {
                        vdata[i * n + (k as usize)] = -vdata[i * n + (k as usize)];
                    }
                }
                // Order the singular values.
                while k < (pp as isize) {
                    if sdata[k as usize] >= sdata[(k + 1) as usize] {
                        break;
                    }
                    let mut t = sdata[k as usize];
                    sdata[k as usize] = sdata[(k + 1) as usize];
                    sdata[(k + 1) as usize] = t;
                    if k < ((n as isize) - 1) {
                        for i in 0..n {
                            t = vdata[i * n + ((k as usize) + 1)];
                            vdata[i * n + ((k as usize) + 1)] = vdata[i * n + (k as usize)];
                            vdata[i * n + (k as usize)] = t;
                        }
                    }
                    if k < ((m as isize) - 1) {
                        for i in 0..m {
                            t = udata[i * m + ((k as usize) + 1)];
                            udata[i * m + ((k as usize) + 1)] = udata[i * m + (k as usize)];
                            udata[i * m + (k as usize)] = t;
                        }
                    }
                    k += 1;
                }
                p -= 1;
            }
        }
        SVD {
            u: Matrix::new(m, m, udata),
            s: Matrix::block_diag(m, n, sdata),
            v: Matrix::new(n, n, vdata),
        }
    }

    /// Returns the left singular vectors `U` (m-by-m).
    pub fn get_u(&self) -> &Matrix<T> {
        &self.u
    }

    /// Returns the m-by-n matrix `S` whose diagonal holds the singular values.
    pub fn get_s(&self) -> &Matrix<T> {
        &self.s
    }

    /// Returns the right singular vectors `V` (n-by-n).
    pub fn get_v(&self) -> &Matrix<T> {
        &self.v
    }

    /// Returns the effective numerical rank: the number of diagonal entries of
    /// `S` that are strictly greater than `max(m, n) * S[0][0] * 2^-52`.
    ///
    /// The tolerance is relative to `S[0][0]`. That entry is the largest singular
    /// value for decompositions built by `SVD::new`, which sorts them;
    /// `SVD::direct` does not sort. The `2^-52` factor is the f64 machine
    /// epsilon and is used for every `T`.
    pub fn rank(&self) -> usize {
        let eps: T = num::cast(2.0f64.powf(-52.0)).unwrap();
        let max_dim: T = num::cast(cmp::max(self.u.rows(), self.v.rows())).unwrap();
        let tol = max_dim * self.s.get(0, 0) * eps;
        // `S` is m-by-n, so it has only min(m, n) diagonal entries. Iterating
        // over `s.rows()` alone would step past the last column for tall inputs.
        let diag_len = cmp::min(self.s.rows(), self.s.cols());
        (0..diag_len).filter(|&i| self.s.get(i, i) > tol).count()
    }

    /// Calculates SVD using the direct method. Note that calculating it this way
    /// is not numerically stable, so it is mostly useful for testing purposes.
    ///
    /// `V` and the squared singular values come from the eigen decomposition of
    /// `A'A`. `U` is then recovered as `A*V*S^-1`. Consequences:
    /// * the singular values keep whatever order `EigenDecomposition` returns;
    ///   no sorting is done here;
    /// * for tall inputs the last `m - n` columns of `U` are zero, so `U` is not
    ///   orthogonal, although `U*S*V'` still reconstructs `A`;
    /// * a slightly negative eigenvalue (rounding error) yields a NaN singular value.
    ///
    /// # Panics
    /// Panics if the square singular-value matrix cannot be inverted
    /// (`s.inverse()` is unwrapped).
    pub fn direct(a: &Matrix<T>) -> SVD<T> {
        use EigenDecomposition;
        // A = USV'
        if a.rows() < a.cols() {
            // a' = (usv')' = vs'u'
            let svd = SVD::direct(&a.t());
            // `svd` is owned, so its factors are moved out instead of cloned.
            return SVD {
                u: svd.v,
                s: svd.s.t(),
                v: svd.u,
            };
        }
        // A'A = VS'U'USV'
        //     = VS'SV'
        let ata = a.t().mul(a);
        let edc = EigenDecomposition::new(&ata);
        let v = edc.get_v();
        let eigs = edc.get_real_eigenvalues();
        let singular_values: Vec<T> = eigs.iter().map(|&e| e.sqrt()).collect();
        // U*S*V' = A
        // U*S*V'*V = A*V
        // U*S = A*V
        // U*S*Sinv = A*V*Sinv
        // U = A*V*Sinv
        let s_size = singular_values.len();
        let s = Matrix::block_diag(s_size, s_size, singular_values);
        let s_inv = s.inverse().unwrap();
        // For tall inputs, pad S with zero rows (m-by-n) and Sinv with zero
        // columns (n-by-m) so the products have the documented shapes.
        let (s_aug, s_inv_aug) =
            if a.rows() == a.cols() { (s, s_inv) }
            else {
                (s.cb(&Matrix::zero(a.rows() - a.cols(), s.cols())),
                 s_inv.cr(&Matrix::zero(s_inv.rows(), a.rows() - a.cols())))
            };
        let u = a.mul(v).mul(&s_inv_aug);
        // `u` and `s_aug` are owned locals and are moved. `v` borrows from `edc`,
        // so it must be cloned.
        SVD {
            u: u,
            s: s_aug,
            v: v.clone(),
        }
    }
}
/// `SVD::new` on a square 3x3 matrix must reproduce the input as `U*S*V'`.
#[test]
fn svd_test() {
    let mat = m!(1.0, 2.0, 3.0; 4.0, 5.0, 6.0; 7.0, 8.0, 9.0);
    let decomposition = SVD::new(&mat);
    let (u, s, v) = (decomposition.get_u(), decomposition.get_s(), decomposition.get_v());
    let reconstructed = u * s * v.t();
    assert!(reconstructed.approx_eq(&mat));
}
/// `SVD::new` on a tall 4x3 matrix (more rows than columns) must round-trip.
#[test]
fn svd_test_m_over_n() {
    let tall = m!(1.0, 2.0, 3.0; 4.0, 5.0, 6.0; 7.0, 8.0, 9.0; 10.0, 11.0, 12.0);
    let factors = SVD::new(&tall);
    let product = factors.get_u() * factors.get_s() * factors.get_v().t();
    assert!(product.approx_eq(&tall));
}
/// `SVD::new` on a wide 3x4 matrix (handled via the transpose) must round-trip.
#[test]
fn svd_test_n_over_m() {
    let wide = m!(1.0, 2.0, 3.0, 4.0; 5.0, 6.0, 7.0, 8.0; 9.0, 10.0, 11.0, 12.0);
    let result = SVD::new(&wide);
    let us = result.get_u() * result.get_s();
    let vt = result.get_v().t();
    assert!((us * vt).approx_eq(&wide));
}
/// `SVD::direct` on a square 3x3 matrix must reproduce the input as `U*S*V'`.
#[test]
fn direct_test() {
    let mat = m!(1.0, 2.0, 3.0; 4.0, 5.0, 6.0; 7.0, 8.0, 9.0);
    let decomposition = SVD::<f64>::direct(&mat);
    let (u, s, v) = (decomposition.get_u(), decomposition.get_s(), decomposition.get_v());
    let reconstructed = u * s * v.t();
    assert!(reconstructed.approx_eq(&mat));
}
/// `SVD::direct` on a tall 4x3 matrix must round-trip.
#[test]
fn direct_test_m_over_n() {
    let tall = m!(1.0, 2.0, 3.0; 4.0, 5.0, 6.0; 7.0, 8.0, 9.0; 10.0, 11.0, 12.0);
    let factors = SVD::<f64>::direct(&tall);
    let product = factors.get_u() * factors.get_s() * factors.get_v().t();
    assert!(product.approx_eq(&tall));
}
/// `SVD::direct` on a wide 3x4 matrix (handled via the transpose) must round-trip.
#[test]
fn direct_test_n_over_m() {
    let wide = m!(1.0, 2.0, 3.0, 4.0; 5.0, 6.0, 7.0, 8.0; 9.0, 10.0, 11.0, 12.0);
    let result = SVD::<f64>::direct(&wide);
    let us = result.get_u() * result.get_s();
    let vt = result.get_v().t();
    assert!((us * vt).approx_eq(&wide));
}