diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 15bff98cb072..8a06608c8b9c 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -815,39 +815,7 @@ XLAGraphExecutor::PostOrderData XLAGraphExecutor::RunPostOrder(
     SyncTensorCollection* coll) {
   tensorflow::profiler::TraceMe activity(
       "RunPostOrder", tensorflow::profiler::TraceMeLevel::kInfo);
-  std::vector<const torch::lazy::Node*> roots;
-  roots.reserve(ir_values.size());
-  for (auto ir_value : ir_values) {
-    roots.push_back(ir_value.node.get());
-  }
-  PostOrderData po_data;
-  po_data.post_order =
-      torch::lazy::Util::ComputePostOrder(roots, &po_data.emission_map);
-  std::unordered_map<xla::ComputationClient::Data::OpaqueHandle, size_t>
-      data_handles;
-
-  for (auto node : po_data.post_order) {
-    const auto backend_data =
-        torch::lazy::getBackend()->GetComputationDataFromNode(node);
-    if (backend_data != nullptr) {
-      /* Acceptable race condition: HasValue may return false. This is OK
-       * since the conditional barrier is a performance optimization. */
-      if (!backend_data->HasValue()) {
-        TensorCollectionBarrier(coll);
-      }
-      xla::ComputationClient::Data::OpaqueHandle handle =
-          backend_data->GetHandle();
-      auto it = data_handles.find(handle);
-      if (it != data_handles.end()) {
-        po_data.parameter_sequence.push_back(it->second);
-      } else {
-        po_data.parameter_sequence.push_back(po_data.parameters_data.size());
-        data_handles[handle] = po_data.parameters_data.size();
-        po_data.parameters_data.push_back(backend_data);
-      }
-    }
-  }
-  return po_data;
+  return torch::lazy::LazyGraphExecutor::RunPostOrder(ir_values, coll);
 }
 
 XLAGraphExecutor::ComputationCache::TypePtr
diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h
index c573e2313a5e..44e4087486db 100644
--- a/torch_xla/csrc/xla_graph_executor.h
+++ b/torch_xla/csrc/xla_graph_executor.h
@@ -56,6 +56,8 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
   // that is affected by the view tensor.
   void ApplyEagerSync(std::vector<XLATensorPtr>& tensors);
 
+  // We don't use the upstream GetDeviceDataIrValue because we need the
+  // xla::PrimitiveType.
   torch::lazy::Value GetDeviceDataIrValue(
       const at::Scalar& value, xla::PrimitiveType type,
       const torch::lazy::BackendDevice& device);
@@ -243,6 +245,8 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
       absl::Span<const size_t> indices,
       const std::vector<at::Tensor>& tensor_data_vec);
 
+  // We don't use the upstream ExtractIRAndPrepareTensorData because we need
+  // to instantiate an xla::Shape.
   void ExtractIRAndPrepareXlaData_(
       std::vector<XLATensorPtr>* tensors, const SyncTensorsConfig& config,
       const absl::Span<const size_t> indices,
@@ -271,8 +275,9 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
       std::string device, ComputationCache::TypePtr cached_computation,
       const std::vector<at::Tensor>& tensor_data_vec);
 
+  // Override to enable the profiler.
   PostOrderData RunPostOrder(const std::vector<torch::lazy::Value>& ir_values,
-                             SyncTensorCollection* coll);
+                             SyncTensorCollection* coll) final;
 
   // We don't use the upstream LookupCachedCompile since
   // our CachedComputation is different from upstream.
@@ -290,6 +295,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
       const std::vector<XLATensorPtr>& tensors,
       absl::Span<const size_t> indices, LoweringContext* lowering_ctx);
 
+  // We don't use the upstream Compile since we need BuildInputOutputAliases.
   CompilationResult Compile(const std::vector<XLATensorPtr>& tensors,
                             absl::Span<const std::string> devices,
                             const SyncTensorCollection& coll,