[IFRT] Add xla::ifrt::Sharding::IsFullyReplicated()
The IFRT `Sharding` type gains `IsFullyReplicated()`, which quickly tells
whether the sharding represents full replication.

The main motivation is to make full-replication information queryable on IFRT
shardings and to prepare for letting IFRT implementations handle full
replication directly.
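
As a hypothetical illustration of that direction (not part of this change), an
implementation could branch on the new query when assembling arrays.
`AssembleArray`, `AssembleReplicated`, and `AssembleFromShards` below are
made-up names; only the `IsFullyReplicated()` query comes from this change:

    // Minimal sketch, assuming hypothetical AssembleReplicated() and
    // AssembleFromShards() helpers.
    absl::StatusOr<tsl::RCReference<Array>> AssembleArray(
        const Sharding& sharding,
        absl::Span<const tsl::RCReference<Array>> shards) {
      if (sharding.IsFullyReplicated()) {
        // Every shard holds the entire data of the logical array, so a single
        // shard suffices to construct the result.
        return AssembleReplicated(shards.front(), sharding);
      }
      return AssembleFromShards(shards, sharding);
    }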

The method follows a preset of rules (see the sketch after this list):

* `SingleDeviceSharding` is trivially fully replicated by its definition.
* `ConcreteSharding` and `OpaqueSharding` are never fully replicated. There are special cases where they may in fact represent full replication, but users are advised to use a more specific sharding type for such cases.
* `ConcreteEvenSharding` may or may not be fully replicated; this is controlled at creation time.
* `ShardingParamSharding` and (IFRT) `HloSharding` depend on whether their lower-level sharding represents full replication.
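
A minimal sketch of how these rules surface through the API, mirroring the
tests in this change (`GetDevices` stands in for the test helper in
sharding_test.cc that builds a `DeviceList`):

    std::shared_ptr<const Sharding> single = SingleDeviceSharding::Create(
        GetDevices({0}).devices().front(), MemoryKind());
    CHECK(single->IsFullyReplicated());  // Trivially fully replicated.

    std::shared_ptr<const Sharding> opaque =
        OpaqueSharding::Create(GetDevices({0, 1}), MemoryKind());
    CHECK(!opaque->IsFullyReplicated());  // Never reported as fully replicated.

    std::shared_ptr<const Sharding> even = ConcreteEvenSharding::Create(
        GetDevices({0, 1}), MemoryKind(), /*shape=*/Shape({30}),
        /*shard_shape=*/Shape({30}), /*is_fully_replicated=*/true);
    CHECK(even->IsFullyReplicated());  // Decided by the creator at creation time.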

`ConcreteEvenSharding` is a noteworthy case where the full-replication
information does not come from an existing source within the sharding itself.
The creators of this sharding (e.g., JAX) typically have the information, but it
is lost when the sharding is coerced into `ConcreteEvenSharding`. This will
become less of a problem as JAX uses a higher-level IFRT sharding type (mainly
(IFRT) `HloSharding`) in more places.
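
A sketch of the creator-side flow, with `FrameworkSharding` and
`ToConcreteEvenSharding` as hypothetical stand-ins for a framework's own
sharding type and its coercion helper; the point is that the replication bit
must now be forwarded explicitly instead of being dropped:

    // Hypothetical coercion point in a framework (e.g., JAX bindings).
    // The framework-level sharding knows whether it is replicated, but
    // ConcreteEvenSharding cannot derive that itself, so the bit must be
    // forwarded explicitly at creation time.
    std::unique_ptr<ConcreteEvenSharding> ToConcreteEvenSharding(
        const FrameworkSharding& src) {
      return ConcreteEvenSharding::Create(
          src.devices(), src.memory_kind(), src.shape(), src.shard_shape(),
          /*is_fully_replicated=*/src.IsReplicated());
    }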

This change extends the `Sharding` type, but the new method is not used by any
existing code.

PiperOrigin-RevId: 635667260
hyeontaek authored and tensorflower-gardener committed May 23, 2024
1 parent 4d210c6 commit aed4341
Showing 10 changed files with 217 additions and 45 deletions.
43 changes: 29 additions & 14 deletions third_party/xla/xla/python/ifrt/sharding.cc
@@ -48,6 +48,12 @@ namespace ifrt {

namespace {

// Returns whether `sharding_param` indicates a fully replicated sharding,
// i.e., every dimension has exactly one shard (e.g., dim_shards {1, 1}).
bool ComputeIsFullyReplicated(const ShardingParam& sharding_param) {
  return llvm::all_of(sharding_param.dim_shards(),
                      [](auto shards) { return shards == 1; });
}

// Iterates the major-to-minor Cartesian product of a Span of containers of the
// same type.
//
@@ -229,8 +235,8 @@ std::unique_ptr<OpaqueSharding> OpaqueSharding::Create(DeviceList devices,
}

OpaqueSharding::OpaqueSharding(DeviceList devices, MemoryKind memory_kind)
: llvm::RTTIExtends<OpaqueSharding, Sharding>(std::move(devices),
memory_kind) {}
: llvm::RTTIExtends<OpaqueSharding, Sharding>(
std::move(devices), memory_kind, /*is_fully_replicated=*/false) {}

absl::StatusOr<std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>>>
OpaqueSharding::Disassemble(const Shape& shape) const {
@@ -285,16 +291,16 @@ std::unique_ptr<ConcreteSharding> ConcreteSharding::Create(

ConcreteSharding::ConcreteSharding(DeviceList devices, MemoryKind memory_kind,
Shape shape, std::vector<Shape> shard_shapes)
: llvm::RTTIExtends<ConcreteSharding, Sharding>(std::move(devices),
memory_kind),
: llvm::RTTIExtends<ConcreteSharding, Sharding>(
std::move(devices), memory_kind, /*is_fully_replicated=*/false),
shape_(std::move(shape)),
shard_shapes_(std::move(shard_shapes)) {}

ConcreteSharding::ConcreteSharding(
DeviceList devices, MemoryKind memory_kind, DynamicShape dynamic_shape,
std::vector<DynamicShape> shard_dynamic_shapes)
: llvm::RTTIExtends<ConcreteSharding, Sharding>(std::move(devices),
memory_kind),
: llvm::RTTIExtends<ConcreteSharding, Sharding>(
std::move(devices), memory_kind, /*is_fully_replicated=*/false),
shape_(std::move(dynamic_shape)),
shard_shapes_(std::move(shard_dynamic_shapes)) {}

@@ -381,18 +387,19 @@ std::string ConcreteSharding::DebugString() const {
}

std::unique_ptr<ConcreteEvenSharding> ConcreteEvenSharding::Create(
DeviceList devices, MemoryKind memory_kind, Shape shape,
Shape shard_shape) {
return std::unique_ptr<ConcreteEvenSharding>(
new ConcreteEvenSharding(std::move(devices), memory_kind,
std::move(shape), std::move(shard_shape)));
DeviceList devices, MemoryKind memory_kind, Shape shape, Shape shard_shape,
bool is_fully_replicated) {
return std::unique_ptr<ConcreteEvenSharding>(new ConcreteEvenSharding(
std::move(devices), memory_kind, std::move(shape), std::move(shard_shape),
is_fully_replicated));
}

ConcreteEvenSharding::ConcreteEvenSharding(DeviceList devices,
MemoryKind memory_kind, Shape shape,
Shape shard_shape)
: llvm::RTTIExtends<ConcreteEvenSharding, Sharding>(std::move(devices),
memory_kind),
Shape shard_shape,
bool is_fully_replicated)
: llvm::RTTIExtends<ConcreteEvenSharding, Sharding>(
std::move(devices), memory_kind, is_fully_replicated),
shape_(std::move(shape)),
shard_shape_(std::move(shard_shape)) {}

@@ -459,6 +466,14 @@ ShardingParamSharding::Create(ShardingParam sharding_param, DeviceList devices,
std::move(sharding_param), std::move(devices), memory_kind));
}

ShardingParamSharding::ShardingParamSharding(ShardingParam sharding_param,
                                             DeviceList devices,
                                             MemoryKind memory_kind)
    : llvm::RTTIExtends<ShardingParamSharding, Sharding>(
          devices, memory_kind, ComputeIsFullyReplicated(sharding_param)),
      sharding_param_(sharding_param) {}

absl::StatusOr<std::vector<std::pair<Shape, std::shared_ptr<const Sharding>>>>
ShardingParamSharding::Disassemble(const Shape& shape) const {
DCHECK(this);
32 changes: 19 additions & 13 deletions third_party/xla/xla/python/ifrt/sharding.h
@@ -57,6 +57,11 @@ class Sharding : public llvm::RTTIExtends<Sharding, Serializable> {
// Returns the memory kind for all shards in this sharding.
MemoryKind memory_kind() const { return memory_kind_; }

// Returns whether this sharding is fully replicated. A fully replicated sharding
// means that the logical shape and shard shapes are identical, and every
// shard of the array contains the entire data of the logical array.
bool IsFullyReplicated() const { return is_fully_replicated_; }

// Breaks a shape up into per-device shapes and shardings. See
// Array::DisassembleIntoSingleDeviceArrays(). It may return an error if
// disassembly is unsupported.
@@ -94,11 +99,14 @@ class Sharding : public llvm::RTTIExtends<Sharding, Serializable> {
static char ID; // NOLINT

protected:
Sharding(DeviceList devices, MemoryKind memory_kind)
: devices_(devices), memory_kind_(memory_kind) {}
Sharding(DeviceList devices, MemoryKind memory_kind, bool is_fully_replicated)
: devices_(devices),
memory_kind_(memory_kind),
is_fully_replicated_(is_fully_replicated) {}

DeviceList devices_;
MemoryKind memory_kind_;
bool is_fully_replicated_;
};

std::ostream& operator<<(std::ostream& os, const Sharding& sharding);
@@ -138,8 +146,8 @@ class SingleDeviceSharding final

private:
explicit SingleDeviceSharding(Device* device, MemoryKind memory_kind)
: llvm::RTTIExtends<SingleDeviceSharding, Sharding>(DeviceList({device}),
memory_kind) {}
: llvm::RTTIExtends<SingleDeviceSharding, Sharding>(
DeviceList({device}), memory_kind, /*is_fully_replicated=*/true) {}
};

// Opaque sharding that does not define a fixed semantics for conversion between
@@ -261,10 +269,11 @@ class ConcreteEvenSharding
: public llvm::RTTIExtends<ConcreteEvenSharding, Sharding> {
public:
// Creates a concrete even sharding.
static std::unique_ptr<ConcreteEvenSharding> Create(DeviceList devices,
MemoryKind memory_kind,
Shape shape,
Shape shard_shape);
// TODO(hyeontaek): Remove the default value of `is_fully_replicated` once all
// callers are updated to provide it explicitly.
static std::unique_ptr<ConcreteEvenSharding> Create(
DeviceList devices, MemoryKind memory_kind, Shape shape,
Shape shard_shape, bool is_fully_replicated = false);

Shape shape() const {
DCHECK(this);
@@ -294,7 +303,7 @@

private:
ConcreteEvenSharding(DeviceList devices, MemoryKind memory_kind, Shape shape,
Shape shard_shape);
Shape shard_shape, bool is_fully_replicated);

Shape shape_;
Shape shard_shape_;
@@ -324,10 +333,7 @@ class ShardingParamSharding

private:
ShardingParamSharding(ShardingParam sharding_param, DeviceList devices,
MemoryKind memory_kind)
: llvm::RTTIExtends<ShardingParamSharding, Sharding>(devices,
memory_kind),
sharding_param_(sharding_param) {}
MemoryKind memory_kind);

ShardingParam sharding_param_;
};
11 changes: 8 additions & 3 deletions third_party/xla/xla/python/ifrt/sharding_serdes.cc
@@ -14,11 +14,15 @@ limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/memory.h"
#include "xla/python/ifrt/serdes.h"
@@ -223,6 +227,7 @@ class ConcreteEvenShardingSerDes
}
*proto.mutable_shape() = sharding.shape().ToProto();
*proto.mutable_shard_shape() = sharding.shard_shape().ToProto();
proto.set_is_fully_replicated(sharding.IsFullyReplicated());
return proto.SerializeAsString();
}

@@ -248,9 +253,9 @@ class ConcreteEvenShardingSerDes
TF_ASSIGN_OR_RETURN(auto shape, Shape::FromProto(proto.shape()));
TF_ASSIGN_OR_RETURN(auto shard_shape,
Shape::FromProto(proto.shard_shape()));
return ConcreteEvenSharding::Create(std::move(devices), memory_kind,
std::move(shape),
std::move(shard_shape));
return ConcreteEvenSharding::Create(
std::move(devices), memory_kind, std::move(shape),
std::move(shard_shape), proto.is_fully_replicated());
}

static char ID; // NOLINT
1 change: 1 addition & 0 deletions third_party/xla/xla/python/ifrt/sharding_serdes.proto
@@ -52,4 +52,5 @@ message ConcreteEvenShardingProto {
optional string memory_kind = 4;
ShapeProto shape = 2;
ShapeProto shard_shape = 3;
bool is_fully_replicated = 5;
}
9 changes: 5 additions & 4 deletions third_party/xla/xla/python/ifrt/sharding_serdes_test.cc
@@ -118,10 +118,10 @@ TEST_P(ShardingSerDesTest, ConcreteShardingWithDynamicShapeRoundTrip) {
}

TEST_P(ShardingSerDesTest, ConcreteEvenShardingRoundTrip) {
auto sharding =
ConcreteEvenSharding::Create(GetDevices({0, 1}), MemoryKind("abc"),
/*shape=*/Shape({10, 20}),
/*shard_shape=*/Shape({5, 20}));
auto sharding = ConcreteEvenSharding::Create(
GetDevices({0, 1}), MemoryKind("abc"),
/*shape=*/Shape({10, 20}),
/*shard_shape=*/Shape({5, 20}), /*is_fully_replicated=*/true);

TF_ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(*sharding));

@@ -134,6 +134,7 @@ TEST_P(ShardingSerDesTest, ConcreteEvenShardingRoundTrip) {
EXPECT_THAT(out_sharding->devices(), ElementsAreArray(sharding->devices()));
EXPECT_THAT(out_sharding->shape(), sharding->shape());
EXPECT_THAT(out_sharding->shard_shape(), sharding->shard_shape());
EXPECT_THAT(out_sharding->IsFullyReplicated(), sharding->IsFullyReplicated());
}

INSTANTIATE_TEST_SUITE_P(NumDevices, ShardingSerDesTest,
89 changes: 83 additions & 6 deletions third_party/xla/xla/python/ifrt/sharding_test.cc
@@ -49,6 +49,13 @@ class ConcreteShardingTest : public test_util::ShardingTest {};
class ConcreteEvenShardingTest : public test_util::ShardingTest {};
class ShardingParamShardingTest : public test_util::ShardingTest {};

TEST_P(SingleDeviceShardingTest, IsFullyReplicated) {
auto device_list = GetDevices({0});
std::shared_ptr<const Sharding> sharding =
SingleDeviceSharding::Create(device_list.devices().front(), MemoryKind());
EXPECT_TRUE(sharding->IsFullyReplicated());
}

TEST_P(SingleDeviceShardingTest, IndexDomains) {
auto device_list = GetDevices({0});
std::shared_ptr<const Sharding> sharding =
@@ -92,6 +99,13 @@ TEST_P(SingleDeviceShardingTest, Disassemble) {
}
}

TEST_P(OpaqueShardingTest, IsFullyReplicated) {
auto device_list = GetDevices({0, 1});
std::shared_ptr<const Sharding> sharding =
OpaqueSharding::Create(device_list, MemoryKind());
EXPECT_FALSE(sharding->IsFullyReplicated());
}

TEST_P(OpaqueShardingTest, FailedToDisassemble) {
auto device_list = GetDevices({0, 1});
std::shared_ptr<const Sharding> sharding =
@@ -125,6 +139,17 @@ TEST_P(OpaqueShardingTest, IndexDomainsFails) {
HasSubstr("OpaqueSharding does not have index domain information")));
}

TEST_P(ConcreteShardingTest, IsFullyReplicated) {
auto device_list = GetDevices({0, 1});
std::vector<Shape> shard_shapes;
shard_shapes.reserve(2);
shard_shapes.push_back(Shape({10}));
shard_shapes.push_back(Shape({20}));
std::shared_ptr<const Sharding> sharding = ConcreteSharding::Create(
device_list, MemoryKind(), Shape({30}), shard_shapes);
EXPECT_FALSE(sharding->IsFullyReplicated());
}

TEST_P(ConcreteShardingTest, Disassemble) {
auto device_list = GetDevices({0, 1});
std::vector<Shape> shard_shapes;
@@ -205,10 +230,29 @@ TEST_P(ConcreteShardingTest, IndexDomainsFails) {
"domain information")));
}

TEST_P(ConcreteEvenShardingTest, IsFullyReplicated) {
auto device_list = GetDevices({0, 1});
{
// Fully replicated.
std::shared_ptr<const Sharding> sharding =
ConcreteEvenSharding::Create(device_list, MemoryKind(), Shape({30}),
Shape({15}), /*is_fully_replicated=*/true);
EXPECT_TRUE(sharding->IsFullyReplicated());
}
{
// Not fully replicated.
std::shared_ptr<const Sharding> sharding = ConcreteEvenSharding::Create(
device_list, MemoryKind(), Shape({30}), Shape({15}),
/*is_fully_replicated=*/false);
EXPECT_FALSE(sharding->IsFullyReplicated());
}
}

TEST_P(ConcreteEvenShardingTest, Disassemble) {
auto device_list = GetDevices({0, 1});
std::shared_ptr<const Sharding> sharding = ConcreteEvenSharding::Create(
device_list, MemoryKind(), Shape({30}), Shape({15}));
std::shared_ptr<const Sharding> sharding =
ConcreteEvenSharding::Create(device_list, MemoryKind(), Shape({30}),
Shape({15}), /*is_fully_replicated=*/false);

TF_ASSERT_OK_AND_ASSIGN(auto disassembled,
sharding->Disassemble(Shape({30})));
@@ -224,8 +268,9 @@ TEST_P(ConcreteEvenShardingTest, Disassemble) {

TEST_P(ConcreteEvenShardingTest, DisassembleFailsForUnexpectedShape) {
auto device_list = GetDevices({0, 1});
std::shared_ptr<const Sharding> sharding = ConcreteEvenSharding::Create(
device_list, MemoryKind(), Shape({30}), Shape({15}));
std::shared_ptr<const Sharding> sharding =
ConcreteEvenSharding::Create(device_list, MemoryKind(), Shape({30}),
Shape({15}), /*is_fully_replicated=*/false);

EXPECT_THAT(sharding->Disassemble(Shape({40})),
StatusIs(tsl::error::INVALID_ARGUMENT,
@@ -235,8 +280,9 @@
TEST_P(ConcreteEvenShardingTest, IndexDomainsFails) {
auto device_list = GetDevices({0, 1});
std::vector<Shape> shard_shapes;
std::shared_ptr<const Sharding> sharding = ConcreteEvenSharding::Create(
device_list, MemoryKind(), Shape({30}), Shape({15}));
std::shared_ptr<const Sharding> sharding =
ConcreteEvenSharding::Create(device_list, MemoryKind(), Shape({30}),
Shape({15}), /*is_fully_replicated=*/false);

EXPECT_THAT(
sharding->IndexDomains(Shape({30})),
@@ -257,6 +303,37 @@ TEST_P(ShardingParamShardingTest, CreateFailsWhenDeviceCountNotMatch) {
"ShardingParam 6 vs from DeviceList 2")));
}

TEST_P(ShardingParamShardingTest, IsFullyReplicated) {
auto device_list = GetDevices({0, 1, 2, 3, 4, 5});
{
// Fully replicated.
ShardingParam param{/*dim_shards=*/{1, 1},
{/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}};
TF_ASSERT_OK_AND_ASSIGN(
std::shared_ptr<const Sharding> param_sharding,
ShardingParamSharding::Create(param, device_list, MemoryKind()));
EXPECT_TRUE(param_sharding->IsFullyReplicated());
}
{
// Not fully replicated.
ShardingParam param{/*dim_shards=*/{1, 6},
{/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}};
TF_ASSERT_OK_AND_ASSIGN(
std::shared_ptr<const Sharding> param_sharding,
ShardingParamSharding::Create(param, device_list, MemoryKind()));
EXPECT_FALSE(param_sharding->IsFullyReplicated());
}
{
// Not fully replicated.
ShardingParam param{/*dim_shards=*/{2, 3},
{/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}};
TF_ASSERT_OK_AND_ASSIGN(
std::shared_ptr<const Sharding> param_sharding,
ShardingParamSharding::Create(param, device_list, MemoryKind()));
EXPECT_FALSE(param_sharding->IsFullyReplicated());
}
}

TEST_P(ShardingParamShardingTest, Disassemble) {
auto device_list = GetDevices({0, 1, 2, 3, 4, 5});
ShardingParam param{/*dim_shards=*/{2, 3},