
Commit

Reverts f1bac5e
PiperOrigin-RevId: 628097889
tensorflower-gardener committed Apr 25, 2024
1 parent 70a5f10 commit 5445d4b
Showing 7 changed files with 5 additions and 128 deletions.
8 changes: 0 additions & 8 deletions third_party/xla/workspace2.bzl
@@ -36,14 +36,6 @@ def _tf_repositories():
# curl -L <url> | sha256sum
# and update the sha256 with the result.

tf_http_archive(
name = "jsoncpp_git",
sha256 = "f409856e5920c18d0c2fb85276e24ee607d2a09b5e7d5f0a371368903c275da2",
strip_prefix = "jsoncpp-1.9.5",
system_build_file = "//third_party/systemlibs:jsoncpp.BUILD",
urls = tf_mirror_urls("https://github.com/open-source-parsers/jsoncpp/archive/1.9.5.tar.gz"),
)

tf_http_archive(
name = "cudnn_frontend_archive",
build_file = "//third_party:cudnn_frontend.BUILD",
1 change: 0 additions & 1 deletion third_party/xla/xla/hlo/experimental/auto_sharding/BUILD
@@ -275,7 +275,6 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@jsoncpp_git//:jsoncpp",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:status",
],
@@ -362,71 +362,6 @@ void FollowArrayOrTokenStrategyGroup(
}
}

std::unique_ptr<StrategyGroup> HandlePartialReduce(
const HloInstruction* ins, const size_t instruction_id,
const bool have_memory_cost, StrategyGroups& strategy_groups,
const ClusterEnvironment& cluster_env, StrategyMap& strategy_map,
const CallGraph& call_graph) {
absl::StatusOr<int64_t> reduction_dim = GetPartialReduceReductionDim(ins);
CHECK_OK(reduction_dim);
const Shape& shape = ins->shape();
const HloInstruction* operand = ins->operand(0);
const StrategyGroup* src_strategy_group = strategy_map.at(operand).get();

std::unique_ptr<StrategyGroup> strategy_group =
CreateTupleStrategyGroup(instruction_id);
int64_t output_size = shape.tuple_shapes_size();
for (size_t i = 0; i < output_size; ++i) {
std::unique_ptr<StrategyGroup> child_strategy_group =
CreateLeafStrategyGroupWithoutInNodes(instruction_id, strategy_groups);
child_strategy_group->in_nodes.push_back(src_strategy_group);
child_strategy_group->following = src_strategy_group;
for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); ++sid) {
const HloSharding& input_spec =
src_strategy_group->strategies[sid].output_sharding;
// There is no way for us to handle manual sharding.
if (input_spec.IsManual() || input_spec.IsManualSubgroup()) {
continue;
}

HloSharding output_spec = input_spec;
if (!(input_spec.IsReplicated() || input_spec.IsTileMaximal())) {
// All 3. sub-cases (reduction dim would be replicated in the
// output)
output_spec = hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
input_spec, {*reduction_dim});
}

// Get a list of input shardings, each corresponds to an operand.
std::vector<std::optional<HloSharding>> input_shardings;
for (int64_t k = 0; k < output_size * 2; ++k) {
if (k < output_size) {
input_shardings.push_back(input_spec);
} else {
input_shardings.push_back(HloSharding::Replicate());
}
}

std::string name = ToStringSimple(output_spec);
double compute_cost = 0, communication_cost = 0;
double memory_cost =
GetBytes(ins->shape().tuple_shapes(i)) / output_spec.NumTiles();
std::pair<ReshardingCosts, ReshardingCosts> resharding_costs =
GenerateReshardingCostsAndMissingShardingsForAllOperands(
ins, output_spec, strategy_map, cluster_env, call_graph,
input_shardings);

child_strategy_group->strategies.push_back(ShardingStrategy(
{std::move(name), std::move(output_spec), compute_cost,
communication_cost, memory_cost, std::move(resharding_costs.first),
std::move(resharding_costs.second), std::move(input_shardings)}));
}

strategy_group->childs.push_back(std::move(child_strategy_group));
}
return strategy_group;
}

std::unique_ptr<StrategyGroup> MaybeFollowInsStrategyGroup(
const StrategyGroup* src_strategy_group, const Shape& shape,
const size_t instruction_id, const bool have_memory_cost,
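For context on the block removed above: HandlePartialReduce followed the operand's strategy group and, for each tiled input sharding, partially replicated it along the reduction dimension to produce the output sharding. Below is a minimal standalone sketch of that transformation; the function name, header paths, and the 2x2 tile assignment are illustrative assumptions, not part of this commit.

#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/hlo/utils/hlo_sharding_util.h"

// Sketch: derive the output sharding of one PartialReduce result from its
// operand's sharding, assuming reduction_dim == 1 and a 4-device mesh.
xla::HloSharding PartialReduceOutputShardingSketch() {
  // Operand tiled as devices=[2,2]<=[4]: dim 0 split in 2, dim 1 split in 2.
  const xla::HloSharding input_spec = xla::HloSharding::IotaTile({2, 2});
  if (input_spec.IsReplicated() || input_spec.IsTileMaximal()) {
    return input_spec;  // Nothing to re-shard.
  }
  // Replicate across the devices that tiled the reduced dimension; the
  // dim-0 tiling is kept and the reduced dimension becomes replicated.
  return xla::hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
      input_spec, /*dims_to_replicate=*/{1});
}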
@@ -283,11 +283,6 @@ std::unique_ptr<StrategyGroup> CreateElementwiseOperatorStrategies(
int64_t max_depth, StrategyGroups& strategy_groups,
AssociativeDotPairs& associative_dot_pairs);

std::unique_ptr<StrategyGroup> HandlePartialReduce(
const HloInstruction* ins, size_t instruction_id, bool have_memory_cost,
StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env,
StrategyMap& strategy_map, const CallGraph& call_graph);

// Factory functions for StrategyGroup.
std::unique_ptr<StrategyGroup> CreateLeafStrategyGroupWithoutInNodes(
size_t instruction_id, StrategyGroups& strategy_groups);
@@ -702,12 +702,8 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence,
src_strategy_group, ins->shape(), instruction_id,
/* have_memory_cost= */ true, strategy_groups, cluster_env,
pretrimmed_strategy_map);
} else if (IsTopKCustomCall(ins)) {
generate_non_following_strategies(false, {0});
} else if (IsPartialReduceCustomCall(ins)) {
strategy_group = HandlePartialReduce(
ins, instruction_id, /* have_memory_cost */ true, strategy_groups,
cluster_env, strategy_map, call_graph);
} else if (ins->has_sharding()) {
generate_non_following_strategies(false);
} else if (OutputInputSameShapes(ins)) {
auto* partitioner =
GetCustomCallPartitioner(ins->custom_call_target());
@@ -722,11 +718,11 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence,
/* have_memory_cost= */ true, strategy_groups, cluster_env,
pretrimmed_strategy_map);
}
} else if (ins->has_sharding()) {
generate_non_following_strategies(false);
} else if (IsTopKCustomCall(ins)) {
generate_non_following_strategies(false, {0});
} else {
// TODO (b/258723035) Handle CustomCall ops for GPUs in a better way.
generate_non_following_strategies(false);
generate_non_following_strategies(true);
}
break;
}
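In the dispatch restored above, a custom call with matching input and output shapes is followed only if a partitioner is registered for its custom_call_target and deems it shardable. A small sketch of that lookup follows; the header path and the shardable check are assumptions here, only the GetCustomCallPartitioner call appears in the hunk above.

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/custom_call_sharding_helper.h"

// Sketch: can this custom call follow operand 0's sharding strategies?
bool CanFollowOperandSketch(const xla::HloInstruction* ins) {
  const xla::CustomCallPartitioner* partitioner =
      xla::GetCustomCallPartitioner(ins->custom_call_target());
  return partitioner != nullptr && partitioner->IsCustomCallShardable(ins);
}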
@@ -41,7 +41,6 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "json/json.h"
#include "xla/array.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h"
#include "xla/hlo/ir/hlo_computation.h"
@@ -2342,35 +2341,5 @@ HloSharding ReplaceGivenShardingsWithUnknownForTuple(
return HloSharding::Tuple(shape, new_tuple_shardings);
}

absl::StatusOr<int64_t> GetPartialReduceReductionDim(
const HloInstruction* ins) {
constexpr char kReductionDimKey[] = "reduction_dim";
if (ins->raw_backend_config_string().empty()) {
return absl::InternalError(
"No backend config for a PartialReduce custom call.");
}
Json::Value parsed_json;
Json::Reader json_reader;
json_reader.parse(ins->raw_backend_config_string(), parsed_json,
/* collectComments */ false);
if (!parsed_json.isObject()) {
return absl::InternalError(
"Error when parsing json backend config for a PartialReduce custom "
"call.");
}
if (!parsed_json.isMember(kReductionDimKey)) {
return absl::InternalError(
"No backend config found for a PartialReduce custom call.");
}

if (!parsed_json[kReductionDimKey].isInt64()) {
return absl::InternalError(
"Error when extracting the reduction key from the json backend config "
"of a PartialReduce custom call.");
}

return parsed_json[kReductionDimKey].asInt64();
}

} // namespace spmd
} // namespace xla
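The helper removed above is what pulled jsoncpp into this target (the commit drops the @jsoncpp_git dependency and the json/json.h include alongside it): it read the reduction_dim field out of the custom call's JSON backend config. As a standalone illustration of that parsing pattern, assuming only that jsoncpp is available as json/json.h; the config string is made up:

#include <cstdint>
#include <iostream>
#include <string>

#include "json/json.h"

int main() {
  const std::string backend_config = R"({"reduction_dim": 2})";  // hypothetical
  Json::Value parsed;
  Json::Reader reader;
  if (!reader.parse(backend_config, parsed, /*collectComments=*/false) ||
      !parsed.isObject() || !parsed.isMember("reduction_dim") ||
      !parsed["reduction_dim"].isInt64()) {
    std::cerr << "malformed backend config\n";
    return 1;
  }
  // Prints "reduction_dim = 2".
  std::cout << "reduction_dim = " << parsed["reduction_dim"].asInt64() << "\n";
  return 0;
}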
@@ -233,12 +233,6 @@ inline bool IsTopKCustomCall(const HloInstruction* inst) {
inst->custom_call_target() == "TopK";
}

// Return whether this instruction is a PartialReduce custom call.
inline bool IsPartialReduceCustomCall(const HloInstruction* inst) {
return inst->opcode() == HloOpcode::kCustomCall &&
inst->custom_call_target() == "PartialReduce";
}

// Pass through the custom call marker and get the source instruction
inline const HloInstruction* PassThroughCustomCallMarkerGetSource(
const HloInstruction* ins) {
@@ -668,9 +662,6 @@ HloSharding ReplaceGivenShardingsWithUnknownForTuple(
const HloSharding& sharding, const Shape& shape,
absl::Span<const bool> to_replace_sharding_ids);

// Extract the reduction_dim of a PartialReduce custom call
absl::StatusOr<int64_t> GetPartialReduceReductionDim(const HloInstruction* ins);

} // namespace spmd
} // namespace xla

