From f2ca55e417db0905acab5109eca747d0b77104f8 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Mon, 12 Dec 2022 03:42:19 +0000 Subject: [PATCH 1/6] Override RunPostOrder --- torch_xla/csrc/xla_graph_executor.cpp | 34 +-------------------------- torch_xla/csrc/xla_graph_executor.h | 3 ++- 2 files changed, 3 insertions(+), 34 deletions(-) diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 15bff98cb072..8a06608c8b9c 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -815,39 +815,7 @@ XLAGraphExecutor::PostOrderData XLAGraphExecutor::RunPostOrder( SyncTensorCollection* coll) { tensorflow::profiler::TraceMe activity( "RunPostOrder", tensorflow::profiler::TraceMeLevel::kInfo); - std::vector roots; - roots.reserve(ir_values.size()); - for (auto ir_value : ir_values) { - roots.push_back(ir_value.node.get()); - } - PostOrderData po_data; - po_data.post_order = - torch::lazy::Util::ComputePostOrder(roots, &po_data.emission_map); - std::unordered_map - data_handles; - - for (auto node : po_data.post_order) { - const auto backend_data = - torch::lazy::getBackend()->GetComputationDataFromNode(node); - if (backend_data != nullptr) { - /* Acceptable race condition: HasValue may return false. This is OK - * since the conditional barrier is a performance optimization. 
*/ - if (!backend_data->HasValue()) { - TensorCollectionBarrier(coll); - } - xla::ComputationClient::Data::OpaqueHandle handle = - backend_data->GetHandle(); - auto it = data_handles.find(handle); - if (it != data_handles.end()) { - po_data.parameter_sequence.push_back(it->second); - } else { - po_data.parameter_sequence.push_back(po_data.parameters_data.size()); - data_handles[handle] = po_data.parameters_data.size(); - po_data.parameters_data.push_back(backend_data); - } - } - } - return po_data; + return torch::lazy::LazyGraphExecutor::RunPostOrder(ir_values, coll); } XLAGraphExecutor::ComputationCache::TypePtr diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h index c573e2313a5e..60005617767a 100644 --- a/torch_xla/csrc/xla_graph_executor.h +++ b/torch_xla/csrc/xla_graph_executor.h @@ -271,8 +271,9 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { std::string device, ComputationCache::TypePtr cached_computation, const std::vector& tensor_data_vec); + // Override to enable profiler. PostOrderData RunPostOrder(const std::vector& ir_values, - SyncTensorCollection* coll); + SyncTensorCollection* coll) final; // We don't use the upstream LookupCachedCompile since // our CachedComputation is different from upstream. From dcf769085a4fc8e995d136cc584c2473849d4cd8 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Mon, 12 Dec 2022 03:48:49 +0000 Subject: [PATCH 2/6] Adds more comments --- torch_xla/csrc/xla_graph_executor.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h index 60005617767a..e64985e45b10 100644 --- a/torch_xla/csrc/xla_graph_executor.h +++ b/torch_xla/csrc/xla_graph_executor.h @@ -56,6 +56,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { // that is affected by the view tensor. void ApplyEagerSync(std::vector& tensors); + // We don't use the upstream GetDeviceDataIrValue to have the xla::PrimitiveType. 
torch::lazy::Value GetDeviceDataIrValue( const at::Scalar& value, xla::PrimitiveType type, const torch::lazy::BackendDevice& device); @@ -291,6 +292,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { const std::vector& tensors, absl::Span indices, LoweringContext* lowering_ctx); + // We don't use upstream Compile to have BuildInputOutputAliases. CompilationResult Compile(const std::vector& tensors, absl::Span devices, const SyncTensorCollection& coll, From 8a92340f43d061f68fa826c3083aec389418ca27 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Mon, 12 Dec 2022 04:06:52 +0000 Subject: [PATCH 3/6] Add one more comment --- torch_xla/csrc/xla_graph_executor.h | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h index e64985e45b10..4168c083049c 100644 --- a/torch_xla/csrc/xla_graph_executor.h +++ b/torch_xla/csrc/xla_graph_executor.h @@ -244,6 +244,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { absl::Span indices, const std::vector& tensor_data_vec); + // We don't use upstream ExtractIRAndPrepareTensorData as we need to instantiate xla::shape. void ExtractIRAndPrepareXlaData_( std::vector* tensors, const SyncTensorsConfig& config, const absl::Span indices, From ceb39b6c9d6cd9db7e59985e111c5fb7682638b8 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Mon, 12 Dec 2022 04:08:41 +0000 Subject: [PATCH 4/6] Fix linters --- torch_xla/csrc/xla_graph_executor.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h index 4168c083049c..44e4087486db 100644 --- a/torch_xla/csrc/xla_graph_executor.h +++ b/torch_xla/csrc/xla_graph_executor.h @@ -56,7 +56,8 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { // that is affected by the view tensor. 
void ApplyEagerSync(std::vector& tensors); - // We don't use the upstream GetDeviceDataIrValue to have the xla::PrimitiveType. + // We don't use the upstream GetDeviceDataIrValue to have the + // xla::PrimitiveType. torch::lazy::Value GetDeviceDataIrValue( const at::Scalar& value, xla::PrimitiveType type, const torch::lazy::BackendDevice& device); @@ -244,7 +245,8 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { absl::Span indices, const std::vector& tensor_data_vec); - // We don't use upstream ExtractIRAndPrepareTensorData as we need to instantiate xla::shape. + // We don't use upstream ExtractIRAndPrepareTensorData as we need to + // instantiate xla::shape. void ExtractIRAndPrepareXlaData_( std::vector* tensors, const SyncTensorsConfig& config, const absl::Span indices, From adab74eee6c9d5716d0b287844607538345d6c7c Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Mon, 12 Dec 2022 04:15:12 +0000 Subject: [PATCH 5/6] Add .torch_pin --- torch_patches/.torch_pin | 1 + 1 file changed, 1 insertion(+) create mode 100644 torch_patches/.torch_pin diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin new file mode 100644 index 000000000000..8cbec8816738 --- /dev/null +++ b/torch_patches/.torch_pin @@ -0,0 +1 @@ +#90680 From cfc77df9ecb49dded982de28e2d63abcbd8f84f9 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Tue, 13 Dec 2022 18:12:09 +0000 Subject: [PATCH 6/6] Revert "Add .torch_pin" This reverts commit adab74eee6c9d5716d0b287844607538345d6c7c. --- torch_patches/.torch_pin | 1 - 1 file changed, 1 deletion(-) delete mode 100644 torch_patches/.torch_pin diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin deleted file mode 100644 index 8cbec8816738..000000000000 --- a/torch_patches/.torch_pin +++ /dev/null @@ -1 +0,0 @@ -#90680