
Allow linalg.lstsq to use SVD to compute the result for rank-deficient matrices #125110

Closed · 237 commits

Conversation

@ZelboK (Contributor) commented Apr 28, 2024

Fixes #117122

This PR adds the logic so that, in the case of rank-deficient matrices, batched mode can fall back to an SVD backend. A big thank you to @tvercaut for the well-written issue and the suggestion on how to approach the problem.

Summary:

  1. At the time of writing I haven't touched the non-batched path yet; I'm hoping for some feedback before proceeding.
  2. The specifics of how we want to fall back to SVD when we encounter rank-deficient matrices could use extra eyes.
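For context, an SVD-based driver (what LAPACK's gelss/gelsd do) computes the minimum-norm least-squares solution by zeroing out singular values below a cutoff, which is why it copes with rank deficiency where QR-based drivers fail. A minimal NumPy sketch of the idea (the rcond choice here is an assumption matching NumPy's default, not necessarily what this PR uses):

```python
import numpy as np

# Rank-deficient matrix: the third column is the sum of the first two.
A = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 9.0],
              [7.0, 8.0, 15.0],
              [2.0, 1.0, 3.0]])
b = np.array([1.0, 2.0, 3.0, 4.0])

# SVD-based minimum-norm solution: invert only the singular values
# above a relative cutoff, treating the rest as exactly zero.
U, s, Vt = np.linalg.svd(A, full_matrices=False)
cutoff = np.finfo(A.dtype).eps * max(A.shape) * s[0]
s_inv = np.where(s > cutoff, 1.0 / s, 0.0)
x_svd = Vt.T @ (s_inv * (U.T @ b))

# Reference: NumPy's lstsq also uses an SVD driver (gelsd).
x_ref, *_ = np.linalg.lstsq(A, b, rcond=None)
assert np.allclose(x_svd, x_ref)
```

Despite A having rank 2, both routes return the same well-defined minimum-norm solution.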

Please keep in mind this is my second PR to PyTorch, and I've never really used PyTorch before. I'm learning independently by digging into the internals, so I may make some obvious mistakes. Please forgive!

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @gujinghui @PenghuiCheng @jianyuh @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @snadampal @mcarilli @ptrblck @leslie-fang-intel @voznesenskym @EikanWang @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @LucasLLC

pytorch-bot commented Apr 28, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125110

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 454d0d4 with merge base 4d063c8 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: sparse release notes category label Apr 28, 2024
@ZelboK ZelboK marked this pull request as draft April 28, 2024 00:57
@ZelboK ZelboK marked this pull request as ready for review April 28, 2024 01:22
@ZelboK (Contributor, Author) commented Apr 28, 2024

@janeyx99 @ptrblck

2nd PR! 🎉 Sorry it took me a few days. I'm learning the PyTorch internals independently, so I'm still getting familiar with the codebase.

Also, I'm curious whether there's a community like a Slack or Discord for PyTorch?

@lezcano (Collaborator) commented Apr 28, 2024

Mind cleaning up all the spurious newlines, and the PR in general?

@ZelboK ZelboK force-pushed the feat-improve-driver-linalg-lstq branch from 06e42e8 to 7372645 Compare April 28, 2024 09:46
@ZelboK (Contributor, Author) commented Apr 28, 2024

@lezcano

My apologies! I've cleaned it up. I missed some stray newlines left over from the debugging/experimentation code I wrote to understand the codebase.

@lezcano (Collaborator) left a review

It looks mostly good.

Needs tests in test_linalg.py, and the docs should be updated to note that this gelss mode is also supported.

aten/src/ATen/native/BatchLinearAlgebra.cpp: 7 review threads (outdated, resolved)
ZelboK and others added 2 commits April 28, 2024 09:24
Co-authored-by: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com>
torch/linalg/__init__.py: 1 review thread (outdated, resolved)
aten/src/ATen/native/BatchLinearAlgebra.cpp: 3 review threads (outdated, resolved)
@ZelboK (Contributor, Author) commented Apr 28, 2024

@lezcano

So for the tests: what kind of test would be appropriate, aside from checking that it no longer throws? I could, for example, add gelss as a driver in test_linalg_lstsq_batch_broadcasting and verify its results are as expected. I'm not too familiar with the test suites yet, so I'm hoping for guidance here.

Edit: The workflow runs exposed two failing tests for CPU and complex lstsq computations. I hadn't noticed that my local build was missing LAPACK, so these tests were being skipped on my machine. Looking into it now.

@ZelboK ZelboK force-pushed the feat-improve-driver-linalg-lstq branch from da93358 to c71e504 Compare April 28, 2024 23:43
@lezcano (Collaborator) left a review

For testing, just add a path for this driver in the relevant tests that exercise the other drivers. We may even already have a test that covers this driver on CPU.
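A sketch of such a test, assuming a CPU build with LAPACK: on CPU, gelss and gelsd are both SVD-based drivers accepted by torch.linalg.lstsq, so their minimum-norm solutions should agree on a rank-deficient batch (the shapes and tolerance here are illustrative, not taken from the PR):

```python
import torch

# Rank-deficient batch: make the last column of each matrix
# a copy of the first, so every matrix in the batch has rank 3.
A = torch.randn(3, 5, 4, dtype=torch.float64)
A[..., -1] = A[..., 0]
b = torch.randn(3, 5, 2, dtype=torch.float64)

# Both drivers are SVD-based on CPU and use the same default rcond,
# so they should agree on the minimum-norm least-squares solution
# even though A is rank deficient.
x_gelss = torch.linalg.lstsq(A, b, driver="gelss").solution
x_gelsd = torch.linalg.lstsq(A, b, driver="gelsd").solution
assert torch.allclose(x_gelss, x_gelsd, atol=1e-8)
```

The QR-based gels driver assumes full rank, which is exactly why the CUDA path currently needs the SVD fallback this PR adds.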

test/test_linalg.py: 1 review thread (outdated, resolved)
CLA check: Missing ID, CLA Not Signed

@lezcano (Collaborator) commented May 19, 2024

Bad rebase: too many people are tagged now. Please open a new PR.

@lezcano lezcano closed this May 19, 2024
Labels
module: amp (automated mixed precision, autocast)
module: cpu (CPU-specific problem, e.g., perf, algorithm)
module: distributed_checkpoint
module: dynamo
module: inductor
module: mkldnn (related to Intel IDEEP or oneDNN, a.k.a. mkldnn, integration)
oncall: distributed (add this issue/PR to the distributed oncall triage queue)
open source
release notes: quantization
release notes: sparse
triaged (this issue has been looked at by a team member, triaged, and prioritized into an appropriate module)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Improve behaviour of torch.linalg.lstsq on CUDA GPU for rank-deficient matrices