Skip to content

Commit

Permalink
Fix two issues that prevent tests in servable_lm_model_test.py from w…
Browse files Browse the repository at this point in the history
…orking with auto-sharding:

1. Reshard model vars based on shardings inferred by the auto-sharding pass. In PAX, we create model vars and initialize them based on the shardings inferred. In SAX, we need to re-shard them as they are created before we run auto-sharding (say when the model is loaded from a checkpoint).
2. Ensure that an empty resharding cost vector (which can arise when an arguments to an HLO is an empty tuple) is not considered as a tuple of infinite costs.

Also enable tests that have been fixed in previous CLs in this chain. A couple remain unfixed.

PiperOrigin-RevId: 535459455
  • Loading branch information
tensorflower-gardener committed May 26, 2023
1 parent 6760b19 commit b737261
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -2852,6 +2852,7 @@ void SetHloShardingPostProcessing(const HloInstructionSequence& sequence,
if (inst->shape().IsTuple()) {
switch (inst->opcode()) {
case HloOpcode::kReduce:
case HloOpcode::kCustomCall:
case HloOpcode::kSort: {
for (size_t i = 0; i < inst->shape().tuple_shapes_size(); ++i) {
const ShardingStrategy& stra = GetShardingStrategyForTuple(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,10 @@ bool AllInfinityCosts(
const std::vector<std::vector<double>>& resharding_costs) {
for (const auto& costs : resharding_costs) {
bool all_infinity = true;
if (costs.empty()) {
all_infinity = false;
continue;
}
for (const auto& cost : costs) {
if (cost < kInfinityCost) {
all_infinity = false;
Expand Down Expand Up @@ -952,7 +956,6 @@ void RemoveDuplicatedStrategy(std::unique_ptr<StrategyVector>& strategies) {
}
}


bool IsDivisible(const HloInstruction* ins, const Array<int64_t>& device_mesh,
absl::Span<const int64_t> tensor_dims,
absl::Span<const int64_t> mesh_dims) {
Expand Down

0 comments on commit b737261

Please sign in to comment.