
FP8 Convolutions in XLA #60807

Merged
copybara-service[bot] merged 4 commits into tensorflow:master from philipphack:u_fp8_conv_xla
Jul 28, 2023
Conversation

@philipphack
Contributor

Enables scaled convolutions of the form

(X, W, x_scale, w_scale, y_scale) -> Y,

where the input X, the filter W and the output Y are based on the F8E4M3FN and F8E5M2 data types and x_scale, w_scale and y_scale are their scaling factors.
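The scaled-convolution semantics can be sketched numerically. The following is an illustrative toy (a 1-D valid convolution in plain C++, not XLA's implementation); `ScaledConv1D` and the use of 448 as the F8E4M3FN maximum are assumptions for the sketch. Operands are dequantized by their scales, accumulated in float, and the result is divided by y_scale and clamped to the FP8 range before the cast back to F8E4M3FN.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Toy model of (X, W, x_scale, w_scale, y_scale) -> Y for a 1-D valid
// convolution. Not XLA's implementation; for illustration only.
std::vector<float> ScaledConv1D(const std::vector<float>& x,
                                const std::vector<float>& w,
                                float x_scale, float w_scale, float y_scale) {
  const float kF8E4M3Max = 448.0f;  // largest finite F8E4M3FN value
  std::vector<float> y(x.size() - w.size() + 1, 0.0f);
  for (std::size_t i = 0; i < y.size(); ++i) {
    float acc = 0.0f;  // accumulate in float, not in FP8
    for (std::size_t k = 0; k < w.size(); ++k) {
      acc += (x[i + k] * x_scale) * (w[k] * w_scale);
    }
    // Apply the output scale and clamp to the FP8-representable range
    // before the (elided) cast back to F8E4M3FN.
    y[i] = std::clamp(acc / y_scale, -kF8E4M3Max, kF8E4M3Max);
  }
  return y;
}
```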

@google-ml-butler google-ml-butler bot added the size:XL CL Change Size:Extra Large label Jun 7, 2023
@philipphack
Contributor Author

CC @reedwm, @nluehr.

@reedwm reedwm self-requested a review June 7, 2023 20:53
@google-ml-butler google-ml-butler bot added the awaiting review Pull request awaiting review label Jun 7, 2023
@gbaned gbaned added the comp:xla XLA label Jun 8, 2023
Contributor

@reedwm reedwm left a comment


Thanks for adding FP8 support to convolutions! Sorry for taking so long to review this.

Also, the comments are in a weird order since I went back and forth between the files a lot while reviewing. I would view the comments in the "Files changed" tab.

@philipphack philipphack requested a review from reedwm July 19, 2023 23:51
}
}

if (pattern_level == 1) {
Contributor


I misunderstood what you were doing before, when I commented

Instead of having this pattern_level concept, can we just iterate over the users again and check for the convert+clamp?

I understand now that you're trying to match a convert of a clamp, which requires looking at the user of the user. I'm still not a fan of the pattern_level concept though, since it's confusing. The best way to address this would be to return a GraphString without the convert+clamp. Then create a FuseConvertToF8 function that is called from CudnnFusedConvRewriter::Run, where you find the existing instruction with a graph string and append the conversion to F8.

If you'd rather handle everything in this function, you can directly get the user of the user to see if that matches, instead of using recursion and pattern_level. E.g. you can do:

if (user->user_count() == 1 && Match(user->users()[0], m::Convert(...))) {...}

Granted, when we eventually match amax calculations, we might need to go deeper, e.g. checking the user of the user of the user. I think doing so directly is better than recursion.
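The direct user-of-the-user check can be sketched with a toy graph; the `Node` struct and `MatchesConvertOfClamp` below are assumptions for illustration, not XLA's `HloInstruction`/`m::` matcher API. The point is that matching convert(clamp(conv)) only needs one extra level of direct inspection, not recursion with a pattern_level counter.

```cpp
#include <string>
#include <vector>

// Toy stand-in for an HLO instruction: an opcode and its users.
// (Hypothetical; not XLA's API.)
struct Node {
  std::string opcode;
  std::vector<Node*> users;
};

// Matches conv -> clamp -> convert by looking directly at the user of
// the user, instead of recursing with a pattern_level counter.
bool MatchesConvertOfClamp(const Node& conv) {
  if (conv.users.size() != 1) return false;
  const Node* clamp = conv.users[0];
  if (clamp->opcode != "clamp" || clamp->users.size() != 1) return false;
  return clamp->users[0]->opcode == "convert";
}
```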

Contributor Author


I think in the context of having multiple users, the recursive approach is advantageous and more straightforward. Can we revisit this issue after you've seen the Amax case?

Contributor


Sure, I'll bring this up again when you add amax support if I still think the non-recursive approach is better.

Comment on lines +149 to +148
if (GetCudaComputeCapability().IsAtLeast(
se::CudaComputeCapability::HOPPER)) {
Contributor


Can we check both custom_call_string and serialized_graph_string on pre-Hopper, using only the two passes (GpuConvRewriter and CudnnFusedConvRewriter) instead of all the passes? In both the Hopper and non-Hopper cases, we can call RunAndFilecheckHloRewrite, and on Hopper only, we can call RunAndCompare. In the Hopper case, you can also call RunFileCheck, but only need a simple sanity check, such as verifying that the graph string is correct, since other passes may modify things like layout, which would make the custom_call_string not match.

This is similar to what we do in gemm_rewrite_test, where we only call RunAndCompare on Hopper.

Contributor Author


Isn't that the solution we arrived at? As I understand it, you don't want to verify the final layout even on Hopper systems which in my opinion renders the test somewhat incomplete. I don't think this is directly comparable to the GEMM case where layout conversions play less of a role.

Contributor


I'm still unsure why layout is important here, compared to the gemm case. Running GpuConvRewriter and CudnnFusedConvRewriter on pre_hlo_string is still causing the FP8 rewrite to happen even if layout assignment doesn't run, right?

Even if layout assignment is important, maybe see what transformations it makes to pre_hlo_string and put the resulting layouts directly in the HLO strings in the test, instead of relying on layout assignment to run.

Contributor Author


My perspective is that the XLA unit tests are usually based on running the compiler pipeline. When that's not possible, we can do partial testing by running only the relevant pass in some artificial setting that we can't easily extend to the full pipeline. It's less clear to me why we'd want to deviate from the normal approach and restrict the testing in cases where we don't have to.

Contributor


The reason to deviate from the normal approach is that we want to test as much as possible on non-Hopper, and right now, the PR only tests the graph string and not things like the custom_call_target, the dim_labels, etc.

But I'll accept only testing the graph string for now, we can reconsider if the tests get broken later due to a lack of Hopper CI.

Contributor Author


We can discard the first part of the custom_call FileCheck string on non-Hopper systems and still compare the configuration of the Custom Call.

Contributor


What is the first part of the custom_call FileCheck string? Is this the f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) part and is that part different if you don't run the rest of the passes?

Contributor Author


Yes, the order of the dimensions changes.

tsl::StatusOr<cudnn_frontend::Tensor> CreateCudnnTensor(
cudnn_frontend::Tensor original, int64_t uid, dnn::DataType dtype,
bool is_virtual = false) {
  return tsl::errors::Internal("Not implemented.");
Contributor


State in the error that copying a cuDNN tensor requires cuDNN 8.8.

Also in cudnn_fused_conv_rewriter.cc, you should check CUDNN_VERSION in addition to CUDA_VERSION to avoid this error.

Contributor Author


I think this might be supported, but the functionality is only used when we require a cuDNN version of at least 8.9, and I can't easily test it. One option would be to give the clone overload of CreateCudnnTensor its own 8.9 version guard instead of sharing it with the existing overload, and remove this.

Contributor


I'm fine with either keeping it as-is or adding an 8.9 version check inside the definition of CreateCudnnTensor.

Either way, check CUDNN_VERSION in cudnn_fused_conv_rewriter.cc though.
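The combined guard the review asks for can be sketched as a predicate; the function name, the CUDA 12.0 threshold, and the decoding of cuDNN 8.9 as the integer 8900 (major*1000 + minor*100) are assumptions for illustration, not the PR's actual code.

```cpp
// Hypothetical sketch of gating the FP8 conv rewrite on both toolkit
// versions, as the review suggests for cudnn_fused_conv_rewriter.cc.
// CUDA_VERSION encodes 12.0 as 12000; CUDNN_VERSION encodes 8.9 as 8900.
bool F8ConvSupported(int cuda_version, int cudnn_version) {
  return cuda_version >= 12000 && cudnn_version >= 8900;
}
```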

@philipphack philipphack requested a review from reedwm July 21, 2023 21:14
@google-ml-butler google-ml-butler bot added kokoro:force-run Tests on submitted change ready to pull PR ready for merge process labels Jul 21, 2023
@reedwm
Contributor

reedwm commented Jul 21, 2023

Can you resolve conflicts?

@google-ml-butler google-ml-butler bot removed the ready to pull PR ready for merge process label Jul 22, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Tests on submitted change label Jul 22, 2023
@google-ml-butler google-ml-butler bot added kokoro:force-run Tests on submitted change ready to pull PR ready for merge process labels Jul 23, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Tests on submitted change label Jul 23, 2023
copybara-service bot pushed a commit to google/tsl that referenced this pull request Jul 26, 2023
Imported from GitHub PR tensorflow/tensorflow#60807

Enables scaled convolutions of the form

(X, W, x_scale, w_scale, y_scale) -> Y,

where the input X, the filter W and the output Y are based on the `F8E4M3FN` and `F8E5M2` data types and x_scale, w_scale and y_scale are their scaling factors.
Copybara import of the project:

--
8a30aa731c21612fe098a6b620a54922578611c2 by Philipp Hack <phack@nvidia.com>:

Support for FP8 convolutions in XLA.

--
caade6453519ad2531ebcf8f206e40187a1687ca by Philipp Hack <phack@nvidia.com>:

Support for FP8 convolutions in XLA.

--
ecd080bd6c64682f6bee62f4455ea2c37c279f26 by Philipp Hack <phack@nvidia.com>:

Support for FP8 convolutions in XLA.

--
da22a881a3d24fd4f357207034ba6c596aa414d0 by Philipp Hack <phack@nvidia.com>:

Support for FP8 convolutions in XLA.

Merging this change closes #60807

FUTURE_COPYBARA_INTEGRATE_REVIEW=tensorflow/tensorflow#60807 from philipphack:u_fp8_conv_xla da22a881a3d24fd4f357207034ba6c596aa414d0
PiperOrigin-RevId: 550604841
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Jul 26, 2023
copybara-service bot pushed a commit to google/tsl that referenced this pull request Jul 27, 2023
copybara-service bot pushed a commit to google/tsl that referenced this pull request Jul 28, 2023
copybara-service bot pushed a commit to tensorflow/mlir-hlo that referenced this pull request Jul 28, 2023
@copybara-service copybara-service bot merged commit 631bbed into tensorflow:master Jul 28, 2023
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Jul 28, 2023

Labels

awaiting review Pull request awaiting review comp:xla XLA ready to pull PR ready for merge process size:XL CL Change Size:Extra Large

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants