Reverts d9417b2
PiperOrigin-RevId: 627795237
dicentra13 authored and tensorflower-gardener committed Apr 24, 2024
1 parent f497887 commit e30abe6
Showing 3 changed files with 55 additions and 210 deletions.
third_party/xla/xla/hlo/utils/hlo_sharding_util.cc (144 changes: 55 additions & 89 deletions)
@@ -18,11 +18,9 @@ limitations under the License.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iterator>
#include <map>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <tuple>
@@ -693,34 +691,30 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
return sharding;
}

// In case of a tiled sharding, the reshaped sharding will be valid if the
// In case of a tiled sharding the reshaped sharding will be a valid if the
// reshape is composed from the following operations:
// * Adding or removing dimensions with size 1.
// * Merging consecutive dimensions where only the most major is sharded.
// * Splitting a dimension to consecutive dimensions.
// * Any reshaping of unsharded dimensions.
//
// Merge and split can happen consecutively on the same dimension, e.g.,
// f32[1024,256] to f32[128,2048] can be considered that 1024 gets split into
// 128 and 8, but 8 then gets merged with 256. We use stacks to make
// supporting such cases easy.
//
// If transpose is needed between source and target shapes, we use the GCD of
// (target_shape_dim, sharding_dim) if source_shape_dim % sharding_dim == 0.
// For example, given the source_shape f32[6,4], target_shape f32[4,6] and
// sharding {devices=[6,1]<=[6]}, the output sharding is {devices=[2,1,3]<=[6]
// last_tile_dim_replicate}.
// Note that merge and split can happen consecutively on the same dimension,
// e.g., f32[1024,256,1024] to f32[128,2048,1024] can be considered that 1024
// gets split into 128 and 8, but 8 then gets merged with 256. We use stacks
// to make supporting such cases easy.
const Shape tile_shape = sharding.TileShape(source_shape);
DimensionVector target_tile_assignment_dimensions;
DimensionVector source_dims_stack(source_shape.dimensions().rbegin(),
source_shape.dimensions().rend());
DimensionVector target_dims_stack(target_shape.dimensions().rbegin(),
target_shape.dimensions().rend());
DimensionVector sharding_tile_dims_stack(
sharding.tile_assignment().dimensions().begin(),
sharding.tile_assignment().dimensions().begin() + source_shape.rank());
std::reverse(sharding_tile_dims_stack.begin(),
sharding_tile_dims_stack.end());

DimensionVector source_dims_stack(source_shape.rank());
DimensionVector target_dims_stack(target_shape.rank());
DimensionVector sharding_tile_dims_stack(source_shape.rank());
int64_t added_to_partially_replicated = 1;
for (int64_t i = 0; i < source_shape.rank(); ++i) {
source_dims_stack[i] = source_shape.dimensions(source_shape.rank() - 1 - i);
sharding_tile_dims_stack[i] =
sharding.tile_assignment().dim(source_shape.rank() - 1 - i);
}
for (int64_t i = 0; i < target_shape.rank(); ++i) {
target_dims_stack[i] = target_shape.dimensions(target_shape.rank() - 1 - i);
}
bool inplace_add_sharding_dim = false;
auto append_sharding_dim = [&](int64_t size) {
if (inplace_add_sharding_dim) {
@@ -730,49 +724,29 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
}
inplace_add_sharding_dim = false;
};

while (!source_dims_stack.empty() || !target_dims_stack.empty()) {
if (Product(sharding_tile_dims_stack) == 1) {
// No more partitions left.
break;
}

int64_t source_dim_product = 1;
while (!sharding_tile_dims_stack.empty() &&
sharding_tile_dims_stack.back() == 1) {
sharding_tile_dims_stack.pop_back();
source_dim_product *= source_dims_stack.back();
source_dims_stack.pop_back();
}
while (!target_dims_stack.empty() &&
source_dim_product % target_dims_stack.back() == 0) {
source_dim_product /= target_dims_stack.back();
target_dims_stack.pop_back();
append_sharding_dim(1);
}
if (source_dim_product != 1) {
source_dims_stack.push_back(source_dim_product);
sharding_tile_dims_stack.push_back(1);
}
if (source_dims_stack.empty() && target_dims_stack.empty()) {
if (target_dims_stack.empty()) {
if (Product(sharding_tile_dims_stack) != 1) {
return std::nullopt;
}
break;
}

int64_t s_size = 1;
int64_t t_size = 1;
int64_t s_partitions = 1;
if (!source_dims_stack.empty()) {
s_size = source_dims_stack.back();
source_dims_stack.pop_back();
s_partitions = sharding_tile_dims_stack.back();
sharding_tile_dims_stack.pop_back();
}

if (target_dims_stack.empty()) {
return std::nullopt;
}
int64_t t_size = target_dims_stack.back();
t_size = target_dims_stack.back();
target_dims_stack.pop_back();

if (s_partitions * Product(sharding_tile_dims_stack) == 1) {
// No more partitions left.
append_sharding_dim(1);
continue;
}
if (s_partitions > 1 && s_size % s_partitions == 0 &&
t_size % s_partitions == 0) {
// If s_partitions evenly divides both s_size and t_size, we can add this
@@ -787,19 +761,22 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
if (s_size == t_size) {
// Same dimension.
append_sharding_dim(s_partitions);
} else if (t_size == 1) {
// Trivial dimension added.
append_sharding_dim(1);
source_dims_stack.push_back(s_size);
sharding_tile_dims_stack.push_back(s_partitions);
} else if (s_size == 1) {
// Trivial dimension removed.
if (s_partitions != 1) {
added_to_partially_replicated *= s_partitions;
}
target_dims_stack.push_back(t_size);
} else if (s_size > t_size) {
// Dimension split.
if (s_size % s_partitions != 0) {
if (s_size % t_size != 0 || s_size % s_partitions != 0) {
return std::nullopt;
}
if (s_size % t_size != 0) {
// Transpose is needed between source and target shapes.
append_sharding_dim(std::gcd(t_size, s_partitions));
break;
}
if (t_size % s_partitions == 0) {
append_sharding_dim(s_partitions);
// We have part of the s_size unprocessed, so put it back to stack.
@@ -811,7 +788,7 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
source_dims_stack.push_back(s_size / t_size);
sharding_tile_dims_stack.push_back(s_partitions / t_size);
} else {
break;
return std::nullopt;
}
} else {
// Dimension merge. Also merge the source dimension with the next, and
@@ -820,56 +797,45 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
return std::nullopt;
}
CHECK(!source_dims_stack.empty());
if (t_size % s_size != 0) {
// Transpose is needed between source and target shapes.
append_sharding_dim(std::gcd(t_size, s_partitions));
break;
}
if (sharding_tile_dims_stack.back() != 1 && s_size != s_partitions) {
// If the next dimension to combine is sharded, we require that the
// current dimension's shard size to be 1. Otherwise, the new shard
// would be non-contiguous.
break;
return std::nullopt;
}
source_dims_stack.back() *= s_size;
sharding_tile_dims_stack.back() *= s_partitions;
target_dims_stack.push_back(t_size);
}
}

if (Product(target_tile_assignment_dimensions) == 1) {
return std::nullopt;
}
while (target_tile_assignment_dimensions.size() < target_shape.rank()) {
target_tile_assignment_dimensions.push_back(1);
}
for (int64_t i = sharding.TiledDataRank();
i < sharding.tile_assignment().num_dimensions(); ++i) {
target_tile_assignment_dimensions.push_back(
i == sharding.SubgroupReplicationDim()
? 1
: sharding.tile_assignment().dim(i));
sharding.tile_assignment().dim(i));
}

auto subgroup_types = sharding.subgroup_types();
auto partially_replicated = std::div(
sharding.TotalNumTiles(), Product(target_tile_assignment_dimensions));
CHECK_EQ(partially_replicated.rem, 0);
if (partially_replicated.quot > 1) {
// If we added dimensions to the partially replicated dimension then add the
// additional dimension on the partially replicated tiling.
if (added_to_partially_replicated > 1) {
if (sharding.ReplicateOnLastTileDim()) {
target_tile_assignment_dimensions.back() = partially_replicated.quot;
subgroup_types.push_back(OpSharding::REPLICATED);
} else if (absl::c_linear_search(subgroup_types, OpSharding::REPLICATED)) {
target_tile_assignment_dimensions[sharding.SubgroupReplicationDim() -
sharding.TiledDataRank() +
target_shape.rank()] =
partially_replicated.quot;
target_tile_assignment_dimensions.back() *= added_to_partially_replicated;
} else {
target_tile_assignment_dimensions.push_back(partially_replicated.quot);
subgroup_types.push_back(OpSharding::REPLICATED);
target_tile_assignment_dimensions.push_back(
added_to_partially_replicated);
}
}

// If subgroup_types doesn't have already partially replicated as a sharding
// type then add it.
if ((sharding.ReplicateOnLastTileDim() ||
added_to_partially_replicated > 1) &&
(subgroup_types.empty() ||
subgroup_types.back() != OpSharding::REPLICATED)) {
subgroup_types.push_back(OpSharding::REPLICATED);
}
auto new_tile_assignment =
sharding.tile_assignment().Reshape(target_tile_assignment_dimensions);
return HloSharding::Subgroup(new_tile_assignment, subgroup_types,
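The comment and tests removed by this revert describe a GCD fallback for reshapes that require a transpose between source and target shapes. The following standalone C++17 sketch is only an illustration of that arithmetic (it is not XLA code, and the variable names are hypothetical); it reproduces the f32[6,4] to f32[4,6] example with sharding {devices=[6,1]<=[6]} given in the deleted comment.

#include <cstdint>
#include <iostream>
#include <numeric>

int main() {
  // Example from the deleted comment: source f32[6,4] with sharding
  // {devices=[6,1]<=[6]} reshaped to target f32[4,6].
  const int64_t s_size = 6;        // source dimension size
  const int64_t s_partitions = 6;  // number of tiles on that source dimension
  const int64_t t_size = 4;        // target dimension it maps onto

  // 6 % 4 != 0, so the reshape needs a transpose. The reverted change fell
  // back to tiling the target dimension by gcd(t_size, s_partitions) and
  // pushing the leftover factor into partial replication.
  const int64_t tile = std::gcd(t_size, s_partitions);  // 2
  const int64_t replicas = s_partitions / tile;          // 3
  std::cout << "output sharding ~ {devices=[" << tile << ",1," << replicas
            << "]<=[6] last_tile_dim_replicate}\n";
  return 0;
}

The printed sharding, {devices=[2,1,3]<=[6] last_tile_dim_replicate}, matches the expected output in the deleted comment and in the deleted PropagateReshapeShardingTranspose2 test below.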
third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc (75 changes: 0 additions & 75 deletions)
@@ -239,48 +239,6 @@ TEST(HloShardingUtilTest, ReshapeShardingScalar) {
EXPECT_FALSE(result.has_value());
}

TEST(HloShardingUtilTest, ReshapeShardingTranspose1) {
Shape input_shape = ShapeUtil::MakeShape(F32, {6, 2, 5});
Shape output_shape = ShapeUtil::MakeShape(F32, {4, 3, 5});
HloSharding sharding = HloSharding::IotaTile({2, 1, 5});
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), sharding);
}

TEST(HloShardingUtilTest, ReshapeShardingTranspose2) {
Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5, 7, 11});
Shape output_shape = ShapeUtil::MakeShape(F32, {10, 21, 11});
HloSharding input_sharding = HloSharding::IotaTile({2, 1, 1, 1, 13});
HloSharding output_sharding = HloSharding::IotaTile({2, 1, 13});
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
}

TEST(HloShardingUtilTest, ReshapeShardingTranspose3) {
Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5});
Shape output_shape = ShapeUtil::MakeShape(F32, {3, 10});
HloSharding input_sharding = HloSharding::IotaTile({1, 1, 5});
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_FALSE(result.has_value());
}

TEST(HloShardingUtilTest, ReshapeShardingTranspose4) {
Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5, 7, 11, 13, 17, 19});
Shape output_shape = ShapeUtil::MakeShape(F32, {3, 2, 55, 91, 19, 17});
HloSharding input_sharding = HloSharding::IotaTile({1, 1, 5, 1, 1, 13, 1, 1});
HloSharding output_sharding =
HloSharding::PartialTile(TileAssignment({1, 1, 5, 1, 1, 1, 13}));
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
}

TEST(HloShardingUtilTest, ReshapeToTileDimension2D) {
// The two sharding in the vector are the same. They will be processed in
// different branches in ReshapeToTileDimension.
@@ -379,39 +337,6 @@ TEST(HloShardingUtilTest, ReshapeToTileDimension4D) {
}
}

TEST(HloShardingUtilTest, PropagateReshapeShardingTranspose1) {
Shape input_shape = ShapeUtil::MakeShape(F32, {6, 4});
Shape output_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 2});
HloSharding input_sharding = HloSharding::IotaTile({6, 1});
HloSharding output_sharding =
HloSharding::PartialTile(TileAssignment({2, 1, 1, 1, 3}));
HloSharding result = PropagateShardingThroughReshape(
input_shape, output_shape, input_sharding);
EXPECT_EQ(result, output_sharding);
}

TEST(HloShardingUtilTest, PropagateReshapeShardingTranspose2) {
Shape input_shape = ShapeUtil::MakeShape(F32, {6, 4});
Shape output_shape = ShapeUtil::MakeShape(F32, {4, 6});
HloSharding input_sharding = HloSharding::IotaTile({6, 1});
HloSharding output_sharding =
HloSharding::PartialTile(TileAssignment({2, 1, 3}));
HloSharding result = PropagateShardingThroughReshape(
input_shape, output_shape, input_sharding);
EXPECT_EQ(result, output_sharding);
}

TEST(HloShardingUtilTest, PropagateReshapeShardingTranspose3) {
Shape input_shape = ShapeUtil::MakeShape(F32, {4, 6, 5});
Shape output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 5, 3});
HloSharding input_sharding = HloSharding::IotaTile({2, 6, 1});
HloSharding output_sharding =
HloSharding::PartialTile(TileAssignment({2, 1, 2, 1, 1, 3}));
HloSharding result = PropagateShardingThroughReshape(
input_shape, output_shape, input_sharding);
EXPECT_EQ(result, output_sharding);
}

TEST(HloShardingUtilTest, PropagateReshapeShardingTiledSplitPartialMatch) {
Shape input_shape = ShapeUtil::MakeShape(F32, {14, 16});
Shape output_shape = ShapeUtil::MakeShape(F32, {2, 7, 4, 4});
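The deleted tests above encode their expected results as PartialTile tile assignments, e.g. TileAssignment({2, 1, 3}) over 6 devices in PropagateReshapeShardingTranspose2. As a reading aid, here is a minimal standalone sketch (again not XLA code) of the bookkeeping implied by last_tile_dim_replicate: the leading dimensions tile the data and the trailing dimension counts replicas, so their product equals the device count.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Tile assignment from the deleted PropagateReshapeShardingTranspose2 test:
  // HloSharding::PartialTile(TileAssignment({2, 1, 3})) over 6 devices.
  const std::vector<int64_t> dims = {2, 1, 3};

  int64_t data_tiles = 1;
  for (size_t i = 0; i + 1 < dims.size(); ++i) {
    data_tiles *= dims[i];  // dimensions that actually tile the data: 2 x 1
  }
  const int64_t replicas = dims.back();  // trailing dimension replicates: 3
  std::cout << "data tiles: " << data_tiles
            << ", replicas per tile: " << replicas
            << ", total devices: " << data_tiles * replicas << "\n";  // 6
  return 0;
}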
third_party/xla/xla/service/sharding_propagation_test.cc (46 changes: 0 additions & 46 deletions)
@@ -17,7 +17,6 @@ limitations under the License.

#include <ostream>
#include <string>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
@@ -1508,51 +1507,6 @@ ENTRY %reshape {
}
}

TEST_P(ParameterizedMetadataTest, ReshapeForwardPassTranspose1) {
const char* const hlo_string = R"(
HloModule module
ENTRY %reshape {
%param0 = f32[6,4,5] parameter(0), sharding={devices=[6,2,1]<=[12] metadata={op_name="a"}}
%reshape.1 = f32[2,3,20] reshape(%param0)
%reshape.2 = f32[2,4,3,5] reshape(%param0)
%reshape.3 = f32[20,6] reshape(%param0)
%reshape.4 = f32[3,5,8] reshape(%param0)
%reshape.5 = f32[10,4,3] reshape(%param0)
%reshape.6 = f32[5,8,3] reshape(%param0)
ROOT %tuple = tuple(%reshape.1, %reshape.2, %reshape.3, %reshape.4, %reshape.5, %reshape.6)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
if (GetParam().clear_metadata) {
ClearMetadata(module.get());
}
TF_ASSERT_OK_AND_ASSIGN(
bool changed,
ShardingPropagation(/*is_spmd=*/false, GetParam().propagate_metadata)
.Run(module.get()));
XLA_VLOG_LINES(1, module->ToString());
EXPECT_TRUE(changed);

std::vector<std::pair<std::string, std::string>> instruction_and_sharding = {
{"reshape.1", "{devices=[2,3,2]<=[12]}"},
{"reshape.2", "{devices=[2,1,1,1,6]<=[12] last_tile_dim_replicate}"},
{"reshape.3", "{devices=[2,1,6]<=[12] last_tile_dim_replicate}"},
{"reshape.4", "{devices=[3,1,1,4]<=[12] last_tile_dim_replicate}"},
{"reshape.5", "{devices=[2,1,1,6]<=[12] last_tile_dim_replicate}"},
{"reshape.6", "{replicated}"}};
for (const auto& [name, sharding] : instruction_and_sharding) {
auto* instruction = FindInstruction(module.get(), name);
ASSERT_NE(instruction, nullptr);
EXPECT_THAT(instruction, op::Sharding(sharding));
if (GetParam().propagate_metadata && !GetParam().clear_metadata) {
EXPECT_THAT(instruction->sharding(),
ShardingMetadata({CreateMetadata("a")}));
} else {
EXPECT_THAT(instruction->sharding(), ShardingMetadata({}));
}
}
}

TEST_P(ParameterizedMetadataTest, ReshapeBackwardPass) {
const char* const hlo_string = R"(
HloModule module
