Skip to content

Commit

Permalink
PR #12942: [GPU] Fix cuDNN GEMM test tolerances.
Browse files Browse the repository at this point in the history
Imported from GitHub PR openxla/xla#12942

Use the maximum absolute difference observed over 20 runs of these tests, each with a different seed value.
Copybara import of the project:

--
c438b08ea7240c23ae98bc8dcf4ef45fa6d2e89c by Ilia Sergachev <isergachev@nvidia.com>:

[GPU] Fix cuDNN GEMM test tolerances.

Use the maximum absolute difference observed over 20 runs of these tests, each with a different seed value.

Merging this change closes #12942

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#12942 from openxla:fix_test_cudnn c438b08ea7240c23ae98bc8dcf4ef45fa6d2e89c
PiperOrigin-RevId: 636188097
  • Loading branch information
sergachev authored and tensorflower-gardener committed May 22, 2024
1 parent 000a07c commit 5ce7a1e
Show file tree
Hide file tree
Showing 9 changed files with 13 additions and 48 deletions.
1 change: 0 additions & 1 deletion tensorflow/lite/schema/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ filegroup(
name = "tflite_internal_cc_3p_api_deps_src",
srcs = [
":schema_fbs_srcs",
":schema_utils.cc",
":schema_utils.h",
],
visibility = [
Expand Down
18 changes: 0 additions & 18 deletions third_party/triton/temporary/linear_layout_compose_asan.patch

This file was deleted.

1 change: 0 additions & 1 deletion third_party/triton/temporary/series.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,4 @@ internal patch during the next triton integration process.
"""

temporary_patch_list = [
"//third_party/triton/temporary:linear_layout_compose_asan.patch",
]
4 changes: 2 additions & 2 deletions third_party/triton/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ load("//third_party/triton/xla_extensions:series.bzl", "extensions_files_patch_l
def repo():
"""Imports Triton."""

TRITON_COMMIT = "cl634675237"
TRITON_SHA256 = "7151d057ee8443c2f45cbe18a7435a42f37e18f562e5d238b844b6e09fc560e6"
TRITON_COMMIT = "cl635840438"
TRITON_SHA256 = "707101b2e8366e63e80150c26f8ab660052099c91ca0c4fa4c713607fa75f318"
tf_http_archive(
name = "triton",
sha256 = TRITON_SHA256,
Expand Down

This file was deleted.

1 change: 0 additions & 1 deletion third_party/xla/third_party/triton/temporary/series.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,4 @@ internal patch during the next triton integration process.
"""

temporary_patch_list = [
"//third_party/triton/temporary:linear_layout_compose_asan.patch",
]
4 changes: 2 additions & 2 deletions third_party/xla/third_party/triton/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ load("//third_party/triton/xla_extensions:series.bzl", "extensions_files_patch_l
def repo():
"""Imports Triton."""

TRITON_COMMIT = "cl634675237"
TRITON_SHA256 = "7151d057ee8443c2f45cbe18a7435a42f37e18f562e5d238b844b6e09fc560e6"
TRITON_COMMIT = "cl635840438"
TRITON_SHA256 = "707101b2e8366e63e80150c26f8ab660052099c91ca0c4fa4c713607fa75f318"
tf_http_archive(
name = "triton",
sha256 = TRITON_SHA256,
Expand Down
4 changes: 2 additions & 2 deletions third_party/xla/xla/service/gpu/fusions/cudnn_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ ENTRY r {
ROOT r = bf16[192,128]{1,0} fusion(p0, p1), kind=kCustom, calls=fusion1,
backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}}
})",
ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
ErrorSpec{/*aabs=*/1, /*arel=*/1e-3}));
}

TEST_F(CuDnnFusionLevel3Test,
Expand All @@ -629,7 +629,7 @@ ENTRY r {
ROOT r = bf16[4,3,16,128]{2,1,3,0} fusion(p0, p1), kind=kCustom, calls=fusion1,
backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}}
})",
ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
ErrorSpec{/*aabs=*/1, /*arel=*/1e-3}));
}

class ElementwiseTest : public CuDnnFusionExecutionTest,
Expand Down
10 changes: 7 additions & 3 deletions third_party/xla/xla/service/gpu/ir_emitter_triton_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,13 @@ absl::Status CreateTritonPipeline(
pm.addPass(mt::gpu::createOptimizeDotOperandsPass(ccCuda.IsAtLeastAmpere()));
pm.addPass(mlir::createCSEPass());

pm.addPass(mt::gpu::createPipelinePass(config.num_stages, config.num_warps,
config.num_ctas, ccAsInt));

// Even though we don't run on pre-Ampere architectures anymore, we keep this
// check for consistency with the upstream pipeline
if (ccCuda.IsAtLeastAmpere()) {
pm.addPass(mt::gpu::createCombineTensorSelectAndIfPass());
pm.addPass(mt::gpu::createPipelinePass(config.num_stages, config.num_warps,
config.num_ctas, ccAsInt));
}
if (!ccCuda.IsAtLeastHopper()) {
pm.addPass(mt::gpu::createPrefetchPass());
}
Expand Down

0 comments on commit 5ce7a1e

Please sign in to comment.