diff --git a/algorithms/linfa-logistic/Cargo.toml b/algorithms/linfa-logistic/Cargo.toml index be32724c2..1b09a83e6 100644 --- a/algorithms/linfa-logistic/Cargo.toml +++ b/algorithms/linfa-logistic/Cargo.toml @@ -26,3 +26,4 @@ linfa = { version = "0.6.0", path = "../..", features=["serde"] } [dev-dependencies] approx = "0.4" linfa-datasets = { version = "0.6.0", path = "../../datasets", features = ["winequality"] } +rmp-serde = "1" diff --git a/algorithms/linfa-logistic/src/lib.rs b/algorithms/linfa-logistic/src/lib.rs index a50055a69..a0018dcf7 100644 --- a/algorithms/linfa-logistic/src/lib.rs +++ b/algorithms/linfa-logistic/src/lib.rs @@ -30,6 +30,7 @@ use ndarray::{ Dimension, IntoDimension, Ix1, Ix2, RemoveAxis, Slice, Zip, }; use ndarray_stats::QuantileExt; +use serde::{Deserialize, Serialize}; use std::default::Default; mod argmin_param; @@ -524,8 +525,8 @@ fn multi_logistic_grad>( } /// A fitted logistic regression which can make predictions -#[derive(PartialEq, Debug, Clone)] -pub struct FittedLogisticRegression { +#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)] +pub struct FittedLogisticRegression { threshold: F, intercept: F, params: Array1, @@ -610,8 +611,8 @@ impl> } /// A fitted multinomial logistic regression which can make predictions -#[derive(PartialEq, Debug, Clone)] -pub struct MultiFittedLogisticRegression { +#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)] +pub struct MultiFittedLogisticRegression { intercept: Array1, params: Array2, classes: Vec, @@ -685,8 +686,8 @@ impl> } } -#[derive(PartialEq, Debug, Clone)] -struct ClassLabel { +#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)] +struct ClassLabel { class: C, label: F, } @@ -1066,6 +1067,15 @@ mod test { &res.predict(dataset.records()), dataset.targets().as_single_targets() ); + + // Test serialization + let ser = rmp_serde::to_vec(&res).unwrap(); + let unser: FittedLogisticRegression = rmp_serde::from_slice(&ser).unwrap(); + + let x = array![[1.0]]; + let y_hat = unser.predict(&x); + + assert!(y_hat[0] == 0.0); } #[test]