Skip to content

Commit

Permalink
[xla:gpu][NFC] Explicitly rewrite AddressComputationFusion in custom …
Browse files Browse the repository at this point in the history
…call tests

AddressComputationFusionRewriter is now part of `RunHloPasses`, we need to explicitly call it to transform the HLO in order to keep tests meaningful.

PiperOrigin-RevId: 620529073
  • Loading branch information
tyb0807 authored and tensorflower-gardener committed Mar 30, 2024
1 parent 92b03bd commit b2e993e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/fusions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ xla_test(
"//xla/service:custom_call_target_registry",
"//xla/service:executable",
"//xla/service:hlo_module_config",
"//xla/service/gpu:address_computation_fusion_rewriter",
"//xla/stream_executor",
"//xla/stream_executor:device_description",
"//xla/stream_executor/gpu:gpu_types_header",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include "xla/ffi/ffi_api.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/custom_call_target_registry.h"
#include "xla/service/gpu/address_computation_fusion_rewriter.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/service_executable_run_options.h"
#include "xla/shape.h"
Expand Down Expand Up @@ -887,14 +888,14 @@ TEST_F(AddressComputationFusionTest, CustomCallSimple) {
TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto(
computation.proto(), hlo_config));

debug_options.set_xla_gpu_enable_address_computation_fusion(true);
hlo_config.set_debug_options(debug_options);
TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto(
computation.proto(), hlo_config));
AddressComputationFusionRewriter pass(PLATFORM);
TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get()));
EXPECT_TRUE(changed);

EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt),
error_spec,
/*run_hlo_passes=*/false));
error_spec, /*run_hlo_passes=*/false));
}

static absl::Status SubBuffers(se::Stream* stream, ffi::BufferBase src0,
Expand Down Expand Up @@ -993,9 +994,12 @@ TEST_F(AddressComputationFusionTest, CustomCallWithTuple) {
TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto(
computation.proto(), hlo_config));

AddressComputationFusionRewriter pass(PLATFORM);
TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get()));
EXPECT_TRUE(changed);

EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt),
error_spec,
/*run_hlo_passes=*/false));
error_spec, /*run_hlo_passes=*/false));
}

static absl::Status NoOp(se::Stream* stream, ffi::BufferBase operand) {
Expand Down Expand Up @@ -1039,6 +1043,10 @@ TEST_F(AddressComputationFusionTest, NilTuple) {
TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto(
computation.proto(), hlo_config));

AddressComputationFusionRewriter pass(PLATFORM);
TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get()));
EXPECT_TRUE(changed);

EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt),
error_spec,
/*run_hlo_passes=*/false));
Expand Down Expand Up @@ -1079,6 +1087,10 @@ TEST_F(AddressComputationFusionTest, CustomCallLegacyAPI) {
TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto(
computation.proto(), hlo_config));

AddressComputationFusionRewriter pass(PLATFORM);
TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get()));
EXPECT_TRUE(changed);

EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt),
error_spec,
/*run_hlo_passes=*/false));
Expand Down Expand Up @@ -1113,6 +1125,10 @@ TEST_F(AddressComputationFusionTest, NilTupleLegacyAPI) {
TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto(
computation.proto(), hlo_config));

AddressComputationFusionRewriter pass(PLATFORM);
TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get()));
EXPECT_TRUE(changed);

EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt),
error_spec,
/*run_hlo_passes=*/false));
Expand Down

0 comments on commit b2e993e

Please sign in to comment.