From d89e338c848ad421c763a47639f81ca42c7f2840 Mon Sep 17 00:00:00 2001 From: Georeth Zhou Date: Wed, 26 Nov 2025 11:07:25 +0800 Subject: [PATCH 1/2] Fix LASSO (#342) * change loss function in doc to match code * allow `n == p` case --- src/linear/lasso.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/linear/lasso.rs b/src/linear/lasso.rs index 2919b025..80dd6cc6 100644 --- a/src/linear/lasso.rs +++ b/src/linear/lasso.rs @@ -9,7 +9,7 @@ //! //! Lasso coefficient estimates solve the problem: //! -//! \\[\underset{\beta}{minimize} \space \space \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\] +//! \\[\underset{\beta}{minimize} \space \space \frac{1}{n} \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\] //! //! This problem is solved with an interior-point method that is comparable to coordinate descent in solving large problems with modest accuracy, //! but is able to solve them with high accuracy with relatively small additional computational cost. @@ -246,7 +246,7 @@ impl, Y: Array1> Las pub fn fit(x: &X, y: &Y, parameters: LassoParameters) -> Result, Failed> { let (n, p) = x.shape(); - if n <= p { + if n < p { return Err(Failed::fit( "Number of rows in X should be >= number of columns in X", )); From 195d893cbd02cf1b5c349e6bfa59438637d9459f Mon Sep 17 00:00:00 2001 From: Zhou Xiaozhou Date: Thu, 27 Nov 2025 16:23:06 +0800 Subject: [PATCH 2/2] lasso add test_full_rank_x --- src/linear/lasso.rs | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/src/linear/lasso.rs b/src/linear/lasso.rs index 80dd6cc6..9057df56 100644 --- a/src/linear/lasso.rs +++ b/src/linear/lasso.rs @@ -369,6 +369,7 @@ impl, Y: Array1> Las #[cfg(test)] mod tests { use super::*; + use crate::linalg::basic::arrays::Array; use crate::linalg::basic::matrix::DenseMatrix; use crate::metrics::mean_absolute_error; @@ -448,6 +449,36 @@ mod tests { assert!(mean_absolute_error(&y_hat, &y) < 2.0); } + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn test_full_rank_x() { + // x: randn(3,3) * 10, demean, then round to 2 decimal points + // y = x @ [10.0, 0.2, -3.0], round to 2 decimal points + let param = LassoParameters::default() + .with_normalize(false) + .with_alpha(200.0); + let x = DenseMatrix::from_2d_array(&[ + &[-8.9, -2.24, 8.89], + &[-4.02, 8.89, 12.33], + &[12.92, -6.65, -21.22], + ]) + .unwrap(); + + let y = vec![-116.12, -75.41, 191.53]; + let w = Lasso::fit(&x, &y, param) + .unwrap() + .coefficients() + .iterator(0) + .copied() + .collect(); + + let expected_w = vec![5.20289531, 0., -5.32823882]; // by coordinate descent + assert!(mean_absolute_error(&w, &expected_w) < 1e-3); // actual mean_absolute_error is about 2e-4 + } + // TODO: serialization for the new DenseMatrix needs to be implemented // #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)] // #[test]