diff --git a/algorithms/linfa-logistic/src/hyperparams.rs b/algorithms/linfa-logistic/src/hyperparams.rs index 1447198db..1705aed3a 100644 --- a/algorithms/linfa-logistic/src/hyperparams.rs +++ b/algorithms/linfa-logistic/src/hyperparams.rs @@ -9,10 +9,12 @@ use serde::{Deserialize, Serialize}; /// A generalized logistic regression type that specializes as either binomial logistic regression /// or multinomial logistic regression. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct LogisticRegressionParams(LogisticRegressionValidParams); +#[serde(bound(deserialize = "D: Deserialize<'de>"))] +pub struct LogisticRegressionParams(LogisticRegressionValidParams); #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct LogisticRegressionValidParams { +#[serde(bound(deserialize = "D: Deserialize<'de>"))] +pub struct LogisticRegressionValidParams { pub(crate) alpha: F, pub(crate) fit_intercept: bool, pub(crate) max_iterations: u64, diff --git a/algorithms/linfa-logistic/src/lib.rs b/algorithms/linfa-logistic/src/lib.rs index a0018dcf7..edc3b66cd 100644 --- a/algorithms/linfa-logistic/src/lib.rs +++ b/algorithms/linfa-logistic/src/lib.rs @@ -526,7 +526,8 @@ fn multi_logistic_grad>( /// A fitted logistic regression which can make predictions #[derive(PartialEq, Debug, Clone, Serialize, Deserialize)] -pub struct FittedLogisticRegression { +#[serde(bound(deserialize = "C: Deserialize<'de>"))] +pub struct FittedLogisticRegression { threshold: F, intercept: F, params: Array1,