
Commit a4d9df4
Added a virtual function (`CanPropagateShardingToOperands`) to `CustomCallShardingHelper` that can be used to specify whether a custom call can propagate sharding from/to its operands.

PiperOrigin-RevId: 621915832
tensorflower-gardener committed Apr 4, 2024
1 parent 8184dd9 commit a4d9df4
Showing 7 changed files with 49 additions and 15 deletions.
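
For illustration, a backend that wants to opt out of operand propagation for its custom call could override the new virtual as sketched below. The class and function names are hypothetical and not part of this commit; only `CustomCallShardingHelper` and its virtuals come from the changed headers.

#include "xla/service/custom_call_sharding_helper.h"

// Hypothetical example (not in this commit): a helper that declares its
// custom call shardable but refuses to propagate sharding from the call
// back to its operands.
class NoOperandPropagationHelper : public xla::CustomCallShardingHelper {
 public:
  bool IsCustomCallShardable(
      const xla::HloInstruction* instruction) const override {
    return true;
  }
  bool CanPropagateShardingToOperands(
      const xla::HloInstruction* instruction) const override {
    // With this override, GetShardingFromUser() returns std::nullopt for
    // this custom call instead of copying the user's sharding.
    return false;
  }
};
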
@@ -1456,8 +1456,9 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding(
     HloInstruction* operand =
         instructions.at(strategy_group->in_nodes.at(i)->instruction_id);
     std::optional<HloSharding> input_sharding =
-        ShardingPropagation::GetShardingFromUser(*operand, *ins, 10,
-                                                 true, call_graph);
+        ShardingPropagation::GetShardingFromUser(
+            *operand, *ins, 10, true, call_graph,
+            /*sharding_helper=*/nullptr);
     StrategyGroup* operand_strategy_group =
         strategy_map.at(operand).get();
     Shape operand_shape = operand->shape();
@@ -97,8 +97,9 @@ std::optional<HloSharding> GetInputSharding(const HloInstruction* ins,
   }
 
   std::optional<HloSharding> inferred_sharding =
-      ShardingPropagation::GetShardingFromUser(
-          *ins_clone->operand(op_index), *ins_clone, 10, true, call_graph);
+      ShardingPropagation::GetShardingFromUser(*ins_clone->operand(op_index),
+                                               *ins_clone, 10, true, call_graph,
+                                               /*sharding_helper=*/nullptr);
 
   if (!inferred_sharding.has_value() && IsTopKCustomCall(ins)) {
     // ShardingPropagation::GetShardingFromUser does not handle TopK custom
5 changes: 5 additions & 0 deletions third_party/xla/xla/service/custom_call_sharding_helper.cc
@@ -35,6 +35,11 @@ bool CustomCallShardingHelper::IsCustomCallShardable(
   return false;
 }
 
+bool CustomCallShardingHelper::CanPropagateShardingToOperands(
+    const HloInstruction* instruction) const {
+  return true;
+}
+
 xla::Status CustomCallPartitioner::Partition(
     spmd::SpmdPartitioningVisitor* partitioner, HloInstruction* hlo) const {
   return xla::Unimplemented("Implement sharding for %s", hlo->ToString());
4 changes: 4 additions & 0 deletions third_party/xla/xla/service/custom_call_sharding_helper.h
@@ -47,6 +47,10 @@ class CustomCallShardingHelper {
       HloInstruction* instruction) const {
     return {};
   }
+  // Returns whether the given custom-call instruction can propagate sharding
+  // to its operands.
+  virtual bool CanPropagateShardingToOperands(
+      const HloInstruction* instruction) const;
   virtual ~CustomCallShardingHelper() = default;
 };
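
Because the base implementation above returns true, helpers that do not override the new virtual keep the existing propagation behavior. A tiny check makes this explicit (the function name is ours; the include path is assumed from this repository layout):

#include "xla/service/custom_call_sharding_helper.h"

// The base class is concrete, and its new hook defaults to permitting
// propagation, so this returns true for any instruction.
bool BaseHelperAllowsPropagation(const xla::HloInstruction* custom_call) {
  xla::CustomCallShardingHelper helper;
  return helper.CanPropagateShardingToOperands(custom_call);
}
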
34 changes: 28 additions & 6 deletions third_party/xla/xla/service/sharding_propagation.cc
@@ -542,7 +542,8 @@ std::optional<HloSharding> LookaheadUserSharding(HloInstruction* instr,
     HloInstruction* current = users_chain[i - 1];
     CHECK(user->has_sharding());
     sharding = ShardingPropagation::GetShardingFromUser(
-        *current, *user, INT64_MAX, is_spmd, call_graph);
+        *current, *user, INT64_MAX, is_spmd, call_graph,
+        /*sharding_helper=*/nullptr);
     // We need to set the sharding to the instruction, because
     // GetShardingFromUser() interface uses sharding from the instruction
     // itself. It will be cleared out later.
@@ -1140,7 +1141,8 @@ bool InferUnspecifiedDimsFromOneUser(HloInstruction* annotate_op,
   std::optional<HloSharding> user_sharding =
       ShardingPropagation::GetShardingFromUser(
           man_conversion_op == nullptr ? *annotate_op : *man_conversion_op,
-          *user, aggressiveness, is_spmd, call_graph);
+          *user, aggressiveness, is_spmd, call_graph,
+          /*sharding_helper=*/nullptr);
   if (!user_sharding.has_value() || user_sharding->IsTileMaximal()) {
     return false;
   }
@@ -1720,7 +1722,8 @@ int64_t ComputeNonRootUsers(const HloInstruction* instr) {
 // Return the sharding that should be propagated from user to instruction.
 std::optional<HloSharding> ShardingPropagation::GetShardingFromUser(
     const HloInstruction& instruction, const HloInstruction& user,
-    int64_t aggressiveness, bool is_spmd, const CallGraph& call_graph) {
+    int64_t aggressiveness, bool is_spmd, const CallGraph& call_graph,
+    const CustomCallShardingHelper* sharding_helper) {
   if (!CanPropagateThroughAtAggressiveLevel(user, aggressiveness)) {
     return std::nullopt;
   }
@@ -2074,6 +2077,23 @@ std::optional<HloSharding> ShardingPropagation::GetShardingFromUser(
       }
       return std::nullopt;
     }
+    case HloOpcode::kCustomCall: {
+      bool compatible_shapes = ShapeUtil::CompatibleIgnoringElementType(
+          instruction.shape(), user.shape());
+      if (!compatible_shapes) {
+        // The shapes are incompatible; do not propagate sharding.
+        return std::nullopt;
+      }
+      if (!sharding_helper) {
+        // No sharding helper is available and the shapes are compatible;
+        // propagate the user's sharding.
+        return user.sharding();
+      }
+      if (sharding_helper->CanPropagateShardingToOperands(&user)) {
+        return user.sharding();
+      }
+      return std::nullopt;
+    }
     default: {
       // If the user output shape is compatible with the current instruction
       // shape excluding element type and the current instruction is supported
@@ -2801,7 +2821,8 @@ bool ShardingPropagation::InferShardingFromUsers(
   } else {
     std::optional<HloSharding> user_sharding =
         ShardingPropagation::GetShardingFromUser(
-            *instruction, *user, aggressiveness, is_spmd, call_graph);
+            *instruction, *user, aggressiveness, is_spmd, call_graph,
+            sharding_helper);
     if (user_sharding && user_sharding->IsManual()) {
       instruction->set_sharding(std::move(*user_sharding));
       return true;
@@ -2820,8 +2841,9 @@ bool ShardingPropagation::InferShardingFromUsers(
   const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0;
   for (const HloInstruction* user : instruction->users()) {
     std::optional<HloSharding> user_sharding =
-        ShardingPropagation::GetShardingFromUser(
-            *instruction, *user, aggressiveness, is_spmd, call_graph);
+        ShardingPropagation::GetShardingFromUser(*instruction, *user,
+                                                 aggressiveness, is_spmd,
+                                                 call_graph, sharding_helper);
     if (user_sharding && instruction->opcode() == HloOpcode::kCustomCall) {
       if (auto* partitioner =
               GetCustomCallPartitioner(instruction->custom_call_target())) {
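
The new kCustomCall case reduces to a three-step decision: incompatible shapes never propagate, a missing helper preserves the old permissive behavior, and otherwise the helper has the final say. An illustrative restatement as a free function (the real logic lives inline in GetShardingFromUser above; the include paths are assumed):

#include <optional>

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/custom_call_sharding_helper.h"
#include "xla/shape_util.h"

// Sketch of the propagation decision for a custom-call user (illustrative
// only, not code from this commit).
std::optional<xla::HloSharding> ShardingFromCustomCallUser(
    const xla::HloInstruction& instruction, const xla::HloInstruction& user,
    const xla::CustomCallShardingHelper* sharding_helper) {
  // 1) The operand and the call must agree on shape (element type aside).
  if (!xla::ShapeUtil::CompatibleIgnoringElementType(instruction.shape(),
                                                     user.shape())) {
    return std::nullopt;
  }
  // 2) Without a helper, fall back to the old permissive behavior.
  // 3) With a helper, defer to its CanPropagateShardingToOperands().
  if (sharding_helper == nullptr ||
      sharding_helper->CanPropagateShardingToOperands(&user)) {
    return user.sharding();
  }
  return std::nullopt;
}
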
3 changes: 2 additions & 1 deletion third_party/xla/xla/service/sharding_propagation.h
@@ -135,7 +135,8 @@ class ShardingPropagation : public HloModulePass {
 
   static std::optional<HloSharding> GetShardingFromUser(
       const HloInstruction& instruction, const HloInstruction& user,
-      int64_t aggressiveness, bool is_spmd, const CallGraph& call_graph);
+      int64_t aggressiveness, bool is_spmd, const CallGraph& call_graph,
+      const CustomCallShardingHelper* sharding_helper);
 
   // Canonicalizes entry_computation_layouts by calling
   // module.layout_canonicalization_callback(), which gives canonicalized
8 changes: 4 additions & 4 deletions third_party/xla/xla/service/spmd/dot_handler.cc
@@ -527,9 +527,9 @@ std::optional<WindowedEinsumConfig> GetWindowedEinsumConfiguration(
   }
   constexpr int kAggressiveness = 3;
   std::optional<HloSharding> original_ideal_sharding =
-      ShardingPropagation::GetShardingFromUser(*to_loop_over, *original_hlo,
-                                               kAggressiveness,
-                                               /*is_spmd=*/true, call_graph);
+      ShardingPropagation::GetShardingFromUser(
+          *to_loop_over, *original_hlo, kAggressiveness,
+          /*is_spmd=*/true, call_graph, /*sharding_helper=*/nullptr);
   // Default to performing collective matmul if GetShardingFromUser() couldn't
   // determine the sharding.
   if (!original_ideal_sharding) {
@@ -542,7 +542,7 @@ std::optional<WindowedEinsumConfig> GetWindowedEinsumConfiguration(
   std::optional<HloSharding> from_user =
       ShardingPropagation::GetShardingFromUser(
          *to_loop_over, *user, kAggressiveness,
-         /*is_spmd=*/true, call_graph);
+         /*is_spmd=*/true, call_graph, /*sharding_helper=*/nullptr);
   // Couldn't determine sharding. Skip to the next one and pretend it wouldn't
   // share the resharding.
   if (!from_user) {
