Add an option (set to false by default, for now) and support for adding strategies for dot operators that trigger windowed einsum. #67369

Merged
merged 1 commit on May 13, 2024
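This change adds a new AutoShardingOption flag, generate_windowed_einsum_strategies (off by default for now), and teaches the dot/conv strategy generator to emit sharding candidates that can trigger all-gather and reduce-scatter windowed einsum during partitioning. A minimal sketch of opting in, assuming the modified AutoShardingOption struct shown below is reachable through its usual header (the include path and namespace here are assumptions, not part of this diff):

#include <iostream>

#include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h"  // assumed path

int main() {
  xla::AutoShardingOption option;  // namespace assumed
  // New flag added by this PR; it stays false by default until the feature is
  // fully implemented.
  option.generate_windowed_einsum_strategies = true;
  // ToString() now reports the new flag alongside the existing options.
  std::cout << option.ToString() << std::endl;
  return 0;
}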
@@ -23,8 +23,11 @@ limitations under the License.
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
@@ -87,6 +90,8 @@ class HandlerBase {
lhs_(ins->operand(0)),
rhs_(ins->operand(1)) {}

virtual ~HandlerBase() = default;

void AppendNewStrategy(const std::string& name,
const HloSharding& output_spec,
absl::Span<const HloSharding> input_specs,
@@ -113,6 +118,34 @@ class HandlerBase {
const std::optional<std::function<double(const HloSharding&)>>&
communication_cost_fn = std::nullopt);

// Given lhs and rhs dim maps, infers a sharding for the output by relying on
// the sharding_propagation pass.
void MaybeAppendInternal(
const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map,
const std::optional<DimMap>& expected_output_dim_map,
const Array<int64_t>& device_mesh, double compute_cost = 0,
const std::optional<std::function<double(const HloSharding&)>>&
communication_cost_fn = std::nullopt);

// Given an existing (non-allreduce) sharding candidate, generate a
// corresponding candidate by additionally sharding (if possible) the
// passed-in operand, such that the generated candidate can trigger
// all-gather windowed einsum during partitioning.
virtual void AppendAllGatherWindowedEinsumStrategyForOperand(
int operand_num, const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map, const DimMap& output_dim_map,
const Array<int64_t>& device_mesh, double compute_cost) {}

// Given an existing (allreduce) sharding candidate, generate a corresponding
// candidate by additionally sharding (if possible) the dot/conv output, such
// that the generated candidate can trigger reduce-scatter windowed einsum
// during partitioning.
virtual void AppendReduceScatterWindowedEinsumStrategy(
const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map, const DimMap& output_dim_map,
const Array<int64_t>& device_mesh, double compute_cost) {}

std::optional<HloSharding> GetShardingFromUser(const HloSharding& lhs_spec,
const HloSharding& rhs_spec);

@@ -177,6 +210,8 @@ class DotHandler : public HandlerBase {
const InstructionBatchDimMap& batch_map, const AutoShardingOption& option,
const CallGraph& call_graph);

~DotHandler() override = default;

void SplitLhsSpaceRhsSpace();

void SplitLhsSpaceOnly();
@@ -205,6 +240,16 @@ class DotHandler : public HandlerBase {

void Add1DBatchSplit();

void AppendAllGatherWindowedEinsumStrategyForOperand(
int operand_num, const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map, const DimMap& output_dim_map,
const Array<int64_t>& device_mesh, double compute_cost) override;

void AppendReduceScatterWindowedEinsumStrategy(
const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map, const DimMap& output_dim_map,
const Array<int64_t>& device_mesh, double compute_cost) override;

Status RegisterStrategies();

// Dimension information
@@ -228,6 +273,8 @@ class ConvHandler : public HandlerBase {
const InstructionBatchDimMap& batch_map,
const AutoShardingOption& option, const CallGraph& call_graph);

~ConvHandler() override = default;

void SplitLhsBatchRhsOutchannel();

void SplitLhsBatchBothInchannel();
@@ -287,7 +334,7 @@ void HandlerBase::AppendNewStrategy(const std::string& name,
// TODO(b/309638633) As we build more confidence in this, we should remove
// this expected_output_dim_map argument and fully rely on sharding
// propagation.
void HandlerBase::MaybeAppend(
void HandlerBase::MaybeAppendInternal(
const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map,
const std::optional<DimMap>& expected_output_dim_map,
@@ -336,6 +383,35 @@ void HandlerBase::MaybeAppend(
communication_cost);
}

void HandlerBase::MaybeAppend(
const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map,
const std::optional<DimMap>& expected_output_dim_map,
const Array<int64_t>& device_mesh, double compute_cost,
const std::optional<std::function<double(const HloSharding&)>>&
communication_cost_fn) {
MaybeAppendInternal(name, lhs_dim_map, rhs_dim_map, expected_output_dim_map,
device_mesh, compute_cost, communication_cost_fn);
if (!option_.generate_windowed_einsum_strategies ||
!expected_output_dim_map.has_value()) {
return;
}
if (absl::StrContains(name, "allreduce")) {
CHECK(communication_cost_fn.has_value());
AppendReduceScatterWindowedEinsumStrategy(name, lhs_dim_map, rhs_dim_map,
*expected_output_dim_map,
device_mesh, compute_cost);
} else {
CHECK(!communication_cost_fn.has_value());
AppendAllGatherWindowedEinsumStrategyForOperand(
0, name, lhs_dim_map, rhs_dim_map, *expected_output_dim_map,
device_mesh, compute_cost);
AppendAllGatherWindowedEinsumStrategyForOperand(
1, name, lhs_dim_map, rhs_dim_map, *expected_output_dim_map,
device_mesh, compute_cost);
}
}

std::optional<HloSharding> HandlerBase::GetShardingFromUser(
const HloSharding& lhs_spec, const HloSharding& rhs_spec) {
std::unique_ptr<HloInstruction> ins_clone = ins_->Clone();
@@ -771,6 +847,108 @@ void DotHandler::Add1DBatchSplit() {
}
}

void DotHandler::AppendAllGatherWindowedEinsumStrategyForOperand(
int operand_num, const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map, const DimMap& output_dim_map,
const Array<int64_t>& device_mesh, double compute_cost) {
const HloInstruction* operand = ins_->operand(operand_num);
const DimMap& operand_dim_map = operand_num == 0 ? lhs_dim_map : rhs_dim_map;
absl::flat_hash_set<int64_t> sharded_tensor_dims;
absl::flat_hash_set<int64_t> used_mesh_dims;
for (const auto [tensor_dim, mesh_dim] : operand_dim_map) {
if (device_mesh.dim(mesh_dim) == 1) {
continue;
}
sharded_tensor_dims.insert(tensor_dim);
used_mesh_dims.insert(mesh_dim);
}
if (used_mesh_dims.size() == device_mesh_.num_dimensions() ||
sharded_tensor_dims.size() == operand->shape().rank()) {
return;
}

for (int64_t tensor_dim = 0; tensor_dim < operand->shape().rank();
++tensor_dim) {
if (sharded_tensor_dims.contains(tensor_dim)) {
continue;
}
for (int64_t mesh_dim = 0; mesh_dim < device_mesh_.num_dimensions();
++mesh_dim) {
if (used_mesh_dims.contains(mesh_dim) ||
(device_mesh.dim(mesh_dim) == 1)) {
continue;
}
DimMap further_sharded_dim_map = operand_dim_map;
further_sharded_dim_map[tensor_dim] = mesh_dim;

auto updated_communication_cost_fn =
[](const HloSharding& output_sharding) -> double {
// TODO(331684721): Model costs for windowed einsum
return 100.0;
};

std::string updated_name =
absl::StrCat(absl::StrFormat("WindowedEinsum @ {%d,%d,%d}",
operand_num, tensor_dim, mesh_dim),
name);
MaybeAppendInternal(
updated_name,
operand_num == 0 ? further_sharded_dim_map : lhs_dim_map,
operand_num == 1 ? further_sharded_dim_map : rhs_dim_map,
output_dim_map, device_mesh, compute_cost,
updated_communication_cost_fn);
}
}
}

void DotHandler::AppendReduceScatterWindowedEinsumStrategy(
const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map, const DimMap& output_dim_map,
const Array<int64_t>& device_mesh, double compute_cost) {
absl::flat_hash_set<int64_t> sharded_tensor_dims;
absl::flat_hash_set<int64_t> used_mesh_dims;
for (const auto [tensor_dim, mesh_dim] : output_dim_map) {
if (device_mesh.dim(mesh_dim) == 1) {
continue;
}
sharded_tensor_dims.insert(tensor_dim);
used_mesh_dims.insert(mesh_dim);
}
if (used_mesh_dims.size() == device_mesh_.num_dimensions() ||
sharded_tensor_dims.size() == ins_->shape().rank()) {
return;
}

for (int64_t tensor_dim = 0; tensor_dim < ins_->shape().rank();
++tensor_dim) {
if (sharded_tensor_dims.contains(tensor_dim)) {
continue;
}
for (int64_t mesh_dim = 0; mesh_dim < device_mesh_.num_dimensions();
++mesh_dim) {
if (used_mesh_dims.contains(mesh_dim) ||
(device_mesh.dim(mesh_dim) == 1)) {
continue;
}
DimMap further_sharded_dim_map = output_dim_map;
further_sharded_dim_map[tensor_dim] = mesh_dim;

auto updated_communication_cost_fn =
[](const HloSharding& output_sharding) -> double {
// TODO(331684721): Model costs for windowed einsum
return 100.0;
};

std::string updated_name = absl::StrCat(
absl::StrFormat("WindowedEinsum @ {%d,%d}", tensor_dim, mesh_dim),
name);
MaybeAppendInternal(updated_name, lhs_dim_map, rhs_dim_map,
further_sharded_dim_map, device_mesh, compute_cost,
updated_communication_cost_fn);
}
}
}

Status DotHandler::RegisterStrategies() {
// SS = SR x RS
// Split lhs space dim and rhs space dim.
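For reference, a standalone sketch of the enumeration that the two new DotHandler methods above perform when proposing windowed-einsum candidates: skip tensor dims that are already sharded and mesh dims that are already used or of size 1, then pair up whatever remains. std:: containers stand in for the XLA DimMap and Array<int64_t> types, so this is a simplified illustration rather than the real API:

#include <cstdint>
#include <iostream>
#include <map>
#include <set>
#include <utility>
#include <vector>

// Simplified stand-in: DimMap maps tensor dim -> mesh dim; the mesh is just a
// vector of per-dimension sizes.
using DimMap = std::map<int64_t, int64_t>;

// Enumerate the (tensor_dim, mesh_dim) pairs that could further shard a tensor
// whose current sharding is `dim_map`, mirroring the loops in the diff above.
std::vector<std::pair<int64_t, int64_t>> FurtherShardingCandidates(
    const DimMap& dim_map, int64_t tensor_rank,
    const std::vector<int64_t>& mesh_dims) {
  std::set<int64_t> sharded_tensor_dims;
  std::set<int64_t> used_mesh_dims;
  for (const auto& [tensor_dim, mesh_dim] : dim_map) {
    if (mesh_dims[mesh_dim] == 1) continue;  // Trivial mesh dims don't count.
    sharded_tensor_dims.insert(tensor_dim);
    used_mesh_dims.insert(mesh_dim);
  }
  std::vector<std::pair<int64_t, int64_t>> candidates;
  if (used_mesh_dims.size() == mesh_dims.size() ||
      static_cast<int64_t>(sharded_tensor_dims.size()) == tensor_rank) {
    return candidates;  // Nothing left to shard further.
  }
  for (int64_t tensor_dim = 0; tensor_dim < tensor_rank; ++tensor_dim) {
    if (sharded_tensor_dims.count(tensor_dim)) continue;
    for (int64_t mesh_dim = 0;
         mesh_dim < static_cast<int64_t>(mesh_dims.size()); ++mesh_dim) {
      if (used_mesh_dims.count(mesh_dim) || mesh_dims[mesh_dim] == 1) continue;
      candidates.push_back({tensor_dim, mesh_dim});
    }
  }
  return candidates;
}

int main() {
  // A rank-2 operand sharded on tensor dim 0 along mesh dim 0, on a 2x4 mesh:
  // the only remaining candidate is (tensor_dim=1, mesh_dim=1).
  for (const auto& [t, m] : FurtherShardingCandidates({{0, 0}}, 2, {2, 4})) {
    std::cout << "tensor_dim=" << t << " mesh_dim=" << m << "\n";
  }
  return 0;
}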
@@ -136,6 +136,9 @@ std::string AutoShardingOption::ToString() const {
lines.push_back(absl::StrCat("model_resharding_memory_costs: ",
model_resharding_memory_costs));

lines.push_back(absl::StrCat("generate_windowed_einsum_strategies: ",
generate_windowed_einsum_strategies));

return absl::StrJoin(lines, "\n");
}

@@ -191,6 +191,12 @@ struct AutoShardingOption {
// for resharding edges.
bool model_resharding_memory_costs = true;

// Whether to generate strategies that model the windowed einsum (or
// collective matmul) optimization.
// TODO(331684721,329508561): Generate windowed-einsum strategies by default
// once it is fully implemented.
bool generate_windowed_einsum_strategies = false;

// Prints a debug string.
std::string ToString() const;

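The generated candidates reuse the name of the strategy they are derived from, prefixed with a "WindowedEinsum @ {...}" tag (operand, tensor dim, and mesh dim for the all-gather variants; tensor dim and mesh dim for the reduce-scatter variant). A small sketch of the resulting labels; the base strategy name below is hypothetical, and only the prefix format comes from this diff:

#include <iostream>
#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"

int main() {
  // Hypothetical base name of an existing dot sharding candidate.
  const std::string base_name = "SS = SR x RS @ {0,1}";
  // All-gather variant for operand 0, further sharding its tensor dim 1 along
  // mesh dim 2, formatted the same way as in the diff above.
  const std::string all_gather_name = absl::StrCat(
      absl::StrFormat("WindowedEinsum @ {%d,%d,%d}", 0, 1, 2), base_name);
  // Reduce-scatter variant that further shards output dim 1 along mesh dim 2.
  // In the pass itself this prefix is only applied to candidates whose name
  // contains "allreduce"; the base name here is just a placeholder.
  const std::string reduce_scatter_name = absl::StrCat(
      absl::StrFormat("WindowedEinsum @ {%d,%d}", 1, 2), base_name);
  std::cout << all_gather_name << "\n"       // WindowedEinsum @ {0,1,2}SS = SR x RS @ {0,1}
            << reduce_scatter_name << "\n";  // WindowedEinsum @ {1,2}SS = SR x RS @ {0,1}
  return 0;
}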