Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
- WARNING: Breaking changes!
- `LassoParameters` and `LassoSearchParameters` have a new field `fit_intercept`. When it is set to `false`, the `beta_0` term in the formula will be forced to zero, and the `intercept` field in `Lasso` will be set to `None`.


## [0.4.0] - 2023-04-05

### Added
Expand Down
2 changes: 2 additions & 0 deletions src/linear/elastic_net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
l1_reg * gamma,
parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(),
true,
)?;

for i in 0..p {
Expand All @@ -371,6 +372,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
l1_reg * gamma,
parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(),
true,
)?;

for i in 0..p {
Expand Down
165 changes: 112 additions & 53 deletions src/linear/lasso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ pub struct LassoParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum number of iterations
pub max_iter: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// If false, force the intercept parameter (beta_0) to be zero.
pub fit_intercept: bool,
}

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
Expand Down Expand Up @@ -86,6 +89,12 @@ impl LassoParameters {
self.max_iter = max_iter;
self
}

/// If false, force the intercept parameter (beta_0) to be zero.
pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
self.fit_intercept = fit_intercept;
self
}
}

impl Default for LassoParameters {
Expand All @@ -95,6 +104,7 @@ impl Default for LassoParameters {
normalize: true,
tol: 1e-4,
max_iter: 1000,
fit_intercept: true,
}
}
}
Expand All @@ -118,8 +128,8 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
{
fn new() -> Self {
Self {
coefficients: Option::None,
intercept: Option::None,
coefficients: None,
intercept: None,
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
}
Expand Down Expand Up @@ -155,6 +165,9 @@ pub struct LassoSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum number of iterations
pub max_iter: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// If false, force the intercept parameter (beta_0) to be zero.
pub fit_intercept: Vec<bool>,
}

/// Lasso grid search iterator
Expand All @@ -164,6 +177,7 @@ pub struct LassoSearchParametersIterator {
current_normalize: usize,
current_tol: usize,
current_max_iter: usize,
current_fit_intercept: usize,
}

impl IntoIterator for LassoSearchParameters {
Expand All @@ -177,6 +191,7 @@ impl IntoIterator for LassoSearchParameters {
current_normalize: 0,
current_tol: 0,
current_max_iter: 0,
current_fit_intercept: 0,
}
}
}
Expand All @@ -189,6 +204,7 @@ impl Iterator for LassoSearchParametersIterator {
&& self.current_normalize == self.lasso_search_parameters.normalize.len()
&& self.current_tol == self.lasso_search_parameters.tol.len()
&& self.current_max_iter == self.lasso_search_parameters.max_iter.len()
&& self.current_fit_intercept == self.lasso_search_parameters.fit_intercept.len()
{
return None;
}
Expand All @@ -198,6 +214,7 @@ impl Iterator for LassoSearchParametersIterator {
normalize: self.lasso_search_parameters.normalize[self.current_normalize],
tol: self.lasso_search_parameters.tol[self.current_tol],
max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter],
fit_intercept: self.lasso_search_parameters.fit_intercept[self.current_fit_intercept],
};

if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() {
Expand All @@ -214,11 +231,19 @@ impl Iterator for LassoSearchParametersIterator {
self.current_normalize = 0;
self.current_tol = 0;
self.current_max_iter += 1;
} else if self.current_fit_intercept + 1 < self.lasso_search_parameters.fit_intercept.len()
{
self.current_alpha = 0;
self.current_normalize = 0;
self.current_tol = 0;
self.current_max_iter = 0;
self.current_fit_intercept += 1;
} else {
self.current_alpha += 1;
self.current_normalize += 1;
self.current_tol += 1;
self.current_max_iter += 1;
self.current_fit_intercept += 1;
}

Some(next)
Expand All @@ -234,6 +259,7 @@ impl Default for LassoSearchParameters {
normalize: vec![default_params.normalize],
tol: vec![default_params.tol],
max_iter: vec![default_params.max_iter],
fit_intercept: vec![default_params.fit_intercept],
}
}
}
Expand Down Expand Up @@ -283,19 +309,23 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
l1_reg,
parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(),
parameters.fit_intercept,
)?;

for (j, col_std_j) in col_std.iter().enumerate().take(p) {
w[j] /= *col_std_j;
}

let mut b = TX::zero();
let b = if parameters.fit_intercept {
let mut xw_mean = TX::zero();
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
xw_mean += w[i] * *col_mean_i;
}

for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
b += w[i] * *col_mean_i;
}

b = TX::from_f64(y.mean_by()).unwrap() - b;
Some(TX::from_f64(y.mean_by()).unwrap() - xw_mean)
} else {
None
};
(X::from_column(&w), b)
} else {
let mut optimizer = InteriorPointOptimizer::new(x, p);
Expand All @@ -306,13 +336,21 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
l1_reg,
parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(),
parameters.fit_intercept,
)?;

(X::from_column(&w), TX::from_f64(y.mean_by()).unwrap())
(
X::from_column(&w),
if parameters.fit_intercept {
Some(TX::from_f64(y.mean_by()).unwrap())
} else {
None
},
)
};

Ok(Lasso {
intercept: Some(b),
intercept: b,
coefficients: Some(w),
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
Expand Down Expand Up @@ -378,30 +416,28 @@ mod tests {
let parameters = LassoSearchParameters {
alpha: vec![0., 1.],
max_iter: vec![10, 100],
fit_intercept: vec![false, true],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.alpha, 0.);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 1.);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 0.);
assert_eq!(next.max_iter, 100);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 1.);
assert_eq!(next.max_iter, 100);

let mut iter = parameters.clone().into_iter();
for current_fit_intercept in 0..parameters.fit_intercept.len() {
for current_max_iter in 0..parameters.max_iter.len() {
for current_alpha in 0..parameters.alpha.len() {
let next = iter.next().unwrap();
assert_eq!(next.alpha, parameters.alpha[current_alpha]);
assert_eq!(next.max_iter, parameters.max_iter[current_max_iter]);
assert_eq!(
next.fit_intercept,
parameters.fit_intercept[current_fit_intercept]
);
}
}
}
assert!(iter.next().is_none());
}

#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn lasso_fit_predict() {
fn get_example_x_y() -> (DenseMatrix<f64>, Vec<f64>) {
let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
Expand All @@ -427,6 +463,17 @@ mod tests {
114.2, 115.7, 116.9,
];

(x, y)
}

#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn lasso_fit_predict() {
let (x, y) = get_example_x_y();

let y_hat = Lasso::fit(&x, &y, Default::default())
.and_then(|lr| lr.predict(&x))
.unwrap();
Expand All @@ -441,6 +488,7 @@ mod tests {
normalize: false,
tol: 1e-4,
max_iter: 1000,
fit_intercept: true,
},
)
.and_then(|lr| lr.predict(&x))
Expand Down Expand Up @@ -479,35 +527,46 @@ mod tests {
assert!(mean_absolute_error(&w, &expected_w) < 1e-3); // actual mean_absolute_error is about 2e-4
}

#[cfg_attr(
    all(target_arch = "wasm32", not(target_os = "wasi")),
    wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_fit_intercept() {
    // With `fit_intercept: false` the beta_0 term is forced to zero, so the
    // fitted model must report `intercept == None`.
    let (x, y) = get_example_x_y();
    let model = Lasso::fit(
        &x,
        &y,
        LassoParameters {
            alpha: 0.1,
            normalize: false,
            tol: 1e-8,
            max_iter: 1000,
            fit_intercept: false,
        },
    )
    .unwrap();

    let coefficients = model.coefficients().iterator(0).copied().collect();
    // Reference values computed with sklearn's LassoLars; plain coordinate
    // descent does not converge tightly enough to use as a reference here.
    let expected = vec![
        0.18335684,
        0.02106526,
        0.00703214,
        -1.35952542,
        0.09295222,
        0.,
    ];
    assert!(mean_absolute_error(&coefficients, &expected) < 1e-6);
    assert_eq!(model.intercept, None);
}

// TODO: serialization for the new DenseMatrix needs to be implemented
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
// #[test]
// #[cfg(feature = "serde")]
// fn serde() {
// let x = DenseMatrix::from_2d_array(&[
// &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
// &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
// &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
// &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
// &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
// &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
// &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
// &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
// &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
// &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
// &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
// &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
// &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
// ]);

// let y = vec![
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
// 114.2, 115.7, 116.9,
// ];

// let (x, y) = get_lasso_sample_x_y();
// let lr = Lasso::fit(&x, &y, Default::default()).unwrap();

// let deserialized_lr: Lasso<f64, f64, DenseMatrix<f64>, Vec<f64>> =
Expand Down
7 changes: 6 additions & 1 deletion src/linear/lasso_optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
lambda: T,
max_iter: usize,
tol: T,
fit_intercept: bool,
) -> Result<Vec<T>, Failed> {
let (n, p) = x.shape();
let p_f64 = T::from_usize(p).unwrap();
Expand All @@ -61,7 +62,11 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
let mu = T::two();

// let y = M::from_row_vector(y.sub_scalar(y.mean_by())).transpose();
let y = y.sub_scalar(T::from_f64(y.mean_by()).unwrap());
let y = if fit_intercept {
y.sub_scalar(T::from_f64(y.mean_by()).unwrap())
} else {
y.to_owned()
};

let mut max_ls_iter = 100;
let mut pitr = 0;
Expand Down