From 7b0010bd929cfda5c646fbc27dd257dea5973d3e Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Tue, 2 Apr 2024 15:41:18 -0700 Subject: [PATCH] [xla][gpu] Implement pipelined-p2p-rewriter. This pass rewrite pipelined point-to-point communication by rotating the SendDone and RecvDone operations in a while-body to the beginning of the next iteration. The SendDone and RecvDone operations for the last iteration are moved to the while-op calling computation, after the while-op. Add the pass to the GPU post-scheduler pipeline. This is another approach to achieve the code pattern to pipeline two Send-Recv chains decomposed from a collective-permute with a source-target pair cycle for performance. The pipelined Send-Recv pattern puts SendDone and RecvDone before Send and Recv in the while-body, and if we generate such code pattern too early in the GPU compilation pipeline, copy-insertion may generate copies of Send causing Send and SendDone with different buffers and thus correctness problem. PiperOrigin-RevId: 621317739 --- third_party/xla/xla/service/gpu/BUILD | 45 ++ .../service/gpu/gpu_algebraic_simplifier.cc | 53 +- .../service/gpu/gpu_algebraic_simplifier.h | 21 +- .../gpu/gpu_algebraic_simplifier_test.cc | 33 +- .../xla/xla/service/gpu/gpu_compiler.cc | 7 + .../xla/xla/service/gpu/gpu_compiler_test.cc | 127 ++++ .../xla/service/gpu/pipelined_p2p_rewriter.cc | 702 ++++++++++++++++++ .../xla/service/gpu/pipelined_p2p_rewriter.h | 133 ++++ .../gpu/pipelined_p2p_rewriter_test.cc | 509 +++++++++++++ 9 files changed, 1621 insertions(+), 9 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.cc create mode 100644 third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.h create mode 100644 third_party/xla/xla/service/gpu/pipelined_p2p_rewriter_test.cc diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 4fe4a9341d9797..2f80141ba54c92 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -3399,6 +3399,7 @@ cc_library( ]), deps = if_gpu_is_configured([ ":gpu_p2p_pipeliner", + ":pipelined_p2p_rewriter", ":collective_permute_cycle_decomposer", ":address_computation_fusion_rewriter", ":algorithm_checker", @@ -3648,6 +3649,7 @@ xla_test( "//xla/service:pattern_matcher_gmock", "//xla/service:xla_debug_info_manager", "//xla/service/gpu:autotuner_util", + "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/log", @@ -3945,12 +3947,15 @@ cc_library( "gpu_algebraic_simplifier.h", ], deps = [ + ":triton_support", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:algebraic_simplifier", "//xla/service:hlo_pass", + "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], @@ -3963,6 +3968,7 @@ xla_cc_test( ":gpu_algebraic_simplifier", "//xla/hlo/ir:hlo", "//xla/service:algebraic_simplifier", + "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_googletest//:gtest", @@ -6138,3 +6144,42 @@ xla_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "pipelined_p2p_rewriter", + srcs = ["pipelined_p2p_rewriter.cc"], + hdrs = ["pipelined_p2p_rewriter.h"], + deps = [ + "//xla:shape_util", + "//xla:status", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:collective_ops_utils", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "pipelined_p2p_rewriter_test", + srcs = ["pipelined_p2p_rewriter_test.cc"], + deps = [ + ":pipelined_p2p_rewriter", + "//xla/hlo/ir:hlo", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) diff --git a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.cc b/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.cc index f0cb1d238fd6af..857da0bdc7bba0 100644 --- a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.cc +++ b/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.cc @@ -15,13 +15,58 @@ limitations under the License. #include "xla/service/gpu/gpu_algebraic_simplifier.h" +#include + +#include "absl/log/check.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/triton_support.h" +#include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" namespace xla::gpu { +bool IsDotSupportedByGemmFusion(const HloInstruction* dot, + se::GpuComputeCapability compute_capability) { + auto supported_output_type = [&](const PrimitiveType t) { + auto cuda_compute_capability = + std::get_if(&compute_capability); + auto rocm_compute_capability = + std::get_if(&compute_capability); + + CHECK(cuda_compute_capability || rocm_compute_capability); + + switch (t) { + case F16: + case F32: + return true; + case BF16: + if (cuda_compute_capability) { + return true; + } + if (rocm_compute_capability) { + return rocm_compute_capability->has_bf16_dtype_support(); + } + return false; + default: + return false; + } + }; + + if (!supported_output_type(dot->shape().element_type())) { + return false; + } + + if (!IsTritonSupportedDataType(dot->operand(0)->shape().element_type(), + compute_capability) || + !IsTritonSupportedDataType(dot->operand(1)->shape().element_type(), + compute_capability)) { + return false; + } + return true; +} + bool GpuAlgebraicSimplifierVisitor::ShouldStrengthReduceDotToReduce( const HloInstruction* hlo) { if (!options_.enable_dot_strength_reduction()) { @@ -44,7 +89,13 @@ bool GpuAlgebraicSimplifierVisitor::ShouldStrengthReduceDotToReduce( rhs->shape().rank()); // Strength-reduce vector-vector dots since they are not supported by // GemmFusion. - return lhs_is_vector && rhs_is_vector; + if (lhs_is_vector && rhs_is_vector) { + return true; + } + + // If GemmFusion cannot handle this dot, we should strength-reduce it so that + // it can be handled by the fusion pipeline. + return !IsDotSupportedByGemmFusion(dot, compute_capability_); } } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.h b/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.h index 4b8d9a30949227..855359654395a0 100644 --- a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.h +++ b/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.h @@ -16,12 +16,15 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_GPU_ALGEBRAIC_SIMPLIFIER_H_ #define XLA_SERVICE_GPU_GPU_ALGEBRAIC_SIMPLIFIER_H_ +#include + #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/algebraic_simplifier.h" #include "xla/service/hlo_pass_interface.h" +#include "xla/stream_executor/device_description.h" #include "xla/util.h" namespace xla::gpu { @@ -30,16 +33,23 @@ class GpuAlgebraicSimplifierVisitor : public AlgebraicSimplifierVisitor { public: explicit GpuAlgebraicSimplifierVisitor( const AlgebraicSimplifierOptions& options, + se::GpuComputeCapability compute_capability, AlgebraicSimplifier* simplifier) - : AlgebraicSimplifierVisitor(options, simplifier) {} + : AlgebraicSimplifierVisitor(options, simplifier), + compute_capability_(std::move(compute_capability)) {} bool ShouldStrengthReduceDotToReduce(const HloInstruction* hlo) override; + + private: + se::GpuComputeCapability compute_capability_; }; class GpuAlgebraicSimplifier : public AlgebraicSimplifier { public: - explicit GpuAlgebraicSimplifier(const AlgebraicSimplifierOptions& options) - : AlgebraicSimplifier(options) {} + explicit GpuAlgebraicSimplifier(const AlgebraicSimplifierOptions& options, + se::GpuComputeCapability compute_capability) + : AlgebraicSimplifier(options), + compute_capability_(std::move(compute_capability)) {} using HloPassInterface::Run; absl::StatusOr Run(HloModule* module, @@ -48,7 +58,7 @@ class GpuAlgebraicSimplifier : public AlgebraicSimplifier { XLA_VLOG_LINES( 2, "GpuAlgebraicSimplifier::Run(), before:\n" + module->ToString()); bool changed = false; - GpuAlgebraicSimplifierVisitor visitor(options_, this); + GpuAlgebraicSimplifierVisitor visitor(options_, compute_capability_, this); for (auto* comp : module->MakeNonfusionComputations(execution_threads)) { if (visitor.Run(comp, options_, this)) { changed = true; @@ -58,6 +68,9 @@ class GpuAlgebraicSimplifier : public AlgebraicSimplifier { 2, "GpuAlgebraicSimplifier::Run(), after:\n" + module->ToString()); return changed; } + + private: + se::GpuComputeCapability compute_capability_; }; } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier_test.cc b/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier_test.cc index e73fae38ac35b8..b0a5cc6a44440a 100644 --- a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/algebraic_simplifier.h" +#include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" @@ -43,8 +44,9 @@ ENTRY entry { const HloInstruction* dot = module->entry_computation()->root_instruction(); AlgebraicSimplifierOptions options; options.set_enable_dot_strength_reduction(true); - GpuAlgebraicSimplifier simplifier(options); - GpuAlgebraicSimplifierVisitor visitor(options, &simplifier); + se::CudaComputeCapability ampere(8, 0); + GpuAlgebraicSimplifier simplifier(options, ampere); + GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier); EXPECT_TRUE(visitor.ShouldStrengthReduceDotToReduce(dot)); } @@ -63,10 +65,33 @@ ENTRY entry { const HloInstruction* dot = module->entry_computation()->root_instruction(); AlgebraicSimplifierOptions options; options.set_enable_dot_strength_reduction(true); - GpuAlgebraicSimplifier simplifier(options); - GpuAlgebraicSimplifierVisitor visitor(options, &simplifier); + se::CudaComputeCapability ampere(8, 0); + GpuAlgebraicSimplifier simplifier(options, ampere); + GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier); EXPECT_FALSE(visitor.ShouldStrengthReduceDotToReduce(dot)); } +TEST_F(GpuAlgebraicSimplifierTest, + DotWithTypeUnsupportedByGemmFusionShouldBeStrengthReduced) { + const std::string& hlo_string = R"( +HloModule m + +ENTRY entry { + p0 = c64[32, 5, 7] parameter(0) + p1 = c64[32, 5] parameter(1) + ROOT dot = c64[32,7] dot(p0, p1), lhs_batch_dims={0}, + lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + const HloInstruction* dot = module->entry_computation()->root_instruction(); + AlgebraicSimplifierOptions options; + options.set_enable_dot_strength_reduction(true); + se::CudaComputeCapability ampere(8, 0); + GpuAlgebraicSimplifier simplifier(options, ampere); + GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier); + EXPECT_TRUE(visitor.ShouldStrengthReduceDotToReduce(dot)); +} + } // namespace } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 998e302d9fd977..69b4f1159c12ea 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -157,6 +157,7 @@ limitations under the License. #include "xla/service/gpu/model/gpu_cost_model_stats_collection.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/move_copy_to_users.h" +#include "xla/service/gpu/pipelined_p2p_rewriter.h" #include "xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h" #include "xla/service/gpu/reduction_degenerate_dim_remover.h" #include "xla/service/gpu/reduction_dimension_grouper.h" @@ -2199,6 +2200,12 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines( { HloPassPipeline pipeline("post-scheduling-passes"); + if (module->config() + .debug_options() + .xla_gpu_enable_pipelined_collectives() || + module->config().debug_options().xla_gpu_enable_pipelined_p2p()) { + pipeline.AddPass(); + } HloPredicate is_nop = HloPredicateIsOp; diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc index 4e886b46e5408c..26e8bcf7c322ac 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/service/xla_debug_info_manager.h" +#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" @@ -395,6 +396,132 @@ ENTRY main { triton_disabled_module->computation_count()); } +TEST_F(GpuCompilerTest, CollectivePermuteDecompositionAndPipelining) { + const char* kModuleStr = R"( +HloModule cp + +cond { + param = (u32[], f32[1, 1024, 1024]) parameter(0) + count = get-tuple-element(%param), index=0 + ub = u32[] constant(11) + ROOT result = pred[] compare(count, ub), direction=LT + } + +body { + param = (u32[], f32[1, 1024, 1024]) parameter(0) + count = get-tuple-element(%param), index=0 + send-data = get-tuple-element(%param), index=1 + + recv-data = f32[1, 1024, 1024] collective-permute(send-data), + source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}, channel_id=1 + + // The computation code that uses the current recv-data and + // produces the send-data for the next iteration. + c1 = u32[] constant(1) + new_count = u32[] add(count, c1) + replica = u32[] replica-id() + c10 = u32[] constant(10) + sum = u32[] add(replica, c10) + sum2 = u32[] add(sum, count) + conv = f32[] convert(sum2) + p = f32[1, 1024, 1024] broadcast(conv), dimensions={} + b = f32[1, 1024, 1024] add(p, recv-data) + c = f32[1, 1024, 1024] multiply(b, b) + d = f32[1, 1024, 1024] tan(c) + s = f32[1, 1024, 1024] dot(c, d), lhs_batch_dims={0}, + lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} + + ROOT result = (u32[], f32[1, 1024, 1024]) tuple(new_count, s) +} + +ENTRY test_computation { + c0 = u32[] constant(0) + f0 = f32[] constant(0.0) + init = f32[1, 1024, 1024] broadcast(f0), dimensions={} + while_init = (u32[], f32[1, 1024, 1024]) tuple(c0, init) + while_result = (u32[], f32[1, 1024, 1024]) while(while_init), body=body, condition=cond + ROOT result = f32[1, 1024, 1024] get-tuple-element(while_result), index=1 +} +)"; + + // In the expected string, we skip some detail on the while-init tuple due to + // b/333572009. + const char* kExpected = R"( +CHECK: %body.1 (param.2.0: (u32[], f32[1,1024,1024], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), u32[])) -> (u32[], f32[1,1024,1024], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), u32[]) { +CHECK: %param.2.0 = parameter(0) +CHECK: %get-tuple-element.38 = get-tuple-element(%param.2.0), index=2 +CHECK: %get-tuple-element.39 = get-tuple-element(%param.2.0), index=3 +CHECK-DAG: %get-tuple-element.22 = get-tuple-element(%param.2.0), index=0 +CHECK-DAG: %recv-done.3 = recv-done(%get-tuple-element.38), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} +CHECK-DAG: %get-tuple-element.25 = get-tuple-element(%recv-done.3), index=0 +CHECK: %loop_multiply_tan_fusion = fusion +CHECK: %send-done.3 = send-done(%get-tuple-element.39), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} +CHECK: %custom-call.1.0 = custom-call +CHECK: %after-all.3.0 = after-all() +CHECK{LITERAL}: %recv.2.0 = recv(%after-all.3.0), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,4}}"}, control-predecessors={%custom-call.1.0} +CHECK{LITERAL}: %send.2.0 = send(%bitcast.119.0, %after-all.3.0), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,4}}"}, control-predecessors={%recv.2.0} +CHECK: %loop_add_fusion = fusion +CHECK: %loop_add_fusion.1 = fusion +CHECK: ROOT %tuple.13 = tuple(%loop_add_fusion.1, %bitcast.119.0, %recv.2.0, %send.2.0, %loop_add_fusion) +CHECK: } + +CHECK: %cond.1 (cond_param.1: (u32[], f32[1,1024,1024], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), u32[])) -> pred[] { +CHECK: %cond_param.1 = parameter(0) +CHECK: %get-tuple-element.5.0 = get-tuple-element(%cond_param.1), index=0 +CHECK: ROOT %loop_compare_fusion = fusion(%get-tuple-element.5.0), kind=kLoop, calls=%fused_compare +CHECK: } + +CHECK: ENTRY %test_computation () -> f32[1,1024,1024] { +CHECK: %after-all.1.0 = after-all() +CHECK: %loop_broadcast_fusion = fusion +CHECK{LITERAL}: %recv.1.0 = recv(%after-all.1.0), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,4}}"} +CHECK{LITERAL}: %send.1.0 = send(%loop_broadcast_fusion, %after-all.1.0), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,4}}"}, control-predecessors={%recv.1.0} +CHECK: %copy_fusion = fusion +CHECK: %get-tuple-element.36 = get-tuple-element(%copy_fusion), index=0 +CHECK: %get-tuple-element.37 = get-tuple-element(%copy_fusion), index=1 +CHECK: %bitcast.170 = bitcast(%get-tuple-element.36) +CHECK: %bitcast.171 = bitcast(%get-tuple-element.37) +CHECK: %while-init = tuple +CHECK-SAME: %recv.1.0, %send.1.0 +CHECK{LITERAL}: %while-result = while(%while-init), condition=%cond.1, body=%body.1, backend_config={"known_trip_count":{"n":"10"}} +CHECK: %get-tuple-element.40 = get-tuple-element(%while-result), index=2 +CHECK: %get-tuple-element.41 = get-tuple-element(%while-result), index=3 +CHECK: %recv-done.4 = recv-done(%get-tuple-element.40), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} +CHECK: %get-tuple-element.7.0 = get-tuple-element(%recv-done.4), index=0 +CHECK: %loop_multiply_tan_fusion.1 = fusion +CHECK: %get-tuple-element.13 = get-tuple-element(%loop_multiply_tan_fusion.1), index=0 +CHECK: %get-tuple-element.14 = get-tuple-element(%loop_multiply_tan_fusion.1), index=1 +CHECK: %bitcast.150.0 = bitcast(%get-tuple-element.13) +CHECK: %bitcast.155.0 = bitcast(%get-tuple-element.14) +CHECK: %send-done.4 = send-done(%get-tuple-element.41), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} +CHECK: %custom-call.3.0 = custom-call(%bitcast.150.0, %bitcast.155.0), custom_call_target="__cublas$gemm" +CHECK: %get-tuple-element.10.0 = get-tuple-element(%custom-call.3.0), index=0 +CHECK: ROOT %bitcast.5.0 = bitcast(%get-tuple-element.10.0) +CHECK: } +)"; + + HloModuleConfig config; + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_latency_hiding_scheduler(true); + debug_options.set_xla_gpu_collective_permute_decomposer_threshold(1); + debug_options.set_xla_gpu_enable_pipelined_p2p(true); + debug_options.set_xla_gpu_enable_triton_gemm(false); + config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, + GetOptimizedModule(std::move(module))); + TF_ASSERT_OK(Schedule(optimized_module.get())); + + HloPrintOptions options; + options.set_print_operand_shape(false); + options.set_print_result_shape(false); + TF_ASSERT_OK_AND_ASSIGN( + bool filecheck_matched, + RunFileCheck(optimized_module->ToString(options), kExpected)); + EXPECT_TRUE(filecheck_matched); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.cc b/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.cc new file mode 100644 index 00000000000000..49a1b3ac06f7ec --- /dev/null +++ b/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.cc @@ -0,0 +1,702 @@ + +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/pipelined_p2p_rewriter.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/dfs_hlo_visitor.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +namespace { +// Maps a computation to a boolean that indicates whether there is any +// collective operations directly or indirectly invoked in the computation. +using CollectiveInComputation = + absl::flat_hash_map; + +using InstructionVector = HloInstruction::InstructionVector; + +// Records starting index and the ending index of a pipelined while-op. They +// are the indices of the while-loop operand. +struct PipelinedP2PInfo { + int64_t opnd_start; + int64_t opnd_end; +}; + +// Returns whether the instruction is a collective operation. +bool IsCollectiveOp(const HloInstruction* op) { + HloOpcode opcode = op->opcode(); + // TODO(NVIDIA/4364298): The information is recorded in b/309639264. + // we need to avoid custom-calls to overlap with Send/Recv to workaround the + // bug. Remove custom-calls here when the bug is fixed. + if (opcode == HloOpcode::kCustomCall) { + return true; + } + + return hlo_query::IsCollectiveCommunicationOp(opcode) || + opcode == HloOpcode::kSend || opcode == HloOpcode::kRecv; +} + +// Returns whether the instruction may invoke collective operations directly +// or indirectly. +bool MayInvokeCollectiveOp( + const HloInstruction* hlo, + const CollectiveInComputation& collective_in_computation) { + if (IsCollectiveOp(hlo)) { + return true; + } + for (HloComputation* callee : hlo->called_computations()) { + auto collective_in_comp = collective_in_computation.find(callee); + CHECK(collective_in_comp != collective_in_computation.end()); + if (collective_in_comp->second) { + return true; + } + } + return false; +} + +// Returns the unique get-tuple-element user with the given idx or nullptr if +// there isn't such a unique user. +HloInstruction* FindUniqueGTEUserWithIndex(const HloInstruction* op, + int64_t idx) { + CHECK(op->shape().IsTuple()); + + HloInstruction* gte = nullptr; + for (auto user : op->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + continue; + } + if (user->tuple_index() == idx) { + if (gte == nullptr) { + gte = user; + } else { + return nullptr; + } + } + } + return gte; +} + +// Returns whether there is any get-tuple-element user with the given idx. +bool HasGTEUserWithIndex(const HloInstruction* op, int64_t idx) { + CHECK(op->shape().IsTuple()); + + for (auto user : op->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + continue; + } + if (user->tuple_index() == idx) { + return true; + } + } + return false; +} + +// Returns the instruction hidden behind a trivial tuple or `op`. This allows +// the discovery of recv-done for the following case, for which the indirection +// would have been removed by tuple-simplification. +// gte.0 = f32[1,1024,1024] get-tuple-element(recv-done), index=0 +// gte.1 = token get-tuple-element(recv-done.p), index=1 +// op = (f32[1,1024,1024], token[]) tuple(gte.0, gte.1) +// +// TODO(bixia): investigate the possible of implementing +// m::TrivialTuple(m::RecvDone(&instr)) as suggested by code review. +HloInstruction* MaySkipTrivialTuple(HloInstruction* op) { + if (op->opcode() != HloOpcode::kTuple) { + return op; + } + HloInstruction* hidden_op = nullptr; + for (auto opnd : op->mutable_operands()) { + if (opnd->opcode() != HloOpcode::kGetTupleElement) { + return op; + } + if (hidden_op == nullptr) { + hidden_op = opnd->mutable_operand(0); + } else if (opnd->mutable_operand(0) != hidden_op) { + return op; + } + } + return hidden_op; +} + +// This routine is similar to the non-const version above except that the +// the given instruction is used for pattern checking only and can't be mutated. +const HloInstruction* MaySkipTrivialTuple(const HloInstruction* op) { + // Use const_cast to avoid repeating the non-const version above to find + // operands of the instruction through operands() instead of + // mutable_operands(). + return MaySkipTrivialTuple(const_cast(op)); +} + +// Finds a consecutive block of balanced SendDone/RecvDone in the while_init +// of a while-loop, assuming its while_init is a tuple. +std::optional +FindConsecutiveAndBalanceBlockOfSendDoneRecvDone( + const HloInstruction* while_init) { + PipelinedP2PInfo pipelined_p2p_info{0, 0}; + // Return whether the first SendDone/RecvDone has been seen. + auto has_started = [&]() { + return pipelined_p2p_info.opnd_start != pipelined_p2p_info.opnd_end; + }; + // Record the difference between the number of SendDone and RecvDone in a + // consecutive block. + int difference = 0; + // If SendDone/RecvDone exists in a consecutive block in the while_init + // tuple, find such block. + for (int64_t i = 0; i < while_init->operand_count(); ++i) { + const HloInstruction* op = while_init->operand(i); + if ((op->opcode() == HloOpcode::kRecvDone || + op->opcode() == HloOpcode::kSendDone) && + op->frontend_attributes().map().count(kSendRecvPipelineAttr) > 0) { + if (op->opcode() == HloOpcode::kRecvDone) { + difference++; + } else { + difference--; + } + if (!has_started()) { + pipelined_p2p_info.opnd_start = i; + } + pipelined_p2p_info.opnd_end = i + 1; + } else { + if (has_started()) { + VLOG(10) << "End a consecutive block"; + break; + } + } + } + + if (difference != 0) { + VLOG(10) << "Mismatch number of SendDone and RecvDone: " << difference; + return std::nullopt; + } + + if (has_started()) { + // Check for SendDone/RecvDone outside the consecutive block. + for (int64_t i = pipelined_p2p_info.opnd_end; + i < while_init->operand_count(); ++i) { + const HloInstruction* op = while_init->operand(i); + if (op->opcode() == HloOpcode::kRecvDone || + op->opcode() == HloOpcode::kSendDone) { + VLOG(10) << "SendDone/RecvDone outside the consecutive block"; + return std::nullopt; + break; + } + } + } + + if (!has_started()) { + VLOG(10) << "No SendDone/RecvDone in while-init "; + return std::nullopt; + } + + return pipelined_p2p_info; +} + +// Checks whether the while-op, its while-body and while-condition have a +// recognized pipelined pattern. If a pipelined pattern is found, returns the +// first and last indices for the pipelined instruction in the while-init tuple. +// For pipelined Send/Recv to work, the SendDone/RecvDone doesn't have to be in +// a consecutive block, but this simplifies the implementation and is the +// pattern that the current gpu-p2p-pipeliner generated. +// +// As a summary, this is what the routine looks for: +// +// . The while-init has a tuple with a single user. +// . The while-init has a consecutive block of SendDone and RecvDone. The +// numbers of SendDone and RecvDone are the same, and there isn't any other +// SendDone and RecvDone outside the block. +// . The while-body has a single tuple parameter. +// . For the while-op result tuple and the while-body parameter tuple: +// The index corresponding to the index of SendDone in while-init should not +// correspond to any get-element-tuple user. +// The index corresponding to the index of RecvDone in while-init should +// correspond to a single get-element-tuple user. +// . In the while-body result tuple, the operand with an index corresponding to +// the index in the while-init SendDone and RecvDone should also be a SendDone +// or RecvDone. +// +// TODO(bixia): support pipelined SendDone/RecvDone not in a consecutive block +// if the gpu-p2p-pipeliner will ever generate such code in the future. +std::optional FindPipelinedP2P( + const HloInstruction* while_op) { + VLOG(10) << "while_op: " << while_op->ToString(); + const HloInstruction* while_init = while_op->while_init(); + if (while_init->opcode() != HloOpcode::kTuple || + while_init->user_count() != 1) { + return std::nullopt; + } + + // The while-body and while-condition should have one parameter of a tuple + // shape. + const HloComputation* while_body = while_op->while_body(); + const HloComputation* while_condition = while_op->while_condition(); + if (while_body->num_parameters() != 1 || + while_condition->num_parameters() != 1) { + return std::nullopt; + } + + std::optional pipelined_p2p_info = + FindConsecutiveAndBalanceBlockOfSendDoneRecvDone(while_init); + if (!pipelined_p2p_info.has_value()) { + return std::nullopt; + } + + VLOG(10) << "opnd_start " << pipelined_p2p_info->opnd_start << " opnd_end " + << pipelined_p2p_info->opnd_end; + + // In the while-result or while-body parameter, the index for RecvDone should + // correspond to one get-tuple-element user and the index for SendDone should + // not correspond to any get-tuple-element user. + for (int64_t i = pipelined_p2p_info->opnd_start; + i < pipelined_p2p_info->opnd_end; ++i) { + const HloInstruction* op = while_init->operand(i); + if (op->opcode() == HloOpcode::kRecvDone) { + if (!FindUniqueGTEUserWithIndex(while_op, i)) { + VLOG(10) << "While result get-tuple-element user with index " << i + << " not unique"; + return std::nullopt; + } + if (!FindUniqueGTEUserWithIndex(while_body->parameter_instruction(0), + i)) { + VLOG(10) << "While-body parameter get-tuple-element user with index " + << i << " not unique"; + return std::nullopt; + } + } else { + CHECK(op->opcode() == HloOpcode::kSendDone); + if (HasGTEUserWithIndex(while_op, i) || + HasGTEUserWithIndex(while_body->parameter_instruction(0), i)) { + VLOG(10) << "SendDone with index " << i << " has unexpected users"; + return std::nullopt; + } + } + } + + // The element in the while-body result tuple corresponding to the pipelined + // SendDone/RecvDone in the while-init have the same opcode. + const HloInstruction* root = while_body->root_instruction(); + for (int64_t i = pipelined_p2p_info->opnd_start; + i < pipelined_p2p_info->opnd_end; ++i) { + const HloInstruction* op_init = while_init->operand(i); + const HloInstruction* op_root = root->operand(i); + op_root = MaySkipTrivialTuple(op_root); + if (op_init->opcode() != op_root->opcode()) { + VLOG(10) << "Mismatching opcode, op_init: " << op_init->ToString() + << " op_root: " << op_root->ToString(); + return std::nullopt; + } + } + + return pipelined_p2p_info.value(); +} + +absl::Status RemoveOpFromParent(HloInstruction* op) { + TF_RETURN_IF_ERROR(op->DropAllControlDeps()); + TF_RETURN_IF_ERROR(op->parent()->RemoveInstruction(op)); + return absl::OkStatus(); +} + +absl::Status ReplaceOpInSequence(HloInstruction* old_op, HloInstruction* new_op, + HloInstructionSequence& instruction_sequence) { + VLOG(10) << "old_op: " << old_op->ToString(); + VLOG(10) << "new_op: " << new_op->ToString(); + instruction_sequence.replace_instruction(old_op, new_op); + return RemoveOpFromParent(old_op); +} + +absl::Status ReplaceUsesAndUpdateSequence( + HloInstruction* old_op, HloInstruction* new_op, + HloInstructionSequence& instruction_sequence, bool diff_shape = false) { + VLOG(10) << "old_op: " << old_op->ToString(); + VLOG(10) << "new_op: " << new_op->ToString(); + if (diff_shape) { + TF_RETURN_IF_ERROR(old_op->ReplaceAllUsesWithDifferentShape(new_op)); + } else { + TF_RETURN_IF_ERROR(old_op->ReplaceAllUsesWith(new_op)); + } + return ReplaceOpInSequence(old_op, new_op, instruction_sequence); +} + +absl::Status ReplaceUsesAndUpdateSequence( + const InstructionVector& old_ops, const InstructionVector& new_ops, + HloInstructionSequence& instruction_sequence) { + CHECK(old_ops.size() == new_ops.size()); + for (int64_t i = 0; i < old_ops.size(); ++i) { + TF_RETURN_IF_ERROR(ReplaceUsesAndUpdateSequence(old_ops[i], new_ops[i], + instruction_sequence)); + } + return absl::OkStatus(); +} + +absl::Status RemoveDoneOpsAndUpdateSequence( + const InstructionVector& ops, + HloInstructionSequence& instruction_sequence) { + auto remove_op = [&](HloInstruction* op) { + VLOG(10) << "op: " << op->ToString(); + TF_RETURN_IF_ERROR(RemoveOpFromParent(op)); + instruction_sequence.remove_instruction(op); + return absl::OkStatus(); + }; + for (auto op : ops) { + if (op->opcode() == HloOpcode::kTuple) { + InstructionVector to_remove; + HloInstruction* tuple_op = op; + op = MaySkipTrivialTuple(tuple_op); + to_remove.push_back(tuple_op); + for (auto opnd : tuple_op->mutable_operands()) { + to_remove.push_back(opnd); + } + for (auto opnd : to_remove) { + TF_RETURN_IF_ERROR(remove_op(opnd)); + } + } + TF_RETURN_IF_ERROR(remove_op(op)); + } + return absl::OkStatus(); +} + +bool InsertBeforeFirstCollectiveOp( + const InstructionVector& ops, + const CollectiveInComputation& collective_in_computation, + HloInstructionSequence& instruction_sequence, int64_t& idx, + int64_t& idx_tot) { + bool inserted = false; + while (idx < idx_tot) { + HloInstruction* hlo = instruction_sequence.instructions()[idx]; + if (MayInvokeCollectiveOp(hlo, collective_in_computation)) { + for (auto op : ops) { + instruction_sequence.insert_instruction(op, idx); + idx++; + idx_tot++; + } + inserted = true; + break; + } + idx++; + } + return inserted; +} + +void CopyInstructionInfo(const HloInstruction* old_op, HloInstruction* new_op) { + new_op->set_metadata(old_op->metadata()); + new_op->add_frontend_attributes(old_op->frontend_attributes()); + new_op->CopyBackendConfigFrom(old_op); +} + +HloInstruction* CreateRecvDoneFrom(const HloInstruction* old_recv_done, + HloInstruction* recv, + HloComputation* computation) { + HloInstruction* recv_done = + computation->AddInstruction(HloInstruction::CreateRecvDone( + recv, old_recv_done->channel_id().value())); + CopyInstructionInfo(old_recv_done, recv_done); + return recv_done; +} + +HloInstruction* CreateSendDoneFrom(const HloInstruction* old_send_done, + HloInstruction* send, + HloComputation* computation) { + HloInstruction* send_done = + computation->AddInstruction(HloInstruction::CreateSendDone( + send, old_send_done->channel_id().value())); + CopyInstructionInfo(old_send_done, send_done); + return send_done; +} + +absl::Status RewritePipelinedP2PWhileBody( + const CollectiveInComputation& collective_in_computation, + const std::vector& new_parameter_shapes, HloInstruction* while_op, + int64_t opnd_start, int64_t opnd_end) { + HloComputation* computation = while_op->while_body(); + HloInstruction* while_init = while_op->while_init(); + HloInstruction* root = computation->root_instruction(); + HloInstructionSequence& instruction_sequence = + computation->parent()->schedule().GetOrCreateSequence(computation); + + HloInstruction* param = computation->parameter_instruction(0); + *param->mutable_shape() = ShapeUtil::MakeTupleShape(new_parameter_shapes); + + InstructionVector recv_dones; + InstructionVector new_recv_dones; + InstructionVector new_send_dones; + for (int64_t i = opnd_start; i < opnd_end; ++i) { + const HloInstruction* op = root->operand(i); + op = MaySkipTrivialTuple(op); + if (op->opcode() == HloOpcode::kRecvDone) { + HloInstruction* gte = FindUniqueGTEUserWithIndex(param, i); + CHECK(gte != nullptr); + recv_dones.push_back(gte); + + // Create the new RecvDone using the new while-body parameter. + HloInstruction* recv = computation->AddInstruction( + HloInstruction::CreateGetTupleElement(param, i)); + + HloInstruction* recv_done = CreateRecvDoneFrom(op, recv, computation); + new_recv_dones.push_back(recv_done); + continue; + } + CHECK(op->opcode() == HloOpcode::kSendDone); + // Create the new SendDone using the new while-op result. + HloInstruction* send = computation->AddInstruction( + HloInstruction::CreateGetTupleElement(param, i)); + HloInstruction* send_done = CreateSendDoneFrom(op, send, computation); + new_send_dones.push_back(send_done); + } + TF_RETURN_IF_ERROR(ReplaceUsesAndUpdateSequence(recv_dones, new_recv_dones, + instruction_sequence)); + + // Create a new root tuple. + InstructionVector done_ops; + InstructionVector new_opnds; + for (int64_t i = 0; i < while_init->operand_count(); ++i) { + HloInstruction* op = root->mutable_operand(i); + if (i >= opnd_start && i < opnd_end) { + new_opnds.push_back(MaySkipTrivialTuple(op)->mutable_operand(0)); + done_ops.push_back(op); + } else { + new_opnds.push_back(op); + } + } + HloInstruction* new_root = + computation->AddInstruction(HloInstruction::CreateTuple(new_opnds)); + computation->set_root_instruction(new_root, + /*accept_different_shape=*/true); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(root)); + instruction_sequence.replace_instruction(root, new_root); + + TF_RETURN_IF_ERROR( + RemoveDoneOpsAndUpdateSequence(done_ops, instruction_sequence)); + + // Find a place to put the new SendDone. It will be either the first + // may-invoke-collective ops that is not in the pipelined Send/Recv chain or + // the first op in the pipelined Send/Recv chain. + int64_t idx = 0; + int64_t idx_end = instruction_sequence.size(); + bool inserted = + InsertBeforeFirstCollectiveOp(new_send_dones, collective_in_computation, + instruction_sequence, idx, idx_end); + CHECK(inserted); // There are Send/Recv in the while-body, expect inserted. + CHECK(idx_end == instruction_sequence.size()); + + TF_RETURN_IF_ERROR(computation->parent()->schedule().Update()); + return absl::OkStatus(); +} + +void RewritePipelinedP2PWhileCond( + const std::vector& new_parameter_shapes, HloInstruction* while_op) { + HloComputation* computation = while_op->while_condition(); + HloInstruction* param = computation->parameter_instruction(0); + *param->mutable_shape() = ShapeUtil::MakeTupleShape(new_parameter_shapes); + VLOG(10) << computation->ToString(); +} + +// Rewrites the while-op with a recognized pipelined SendDone/RecvDone pattern +// to pipeline Send/Recv instead. +absl::Status TransformLoop( + const PipelinedP2PInfo& pipelined_info, + const CollectiveInComputation& collective_in_computation, int64_t& idx, + int64_t& idx_end, HloInstructionSequence& instruction_sequence, + HloInstruction* while_op) { + HloComputation* computation = while_op->parent(); + int64_t opnd_start = pipelined_info.opnd_start; + int64_t opnd_end = pipelined_info.opnd_end; + VLOG(10) << "Transform pipelined while-op " << while_op->ToString(); + HloInstruction* while_init = while_op->while_init(); + InstructionVector new_while_init_opnds; + std::vector new_parameter_shapes; + for (int64_t i = 0; i < while_init->operand_count(); ++i) { + HloInstruction* op = while_init->mutable_operand(i); + if (i >= opnd_start && i < opnd_end) { + // Get Send/Recv from SendDone/RecvDone. + new_while_init_opnds.push_back(op->mutable_operand(0)); + } else { + new_while_init_opnds.push_back(op); + } + new_parameter_shapes.push_back(new_while_init_opnds.back()->shape()); + } + + RewritePipelinedP2PWhileCond(new_parameter_shapes, while_op); + TF_RETURN_IF_ERROR(RewritePipelinedP2PWhileBody( + collective_in_computation, new_parameter_shapes, while_op, opnd_start, + opnd_end)); + HloInstruction* new_while_init = computation->AddInstruction( + HloInstruction::CreateTuple(new_while_init_opnds), "while-init"); + VLOG(10) << "new_while_init: " << new_while_init->ToString(); + HloInstruction* new_while_op = computation->AddInstruction( + HloInstruction::CreateWhile( + while_op->while_body()->root_instruction()->shape(), + while_op->while_condition(), while_op->while_body(), new_while_init), + "while-result"); + CopyInstructionInfo(while_op, new_while_op); + VLOG(10) << "new_while_op: " << new_while_op->ToString(); + + InstructionVector recv_dones; + InstructionVector new_recv_dones; + InstructionVector new_send_dones; + InstructionVector done_ops; + for (int64_t i = opnd_start; i < opnd_end; ++i) { + HloInstruction* op = while_init->mutable_operand(i); + done_ops.push_back(op); + if (op->opcode() == HloOpcode::kRecvDone) { + HloInstruction* gte = FindUniqueGTEUserWithIndex(while_op, i); + CHECK(gte != nullptr); + recv_dones.push_back(gte); + + // Create the new RecvDone using the new while-op result. + HloInstruction* recv = computation->AddInstruction( + HloInstruction::CreateGetTupleElement(new_while_op, i)); + HloInstruction* recv_done = computation->AddInstruction( + HloInstruction::CreateRecvDone(recv, op->channel_id().value())); + new_recv_dones.push_back(recv_done); + CopyInstructionInfo(op, recv_done); + continue; + } + CHECK(op->opcode() == HloOpcode::kSendDone); + // Create the new SendDone using the new while-op result. + HloInstruction* send = computation->AddInstruction( + HloInstruction::CreateGetTupleElement(new_while_op, i)); + HloInstruction* send_done = computation->AddInstruction( + HloInstruction::CreateSendDone(send, op->channel_id().value())); + new_send_dones.push_back(send_done); + CopyInstructionInfo(op, send_done); + } + + TF_RETURN_IF_ERROR(ReplaceUsesAndUpdateSequence( + while_op, new_while_op, instruction_sequence, /*diff_shape*/ true)); + TF_RETURN_IF_ERROR( + ReplaceOpInSequence(while_init, new_while_init, instruction_sequence)); + TF_RETURN_IF_ERROR(ReplaceUsesAndUpdateSequence(recv_dones, new_recv_dones, + instruction_sequence)); + TF_RETURN_IF_ERROR( + RemoveDoneOpsAndUpdateSequence(done_ops, instruction_sequence)); + + int64_t opnd_tot = opnd_end - opnd_start; + // Verify that the numbers of ops we have removed from the sequence is + // opnd_tot and they are before the position of the new while-op. + CHECK(idx_end == instruction_sequence.size() + opnd_tot); + CHECK(instruction_sequence.instructions()[idx - opnd_tot] == new_while_op); + + // Update idx_end to reflect the current size of the instruction sequence. + // Update idx to right after the new while-op. + idx_end -= opnd_tot; + idx = idx - opnd_tot + 1; + bool inserted = + InsertBeforeFirstCollectiveOp(new_send_dones, collective_in_computation, + instruction_sequence, idx, idx_end); + CHECK(idx_end == instruction_sequence.size()); + // If there isn't any may-invoke-collective ops after the while-op, add + // the new SendDone ops before the last instruction in the sequence. + if (!inserted) { + CHECK(idx_end == idx); + idx--; + for (auto send_done : new_send_dones) { + instruction_sequence.insert_instruction(send_done, idx++); + } + } + return absl::OkStatus(); +} + +// Find while-loop with pipelined Send/Recv and rotates the SendDone/RecvDone +// for such while-loop. +absl::StatusOr ProcessComputation( + HloModule* module, HloComputation* computation, + CollectiveInComputation& collective_in_computation) { + VLOG(10) << "Process compuation " << computation->name(); + bool changed = false; + HloInstructionSequence& instruction_sequence = + module->schedule().GetOrCreateSequence(computation); + int64_t idx = 0; + int64_t idx_end = instruction_sequence.size(); + while (idx < idx_end) { + HloInstruction* hlo = instruction_sequence.instructions()[idx]; + + if (MayInvokeCollectiveOp(hlo, collective_in_computation)) { + collective_in_computation[computation] = true; + } + + if (hlo->opcode() != HloOpcode::kWhile) { + idx++; + continue; + } + + std::optional pipelined_info = FindPipelinedP2P(hlo); + if (!pipelined_info.has_value()) { + idx++; + continue; + } + TF_RETURN_IF_ERROR(TransformLoop(pipelined_info.value(), + collective_in_computation, idx, idx_end, + instruction_sequence, hlo)); + changed = true; + } + return changed; +} +} // namespace + +absl::StatusOr PipelinedP2PRewriter::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + if (!module->has_schedule()) return changed; + CollectiveInComputation collective_in_computation; + // Visit the computations in the order of callees to callers, so that + // while-body is processed before while-op. + for (auto* computation : + module->MakeComputationPostOrder(execution_threads)) { + if (computation->IsFusionComputation()) { + collective_in_computation[computation] = false; + continue; + } + + TF_ASSIGN_OR_RETURN( + bool cur_changed, + ProcessComputation(module, computation, collective_in_computation)); + changed |= cur_changed; + } + + if (changed) { + TF_RETURN_IF_ERROR(module->schedule().Update()); + } + + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.h b/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.h new file mode 100644 index 00000000000000..88b6bb662f2ed7 --- /dev/null +++ b/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.h @@ -0,0 +1,133 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_PIPELINED_P2P_REWRITER_H_ +#define XLA_SERVICE_GPU_PIPELINED_P2P_REWRITER_H_ + +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// PipelinedP2PRewriter is a pass that rewrites pipelined Send/Recv related +// code for point-to-point communication to rotate SendDone and RecvDone at the +// end of a loop iteration to the beginning of the next iteration. This pass +// operates on scheduled module and updates the instruction sequence. +// +// In particular, a pipelined Send/Recv chain with one channel group with this +// code pattern: +// +// main: +// recv +// send +// recv-done +// send-done +// while-init = (recv-done, send-done, ...) +// while-op = while(whiel-init) ... +// +// while-body: +// ... +// recv +// send +// recv-done +// send-done +// ROOT tuple(recv-done, send-done, ...) +// +// Will be transformed to: +// +// main: +// recv +// send +// while-init = (recv, send, ...) +// while-op = while(whiel-init) ... +// recv-done +// send-done +// +// while-body: +// recv-done +// ... +// send-done +// recv +// send +// ROOT tuple(recv, send, ...) +// +// A pipelined Send/Recv chain with two channel groups with this code pattern: +// +// main: +// recv.0 +// send.0 +// recv.1 +// send.1 +// recv-done.0 +// send-done.0 +// recv-done.1 +// send-done.1 +// while-init = (recv-done.0, send-done.0, recv-done.1, send-done.1, ...) +// while-op = while(whiel-init) ... +// +// while-body: +// ... +// recv.0 +// send.0 +// recv.1 +// send.1 +// recv-done.0 +// send-done.0 +// recv-done.1 +// send-done.1 +// ROOT = tuple(recv-done.0, send-done.0, recv-done.1, send-done.1, ...) +// +// Will be transformed to: +// +// main: +// +// recv.0 +// send.0 +// recv.1 +// send.1 +// while-init = (recv.0, send.0, recv.1, send.1, ...) +// while-op = while(while-init) ... +// recv-done.0 +// send-done.0 +// recv-done.1 +// send-done.1 +// +// while-body: +// recv-done.0 +// recv-done.1 +// ... +// send-done.0 +// send-done.1 +// recv.0 +// send.1 +// recv.1 +// send.1 +// ROOT tuple(recv.0, send.0, recv.1, send.1, ...) +// +class PipelinedP2PRewriter : public HloModulePass { + public: + absl::string_view name() const override { return "pipelined-p2p-rewriter"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_PIPELINED_P2P_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter_test.cc b/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter_test.cc new file mode 100644 index 00000000000000..373618c9e588a3 --- /dev/null +++ b/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter_test.cc @@ -0,0 +1,509 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/pipelined_p2p_rewriter.h" + +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/tests/filecheck.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +class PipelinedP2pRewriterTest : public HloTestBase { + protected: + void DoFileCheck(const HloModule* module, absl::string_view expected) { + HloPrintOptions options; + options.set_print_operand_shape(false); + options.set_print_result_shape(false); + TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched, + RunFileCheck(module->ToString(options), expected)); + EXPECT_TRUE(filecheck_matched); + } +}; + +TEST_F(PipelinedP2pRewriterTest, SendRecUnpipelinedNotTransform) { + const char* kModuleStr = R"( +HloModule test + +cond { + param = (u32[], u32[2]) parameter(0) + count = get-tuple-element(%param), index=0 + ub = u32[] constant(11) + ROOT result = pred[] compare(count, ub), direction=LT + } + +body { + param = (u32[], u32[2]) parameter(0) + count = get-tuple-element(param), index=0 + send-data = u32[2] get-tuple-element(param), index=1 + + after-all.0.n = token[] after-all() + recv.0 = (u32[2], u32[], token[]) recv(after-all.0.n), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_send_recv_pipeline="0" + } + send.0 = (u32[2], u32[], token[]) send(send-data, after-all.0.n), + channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_send_recv_pipeline="0" + } + recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + send-done.0 = token[] send-done(send.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + + recv-data = u32[2] get-tuple-element(recv-done.0), index=0 + + c1 = u32[] constant(1) + new_count = u32[] add(count, c1) + + r = u32[2] broadcast(c1), dimensions={} + s = u32[2] add(r, recv-data) + + ROOT result = (u32[], u32[2]) tuple(new_count, s) + } + + ENTRY test_computation { + c0 = u32[] constant(0) + c1 = u32[] constant(1) + r = u32[] replica-id() + a = u32[] add(c1, r) + init = u32[2] broadcast(a), dimensions={} + while_init = (u32[], u32[2]) tuple(c0, init) + while_result = (u32[], u32[2]) while(while_init), body=body, condition=cond, + backend_config={"known_trip_count":{"n":"11"}} + ROOT recv-data = u32[2] get-tuple-element(while_result), index=1 + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + PipelinedP2PRewriter rewriter; + TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get())); + EXPECT_FALSE(changed); +} + +// Tests the rewrite for a pipelined Send/Recv chain with only one channel +// group. +TEST_F(PipelinedP2pRewriterTest, SendRecvPipelined1) { + const char* kModuleStr = R"( + HloModule test, is_scheduled=true + + while-cond { + param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0) + count = get-tuple-element(param), index=0 + ub = u32[] constant(25) + ROOT cond-result = pred[] compare(count, ub), direction=LT + } + + while-body { + param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0) + count = get-tuple-element(param), index=0 + + recv-done.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1 + recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.q), index=0 + + c1 = u32[] constant(1) + new-count = u32[] add(count, c1) + replica = u32[] replica-id() + c10 = u32[] constant(10) + sum = u32[] add(replica, c10) + sum2 = u32[] add(sum, count) + conv = f32[] convert(sum2) + p = f32[1, 1024, 1024] broadcast(conv), dimensions={} + b = f32[1, 1024, 1024] add(p, recv-data) + c = f32[1, 1024, 1024] multiply(b, b) + d = f32[1, 1024, 1024] tan(c) + s = f32[1, 1024, 1024] dot(c, d), lhs_batch_dims={0}, + lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} + send-data = f32[1, 1024, 1024] add(c, s) + + after-all = token[] after-all() + recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", + _xla_send_recv_pipeline="0" + } + send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all), + channel_id=1, frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", + _xla_send_recv_pipeline="0" + } + recv-done.p = (f32[1,1024,1024], token[]) recv-done(recv), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + send-done.p = token[] send-done(send), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + gte.0 = f32[1,1024,1024] get-tuple-element(recv-done.p), index=0 + gte.1 = token[] get-tuple-element(recv-done.p), index=1 + recv-done-tuple = (f32[1,1024,1024], token[]) tuple(gte.0, gte.1) + ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[]) + tuple(new-count, recv-done-tuple, send-done.p) + } + + ENTRY main { + c0 = u32[] constant(0) + f0 = f32[] constant(0.0) + init = f32[1, 1024, 1024] broadcast(f0), dimensions={} + + after-all.1 = token[] after-all() + recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", + _xla_send_recv_pipeline="0" + } + send.1 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.1), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", + _xla_send_recv_pipeline="0" + } + recv-done.1.p = (f32[1,1024,1024], token[]) recv-done(recv.1), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + send-done.1.p = token[] send-done(send.1), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + while-init.p = (u32[], (f32[1,1024,1024], token[]), token[]) + tuple(c0, recv-done.1.p, send-done.1.p) + while-result.p = (u32[], (f32[1,1024,1024], token[]), token[]) + while(while-init.p), + body=while-body, condition=while-cond, + backend_config={"known_trip_count":{"n":"25"}} + + recv-done.1.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result.p), index=1 + + ROOT entry-result = f32[1, 1024, 1024] get-tuple-element(recv-done.1.q), index=0 + } + )"; + + const char* kExpected = R"( + CHECK: %while-body (param.1: (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]))) -> (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[])) { + CHECK: %param.1 = parameter(0) + CHECK: %get-tuple-element = get-tuple-element(%param.1), index=1 + CHECK: %get-tuple-element.1 = get-tuple-element(%param.1), index=2 + CHECK: %count.1 = get-tuple-element(%param.1), index=0 + CHECK: %recv-done = recv-done(%get-tuple-element), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} + CHECK: %recv-data = get-tuple-element(%recv-done), index=0 + CHECK: %c1 = constant(1) + CHECK: %new-count = add(%count.1, %c1) + CHECK: %replica = replica-id() + CHECK: %c10 = constant(10) + CHECK: %sum = add(%replica, %c10) + CHECK: %sum2 = add(%sum, %count.1) + CHECK: %conv = convert(%sum2) + CHECK: %p = broadcast(%conv), dimensions={} + CHECK: %b = add(%p, %recv-data) + CHECK: %c = multiply(%b, %b) + CHECK: %d = tan(%c) + CHECK: %s = dot(%c, %d), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} + CHECK: %send-data = add(%c, %s) + CHECK: %after-all = after-all() + CHECK: %send-done = send-done(%get-tuple-element.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} + CHECK{LITERAL}: %recv = recv(%after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"} + CHECK{LITERAL}: %send = send(%send-data, %after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"} + CHECK: ROOT %tuple = tuple(%new-count, %recv, %send) + CHECK: } + + CHECK: %while-cond (param: (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]))) -> pred[] { + CHECK: %param = parameter(0) + CHECK: %count = get-tuple-element(%param), index=0 + CHECK: %ub = constant(25) + CHECK: ROOT %cond-result = compare(%count, %ub), direction=LT + CHECK: } + + CHECK: ENTRY %main () -> f32[1,1024,1024] { + CHECK: %c0 = constant(0) + CHECK: %f0 = constant(0) + CHECK: %init = broadcast(%f0), dimensions={} + CHECK: %after-all.1 = after-all() + CHECK{LITERAL}: %recv.1 = recv(%after-all.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"} + CHECK{LITERAL}: %send.1 = send(%init, %after-all.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"} + CHECK: %while-init = tuple(%c0, %recv.1, %send.1) + CHECK: %while-result = while(%while-init), condition=%while-cond, body=%while-body, + CHECK-SAME{LITERAL}: backend_config={"known_trip_count":{"n":"25"}} + CHECK: %get-tuple-element.2 = get-tuple-element(%while-result), index=1 + CHECK: %get-tuple-element.3 = get-tuple-element(%while-result), index=2 + CHECK: %recv-done.1 = recv-done(%get-tuple-element.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} + CHECK: %send-done.1 = send-done(%get-tuple-element.3), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} + CHECK: ROOT %entry-result = get-tuple-element(%recv-done.1), index=0 + CHECK: })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + PipelinedP2PRewriter rewriter; + TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get())); + EXPECT_TRUE(changed); + + DoFileCheck(module.get(), kExpected); +} + +// Tests the rewrite for a pipelined Send/Recv chain with two channel groups. +TEST_F(PipelinedP2pRewriterTest, SendRecvPipelined2) { + const char* kModuleStr = R"( + HloModule test, is_scheduled=true + + while-cond { + param = (u32[], (f32[1,1024,1024], token[]), token[], + (f32[1,1024,1024], token[]), token[]) parameter(0) + count = get-tuple-element(param), index=0 + ub = u32[] constant(25) + ROOT cond-result = pred[] compare(count, ub), direction=LT + } + + while-body { + param = (u32[], (f32[1,1024,1024], token[]), token[], + (f32[1,1024,1024], token[]), token[]) parameter(0) + count = get-tuple-element(param), index=0 + + recv-done.0.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1 + recv-data.0 = f32[1, 1024, 1024] get-tuple-element(recv-done.0.q), index=0 + recv-done.1.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=3 + recv-data.1 = f32[1, 1024, 1024] get-tuple-element(recv-done.1.q), index=0 + + replica = u32[] replica-id() + constant0 = u32[] constant(0) + compare0 = pred[] compare(replica, constant0), direction=EQ + compare = pred[1, 1024, 1024] broadcast(compare0), dimensions={} + recv-data = f32[1, 1024, 1024] select(compare, recv-data.0, recv-data.1) + + c1 = u32[] constant(1) + new-count = u32[] add(count, c1) + c10 = u32[] constant(10) + sum = u32[] add(replica, c10) + sum2 = u32[] add(sum, count) + conv = f32[] convert(sum2) + p = f32[1, 1024, 1024] broadcast(conv), dimensions={} + b = f32[1, 1024, 1024] add(p, recv-data) + c = f32[1, 1024, 1024] multiply(b, b) + d = f32[1, 1024, 1024] tan(c) + s = f32[1, 1024, 1024] dot(c, d), lhs_batch_dims={0}, + lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} + send-data = f32[1, 1024, 1024] add(c, s) + + after-all = token[] after-all() + recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_send_recv_pipeline="0" + } + send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all), + channel_id=1, frontend_attributes={ + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_send_recv_pipeline="0" + } + recv-done.p = (f32[1,1024,1024], token[]) recv-done(recv), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + send-done.p = token[] send-done(send), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + + after-all.1 = token[] after-all() + recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=2, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}", + _xla_send_recv_pipeline="1" + } + send.1 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.1), + channel_id=2, frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}", + _xla_send_recv_pipeline="1" + } + recv-done.1.p = (f32[1,1024,1024], token[]) recv-done(recv.1), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } + send-done.1.p = token[] send-done(send.1), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } + + ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[], + (f32[1,1024,1024], token[]), token[]) + tuple(new-count, recv-done.p, send-done.p, recv-done.1.p, send-done.1.p) + } + + ENTRY main { + c0 = u32[] constant(0) + f0 = f32[] constant(0.0) + init = f32[1, 1024, 1024] broadcast(f0), dimensions={} + + after-all.2 = token[] after-all() + recv.2 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.2), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_send_recv_pipeline="0" + } + send.2 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.2), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_send_recv_pipeline="0" + } + recv-done.2.p = (f32[1,1024,1024], token[]) recv-done(recv.2), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + send-done.2.p = token[] send-done(send.2), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + + after-all.3 = token[] after-all() + recv.3 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.3), channel_id=2, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}", + _xla_send_recv_pipeline="1" + } + send.3 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.3), channel_id=2, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}", + _xla_send_recv_pipeline="1" + } + recv-done.3.p = (f32[1,1024,1024], token[]) recv-done(recv.3), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } + send-done.3.p = token[] send-done(send.3), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } + + while-init.p = (u32[], (f32[1,1024,1024], token[]), token[], + (f32[1,1024,1024], token[]), token[]) tuple(c0, recv-done.2.p, send-done.2.p, recv-done.3.p, send-done.3.p) + while-result.p = (u32[], (f32[1,1024,1024], token[]), token[], + (f32[1,1024,1024], token[]), token[]) while(while-init.p), + body=while-body, condition=while-cond, + backend_config={"known_trip_count":{"n":"25"}} + + recv-done.2.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result.p), index=1 + recv-data.2 = f32[1, 1024, 1024] get-tuple-element(recv-done.2.q), index=0 + recv-done.3.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result.p), index=3 + recv-data.3 = f32[1, 1024, 1024] get-tuple-element(recv-done.3.q), index=0 + + replica = u32[] replica-id() + constant0 = u32[] constant(0) + compare0 = pred[] compare(replica, constant0), direction=EQ + compare = pred[1, 1024, 1024] broadcast(compare0), dimensions={} + ROOT entry-result = f32[1, 1024, 1024] select(compare, recv-data.2, recv-data.3) + } + )"; + + const char* kExpected = R"( + CHECK: %while-body (param.1: (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]))) -> (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[])) { + CHECK: %param.1 = parameter(0) + CHECK: %get-tuple-element = get-tuple-element(%param.1), index=1 + CHECK: %get-tuple-element.1 = get-tuple-element(%param.1), index=2 + CHECK: %get-tuple-element.2 = get-tuple-element(%param.1), index=3 + CHECK: %get-tuple-element.3 = get-tuple-element(%param.1), index=4 + CHECK: %count.1 = get-tuple-element(%param.1), index=0 + CHECK: %recv-done = recv-done(%get-tuple-element), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} + CHECK: %recv-data.0 = get-tuple-element(%recv-done), index=0 + CHECK: %recv-done.1 = recv-done(%get-tuple-element.2), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} + CHECK: %recv-data.1 = get-tuple-element(%recv-done.1), index=0 + CHECK: %replica = replica-id() + CHECK: %constant0 = constant(0) + CHECK: %compare0 = compare(%replica, %constant0), direction=EQ + CHECK: %compare = broadcast(%compare0), dimensions={} + CHECK: %recv-data.2 = select(%compare, %recv-data.0, %recv-data.1) + CHECK: %c1 = constant(1) + CHECK: %new-count = add(%count.1, %c1) + CHECK: %c10 = constant(10) + CHECK: %sum = add(%replica, %c10) + CHECK: %sum2 = add(%sum, %count.1) + CHECK: %conv = convert(%sum2) + CHECK: %p = broadcast(%conv), dimensions={} + CHECK: %b = add(%p, %recv-data.2) + CHECK: %c = multiply(%b, %b) + CHECK: %d = tan(%c) + CHECK: %s = dot(%c, %d), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} + CHECK: %send-data = add(%c, %s) + CHECK: %after-all = after-all() + CHECK: %send-done = send-done(%get-tuple-element.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} + CHECK: %send-done.1 = send-done(%get-tuple-element.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} + CHECK{LITERAL}: %recv = recv(%after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{3,0}}"} + CHECK{LITERAL}: %send = send(%send-data, %after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{3,0}}"} + CHECK: %after-all.1 = after-all() + CHECK{LITERAL}: %recv.1 = recv(%after-all.1), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}"} + CHECK{LITERAL}: %send.1 = send(%send-data, %after-all.1), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}"} + CHECK: ROOT %tuple = tuple(%new-count, %recv, %send, %recv.1, %send.1) + CHECK: } + + CHECK: %while-cond (param: (u32[], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]))) -> pred[] { + CHECK: %param = parameter(0) + CHECK: %count = get-tuple-element(%param), index=0 + CHECK: %ub = constant(25) + CHECK: ROOT %cond-result = compare(%count, %ub), direction=LT + CHECK: } + + CHECK: ENTRY %main () -> f32[1,1024,1024] { + CHECK: %c0 = constant(0) + CHECK: %f0 = constant(0) + CHECK: %init = broadcast(%f0), dimensions={} + CHECK: %after-all.2 = after-all() + CHECK{LITERAL}: %recv.2 = recv(%after-all.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{3,0}}"} + CHECK{LITERAL}: %send.2 = send(%init, %after-all.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{3,0}}"} + CHECK: %after-all.3 = after-all() + CHECK{LITERAL}: %recv.3 = recv(%after-all.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}"} + CHECK{LITERAL}: %send.3 = send(%init, %after-all.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}"} + CHECK: %while-init = tuple(%c0, %recv.2, %send.2, %recv.3, %send.3) + CHECK{LITERAL}: %while-result = while(%while-init), condition=%while-cond, body=%while-body, backend_config={"known_trip_count":{"n":"25"}} + CHECK: %get-tuple-element.4 = get-tuple-element(%while-result), index=1 + CHECK: %get-tuple-element.5 = get-tuple-element(%while-result), index=2 + CHECK: %get-tuple-element.6 = get-tuple-element(%while-result), index=3 + CHECK: %get-tuple-element.7 = get-tuple-element(%while-result), index=4 + CHECK: %recv-done.2 = recv-done(%get-tuple-element.4), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} + CHECK: %recv-data.3 = get-tuple-element(%recv-done.2), index=0 + CHECK: %recv-done.3 = recv-done(%get-tuple-element.6), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} + CHECK: %recv-data.4 = get-tuple-element(%recv-done.3), index=0 + CHECK: %replica.1 = replica-id() + CHECK: %constant0.1 = constant(0) + CHECK: %compare0.1 = compare(%replica.1, %constant0.1), direction=EQ + CHECK: %compare.1 = broadcast(%compare0.1), dimensions={} + CHECK: %send-done.2 = send-done(%get-tuple-element.5), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} + CHECK: %send-done.3 = send-done(%get-tuple-element.7), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} + CHECK: ROOT %entry-result = select(%compare.1, %recv-data.3, %recv-data.4) + CHECK: })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + PipelinedP2PRewriter rewriter; + TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get())); + EXPECT_TRUE(changed); + + DoFileCheck(module.get(), kExpected); +} + +} // namespace +} // namespace gpu +} // namespace xla