diff --git a/src/linear/lasso.rs b/src/linear/lasso.rs index 2919b025..9057df56 100644 --- a/src/linear/lasso.rs +++ b/src/linear/lasso.rs @@ -9,7 +9,7 @@ //! //! Lasso coefficient estimates solve the problem: //! -//! \\[\underset{\beta}{minimize} \space \space \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\] +//! \\[\underset{\beta}{minimize} \space \space \frac{1}{n} \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\] //! //! This problem is solved with an interior-point method that is comparable to coordinate descent in solving large problems with modest accuracy, //! but is able to solve them with high accuracy with relatively small additional computational cost. @@ -246,7 +246,7 @@ impl, Y: Array1> Las pub fn fit(x: &X, y: &Y, parameters: LassoParameters) -> Result, Failed> { let (n, p) = x.shape(); - if n <= p { + if n < p { return Err(Failed::fit( "Number of rows in X should be >= number of columns in X", )); @@ -369,6 +369,7 @@ impl, Y: Array1> Las #[cfg(test)] mod tests { use super::*; + use crate::linalg::basic::arrays::Array; use crate::linalg::basic::matrix::DenseMatrix; use crate::metrics::mean_absolute_error; @@ -448,6 +449,36 @@ mod tests { assert!(mean_absolute_error(&y_hat, &y) < 2.0); } + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn test_full_rank_x() { + // x: randn(3,3) * 10, demean, then round to 2 decimal points + // y = x @ [10.0, 0.2, -3.0], round to 2 decimal points + let param = LassoParameters::default() + .with_normalize(false) + .with_alpha(200.0); + let x = DenseMatrix::from_2d_array(&[ + &[-8.9, -2.24, 8.89], + &[-4.02, 8.89, 12.33], + &[12.92, -6.65, -21.22], + ]) + .unwrap(); + + let y = vec![-116.12, -75.41, 191.53]; + let w = Lasso::fit(&x, &y, param) + .unwrap() + .coefficients() + .iterator(0) + .copied() + .collect(); + + let expected_w = vec![5.20289531, 0., -5.32823882]; // by coordinate descent + assert!(mean_absolute_error(&w, &expected_w) < 1e-3); // actual mean_absolute_error is about 2e-4 + } + // TODO: serialization for the new DenseMatrix needs to be implemented // #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)] // #[test]