Multinomial logistic regression #159

Merged: 19 commits, Aug 31, 2021

Conversation

YuhanLiin
Collaborator

Implementation of multinomial logistic regression. Also added some tests. Still need to add docs and possibly refactor.

YuhanLiin marked this pull request as draft on August 16, 2021 05:05

codecov-commenter commented Aug 16, 2021

Codecov Report

Merging #159 (65dc913) into master (992938e) will increase coverage by 0.62%.
The diff coverage is 75.61%.


@@            Coverage Diff             @@
##           master     #159      +/-   ##
==========================================
+ Coverage   59.84%   60.46%   +0.62%     
==========================================
  Files          84       84              
  Lines        7748     7923     +175     
==========================================
+ Hits         4637     4791     +154     
- Misses       3111     3132      +21     
| Impacted Files | Coverage Δ |
|---|---|
| algorithms/linfa-logistic/src/float.rs | 0.00% <0.00%> (ø) |
| algorithms/linfa-logistic/src/argmin_param.rs | 47.05% <36.36%> (+25.63%) ⬆️ |
| algorithms/linfa-logistic/src/lib.rs | 78.02% <77.77%> (+5.04%) ⬆️ |
| src/correlation.rs | 52.45% <0.00%> (-0.88%) ⬇️ |
| algorithms/linfa-hierarchical/src/lib.rs | 52.23% <0.00%> (-0.80%) ⬇️ |
| algorithms/linfa-linear/src/glm.rs | 46.66% <0.00%> (-0.40%) ⬇️ |
| ...linfa-clustering/src/gaussian_mixture/algorithm.rs | 56.52% <0.00%> (-0.10%) ⬇️ |
| src/dataset/mod.rs | 88.09% <0.00%> (ø) |
| algorithms/linfa-nn/tests/nn.rs | 85.13% <0.00%> (ø) |
| algorithms/linfa-pls/src/pls_generic.rs | 69.76% <0.00%> (+0.27%) ⬆️ |

... and 5 more

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

Member

bytesnake commented Aug 22, 2021

Would a review be helpful, even though this is still a draft?

YuhanLiin marked this pull request as ready for review on August 23, 2021 20:06
Collaborator Author

YuhanLiin commented Aug 23, 2021

One observation about the multinomial implementation: it tends to diverge when the input is not well normalized. I'm not sure whether this is a natural property of the algorithm itself.
Specifically, the algorithm gets stuck on the wine quality data due to divergence, even though binomial regression works fine on that data.
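
For reference, one common way to normalize input before fitting is to standardize each feature column to zero mean and unit variance. The sketch below is a minimal, std-only illustration: the function name `standardize` and the row-major `Vec<Vec<f64>>` layout are assumptions made for this example, not part of this PR or of linfa's API.

```rust
/// Standardize each column of a row-major feature matrix to zero mean and
/// unit (population) variance. Columns with zero variance are only centered.
/// Hypothetical helper for illustration; assumes all rows have equal length.
fn standardize(rows: &[Vec<f64>]) -> Vec<Vec<f64>> {
    if rows.is_empty() {
        return Vec::new();
    }
    let n = rows.len() as f64;
    let d = rows[0].len();

    // Column means.
    let mut mean = vec![0.0; d];
    for row in rows {
        for (m, x) in mean.iter_mut().zip(row) {
            *m += x / n;
        }
    }

    // Column standard deviations.
    let mut sd = vec![0.0; d];
    for row in rows {
        for j in 0..d {
            sd[j] += (row[j] - mean[j]).powi(2) / n;
        }
    }
    for s in sd.iter_mut() {
        *s = s.sqrt();
    }

    // Center every value and scale by the column's standard deviation.
    rows.iter()
        .map(|row| {
            row.iter()
                .enumerate()
                .map(|(j, x)| {
                    let centered = x - mean[j];
                    if sd[j] > 0.0 { centered / sd[j] } else { centered }
                })
                .collect()
        })
        .collect()
}
```

Whether standardizing is enough to avoid the divergence on the wine quality data is not something this sketch establishes; it only shows the preprocessing step itself.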

@YuhanLiin
Collaborator Author

The unnormalized wine quality data also diverges on scikit-learn. However, there the algorithm actually terminates because it hits the max iteration limit. With this implementation, the algorithm just gets stuck on the first iteration of the solver forever. This likely has something to do with the line search library we're using.
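
To illustrate the difference between terminating on a limit and spinning forever, here is a rough sketch of a backtracking (Armijo) line search with a hard cap on trial steps. This is a swapped-in, simplified technique in one dimension, not the line search the solver in this PR uses; `bounded_backtracking`, `max_trials`, and the constants are all assumptions made for the example.

```rust
/// Backtracking (Armijo) line search along the steepest-descent direction,
/// with a hard cap on the number of trial steps.
/// Returns `Some(step)` when a step with sufficient decrease is found, or
/// `None` once `max_trials` is exhausted, so the caller can stop cleanly.
/// Hypothetical 1-D sketch; not the solver's actual line search.
fn bounded_backtracking<F>(f: F, x: f64, grad: f64, max_trials: usize) -> Option<f64>
where
    F: Fn(f64) -> f64,
{
    let c = 1e-4; // sufficient-decrease constant
    let shrink = 0.5; // step reduction factor
    let direction = -grad; // steepest-descent direction
    let fx = f(x);
    let mut step = 1.0;
    for _ in 0..max_trials {
        let candidate = f(x + step * direction);
        // Armijo condition: f(x + t*d) <= f(x) + c*t*g*d.
        // Non-finite values (NaN or infinity from a diverging loss) are rejected.
        if candidate.is_finite() && candidate <= fx + c * step * grad * direction {
            return Some(step);
        }
        step *= shrink;
    }
    None
}
```

With a cap like this, a loss that evaluates to NaN or infinity produces `None` after `max_trials` attempts rather than an endless search, which the caller can surface as a convergence failure.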

Member

bytesnake left a comment

nice work 👍

algorithms/linfa-logistic/src/argmin_param.rs
self.initial_params = Some((params, intercept));
/// The `params` array must have at least the same number of rows as there are columns on the
/// feature matrix `x` passed to the `fit` method
pub fn initial_params(mut self, params: Array<F, D>) -> Self {
Member

Doesn't this try to achieve what we are doing with FitWith? (I.e. you can also initialize a linear model with these parameters and call fit_with on that.)

Collaborator Author

I thought FitWith was meant to be used for "online" versions of algorithms for faster fitting? This method existed before, so I just kept it. It serves the same purpose as setting the initial centroids when using KMeans, which is useful even when using FitWith.

Member

Yes, but I think it's more intuitive if a fitted model can be constructed (i.e. having a FittedLogisticRegression::new function) with a given intercept and hyperplane. You can then update those initial weights with FitWith instead of putting them in the hyperparameters.
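
As a rough sketch of what constructing a fitted model from a given intercept and hyperplane could look like: the type `FittedModel` and its methods below are hypothetical stand-ins for a binary model, not linfa's `FittedLogisticRegression`, and class labels are left out for brevity.

```rust
/// Hypothetical stand-in for a fitted binary model; not linfa's actual type.
struct FittedModel {
    intercept: f64,
    weights: Vec<f64>, // hyperplane normal, one entry per feature
}

impl FittedModel {
    /// Build a model directly from known parameters instead of fitting.
    fn new(intercept: f64, weights: Vec<f64>) -> Self {
        Self { intercept, weights }
    }

    /// Probability of the positive class via the logistic function.
    fn predict_proba(&self, x: &[f64]) -> f64 {
        let z: f64 = self.intercept
            + self.weights.iter().zip(x).map(|(w, xi)| w * xi).sum::<f64>();
        1.0 / (1.0 + (-z).exp())
    }
}
```

A model built this way could then serve as the starting point that an incremental fitting call refines.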

Collaborator Author

Both FittedLogisticRegression and MultiFittedLogisticRegression require the set of labels to be specified at construction, which isn't info that's available before fitting. If we make FitWith perform a single iteration, then users will have to write their own loops and stopping conditions, which is more work. We could make FitWith perform the complete algorithm to solve this. Also, if the Fit version of the algorithm always starts with param arrays of 0s, it'd be pretty much useless, since the optimal starting params are random values, not 0s.

I just don't think it makes sense to delegate the offline and online versions of the algorithm to the same impl. What we can do is have separate param types for the online and offline versions of the algorithm, both implementing FitWith. The offline impl would run the entire algorithm while the online impl runs a single step. This makes sense because the offline algorithm requires a few hyperparams that the online version doesn't need.
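
The split between online and offline param types could be sketched roughly as follows. Everything here is an assumption made for illustration: `FitWithSketch` is a simplified stand-in trait and not linfa's `FitWith` signature, `Model` is a one-feature binary model, and plain gradient descent replaces the real solver.

```rust
/// Minimal 1-feature binary logistic model used only for illustration.
#[derive(Debug, Clone, PartialEq)]
struct Model {
    weight: f64,
    intercept: f64,
}

/// Hypothetical trait loosely modelled on the idea of `FitWith`:
/// update an existing model from data. Not linfa's actual signature.
trait FitWithSketch {
    fn fit_with(&self, model: Model, xs: &[f64], ys: &[f64]) -> Model;
}

/// One gradient-descent step on the mean logistic loss.
fn gradient_step(model: &Model, xs: &[f64], ys: &[f64], lr: f64) -> Model {
    let n = xs.len() as f64;
    let (mut gw, mut gb) = (0.0, 0.0);
    for (x, y) in xs.iter().zip(ys) {
        let p = 1.0 / (1.0 + (-(model.weight * x + model.intercept)).exp());
        gw += (p - y) * x / n;
        gb += (p - y) / n;
    }
    Model {
        weight: model.weight - lr * gw,
        intercept: model.intercept - lr * gb,
    }
}

/// Online params: only a learning rate; `fit_with` performs a single step.
struct OnlineParams {
    lr: f64,
}

/// Offline params: extra hyperparameters that control the full loop.
struct OfflineParams {
    lr: f64,
    max_iter: usize,
    tol: f64,
}

impl FitWithSketch for OnlineParams {
    fn fit_with(&self, model: Model, xs: &[f64], ys: &[f64]) -> Model {
        gradient_step(&model, xs, ys, self.lr)
    }
}

impl FitWithSketch for OfflineParams {
    fn fit_with(&self, mut model: Model, xs: &[f64], ys: &[f64]) -> Model {
        // Iterate until the parameters stop changing or the cap is reached.
        for _ in 0..self.max_iter {
            let next = gradient_step(&model, xs, ys, self.lr);
            let change = (next.weight - model.weight).abs()
                + (next.intercept - model.intercept).abs();
            model = next;
            if change < self.tol {
                break;
            }
        }
        model
    }
}
```

In this sketch the online type needs only a learning rate, while the offline type carries `max_iter` and `tol`, which mirrors the point that the offline algorithm requires hyperparameters the online one does not.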

algorithms/linfa-logistic/src/lib.rs
YuhanLiin merged commit e06f0be into rust-ml:master on Aug 31, 2021
YuhanLiin deleted the multinomial-logistic branch on November 21, 2021 17:31
3 participants