diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp index 790b5deae7bd..732f9c5a6615 100644 --- a/torch_xla/csrc/debug_util.cpp +++ b/torch_xla/csrc/debug_util.cpp @@ -16,6 +16,7 @@ #include "torch_xla/csrc/ir.h" #include "torch_xla/csrc/ir_dump_util.h" #include "torch_xla/csrc/ir_util.h" +#include "torch_xla/csrc/xla_graph_executor.h" namespace torch_xla { namespace { @@ -120,6 +121,13 @@ void DebugUtil::SaveTensorsGraphInfo(const char* name, "XLA_SAVE_TENSORS_FILE", "", GetCurrentDevice().ordinal()); if (!save_file.empty()) { static std::mutex lock; + if (format == DebugUtil::GraphFormat::kHlo && indices->size() > 0) { + // Dumping the HLO might access the placeholder data created during + // previous execution. We need to wait for last execution to finish before + // proceeding. + torch::lazy::BackendDevice device = tensors[(*indices)[0]]->GetDevice(); + XLAGraphExecutor::Get()->WaitDeviceOps({device.toString()}); + } std::string info = GetTensorsGraphInfo(tensors, indices, format); std::lock_guard guard(lock); std::ofstream graph_file(save_file, std::ios_base::app);