Add an option (set to false by default, for now) and support for adding strategies for dot operators that trigger windowed einsum.

PiperOrigin-RevId: 621681901
tensorflower-gardener committed May 13, 2024
1 parent 68f1e14 commit e2aa3e6
Showing 4 changed files with 190 additions and 3 deletions.
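For context, the new behavior is gated behind the `generate_windowed_einsum_strategies` field added to `AutoShardingOption` in this commit (off by default). Below is a minimal sketch of how a caller might opt in; the include path, namespace, and surrounding setup are assumptions for illustration and are not part of this commit.

```cpp
// Sketch only. The include path and namespace are assumptions about the tree
// layout at this revision; generate_windowed_einsum_strategies and ToString()
// are the pieces actually added or extended by this commit.
#include <iostream>

#include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h"

int main() {
  xla::AutoShardingOption option;
  // New flag, false by default for now: emit extra sharding strategies for
  // dot operators that can trigger windowed einsum (collective matmul)
  // during SPMD partitioning.
  option.generate_windowed_einsum_strategies = true;
  // ToString() now reports the flag alongside the existing options.
  std::cout << option.ToString() << std::endl;
  return 0;
}
```

In practice the option would be handed to the auto-sharding pass; the sketch only shows the flag itself and the updated ToString() report.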
tensorflow/lite/g3doc/models/convert/index.md: 4 changes (2 additions & 2 deletions)
@@ -97,7 +97,7 @@ for your model:
1. [Optimization flags](../../performance/model_optimization) allow you to
specify the type of optimization to apply
during conversion. The most commonly used optimization technique is
[post-training quanitization]().
[post-training quantization]().
1. [Metadata flags](metadata) allow you to add metadata to the converted model
which makes it easier to create platform specific wrapper code when deploying
models on devices.
@@ -142,7 +142,7 @@ format model and a custom runtime environment for that model.
converting your model.
* See the [optimization overview](../../performance/model_optimization) for
guidance on how to optimize your converted model using techniques like
[post-training quanitization](../../performance/post_training_quantization).
[post-training quantization](../../performance/post_training_quantization).
* See the [Adding metadata overview](metadata) to learn how to add metadata to
your models. Metadata provides other users a description of your model as well
as information that can be leveraged by code generators.
@@ -23,8 +23,11 @@ limitations under the License.
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
@@ -87,6 +90,8 @@ class HandlerBase {
lhs_(ins->operand(0)),
rhs_(ins->operand(1)) {}

virtual ~HandlerBase() = default;

void AppendNewStrategy(const std::string& name,
const HloSharding& output_spec,
absl::Span<const HloSharding> input_specs,
@@ -113,6 +118,34 @@
const std::optional<std::function<double(const HloSharding&)>>&
communication_cost_fn = std::nullopt);

// Given lhs and rhs dim maps, infers a sharding for the output by relying on
// the sharding_propagation pass.
void MaybeAppendInternal(
const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map,
const std::optional<DimMap>& expected_output_dim_map,
const Array<int64_t>& device_mesh, double compute_cost = 0,
const std::optional<std::function<double(const HloSharding&)>>&
communication_cost_fn = std::nullopt);

// Given an existing (non-allreduce) sharding candidate, generate a
// corresponding candidate by additionally sharding (if possible) the passed
// in operand, such that the generated candidate can trigger all-gather
// windowed einsum during partitioning.
virtual void AppendAllGatherWindowedEinsumStrategyForOperand(
int operand_num, const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map, const DimMap& output_dim_map,
const Array<int64_t>& device_mesh, double compute_cost) {}

// Given an existing (allreduce) sharding candidate, generate a corresponding
// candidate by additionally sharding (if possible) the dot/conv output, such
// that the generated candidate can trigger reduce-scatter windowed einsum
// during partitioning.
virtual void AppendReduceScatterWindowedEinsumStrategy(
const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map, const DimMap& output_dim_map,
const Array<int64_t>& device_mesh, double compute_cost) {}

std::optional<HloSharding> GetShardingFromUser(const HloSharding& lhs_spec,
const HloSharding& rhs_spec);

@@ -177,6 +210,8 @@ class DotHandler : public HandlerBase {
const InstructionBatchDimMap& batch_map, const AutoShardingOption& option,
const CallGraph& call_graph);

~DotHandler() override = default;

void SplitLhsSpaceRhsSpace();

void SplitLhsSpaceOnly();
@@ -205,6 +240,16 @@

void Add1DBatchSplit();

void AppendAllGatherWindowedEinsumStrategyForOperand(
int operand_num, const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map, const DimMap& output_dim_map,
const Array<int64_t>& device_mesh, double compute_cost) override;

void AppendReduceScatterWindowedEinsumStrategy(
const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map, const DimMap& output_dim_map,
const Array<int64_t>& device_mesh, double compute_cost) override;

Status RegisterStrategies();

// Dimension information
@@ -228,6 +273,8 @@ class ConvHandler : public HandlerBase {
const InstructionBatchDimMap& batch_map,
const AutoShardingOption& option, const CallGraph& call_graph);

~ConvHandler() override = default;

void SplitLhsBatchRhsOutchannel();

void SplitLhsBatchBothInchannel();
@@ -287,7 +334,7 @@ void HandlerBase::AppendNewStrategy(const std::string& name,
// TODO(b/309638633) As we build more confidence in this, we should remove
// this expected_output_dim_map argument and fully rely on sharding
// propagation.
void HandlerBase::MaybeAppend(
void HandlerBase::MaybeAppendInternal(
const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map,
const std::optional<DimMap>& expected_output_dim_map,
@@ -336,6 +383,35 @@ void HandlerBase::MaybeAppend(
communication_cost);
}

void HandlerBase::MaybeAppend(
const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map,
const std::optional<DimMap>& expected_output_dim_map,
const Array<int64_t>& device_mesh, double compute_cost,
const std::optional<std::function<double(const HloSharding&)>>&
communication_cost_fn) {
MaybeAppendInternal(name, lhs_dim_map, rhs_dim_map, expected_output_dim_map,
device_mesh, compute_cost, communication_cost_fn);
if (!option_.generate_windowed_einsum_strategies ||
!expected_output_dim_map.has_value()) {
return;
}
if (absl::StrContains(name, "allreduce")) {
CHECK(communication_cost_fn.has_value());
AppendReduceScatterWindowedEinsumStrategy(name, lhs_dim_map, rhs_dim_map,
*expected_output_dim_map,
device_mesh, compute_cost);
} else {
CHECK(!communication_cost_fn.has_value());
AppendAllGatherWindowedEinsumStrategyForOperand(
0, name, lhs_dim_map, rhs_dim_map, *expected_output_dim_map,
device_mesh, compute_cost);
AppendAllGatherWindowedEinsumStrategyForOperand(
1, name, lhs_dim_map, rhs_dim_map, *expected_output_dim_map,
device_mesh, compute_cost);
}
}

std::optional<HloSharding> HandlerBase::GetShardingFromUser(
const HloSharding& lhs_spec, const HloSharding& rhs_spec) {
std::unique_ptr<HloInstruction> ins_clone = ins_->Clone();
@@ -771,6 +847,108 @@ void DotHandler::Add1DBatchSplit() {
}
}

void DotHandler::AppendAllGatherWindowedEinsumStrategyForOperand(
int operand_num, const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map, const DimMap& output_dim_map,
const Array<int64_t>& device_mesh, double compute_cost) {
const HloInstruction* operand = ins_->operand(operand_num);
const DimMap& operand_dim_map = operand_num == 0 ? lhs_dim_map : rhs_dim_map;
absl::flat_hash_set<int64_t> sharded_tensor_dims;
absl::flat_hash_set<int64_t> used_mesh_dims;
for (const auto [tensor_dim, mesh_dim] : operand_dim_map) {
if (device_mesh.dim(mesh_dim) == 1) {
continue;
}
sharded_tensor_dims.insert(tensor_dim);
used_mesh_dims.insert(mesh_dim);
}
if (used_mesh_dims.size() == device_mesh_.num_dimensions() ||
sharded_tensor_dims.size() == operand->shape().rank()) {
return;
}

for (int64_t tensor_dim = 0; tensor_dim < operand->shape().rank();
++tensor_dim) {
if (sharded_tensor_dims.contains(tensor_dim)) {
continue;
}
for (int64_t mesh_dim = 0; mesh_dim < device_mesh_.num_dimensions();
++mesh_dim) {
if (used_mesh_dims.contains(mesh_dim) ||
(device_mesh.dim(mesh_dim) == 1)) {
continue;
}
DimMap further_sharded_dim_map = operand_dim_map;
further_sharded_dim_map[tensor_dim] = mesh_dim;

auto updated_communication_cost_fn =
[](const HloSharding& output_sharding) -> double {
// TODO(331684721): Model costs for windowed einsum
return 100.0;
};

std::string updated_name =
absl::StrCat(absl::StrFormat("WindowedEinsum @ {%d,%d,%d}",
operand_num, tensor_dim, mesh_dim),
name);
MaybeAppendInternal(
updated_name,
operand_num == 0 ? further_sharded_dim_map : lhs_dim_map,
operand_num == 1 ? further_sharded_dim_map : rhs_dim_map,
output_dim_map, device_mesh, compute_cost,
updated_communication_cost_fn);
}
}
}

void DotHandler::AppendReduceScatterWindowedEinsumStrategy(
const std::string& name, const DimMap& lhs_dim_map,
const DimMap& rhs_dim_map, const DimMap& output_dim_map,
const Array<int64_t>& device_mesh, double compute_cost) {
absl::flat_hash_set<int64_t> sharded_tensor_dims;
absl::flat_hash_set<int64_t> used_mesh_dims;
for (const auto [tensor_dim, mesh_dim] : output_dim_map) {
if (device_mesh.dim(mesh_dim) == 1) {
continue;
}
sharded_tensor_dims.insert(tensor_dim);
used_mesh_dims.insert(mesh_dim);
}
if (used_mesh_dims.size() == device_mesh_.num_dimensions() ||
sharded_tensor_dims.size() == ins_->shape().rank()) {
return;
}

for (int64_t tensor_dim = 0; tensor_dim < ins_->shape().rank();
++tensor_dim) {
if (sharded_tensor_dims.contains(tensor_dim)) {
continue;
}
for (int64_t mesh_dim = 0; mesh_dim < device_mesh_.num_dimensions();
++mesh_dim) {
if (used_mesh_dims.contains(mesh_dim) ||
(device_mesh.dim(mesh_dim) == 1)) {
continue;
}
DimMap further_sharded_dim_map = output_dim_map;
further_sharded_dim_map[tensor_dim] = mesh_dim;

auto updated_communication_cost_fn =
[](const HloSharding& output_sharding) -> double {
// TODO(331684721): Model costs for windowed einsum
return 100.0;
};

std::string updated_name = absl::StrCat(
absl::StrFormat("WindowedEinsum @ {%d,%d}", tensor_dim, mesh_dim),
name);
MaybeAppendInternal(updated_name, lhs_dim_map, rhs_dim_map,
further_sharded_dim_map, device_mesh, compute_cost,
updated_communication_cost_fn);
}
}
}

Status DotHandler::RegisterStrategies() {
// SS = SR x RS
// Split lhs space dim and rhs space dim.
@@ -136,6 +136,9 @@ std::string AutoShardingOption::ToString() const {
lines.push_back(absl::StrCat("model_resharding_memory_costs: ",
model_resharding_memory_costs));

lines.push_back(absl::StrCat("generate_windowed_einsum_strategies: ",
generate_windowed_einsum_strategies));

return absl::StrJoin(lines, "\n");
}

@@ -191,6 +191,12 @@ struct AutoShardingOption {
// for resharding edges.
bool model_resharding_memory_costs = true;

// Whether or not to generate strategies that model the windowed einsum (or
// collective matmul) optimization
// TODO(331684721,329508561): Generate windowed-einsum strategies by default
// once it is fully implemented.
bool generate_windowed_einsum_strategies = false;

// Prints a debug string.
std::string ToString() const;

