Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion torch_xla/csrc/debug_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "torch_xla/csrc/ir_dump_util.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/xla_util.h"
#include "torch_xla/csrc/xla_graph_executor.h"

namespace torch_xla {
Expand Down Expand Up @@ -431,9 +432,20 @@ void DebugUtil::post_compilation_analysis(
return;
}

std::stringstream ss;
// This can be used to verify the hash of the underlying computation proto.
// Note that for UserComputation computations, the protobuf is factored in
// the graph hash.
std::string serialized_computation =
ConsumeValue(runtime::util::GetDeterministicSerializedModuleProto(
computation->computation().proto()));
ss << "\n"
<< "Computation hash: "
<< torch::lazy::HashToString(torch::lazy::Hash(serialized_computation))
<< "\n";

constexpr std::string_view debug_output_prefix =
"Post Compilation Analysis: ";
std::stringstream ss;
ss << "\n"
<< debug_output_prefix
<< "======================================================================"
Expand Down
17 changes: 2 additions & 15 deletions torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1210,13 +1210,7 @@ XLAGraphExecutor::LookupCachedCompile(const torch::lazy::hash_t& hash) {
TORCH_LAZY_COUNTER("UncachedCompile", 1);
return nullptr;
}
std::string serialized_computation =
ConsumeValue(runtime::util::GetDeterministicSerializedModuleProto(
cached_computation->computation->computation().proto()));
TF_VLOG(5) << "Graph hash " << torch::lazy::HashToString(hash)
<< " is computation hash "
<< torch::lazy::HashToString(
torch::lazy::Hash(serialized_computation));
TF_VLOG(5) << "Graph hash: " << torch::lazy::HashToString(hash);
TORCH_LAZY_COUNTER("CachedCompile", 1);
return cached_computation;
}
Expand Down Expand Up @@ -1474,14 +1468,7 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
<< coll.device << " done!";
TF_VLOG(5) << "Compiled program shape "
<< computations.front()->program_shape().ToString() << std::endl;
std::string serialized_computation =
ConsumeValue(runtime::util::GetDeterministicSerializedModuleProto(
computations.front()->computation().proto()));
TF_VLOG(5) << "Graph hash " << torch::lazy::HashToString(coll.hash)
<< " is computation hash "
<< torch::lazy::HashToString(
torch::lazy::Hash(serialized_computation));

TF_VLOG(5) << "Graph hash: " << torch::lazy::HashToString(coll.hash);
if (use_autosharding) {
const xla::HloModuleProto& computation_proto =
computations.front()->computation().proto();
Expand Down
Loading