Add GetLevelForDuration() helper function.
PiperOrigin-RevId: 641084238
tensorflower-gardener committed Jun 11, 2024
1 parent 5636dc7 commit 7974fd4
Showing 20 changed files with 143 additions and 21 deletions.
1 change: 1 addition & 0 deletions ci/official/requirements_updater/requirements.in
@@ -13,6 +13,7 @@ gast == 0.4.0
termcolor == 2.3.0
wrapt == 1.16.0
tblib == 2.0.0
ml_dtypes >= 0.4.0, < 0.5.0

# Install tensorboard, and keras
# Note that here we want the latest version that matches TF major.minor version
1 change: 1 addition & 0 deletions requirements_lock_3_10.txt
@@ -346,6 +346,7 @@ ml-dtypes==0.4.0 \
--hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \
--hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1
# via
# -r requirements.in
# jax
# keras-nightly
namex==0.0.8 \
1 change: 1 addition & 0 deletions requirements_lock_3_11.txt
@@ -346,6 +346,7 @@ ml-dtypes==0.4.0 \
--hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \
--hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1
# via
# -r requirements.in
# jax
# keras-nightly
namex==0.0.8 \
1 change: 1 addition & 0 deletions requirements_lock_3_12.txt
@@ -346,6 +346,7 @@ ml-dtypes==0.4.0 \
--hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \
--hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1
# via
# -r requirements.in
# jax
# keras-nightly
namex==0.0.8 \
1 change: 1 addition & 0 deletions requirements_lock_3_9.txt
@@ -350,6 +350,7 @@ ml-dtypes==0.4.0 \
--hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \
--hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1
# via
# -r requirements.in
# jax
# keras-nightly
namex==0.0.8 \
4 changes: 2 additions & 2 deletions tensorflow/compiler/mlir/tensorflow/translate/BUILD
@@ -212,6 +212,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/graph/regularization:util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
@@ -238,13 +239,12 @@ cc_library(
":translate_cl_options",
":translate_lib",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tf2xla/api/v2:tf_executor_to_graph",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:TranslateLib",
"@local_tsl//tsl/platform:protobuf",
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_EXPORT_GRAPHDEF_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_EXPORT_GRAPHDEF_H_

#include "absl/base/attributes.h"
#include "absl/container/flat_hash_set.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
@@ -29,6 +30,8 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

ABSL_DEPRECATED("Use tensorflow::tf2xla::api::ConvertMlirToGraphdef instead.")
// Given an MLIR module, returns a GraphDef.
absl::StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
mlir::ModuleOp module, const GraphExportConfig& configs);
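Note: the deprecation message names tensorflow::tf2xla::api::ConvertMlirToGraphdef, while the call site updated in this commit (tf_mlir_translate.cc below) uses tensorflow::tf2xla::v2::ConvertMlirToGraphdef from tf2xla/api/v2.
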
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
#include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "xla/client/client_library.h"
#include "xla/client/compile_only_client.h"
@@ -173,7 +174,7 @@ static LogicalResult MlirToGraphdefTranslateFunction(
confs.export_original_tf_func_name = export_original_tf_func_name;

absl::StatusOr<std::unique_ptr<tensorflow::GraphDef>> graphdef_or(
tensorflow::ConvertMlirToGraphdef(module, confs));
tensorflow::tf2xla::v2::ConvertMlirToGraphdef(module, confs));
if (!graphdef_or.status().ok()) {
LOG(ERROR) << "Graph export failed: " << graphdef_or.status();
return mlir::failure();
10 changes: 10 additions & 0 deletions tensorflow/core/profiler/convert/trace_viewer/trace_events.cc
@@ -131,6 +131,16 @@ void MaybeAddEventUniqueId(std::vector<TraceEvent*>& events) {

} // namespace

int GetLevelForDuration(uint64_t duration_ps) {
int i = 0;
for (; i < NumLevels(); ++i) {
if (duration_ps > kLayerResolutions[i]) {
return i;
}
}
return i;
}

std::vector<TraceEvent*> MergeEventTracks(
const std::vector<const TraceEventTrack*>& event_tracks) {
std::vector<TraceEvent*> events;
3 changes: 3 additions & 0 deletions tensorflow/core/profiler/convert/trace_viewer/trace_events.h
@@ -73,6 +73,9 @@ absl::Status ReadFileTraceMetadata(std::string& filepath, Trace* trace);
std::vector<std::vector<const TraceEvent*>> GetEventsByLevel(
const Trace& trace, std::vector<TraceEvent*>& events);

// Returns the level that an event with `duration_ps` would go into.
int GetLevelForDuration(uint64_t duration_ps);

struct EventFactory {
TraceEvent* Create() {
events.push_back(std::make_unique<TraceEvent>());
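For orientation, a minimal standalone sketch of the new helper's behavior. The threshold values below are illustrative assumptions only; the real kLayerResolutions table is defined in trace_events.cc and may differ.

#include <cstdint>
#include <iostream>

// Illustrative thresholds in picoseconds, coarsest level first.
constexpr uint64_t kLayerResolutions[] = {
    1'000'000'000'000ULL,  // 1 s
    1'000'000'000ULL,      // 1 ms
    1'000'000ULL,          // 1 us
    1'000ULL,              // 1 ns
};

constexpr int NumLevels() {
  return sizeof(kLayerResolutions) / sizeof(kLayerResolutions[0]);
}

// Same logic as the committed helper: return the first level whose
// resolution the duration exceeds; events too short for every level
// fall through to NumLevels().
int GetLevelForDuration(uint64_t duration_ps) {
  int i = 0;
  for (; i < NumLevels(); ++i) {
    if (duration_ps > kLayerResolutions[i]) return i;
  }
  return i;
}

int main() {
  std::cout << GetLevelForDuration(2'000'000'000'000ULL) << "\n";  // 0: > 1 s
  std::cout << GetLevelForDuration(5'000'000ULL) << "\n";          // 2: > 1 us
  std::cout << GetLevelForDuration(1ULL) << "\n";                  // 4: below all
}
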
2 changes: 1 addition & 1 deletion tensorflow/tools/pip_package/setup.py
@@ -83,7 +83,7 @@ def standard_or_nightly(standard, nightly):
'google_pasta >= 0.1.1',
'h5py >= 3.10.0',
'libclang >= 13.0.0',
'ml_dtypes ~= 0.3.1',
'ml_dtypes >= 0.3.1, < 0.5.0',
# TODO(b/304751256): Adjust the numpy pin to a single version, when ready
'numpy >= 1.23.5, < 2.0.0 ; python_version <= "3.11"',
'numpy >= 1.26.0, < 2.0.0 ; python_version >= "3.12"',
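For reference, the old pin ml_dtypes ~= 0.3.1 is PEP 440 shorthand for >= 0.3.1, == 0.3.*, which excluded the 0.4 series; the widened >= 0.3.1, < 0.5.0 range admits the ml_dtypes >= 0.4.0 requirement added to requirements.in above.
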
2 changes: 1 addition & 1 deletion third_party/xla/xla/hlo/ir/hlo_module.cc
@@ -503,7 +503,7 @@ HloModuleProto HloModule::ToProto() const {
return proto;
}

absl::StatusOr<HloModuleProtoWithConfig> HloModule::ToProtoWithConfig() const {
HloModuleProtoWithConfig HloModule::ToProtoWithConfig() const {
HloModuleProtoWithConfig result;
*result.mutable_config() = config_.get().ToProto();
*result.mutable_hlo_module() = ToProto();
2 changes: 1 addition & 1 deletion third_party/xla/xla/hlo/ir/hlo_module.h
@@ -456,7 +456,7 @@ class HloModule {
bool prohibit_empty_literal = true);

// Convert an HloModule to or from a proto that includes module configuration
absl::StatusOr<HloModuleProtoWithConfig> ToProtoWithConfig() const;
HloModuleProtoWithConfig ToProtoWithConfig() const;
static absl::StatusOr<std::unique_ptr<HloModule>> CreateFromProtoWithConfig(
const HloModuleProtoWithConfig& proto,
bool prohibit_empty_literal = true);
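Since ToProtoWithConfig() can no longer fail, callers drop the StatusOr unwrapping; a sketch of a hypothetical call site before and after:

// Before: the result had to be unwrapped.
// TF_ASSIGN_OR_RETURN(xla::HloModuleProtoWithConfig proto,
//                     module->ToProtoWithConfig());
// After: the proto is returned by value.
xla::HloModuleProtoWithConfig proto = module->ToProtoWithConfig();
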
3 changes: 1 addition & 2 deletions third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
@@ -1141,8 +1141,7 @@ PJRT_Error* PJRT_Executable_OptimizedProgram(
program->format_size = kHloWithConfigFormat.size();
PJRT_ASSIGN_OR_RETURN(std::shared_ptr<xla::HloModule> hlo_module,
GetOptimizedProgramModule(args));
PJRT_ASSIGN_OR_RETURN(xla::HloModuleProtoWithConfig proto,
hlo_module->ToProtoWithConfig());
xla::HloModuleProtoWithConfig proto = hlo_module->ToProtoWithConfig();
if (program->code == nullptr) {
program->code_size = proto.ByteSizeLong();
if (program->code_size >= 2ull * 1024 * 1024 * 1024) {
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/BUILD
@@ -263,7 +263,9 @@ xla_cc_test(
deps = [
":all_reduce_splitter",
":hlo_module_config",
":hlo_pass_pipeline",
"//xla/hlo/ir:hlo",
"//xla/service/gpu:gpu_reduce_scatter_creator",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"@com_google_absl//absl/algorithm:container",
103 changes: 103 additions & 0 deletions third_party/xla/xla/service/all_reduce_splitter_test.cc
@@ -29,7 +29,9 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/gpu/gpu_reduce_scatter_creator.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/hlo_pass_pipeline.h"
#include "xla/tests/filecheck.h"
#include "xla/tests/hlo_test_base.h"
#include "tsl/lib/core/status_test_util.h"
@@ -398,6 +400,107 @@ ENTRY main {
EXPECT_EQ(AllReduceCount(*module), 2);
}

TEST_F(
AllReduceSplitterFilecheckTest,
PipelineMatchesBasicPatternWithDynamicSliceAsRootAndRewritesToReduceScatter) { // NOLINT
absl::string_view hlo_string = R"(
HloModule m
sum {
a = bf16[] parameter(0)
b = bf16[] parameter(1)
ROOT _ = bf16[] add(a,b)
}
ENTRY main {
p = bf16[2,4096,4096] parameter(0)
first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
zero = bf16[] constant(0)
reduce = bf16[4096] reduce(first.ar, zero), dimensions={0,1}, to_apply=sum
all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=2
table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
pid = u32[] partition-id()
id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
reshape = s32[] reshape(id)
slice_size = s32[] constant(1024)
offset = s32[] multiply(reshape, slice_size)
ROOT _ = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
}
)";

TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));

HloPassPipeline pipeline("all-reduce-splitter-rewrite");
pipeline.AddPass<AllReduceSplitter>();
pipeline.AddPass<ReduceScatterCreator>();
EXPECT_THAT(pipeline.Run(module.get()), IsOkAndHolds(true));
TF_EXPECT_OK(FileCheck(module->ToString(), R"(
CHECK-DAG: %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0)
CHECK: %[[AR0:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]])
CHECK-SAME: replica_groups={[[DESIRED_RGS:.*]]}
CHECK-DAG: %[[ZERO:.*]] = bf16[] constant(0)
CHECK-DAG: %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[AR0]], bf16[] %[[ZERO]])
CHECK: %[[REDUCE_SCATTER:.*]] = bf16[1024]{0} reduce-scatter(bf16[4096]{0} %[[LOCAL_REDUCE]])
CHECK-SAME: replica_groups={[[DESIRED_RGS]]}
CHECK-NEXT: ROOT %[[AR2:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[REDUCE_SCATTER]])
CHECK-SAME: replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}}
)"));
}

TEST_F(
AllReduceSplitterFilecheckTest,
PipelineMatchesBasicPatternWithDynamicSliceNotAsRootAndRewritesToReduceScatter) { // NOLINT
absl::string_view hlo_string = R"(
HloModule m
sum {
a = bf16[] parameter(0)
b = bf16[] parameter(1)
ROOT _ = bf16[] add(a,b)
}
ENTRY main {
p = bf16[2,4096,4096] parameter(0)
zero = bf16[] constant(0)
first.ar = bf16[2,4096,4096] all-reduce(p), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
reduce = bf16[4096] reduce(p, zero), dimensions={0,1}, to_apply=sum
all-reduce = bf16[4096] all-reduce(reduce), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=sum, use_global_device_ids=true, channel_id=1
table = s32[8]{0} constant({0,1,2,3,0,1,2,3})
pid = u32[] partition-id()
id = s32[1] dynamic-slice(table, pid), dynamic_slice_sizes={1}
reshape = s32[] reshape(id)
slice_size = s32[] constant(1024)
offset = s32[] multiply(reshape, slice_size)
dynamic_slice = bf16[1024] dynamic-slice(all-reduce, offset), dynamic_slice_sizes={1024}
broadcast = bf16[1024,1024] broadcast(dynamic_slice), dimensions={0}
ROOT _ = tuple(broadcast, first.ar)
}
)";

TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
PrepareModule(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8));

HloPassPipeline pipeline("all-reduce-splitter-rewrite");
pipeline.AddPass<AllReduceSplitter>();
pipeline.AddPass<ReduceScatterCreator>();
EXPECT_THAT(pipeline.Run(module.get()), IsOkAndHolds(true));
TF_EXPECT_OK(FileCheck(module->ToString(), R"(
CHECK-DAG: %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0)
CHECK-DAG: %[[ZERO:.*]] = bf16[] constant(0)
CHECK-DAG: %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[P0]], bf16[] %[[ZERO]])
CHECK: %[[REDUCE_SCATTER:.*]] = bf16[1024]{0} reduce-scatter(bf16[4096]{0} %[[LOCAL_REDUCE]])
CHECK-NEXT: %[[AR1:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[REDUCE_SCATTER]])
CHECK-SAME: replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}}
CHECK: %[[EXISTING_AR:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]])
CHECK: ROOT
CHECK-NOT: %[[AR1]]
CHECK-SAME: %[[EXISTING_AR]]
)"));
}

} // namespace
} // namespace gpu
} // namespace xla
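
Both new FileCheck tests run AllReduceSplitter followed by ReduceScatterCreator in an HloPassPipeline and check that the split all-reduce is rewritten into a reduce-scatter feeding a smaller all-reduce, once with the dynamic-slice as the module root and once with it feeding further computation.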
3 changes: 1 addition & 2 deletions third_party/xla/xla/service/gpu/gpu_compiler.cc
@@ -314,8 +314,7 @@ class GpuThunkAotCompilationResult : public AotCompilationResult {
std::string_view asm_text, absl::Span<const uint8_t> binary,
const Thunk::BinaryMap& dnn_compiled_graphs) {
CompilationResultProto proto;
TF_ASSIGN_OR_RETURN(*proto.mutable_hlo_module_with_config(),
hlo_module->ToProtoWithConfig());
*proto.mutable_hlo_module_with_config() = hlo_module->ToProtoWithConfig();
*proto.mutable_buffer_assignment() = buffer_assignment->ToProto();
proto.set_asm_text(std::string(asm_text));
proto.set_binary(binary.data(), binary.size());
8 changes: 3 additions & 5 deletions third_party/xla/xla/service/hlo_module_test.cc
@@ -746,8 +746,7 @@ ENTRY ReduceR3ToR2.v3 {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(computation_text));

TF_ASSERT_OK_AND_ASSIGN(xla::HloModuleProtoWithConfig proto,
module->ToProtoWithConfig());
xla::HloModuleProtoWithConfig proto = module->ToProtoWithConfig();
std::string serialized_module;
ASSERT_TRUE(tsl::SerializeToStringDeterministic(proto, &serialized_module));
std::string original_debug_str = proto.DebugString();
@@ -756,9 +755,8 @@ ENTRY ReduceR3ToR2.v3 {
// Verify that we can create a module from our parsed proto copy
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> reconstructed_module,
HloModule::CreateFromProtoWithConfig(proto));
TF_ASSERT_OK_AND_ASSIGN(
xla::HloModuleProtoWithConfig reconstructed_module_proto,
reconstructed_module->ToProtoWithConfig());
xla::HloModuleProtoWithConfig reconstructed_module_proto =
reconstructed_module->ToProtoWithConfig();

// The two protos should be equivalent except for the `id` field
google::protobuf::util::MessageDifferencer diff;
@@ -326,10 +326,9 @@ TEST(XlaHloModule, ToAndFromC) {
ASSERT_TRUE(out_module_ptr.ok());
xla::HloModule& out_module = *out_module_ptr.value();

TF_ASSERT_OK_AND_ASSIGN(xla::HloModuleProtoWithConfig in_module_proto,
in_module.ToProtoWithConfig());
TF_ASSERT_OK_AND_ASSIGN(xla::HloModuleProtoWithConfig out_module_proto,
out_module.ToProtoWithConfig());
xla::HloModuleProtoWithConfig in_module_proto = in_module.ToProtoWithConfig();
xla::HloModuleProtoWithConfig out_module_proto =
out_module.ToProtoWithConfig();

tsl::protobuf::util::MessageDifferencer diff;
diff.set_message_field_comparison(
4 changes: 2 additions & 2 deletions third_party/xla/xla/tests/local_client_execute_test.cc
@@ -967,7 +967,7 @@ XLA_TEST_F(LocalClientExecuteTest, ValidateFDOProfile) {
const HloModule& compiled_module =
executables.front()->executable()->module();
EXPECT_EQ(compiled_module.config().fdo_profile(), kFdoProfile);
TF_ASSERT_OK_AND_ASSIGN(auto proto, compiled_module.ToProtoWithConfig());
auto proto = compiled_module.ToProtoWithConfig();
EXPECT_EQ(proto.config().fdo_profile(), kFdoProfile);
}

@@ -991,7 +991,7 @@ XLA_TEST_F(LocalClientExecuteTest, ValidateDeviceMemorySize) {
const HloModule& compiled_module =
executables.front()->executable()->module();
EXPECT_EQ(compiled_module.config().device_memory_size(), kDeviceMemorySize);
TF_ASSERT_OK_AND_ASSIGN(auto proto, compiled_module.ToProtoWithConfig());
auto proto = compiled_module.ToProtoWithConfig();
EXPECT_EQ(proto.config().device_memory_size(), kDeviceMemorySize);
}

