Skip to content

Commit

Permalink
manual dims must be excluded when computing the new tile assignment, or
Browse files Browse the repository at this point in the history
new_num_tiles will always be > NumTiles() + 1 and sharding propagation will
infinitely reassign the same partial manual shardings.

PiperOrigin-RevId: 632669667
  • Loading branch information
pschuh authored and tensorflower-gardener committed May 11, 2024
1 parent acb4ea8 commit 7d61a7b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
19 changes: 13 additions & 6 deletions third_party/xla/xla/hlo/utils/hlo_sharding_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,13 @@ bool MergeShardingIfCompatible(const HloSharding& to_merge,
}
}

const int64_t num_devices = to_merge.tile_assignment().num_elements();
const int64_t new_num_tiles = Product(merged_tile_dims);
if (num_devices % new_num_tiles != 0 || new_num_tiles < minimum_tiles) {
return false;
}
int64_t replication;

if (to_merge_man_dim >= 0) {
int64_t man_group_size = to_merge.tile_assignment().dim(to_merge_man_dim);
if (man_group_size != dst->tile_assignment().dim(dst_man_dim)) {
Expand All @@ -365,14 +372,14 @@ bool MergeShardingIfCompatible(const HloSharding& to_merge,
merged_tile_dims.push_back(man_group_size);
num_merge_groups *= man_group_size;
num_dst_groups *= man_group_size;
if (num_devices % (new_num_tiles * man_group_size) != 0) {
return false;
}
replication = num_devices / (new_num_tiles * man_group_size);
} else {
replication = num_devices / new_num_tiles;
}

const int64_t num_devices = to_merge.tile_assignment().num_elements();
const int64_t new_num_tiles = Product(merged_tile_dims);
if (num_devices % new_num_tiles != 0 || new_num_tiles < minimum_tiles) {
return false;
}
const int64_t replication = num_devices / new_num_tiles;
if (replication > 1) {
merged_tile_dims.push_back(replication);
}
Expand Down
11 changes: 11 additions & 0 deletions third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,17 @@ TEST(HloShardingUtilTest,
EXPECT_EQ(result, output_sharding);
}

// Merging a manual/replicated subgroup sharding with an identical copy of
// itself must be rejected when the caller asks for one more tile than the
// destination sharding reports via NumTiles().
TEST(HloShardingUtilTest, MergeManualSubgroupSharding) {
  // Subgroup sharding under test:
  // {devices=[16,4]<=[64] last_tile_dims={manual, replicated}}
  TileAssignment device_tiles({16, 4});
  const std::vector<OpSharding::Type> trailing_dim_kinds{
      OpSharding::MANUAL, OpSharding::REPLICATED};
  HloSharding dst = HloSharding::Subgroup(device_tiles, trailing_dim_kinds);

  // The source of the merge is an exact copy of the destination.
  const HloSharding to_merge = dst;

  // Request strictly more tiles than the destination currently has.
  const int64_t minimum_tiles = dst.NumTiles() + 1;
  const bool merged = MergeShardingIfCompatible(to_merge, minimum_tiles, &dst);
  EXPECT_FALSE(merged);
}

TEST(HloShardingUtilTest, GetManualSubgroupSharding_ManualOnly) {
TileAssignment tile_assignment({1, 2, 2});
std::vector<OpSharding::Type> subgroup_types = {OpSharding::MANUAL};
Expand Down

0 comments on commit 7d61a7b

Please sign in to comment.