LinearRegression: add support for multiple targets #4988

ahendriksen
Contributor

Previously, LinearRegression did not support targets with multiple columns. This PR adds that support.
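
To illustrate what multi-target support means mathematically, here is a minimal NumPy sketch (not the cuML implementation, and the variable names are only illustrative). It shows that solving ordinary least squares once for a 2-D target gives the same coefficients as solving it separately for each target column:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features, n_targets = 200, 3, 4

X = rng.normal(size=(n_samples, n_features))
Y = rng.normal(size=(n_samples, n_targets))

# One solve for all targets: coef_all has shape (n_features, n_targets).
coef_all, *_ = np.linalg.lstsq(X, Y, rcond=None)

# Reference: one solve per target column, stacked side by side.
coef_loop = np.column_stack(
    [np.linalg.lstsq(X, Y[:, i], rcond=None)[0] for i in range(n_targets)]
)

print(np.allclose(coef_all, coef_loop))
```

Each target column is an independent least-squares problem that shares the same design matrix, which is why a single call can replace the per-column loop.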

@ahendriksen ahendriksen requested a review from a team as a code owner November 9, 2022 18:28
@github-actions github-actions bot added the Cython / Python label Nov 9, 2022
@ahendriksen ahendriksen added the 3 - Ready for Review, feature request, Cython / Python, and non-breaking labels and removed the Cython / Python label Nov 9, 2022
@tfeher
Contributor

tfeher commented Nov 10, 2022

closes #3850

Contributor

@tfeher tfeher left a comment

Thanks Allard for the PR! It looks great overall, I just have a few comments.

Resolved review threads (outdated):
python/cuml/linear_model/linear_regression.pyx
python/cuml/linear_model/linear_regression.pyx
python/cuml/tests/test_linear_model.py
python/cuml/tests/test_linear_model.py
@ahendriksen
Contributor Author

Thanks for catching the skipped tests. That was an oversight on my part.

Contributor

@tfeher tfeher left a comment

Thanks Allard for the update, the PR looks good to me!

@caryr35 caryr35 added this to PR-WIP in v22.12 Release via automation Nov 14, 2022
@caryr35 caryr35 moved this from PR-WIP to PR-Needs review in v22.12 Release Nov 14, 2022
Resolved review threads (outdated):
python/cuml/linear_model/linear_regression.pyx
python/cuml/linear_model/base.pyx
python/cuml/linear_model/base.pyx
@csadorf
Contributor

csadorf commented Nov 15, 2022

I forgot to mention this in my review: I think the generate_docstring decorator would also need an update, but that might be difficult if the change does not apply to all LinearPredictMixin children:

'shape': '(n_samples, 1)'})
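
To make the shape concern concrete, here is a small NumPy sketch of how the prediction shape follows the shape of the fitted coefficients. The `predict` helper is hypothetical, and the coefficient layout (`(n_features,)` for one target, `(n_targets, n_features)` for several) follows the scikit-learn convention; whether cuML uses exactly this layout is an assumption here.

```python
import numpy as np


def predict(X, coef, intercept=0.0):
    # Hypothetical helper, not the cuML predict method.
    # coef: (n_features,) or (n_targets, n_features), scikit-learn style.
    return X @ np.asarray(coef).T + intercept


X = np.ones((5, 3))

single = predict(X, coef=np.array([1.0, 2.0, 3.0]))
multi = predict(X, coef=np.ones((4, 3)), intercept=np.zeros(4))

print(single.shape)  # one value per sample
print(multi.shape)   # one column per target
```

With a 2-D target, a documented output shape of `(n_samples, 1)` no longer describes every case, which is the mismatch the comment above points at.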

@csadorf
Contributor

csadorf commented Nov 15, 2022

@ahendriksen Thanks a lot for tackling this! 👍 I have left a few comments. My main concern is potentially unnecessary conversions; I would appreciate it if those could be addressed.

@ahendriksen
Contributor Author

On a representative benchmark, this PR speeds up LinearRegression by 20-50x compared to fitting one target at a time in a Python loop, which is how multi-target linear regression has been implemented in practice so far.

Results on Volta for the script below show that the new code is much faster when the number of targets is large and at least as fast when it is small.

loop: 13.24 seconds
new: 0.22 seconds

# and with fit_intercept=True
loop: 15.02 seconds
new: 0.20 seconds

# with n_targets = 2
loop: 0.01 seconds
new: 0.01 seconds

# with n_targets = 10
loop: 0.04 seconds
new: 0.01 seconds

import cupy as cp
from cuml.linear_model import LinearRegression
from time import perf_counter as timer
from contextlib import contextmanager

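@contextmanager
def time(name):
    # Print the wall-clock time spent inside the `with` block.
    start = timer()
    yield
    # Wait for queued GPU work so the measurement covers kernel execution.
    cp.cuda.Device().synchronize()
    duration = timer() - start
    print(f"{name}: {duration:0.2f} seconds")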

n_features = 3
n_samples = 91_000
n_targets = 5_000

X = cp.random.normal(size=(n_samples, n_features))
y = cp.random.normal(size=(n_samples, n_targets))
out1 = cp.zeros(y.shape)
out2 = cp.zeros(y.shape)

# Create linear regression instance that can be reused.
lr = LinearRegression(fit_intercept=False, output_type="cupy", algorithm="svd")

with time("loop"):
    for i in range(n_targets):
        lr.fit(X, y[:, i])
        out1[:, i] = lr.predict(X)

with time("new"):
    lr.fit(X, y)
    out2[:] = lr.predict(X)
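
One plausible reason the batched call beats the loop is that the factorization of `X` does not depend on the targets, so it only needs to be computed once. The following CPU-only NumPy sketch illustrates that idea; it is an assumption about where the saving comes from, not a description of the cuML internals, and it assumes `X` has full column rank so that no singular value is zero. The function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features, n_targets = 500, 3, 8
X = rng.normal(size=(n_samples, n_features))
Y = rng.normal(size=(n_samples, n_targets))


def solve_one_target(X, y):
    # Loop-style fit: factorizes X again for every target column.
    U, s, Vt = np.linalg.svd(X, full_matrices=False)
    return Vt.T @ ((U.T @ y) / s)


def solve_all_targets(X, Y):
    # Batched fit: one factorization, then one matrix product over all columns.
    U, s, Vt = np.linalg.svd(X, full_matrices=False)
    return Vt.T @ ((U.T @ Y) / s[:, None])


coef_loop = np.column_stack(
    [solve_one_target(X, Y[:, i]) for i in range(n_targets)]
)
coef_batched = solve_all_targets(X, Y)
print(np.allclose(coef_loop, coef_batched))
```

The loop repeats the SVD `n_targets` times, while the batched version pays for it once and also avoids per-call Python and launch overhead.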

@codecov-commenter

Codecov Report

❗ No coverage uploaded for pull request base (branch-22.12@6500897). Click here to learn what that means.
Patch has no changes to coverable lines.

Additional details and impacted files
@@               Coverage Diff               @@
##             branch-22.12    #4988   +/-   ##
===============================================
  Coverage                ?   79.44%           
===============================================
  Files                   ?      184           
  Lines                   ?    11698           
  Branches                ?        0           
===============================================
  Hits                    ?     9293           
  Misses                  ?     2405           
  Partials                ?        0           
Flag Coverage Δ
dask 45.93% <0.00%> (?)
non-dask 68.97% <0.00%> (?)


@csadorf
Contributor

csadorf commented Nov 15, 2022

I think we should also update the tags of the estimator class by adding

@staticmethod
def _more_static_tags():
    return {"multioutput": True}

to the base class.
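
As a rough illustration of why adding the tag to the base class would reach every subclass, here is a self-contained sketch of a tag lookup that merges `_more_static_tags` along the class hierarchy. All class names and the `get_static_tags` method are hypothetical; the real cuML tag machinery may differ in its details.

```python
import inspect


class TagsBase:
    # Hypothetical stand-in for an estimator base with default tags.
    @staticmethod
    def _default_tags():
        return {"multioutput": False, "requires_y": False}

    @classmethod
    def get_static_tags(cls):
        tags = cls._default_tags()
        # Walk from the most generic class to the most specific one,
        # so that subclasses can override tags set by their parents.
        for klass in reversed(inspect.getmro(cls)):
            if "_more_static_tags" in vars(klass):
                tags.update(klass._more_static_tags())
        return tags


class LinearModelBase(TagsBase):
    @staticmethod
    def _more_static_tags():
        return {"multioutput": True}


class MyLinearRegression(LinearModelBase):
    pass


print(TagsBase.get_static_tags()["multioutput"])
print(MyLinearRegression.get_static_tags()["multioutput"])
```

In this sketch the subclass inherits the `multioutput` tag without declaring anything itself, while unrelated estimators keep the default.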

@ahendriksen ahendriksen force-pushed the fea-lin-reg-multiple-targets branch 4 times, most recently from 88402bf to c18d7ff Compare November 15, 2022 15:23
@ahendriksen
Contributor Author

Thank you for the review! Removing the conversions made the code another factor of 2 faster. Apologies for the WIP commits; my workflow is a bit hobbled when working with pyx files.

@ahendriksen
Contributor Author

Can this PR be merged?

Member

@cjnolet cjnolet left a comment

This LGTM. Thanks Allard!

v22.12 Release automation moved this from PR-Needs review to PR-Reviewer approved Nov 16, 2022
@cjnolet
Member

cjnolet commented Nov 16, 2022

@gpucibot merge

@rapids-bot rapids-bot bot merged commit 0f19b92 into rapidsai:branch-22.12 Nov 16, 2022
v22.12 Release automation moved this from PR-Reviewer approved to Done Nov 16, 2022
jakirkham pushed a commit to jakirkham/cuml that referenced this pull request Feb 27, 2023
LinearRegression did not have support for target vectors with multiple columns previously. This PR adds support.

Authors:
  - Allard Hendriksen (https://github.com/ahendriksen)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#4988