Skip to content

Commit

Permalink
[XLA:GPU][NFC] Add and refactor GPU reduce-scatter-creator tests.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 636472439
  • Loading branch information
golechwierowicz authored and tensorflower-gardener committed May 23, 2024
1 parent 14b89c4 commit 76d56bf
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 2 deletions.
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,10 @@ xla_cc_test(
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/platform:statusor",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ limitations under the License.
#include <memory>
#include <utility>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/algorithm/container.h"
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
Expand Down Expand Up @@ -61,8 +64,18 @@ class GpuReduceScatterCreatorTest : public HloTestBase {
}

// Returns how many instructions in the entry computation of `module` have the
// opcode HloOpcode::kAllReduce.
size_t AllReduceCount(std::unique_ptr<HloModule> &module) {
  // Delegate to the shared opcode counter; the previous inline
  // absl::c_count_if return left this call unreachable.
  return CollectiveCount(module, HloOpcode::kAllReduce);
}

// Returns how many instructions in the entry computation of `module` have the
// opcode HloOpcode::kReduceScatter.
size_t ReduceScatterCount(std::unique_ptr<HloModule> &module) {
  // Must count reduce-scatter ops; counting kAllReduce here would make this
  // helper an alias of AllReduceCount and hide a missing reduce-scatter.
  return CollectiveCount(module, HloOpcode::kReduceScatter);
}

private:
size_t CollectiveCount(std::unique_ptr<HloModule> &module, HloOpcode opcode) {
return absl::c_count_if(
module->entry_computation()->instructions(),
[&opcode](HloInstruction *instr) { return instr->opcode() == opcode; });
}
};

Expand Down Expand Up @@ -382,6 +395,41 @@ ENTRY %AllReduce {
EXPECT_EQ(AllReduceCount(module), 0);
}

// Runs the pass on a module whose entry computation holds an all-reduce that
// feeds a dynamic-slice, followed by a second all-reduce at the root. The
// pass is expected to change the module, after which the helpers must report
// exactly one all-reduce and one reduce-scatter in the entry computation.
TEST_F(GpuReduceScatterCreatorTest, AllReduceFollowedByAllReduce) {
  constexpr int kNumReplicas = 2;
  constexpr int kNumPartitions = 8;
  constexpr bool kExpectChange = true;

  const absl::string_view hlo_text = R"(
HloModule AllReduce
%sum {
%a = f32[] parameter(0)
%b = f32[] parameter(1)
ROOT %add = f32[] add(%a, %b)
}
ENTRY %AllReduce {
%param = f32[32,8,128]{2,1,0} parameter(0)
%all-reduce.scattered = f32[32,8,128]{2,1,0} all-reduce(%param),
replica_groups={{0,1,2,3,4,5,6,7},{8,9,10,11,12,13,14,15}}, to_apply=%sum, use_global_device_ids=true, channel_id=1
%table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
%pid = u32[] partition-id()
%id = s32[1] dynamic-slice(%table, %pid), dynamic_slice_sizes={1}
%reshape = s32[] reshape(%id)
%slice_size = s32[] constant(4)
%offset = s32[] multiply(%reshape, %slice_size)
%zero = s32[] constant(0)
%dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce.scattered, %offset, %zero, %zero),
dynamic_slice_sizes={4,8,128}
ROOT %all-reduce.sync = f32[4,8,128]{2,1,0} all-reduce(%dynamic-slice),
replica_groups={{0,8},{1,9},{2,10},{3,11},{4,12},{5,13},{6,14},{7,15}}, to_apply=%sum, use_global_device_ids=true, channel_id=2
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      auto rewritten,
      RunPass(hlo_text, kNumReplicas, kNumPartitions, kExpectChange));

  // Check both collective counts on the module returned by the pass.
  EXPECT_EQ(AllReduceCount(rewritten), 1);
  EXPECT_EQ(ReduceScatterCount(rewritten), 1);
}

TEST_F(GpuReduceScatterCreatorTest, SubgroupsGlobals) {
absl::string_view hlo_string = R"(
HloModule AllReduce
Expand Down

0 comments on commit 76d56bf

Please sign in to comment.