Exporting & Loading Trained Model? #290

Closed
Bastian1110 opened this issue Feb 17, 2023 · 7 comments

Comments

@Bastian1110

Is there a way to save a trained model (GaussianNb) and then load it?

I'm just learning Rust and have just managed to implement a Gaussian Naive Bayes classifier. Is there any way to use the "predict" method without having to train the whole model again?
I know that in libraries like Sklearn you can export models and load them back from .pkl files. Is there something similar in linfa?

Thank you so much!

@YuhanLiin
Collaborator

The models should implement Serde's Serialize and Deserialize traits, so you can serialize them using something like ciborium.

@Bastian1110
Author

I'm having a little problem, hehe. I'm trying to serialize the model (I'm using the linfa_svm example), but I don't know if I'm using the correct syntax, since I get this error on the line where I use cbor!:

the trait bound `MultiClassModel<ndarray::ArrayBase<ndarray::data_repr::OwnedRepr<f64>, ndarray::dimension::dim::Dim<[usize; 2]>>, usize>: serde::ser::Serialize` is not satisfied
the following other types implement trait `serde::ser::Serialize`:

This is the code I'm using (linfa_svm/examples/winequality_multi.rs):

    let model = train
        .one_vs_all()?
        .into_iter()
        .map(|(l, x)| (l, params.fit(&x).unwrap()))
        .collect::<MultiClassModel<_, _>>();
    let pred = model.predict(&valid);

    // Trying to serialize the model (this is the line that fails to compile:
    // MultiClassModel does not implement serde::Serialize)
    let save_model = cbor!(model).unwrap();

Could you give me a more detailed example? I would really appreciate it!

@YuhanLiin
Collaborator

It seems MultiClassModel and linfa-bayes have no Serde support. Weird. We'll need to add that.

@Bastian1110
Author

Oh! I will try with normal SVM then, thank you!

@Bastian1110
Author

One last question:
I already managed to serialize the plain (non-multiclass) SVM model to CBOR using ciborium, and I also managed to deserialize it in another file into a Value. The last step would be converting that Value into an Svm. Any idea how to do this?

Code for creating and exporting the model:

    use std::{fs, path::Path};
    use ciborium::{cbor, value::Value};
    use linfa::prelude::*;
    use linfa_svm::Svm;

    let model = Svm::<_, bool>::params().pos_neg_weights(50000., 5000.).gaussian_kernel(80.0).fit(&train)?;

    // Serializing the trained model with ciborium: first into a CBOR Value...
    let value_model: Value = cbor!(model).unwrap();
    let mut vec_model: Vec<u8> = Vec::new();
    // ...then encoding that Value into bytes (check the Result instead of ignoring it)
    ciborium::ser::into_writer(&value_model, &mut vec_model).unwrap();

    // Exporting it to a .cbor file
    let path: &Path = Path::new("./model.cbor");
    fs::write(path, vec_model).unwrap();

Attempt to use the trained model in another .rs program:

    use std::{fs::File, io::Read};
    use ciborium::value::Value;
    use linfa_svm::Svm;

    let mut file = File::open("./model.cbor").unwrap();
    let mut data: Vec<u8> = Vec::new();
    file.read_to_end(&mut data).unwrap();

    let model_value: Value = ciborium::de::from_reader::<Value, _>(&data[..]).unwrap();
    let model: Svm<_, bool> = model_value.deserialized().unwrap(); // Error

But I keep getting an error when converting from ciborium::Value to Svm. rust-analyzer suggests "consider specifying the generic argument: ::<Svm<_, bool>>". I guess I have to pass the SVM's Serde deserializer somehow, but I don't know how to do that.

I know this has nothing to do with linfa, but I really think that exporting and importing the models can be very useful.

Thank you!

@Bastian1110
Author

My bad, it turns out the example SVM model is Svm<f64, bool>, not Svm<_, bool>. I only had to change the line to:

    let model: Svm<f64, bool> = model_value.deserialized().unwrap();

And it works, super cool!

@coolstudio1678

I use cbor!(model) and it shows errors. How do I resolve this?

the trait bound MultiClassModel<ndarray::ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>, Option<&str>>: serde::ser::Serialize is not satisfied
the following other types implement trait serde::ser::Serialize:
bool
char
isize
i8
i16
i32
i64
i128
and 196 others
lib.rs(222, 42): Actual error occurred here
lib.rs(222, 9): required by a bound introduced by this call
ser.rs(435, 35): required by a bound in value::ser::<impl ciborium::Value>::serialized


3 participants