
Add multi class model and Platt scaling #112

Merged
bytesnake merged 25 commits into rust-ml:master from bytesnake:multi_class_model
Apr 25, 2021

@bytesnake
Member

@bytesnake bytesnake commented Mar 30, 2021

Composing model for binary to multi-class transformation, WIP

  • finish MultiClassModel
  • add Platt scaling for proper SVM probability values
  • add example
  • improve documentation

Example

// we have to specify that we want to predict probabilities
// `Svm::<_, bool>` would also be possible and avoid Platt scaling.
let params = Svm::<_, Pr>::params()
    .gaussian_kernel(30.0);

let model = train.one_vs_all()?
    .into_iter()
    .map(|(l, x)| (l, params.fit(&x).unwrap()))
    .collect::<MultiClassModel<_, _>>();

// predict with validation dataset, the prediction has type `usize`
let pred = model.predict(&valid);

// create and print a confusion matrix against the validation targets
let cm = pred.confusion_matrix(&valid)?;
println!("{:?}", cm);
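
To make the idea behind the composed model concrete, here is a minimal, self-contained sketch of one-vs-rest prediction. It is not the code from this PR: `BinaryModel`, `predict_proba` and `predict_one_vs_rest` are hypothetical names, and the logistic scoring merely stands in for a fitted SVM with Platt scaling. The only point is that each label gets its own binary model and the label with the highest probability wins.

```rust
/// Hypothetical stand-in for a fitted binary classifier that reports the
/// probability of its positive class (assumption: a linear score squashed
/// by a logistic function, not an actual SVM).
struct BinaryModel {
    weights: Vec<f64>,
    bias: f64,
}

impl BinaryModel {
    fn predict_proba(&self, x: &[f64]) -> f64 {
        let score = self.weights.iter().zip(x).map(|(w, v)| w * v).sum::<f64>() + self.bias;
        1.0 / (1.0 + (-score).exp())
    }
}

/// One-vs-rest prediction: query every (label, model) pair and return the
/// label whose model reports the highest probability. Returns `None` when
/// there are no models.
fn predict_one_vs_rest(models: &[(usize, BinaryModel)], x: &[f64]) -> Option<usize> {
    models
        .iter()
        .map(|(label, model)| (*label, model.predict_proba(x)))
        .max_by(|a, b| a.1.total_cmp(&b.1))
        .map(|(label, _)| label)
}

fn main() {
    let models = vec![
        (0, BinaryModel { weights: vec![1.0, 0.0], bias: 0.0 }),
        (1, BinaryModel { weights: vec![0.0, 1.0], bias: 0.0 }),
        (2, BinaryModel { weights: vec![-1.0, -1.0], bias: 0.0 }),
    ];
    // The second feature dominates, so the model for label 1 is most confident.
    println!("{:?}", predict_one_vs_rest(&models, &[0.2, 3.0])); // prints "Some(1)"
}
```

Comparing probabilities rather than hard `bool` decisions is what makes the arg-max meaningful, which is presumably why the example above asks for `Pr` predictions.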

@bytesnake bytesnake marked this pull request as ready for review April 14, 2021 10:02
@bytesnake bytesnake changed the title [WIP] Add multi class model and Platt scaling Add multi class model and Platt scaling Apr 14, 2021
Member

@Sauro98 Sauro98 left a comment


Good work, nice to have a unified multi-class model 🚀

I would just add a note about the difference in the SVM params with bool or Pr also in the parameters' documentation so that the user doesn't have to go look in the example.

Comment thread algorithms/linfa-svm/src/classification.rs Outdated
Comment thread algorithms/linfa-svm/src/classification.rs Outdated
Comment thread algorithms/linfa-svm/examples/winequality_multi.rs
Comment thread algorithms/linfa-svm/examples/winequality_multi.rs Outdated
@quietlychris
Member

I saw that you linked the paper at https://www.csie.ntu.edu.tw/~cjlin/papers/plattprob.pdf in your explanation. It's totally possible that I missed it, but I couldn't find the first half of the

    // avoid numerical problems for large f_apb
    if f_apb >= 0.0 {
        Pr((-f_apb).exp() / (1.0 + (-f_apb).exp()))
    } else {
        Pr(1.0 / (1.0 + f_apb.exp()))
    }

branch that you use for numerical stability, with f_apb in the numerator, in either that article or on the Wikipedia link. From first principles, it makes sense that it would work, but I just wasn't sure if that was something you found in another paper that you might want to link as well. It's probably not necessary either way, but it got me curious. Otherwise, everything else looks great :)

@bytesnake
Member Author

> Good work, nice to have a unified multi-class model 🚀
>
> I would just add a note about the difference in the SVM params with bool or Pr also in the parameters' documentation so that the user doesn't have to go look in the example.

good point 👍 I'm not 100% sure why the compiler can't figure out when to use bool vs Pr though
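
One plausible explanation, sketched below under stated assumptions: if the parameter type is generic over the prediction type and fitting is implemented separately for each choice, nothing in a bare constructor call tells the compiler which one is meant. `Params`, `Prob` and `Fit` here are hypothetical stand-ins, not the actual linfa types, so this only illustrates the inference problem in general.

```rust
use std::marker::PhantomData;

/// Hypothetical probability wrapper, standing in for a type like `Pr`.
#[derive(Debug, PartialEq)]
struct Prob(f64);

/// Hypothetical parameter set, generic over the prediction type `T`.
struct Params<T> {
    threshold: f64,
    _target: PhantomData<T>,
}

impl<T> Params<T> {
    fn new(threshold: f64) -> Self {
        Params { threshold, _target: PhantomData }
    }
}

/// Fitting is implemented once per prediction type.
trait Fit {
    type Output;
    fn fit(&self, score: f64) -> Self::Output;
}

impl Fit for Params<bool> {
    type Output = bool;
    fn fit(&self, score: f64) -> bool {
        score > self.threshold
    }
}

impl Fit for Params<Prob> {
    type Output = Prob;
    fn fit(&self, score: f64) -> Prob {
        Prob(1.0 / (1.0 + (-(score - self.threshold)).exp()))
    }
}

fn main() {
    // `Params::new(0.0).fit(1.0)` without an annotation is rejected: both
    // impls match, so `T` cannot be inferred from the call alone.
    let hard = Params::<bool>::new(0.0).fit(1.0);
    let soft = Params::<Prob>::new(0.0).fit(0.0);
    println!("{:?} {:?}", hard, soft); // prints "true Prob(0.5)"
}
```

With this shape, the turbofish is the only place where the caller's intent is recorded, which matches the `Svm::<_, Pr>::params()` call in the example.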

> I saw that you linked the paper at https://www.csie.ntu.edu.tw/~cjlin/papers/plattprob.pdf in your explanation. It's totally possible that I missed it, but I couldn't find the first half of the
>
>     // avoid numerical problems for large f_apb
>     if f_apb >= 0.0 {
>         Pr((-f_apb).exp() / (1.0 + (-f_apb).exp()))
>     } else {
>         Pr(1.0 / (1.0 + f_apb.exp()))
>     }
>
> branch that you use for numerical stability, with f_apb in the numerator, in either that article or on the Wikipedia link. From first principles, it makes sense that it would work, but I just wasn't sure if that was something you found in another paper that you might want to link as well. It's probably not necessary either way, but it got me curious. Otherwise, everything else looks great :)

yes, you just have to multiply the numerator and denominator of 1.0 / (1.0 + f_apb.exp()) by (-f_apb).exp(). I don't have a source right now, but it is probably similar to stable sigmoid implementations. Will find something :)
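
The algebra can be checked in a few lines. Multiplying the numerator and denominator of 1 / (1 + e^f) by e^(-f) gives e^(-f) / (1 + e^(-f)), so both branches compute the same value, and each branch only exponentiates a non-positive number. The sketch below uses hypothetical function names (`platt_naive`, `platt_stable`) and plain `f64` instead of `Pr`. Note that under IEEE 754 the naive form would still evaluate to 0 when the exponential overflows, so the branching mainly avoids infinite intermediates rather than changing the result.

```rust
/// Naive form of the Platt probability: 1 / (1 + exp(f_apb)).
/// For large positive `f_apb` the exponential overflows to infinity.
fn platt_naive(f_apb: f64) -> f64 {
    1.0 / (1.0 + f_apb.exp())
}

/// Branching form: the argument of `exp` is never positive, so the
/// intermediate value always lies in [0, 1].
fn platt_stable(f_apb: f64) -> f64 {
    if f_apb >= 0.0 {
        // numerator and denominator multiplied by exp(-f_apb)
        let e = (-f_apb).exp();
        e / (1.0 + e)
    } else {
        1.0 / (1.0 + f_apb.exp())
    }
}

fn main() {
    for &f in &[-30.0, -2.0, 0.0, 2.0, 30.0] {
        let (naive, stable) = (platt_naive(f), platt_stable(f));
        // The two forms agree up to rounding for moderate inputs.
        assert!((naive - stable).abs() < 1e-12);
        println!("f_apb = {:>6}: naive = {:e}, stable = {:e}", f, naive, stable);
    }
    // The naive intermediate overflows here; the stable one does not.
    assert!(1000.0_f64.exp().is_infinite());
    assert_eq!(platt_stable(1000.0), 0.0);
}
```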

thank you both, I will complete the PR now and then merge

@codecov-commenter

codecov-commenter commented Apr 25, 2021

Codecov Report

Merging #112 (d4dd77f) into master (908efde) will decrease coverage by 0.17%.
The diff coverage is 55.41%.

❗ Current head d4dd77f differs from pull request most recent head 7a6cc20. Consider uploading reports for the commit 7a6cc20 to get more accurate results

@@            Coverage Diff             @@
##           master     #112      +/-   ##
==========================================
- Coverage   58.32%   58.15%   -0.18%     
==========================================
  Files          75       77       +2     
  Lines        6695     6813     +118     
==========================================
+ Hits         3905     3962      +57     
- Misses       2790     2851      +61     

Impacted Files                                 Coverage Δ
algorithms/linfa-svm/src/lib.rs                22.89% <0.00%> (-2.43%) ⬇️
algorithms/linfa-svm/src/regression.rs         75.80% <ø> (ø)
algorithms/linfa-svm/src/solver_smo.rs         37.23% <ø> (ø)
src/composing/multi_target_model.rs            73.91% <ø> (ø)
algorithms/linfa-svm/src/classification.rs     79.68% <51.72%> (-5.17%) ⬇️
src/composing/multi_class_model.rs             52.77% <52.77%> (ø)
src/composing/platt_scaling.rs                 58.02% <58.02%> (ø)
src/dataset/impl_dataset.rs                    44.23% <100.00%> (+1.37%) ⬆️
src/dataset/mod.rs                             86.20% <100.00%> (-2.10%) ⬇️
... and 13 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 908efde...7a6cc20.

@bytesnake bytesnake merged commit b7c31c5 into rust-ml:master Apr 25, 2021
4 participants