Integrating cuBLASLt #53437

Closed
wants to merge 1 commit

Conversation

philipphack
Contributor

Adds support for the cuBLASLt library for GEMM operations on GPUs.

The library can be activated by setting the environment variable TF_USE_CUBLASLT=1.

@google-ml-butler google-ml-butler bot added the size:XL label Dec 15, 2021
@google-ml-butler google-ml-butler bot added the awaiting review label Dec 15, 2021
@gbaned gbaned added this to Assigned Reviewer in PR Queue via automation Dec 16, 2021
@gbaned
Contributor

gbaned commented Dec 31, 2021

@philipphack Can you please resolve conflicts? Thanks!

@gbaned gbaned added the stat:awaiting response label and removed the awaiting review label Dec 31, 2021
@philipphack
Contributor Author

@gbaned Done.

@tensorflowbutler tensorflowbutler removed the stat:awaiting response label Jan 6, 2022
@gbaned gbaned requested a review from cheshire January 6, 2022 14:15
@google-ml-butler google-ml-butler bot added the awaiting review label Jan 6, 2022
@cheshire
Member

cheshire commented Jan 6, 2022

Wow this is great, thanks!

cheshire (Member) left a comment

Left comments on the first half of the patch. The biggest feedback so far is that XLA can't depend on CUDA stuff like this directly, so it has to be wrapped in tensorflow/stream_executor (look at the recent cuDNN frontend integration for an example).

tensorflow/compiler/tests/qr_op_test.py (two outdated review threads, resolved)
tensorflow/compiler/xla/service/gpu/BUILD (outdated review thread, resolved)
PR Queue automation moved this from Assigned Reviewer to Reviewer Requested Changes Jan 6, 2022
@cheshire cheshire requested review from awpr and removed request for joker-eph and penpornk January 6, 2022 18:11
@tensorflowbutler tensorflowbutler removed the awaiting review label Jan 8, 2022
@google-ml-butler google-ml-butler bot added the awaiting review label Jan 13, 2022
@philipphack
Contributor Author

Thanks again for the review @cheshire. Can you PTAL?

tensorflow/compiler/xla/service/gpu/gemm_thunk.cc (two outdated review threads, resolved)
tensorflow/core/kernels/gpu_utils.cc (outdated review thread, resolved)
tensorflow/stream_executor/matmul_util.h (outdated review thread, resolved)
@tensorflowbutler tensorflowbutler removed the awaiting review label Jan 18, 2022
PR Queue automation moved this from Approved by Reviewer to Closed/Rejected Mar 10, 2022
@google-ml-butler google-ml-butler bot removed the awaiting review label Mar 10, 2022
@SandSnip3r
Contributor

Hmm, @philipphack, I know this was intended to be disabled by default, but that seems to not be the case. We've seen some out-of-memory test failures and are rolling back this change. I suspect that the DoGemm vs. DoGemmLt changes in RunGemm might not be respecting the configuration flag. I'm looking into it currently, but would also like your input. Thanks!

@SandSnip3r SandSnip3r reopened this Mar 10, 2022
PR Queue automation moved this from Closed/Rejected to Assigned Reviewer Mar 10, 2022
@philipphack
Contributor Author

@SandSnip3r can you share additional information on the failing tests? I don't see how DoGemmLt can be reached unless the flag has been set. Having said that, the PR modifies the matrix solve op test to run with cuBLASLt enabled to verify its functionality.

copybara-service bot pushed a commit that referenced this pull request Mar 10, 2022
Rollback due to increased memory footprint

PiperOrigin-RevId: 433854142
@@ -60,7 +60,7 @@ def CompareOrthogonal(self, x, y, rank):
   def CheckApproximation(self, a, q, r):
     # Tests that a ~= q*r.
     precision = self.AdjustedNorm(a - np.matmul(q, r))
-    self.assertTrue(np.all(precision < 10.0))
+    self.assertTrue(np.all(precision < 11.0))
Contributor

If cublaslt is disabled by default, and there are no other changes to this test, why the change in precision?

Contributor Author

That's in case cuBLASLt is manually turned on when running the test. The matrix solve op test is the only test that by default runs with cuBLASLt.

}
}

void InsertBasedOnScore(const Parameters& params, const Config& config) {
Contributor

Does any code ever call this newly added function?

Contributor Author

Not currently. The PR originally integrated cuBLASLt into native TF and XLA, before the TF part was removed following George's advice. I'll submit a PR for the second part, which will make use of this function, once the current PR has been merged. If it's a concern, I'll remove the function.

Contributor

It's fine, you can leave it. I missed that part of the history and was curious.

@gbaned
Contributor

gbaned commented Mar 16, 2022

@philipphack Can you please resolve conflicts? Thanks!

@SandSnip3r
Contributor

SandSnip3r commented Mar 16, 2022

> @SandSnip3r can you share additional information on the failing tests? I don't see how DoGemmLt can be reached unless the flag has been set. Having said that, the PR modifies the matrix solve op test to run with cuBLASLt enabled to verify its functionality.

@philipphack Yesterday I verified a big spike in memory usage (~2.25x, 120MB->270MB) from this change in one of our tests. I am asking around to see what is required to share this test, or one similar to it. The biggest concern at the moment is that there is any memory usage change at all with this PR, since it should be disabled by default. I am investigating this issue, but would also appreciate your help in doing so.

@philipphack
Contributor Author

philipphack commented Mar 17, 2022

@SandSnip3r Profiling of several tests, including //tensorflow/compiler/tests:qr_op_test, indicates identical device memory usage for TF master and cc25dbb when the XLA flag is not set.

@philipphack
Contributor Author

@SandSnip3r can you run your test after setting the environment variables CUBLASLT_LOG_LEVEL=2 and CUBLASLT_LOG_FILE=LOGFILE.%i? This should generate logs of all cuBLASLt calls.

For the tests in tensorflow/compiler/tests, I see the same output for github/master and cc25dbb.

return false;
}
*config = iter->second.config;
return true;
}

void Insert(const Parameters& params, const Config& config) {
Member

A lot of the logic in Insert has been moved to InsertBasedOnScore, but the latter function is never called. Was this logic important?

Contributor Author

Thanks, Reed. Can you see if the changes I made resolve the problem with your test?

It looks like the rollback is confusing GitHub.

Member

@SandSnip3r was looking at the test, not me (although I think now a third person is looking at it). I happened to notice some logic in Insert was removed, but I don't think it was affecting the test. Anyway, the test is still being worked on, and we hope we can have an update soon.

It looks like you reverted most of the PR. @SandSnip3r has been using the version of the PR from when it was originally submitted to debug the test. Are any of the updates to the PR since then important?

Member

And by "submitted" in the comment above, I mean "merged". @SandSnip3r is testing with e1ba3cd, which has since been rolled back. (@SandSnip3r, correct me if I'm wrong)

Contributor Author

The only update since the original merge was the removal of InsertBasedOnScore yesterday.

GitHub seems to see only the changes after the merge. It might be best to continue with a new PR.

Contributor

@philipphack, you're welcome to make a new PR if you want, but I don't think it is necessary unless you're introducing non-trivial new changes. Your original PR was merged and then rolled back due to performance issues. I have a pending rollback of the rollback once we resolve the issue. There are a handful of edits that I've made in order to make everything compliant with Google style guides and build rules. If it matters to you, I think the rollback of the rollback would no longer attribute the work to your username.

As @reedwm mentioned, we have another person looking into this specific failing test which originally motivated the rollback; not everyone internally has access to this test. I'll be sure to update you on any findings.

Contributor Author

Re-merging based on commit e1ba3cd is probably not the best way forward, since the autotuning logic should be fixed. Although this would otherwise be resolved in the planned follow-up PR on integrating cuBLASLt into native TF, it may negatively affect performance in the meantime.

Contributor

What is the "autotuning logic" fix?

Contributor Author

philipphack commented Mar 23, 2022

The function Insert in gpu_utils.h was originally replaced by InsertBasedOnScore. However, after the related changes were removed when the PR was split into two parts, this modification effectively changed the behavior of the autotuner.

I've submitted PR 55351 which also makes it clear that cuBLASLt is being integrated into XLA only.

PR Queue automation moved this from Assigned Reviewer to Closed/Rejected Mar 23, 2022
@philipphack philipphack reopened this Mar 23, 2022
PR Queue automation moved this from Closed/Rejected to Assigned Reviewer Mar 23, 2022
PR Queue automation moved this from Assigned Reviewer to Closed/Rejected Mar 23, 2022
@philipphack philipphack reopened this Mar 23, 2022
PR Queue automation moved this from Closed/Rejected to Assigned Reviewer Mar 23, 2022
PR Queue automation moved this from Assigned Reviewer to Closed/Rejected Mar 23, 2022
Labels: size:XL
Projects: PR Queue (Closed/Rejected)
10 participants