Skip to content

Commit

Permalink
Merge pull request #1 from postgresml/levkk-add-ser-for-logistic
Browse files Browse the repository at this point in the history
Add serialization for LogisticRegression
  • Loading branch information
levkk committed Oct 6, 2022
2 parents d34313c + fb7a1fa commit 0930322
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
1 change: 1 addition & 0 deletions algorithms/linfa-logistic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ linfa = { version = "0.6.0", path = "../..", features=["serde"] }
[dev-dependencies]
approx = "0.4"
linfa-datasets = { version = "0.6.0", path = "../../datasets", features = ["winequality"] }
rmp-serde = "1"
22 changes: 16 additions & 6 deletions algorithms/linfa-logistic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use ndarray::{
Dimension, IntoDimension, Ix1, Ix2, RemoveAxis, Slice, Zip,
};
use ndarray_stats::QuantileExt;
use serde::{Deserialize, Serialize};
use std::default::Default;

mod argmin_param;
Expand Down Expand Up @@ -524,8 +525,8 @@ fn multi_logistic_grad<F: Float, A: Data<Elem = F>>(
}

/// A fitted logistic regression which can make predictions
#[derive(PartialEq, Debug, Clone)]
pub struct FittedLogisticRegression<F: Float, C: PartialOrd + Clone> {
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub struct FittedLogisticRegression<F, C: PartialOrd + Clone> {
threshold: F,
intercept: F,
params: Array1<F>,
Expand Down Expand Up @@ -610,8 +611,8 @@ impl<C: PartialOrd + Clone + Default, F: Float, D: Data<Elem = F>>
}

/// A fitted multinomial logistic regression which can make predictions
#[derive(PartialEq, Debug, Clone)]
pub struct MultiFittedLogisticRegression<F: Float, C: PartialOrd + Clone> {
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub struct MultiFittedLogisticRegression<F, C: PartialOrd + Clone> {
intercept: Array1<F>,
params: Array2<F>,
classes: Vec<C>,
Expand Down Expand Up @@ -685,8 +686,8 @@ impl<C: PartialOrd + Clone + Default, F: Float, D: Data<Elem = F>>
}
}

#[derive(PartialEq, Debug, Clone)]
struct ClassLabel<F: Float, C: PartialOrd> {
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
struct ClassLabel<F, C: PartialOrd> {
class: C,
label: F,
}
Expand Down Expand Up @@ -1066,6 +1067,15 @@ mod test {
&res.predict(dataset.records()),
dataset.targets().as_single_targets()
);

// Test serialization
let ser = rmp_serde::to_vec(&res).unwrap();
let unser: FittedLogisticRegression<f32, f32> = rmp_serde::from_slice(&ser).unwrap();

let x = array![[1.0]];
let y_hat = unser.predict(&x);

assert!(y_hat[0] == 0.0);
}

#[test]
Expand Down

0 comments on commit 0930322

Please sign in to comment.