From 6d48ff5bfb86d1d609a916375e46699d10f6a7a3 Mon Sep 17 00:00:00 2001
From: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Date: Tue, 11 Jun 2024 13:12:30 -0700
Subject: [PATCH] Enable eager mode for PyTorch/XLA (#7234)

---
 docs/source/index.rst                        |  4 +++
 examples/eager/train_decoder_only_eager.py   | 12 +++++++++
 test/debug_tool/test_pt_xla_debug.py         | 12 +++++++++
 test/test_metrics.py                         | 19 ++++++++++++++
 test/tpu/run_tests.sh                        |  1 +
 torch_xla/csrc/debug_util.cpp                | 13 ++++++++++
 torch_xla/csrc/init_python_bindings.cpp      |  5 ++++
 torch_xla/csrc/runtime/computation_client.cc | 12 +++++++++
 torch_xla/csrc/runtime/computation_client.h  | 25 +++++++++++--------
 .../csrc/runtime/pjrt_computation_client.cc  | 12 +++++++--
 torch_xla/csrc/tensor.cpp                    |  7 +++---
 torch_xla/csrc/xla_graph_executor.cpp        | 18 ++++++++-----
 torch_xla/csrc/xla_graph_executor.h          |  7 ++++++
 torch_xla/experimental/__init__.py           |  5 ++++
 torch_xla/experimental/eager.py              |  9 +++++++
 15 files changed, 139 insertions(+), 22 deletions(-)
 create mode 100644 examples/eager/train_decoder_only_eager.py
 create mode 100644 torch_xla/experimental/eager.py

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 49f3c831c22d..cf0eb8b01256 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -73,6 +73,10 @@ spmd
 .. autoclass:: HybridMesh
 .. autoclass:: ShardingSpec
 
+experimental
+----------------------------------
+.. automodule:: torch_xla.experimental
+.. autofunction:: eager_mode
 
 debug
 ----------------------------------
diff --git a/examples/eager/train_decoder_only_eager.py b/examples/eager/train_decoder_only_eager.py
new file mode 100644
index 000000000000..67ead6cc8300
--- /dev/null
+++ b/examples/eager/train_decoder_only_eager.py
@@ -0,0 +1,12 @@
+import sys
+import os
+example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
+sys.path.append(example_folder)
+from train_decoder_only_base import TrainDecoderOnlyBase
+
+import torch_xla
+
+if __name__ == '__main__':
+  torch_xla.experimental.eager_mode(True)
+  base = TrainDecoderOnlyBase()
+  base.start_training()
diff --git a/test/debug_tool/test_pt_xla_debug.py b/test/debug_tool/test_pt_xla_debug.py
index ed83b236a115..ecc0f2c05201 100644
--- a/test/debug_tool/test_pt_xla_debug.py
+++ b/test/debug_tool/test_pt_xla_debug.py
@@ -29,6 +29,18 @@ def setUpClass(cls):
       assert False, "This test should be run with PT_XLA_DEBUG_FILE"
     open(cls.debug_file_name, 'w').close()
 
+  def test_eager_mark_step(self):
+    torch_xla.experimental.eager_mode(True)
+    device = xm.xla_device()
+    t1 = torch.randn(5, 9, device=device)
+    xm.mark_step()
+    with open(self.debug_file_name, 'rb') as f:
+      lines = f.readlines()
+    # We expect PT_XLA_DEBUG not to output anything in eager mode
+    self.assertEqual(len(lines), 0)
+    torch_xla.experimental.eager_mode(False)
+    open(self.debug_file_name, 'w').close()
+
   def test_user_mark_step(self):
     device = xm.xla_device()
     t1 = torch.randn(2, 2, device=device)
diff --git a/test/test_metrics.py b/test/test_metrics.py
index 409876d8d9de..cabd2e768b87 100644
--- a/test/test_metrics.py
+++ b/test/test_metrics.py
@@ -52,6 +52,25 @@ def test_tracing_time_metrics(self):
     self.assertIn('LazyTracing', met.metric_names())
     self.assertGreater(met.metric_data('LazyTracing')[0], 1)
 
+  def test_eager_metrics(self):
+    torch_xla.experimental.eager_mode(True)
+    xla_device = xm.xla_device()
+    met.clear_all()
+    t1 = torch.tensor(156, device=xla_device)
+    t2 = t1 + 100
+    xm.wait_device_ops()
+    self.assertIn('EagerOpCompileTime', met.metric_names())
+    # one for constant, one for add
+    self.assertEqual(met.metric_data('EagerOpCompileTime')[0], 2)
+    self.assertIn('EagerOpExecuteTime', met.metric_names())
+    # one for constant, one for add
+    self.assertEqual(met.metric_data('EagerOpExecuteTime')[0], 2)
+    # mark_step should be a no-op
+    xm.mark_step()
+    self.assertNotIn('CompileTime', met.metric_names())
+    self.assertNotIn('ExecuteTime', met.metric_names())
+    torch_xla.experimental.eager_mode(False)
+
   def test_short_metrics_report_default_list(self):
     xla_device = xm.xla_device()
     t1 = torch.tensor(1456, device=xla_device)
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index ddd439d1c60f..c7f12f93e149 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -36,3 +36,4 @@ python3 examples/data_parallel/train_resnet_xla_ddp.py
 python3 examples/fsdp/train_decoder_only_fsdp_v2.py
 python3 examples/fsdp/train_resnet_fsdp_auto_wrap.py
 python3 examples/train_resnet_amp.py
+python3 examples/eager/train_decoder_only_eager.py
diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp
index 11843a39a59b..c82369045bf3 100644
--- a/torch_xla/csrc/debug_util.cpp
+++ b/torch_xla/csrc/debug_util.cpp
@@ -267,6 +267,12 @@ void DebugUtil::analyze_graph_execution_python_frame(
     return;
   }
 
+  // don't output analysis for eager mode execution/compilation
+  if (XLAGraphExecutor::Get()->UseEagerMode() &&
+      source != GraphAnalysisSource::DynamoExecution) {
+    return;
+  }
+
   if (pt_xla_debug_level <= 1 && source != GraphAnalysisSource::Compilation) {
     // for debug level <=1, only output compilation analysis in this function.
     return;
@@ -385,6 +391,13 @@ void DebugUtil::post_compilation_analysis(
   if (pt_xla_debug_level <= 0 || !is_master_process) {
     return;
   }
+
+  // don't output analysis for eager mode execution/compilation.
+  // TODO(JackCaoG): enable this for eager+dynamo
+  if (XLAGraphExecutor::Get()->UseEagerMode()) {
+    return;
+  }
+
   static const std::string debug_output_prefix = "Post Compilation Analysis: ";
   std::stringstream ss;
   ss << "\n"
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index c3a062565fe1..de050ead8f87 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -2403,6 +2403,11 @@ void InitXlaModuleBindings(py::module m) {
       });
   m.def("_get_xla_enable_device_data_cache",
         []() { return FLAGS_torch_lazy_enable_device_data_cache; });
+  m.def("_set_use_eager_mode", [](bool use_eager_mode) {
+    XLAGraphExecutor::Get()->SetUseEagerMode(use_eager_mode);
+  });
+  m.def("_get_use_eager_mode",
+        []() { return XLAGraphExecutor::Get()->UseEagerMode(); });
   m.def("_replace_xla_tensor",
         [](at::Tensor& self, const at::Tensor& source) -> at::Tensor& {
           return XLANativeFunctions::set_(self, source);
diff --git a/torch_xla/csrc/runtime/computation_client.cc b/torch_xla/csrc/runtime/computation_client.cc
index f29bfb90a94c..af304fc3ec6b 100644
--- a/torch_xla/csrc/runtime/computation_client.cc
+++ b/torch_xla/csrc/runtime/computation_client.cc
@@ -77,12 +77,24 @@ metrics::Metric* ComputationClient::CompileMetric() {
   return metric;
 }
 
+metrics::Metric* ComputationClient::EagerCompileMetric() {
+  static metrics::Metric* metric =
+      new metrics::Metric("EagerOpCompileTime", metrics::MetricFnTime);
+  return metric;
+}
+
 metrics::Metric* ComputationClient::ExecuteMetric() {
   static metrics::Metric* metric =
       new metrics::Metric("ExecuteTime", metrics::MetricFnTime);
   return metric;
 }
 
+metrics::Metric* ComputationClient::EagerExecuteMetric() {
+  static metrics::Metric* metric =
+      new metrics::Metric("EagerOpExecuteTime", metrics::MetricFnTime);
+  return metric;
+}
+
 metrics::Metric* ComputationClient::ExecuteReplicatedMetric() {
   static metrics::Metric* metric =
       new metrics::Metric("ExecuteReplicatedTime", metrics::MetricFnTime);
diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h
index a66ae2a7fa42..105d19f64c01 100644
--- a/torch_xla/csrc/runtime/computation_client.h
+++ b/torch_xla/csrc/runtime/computation_client.h
@@ -46,6 +46,7 @@ class XlaCoordinator;
 // ComputationClient.
 struct ClientExecuteOptions {
   bool explode_tuple{true};
+  bool eager_mode{false};
 };
 
 class ComputationClient {
@@ -213,16 +214,14 @@ class ComputationClient {
   // of torch::lazy::Computation?
   struct CompileInstance {
     CompileInstance() = default;
-    CompileInstance(xla::XlaComputation computation,
-                    std::string compilation_device,
-                    std::vector<std::string> devices,
-                    const xla::Shape* output_shape,
-                    bool parameter_is_tupled_arguments = false,
-                    bool is_sharded = false,
-                    bool allow_spmd_sharding_propagation_to_output = true,
-                    bool use_auto_spmd_partitioning = false,
-                    std::vector<int64_t> auto_spmd_mesh_shape = {},
-                    std::vector<int64_t> auto_spmd_mesh_ids = {})
+    CompileInstance(
+        xla::XlaComputation computation, std::string compilation_device,
+        std::vector<std::string> devices, const xla::Shape* output_shape,
+        bool parameter_is_tupled_arguments = false, bool is_sharded = false,
+        bool allow_spmd_sharding_propagation_to_output = true,
+        bool use_auto_spmd_partitioning = false,
+        std::vector<int64_t> auto_spmd_mesh_shape = {},
+        std::vector<int64_t> auto_spmd_mesh_ids = {}, bool eager_mode = false)
         : computation(std::move(computation)),
           compilation_device(std::move(compilation_device)),
          devices(std::move(devices)),
@@ -233,7 +232,8 @@ class ComputationClient {
               allow_spmd_sharding_propagation_to_output),
           use_auto_spmd_partitioning(use_auto_spmd_partitioning),
           auto_spmd_mesh_shape(auto_spmd_mesh_shape),
-          auto_spmd_mesh_ids(auto_spmd_mesh_ids) {}
+          auto_spmd_mesh_ids(auto_spmd_mesh_ids),
+          eager_mode(eager_mode) {}
 
     xla::XlaComputation computation;
     std::string compilation_device;
@@ -245,6 +245,7 @@ class ComputationClient {
     bool use_auto_spmd_partitioning;
     std::vector<int64_t> auto_spmd_mesh_shape;
    std::vector<int64_t> auto_spmd_mesh_ids;
+    bool eager_mode;
   };
 
   struct ExecuteComputationOptions : public ClientExecuteOptions {};
@@ -430,7 +431,9 @@ class ComputationClient {
   static metrics::Metric* TransferToDeviceTransformMetric();
   static metrics::Metric* TransferFromDeviceMetric();
   static metrics::Metric* CompileMetric();
+  static metrics::Metric* EagerCompileMetric();
   static metrics::Metric* ExecuteMetric();
+  static metrics::Metric* EagerExecuteMetric();
   static metrics::Metric* ExecuteReplicatedMetric();
   static metrics::Metric* ExecuteParallelMetric();
   static metrics::Metric* ExecuteChainedMetric();
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index fe625f6d1613..ec1848b9b064 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -531,7 +531,11 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromDevice(
 
 std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
     std::vector<CompileInstance> instances) {
-  metrics::TimedSection timed(CompileMetric());
+  auto metrics_fn = CompileMetric;
+  if (instances[0].eager_mode) {
+    metrics_fn = EagerCompileMetric;
+  }
+  metrics::TimedSection timed(metrics_fn());
   tsl::profiler::TraceMe activity("PjRtComputationClient::Compile",
                                   tsl::profiler::TraceMeLevel::kInfo);
   std::vector<ComputationClient::ComputationPtr> computations;
@@ -695,7 +699,11 @@ PjRtComputationClient::ExecuteComputation(
   // Shared ownership of the timed section ensures that it will only get logged
   // once both `ExecuteComputation` and the async work in `ExecuteSharded` are
   // complete; a copy is held from the lambda that releases it when done.
-  auto timed = std::make_shared<metrics::TimedSection>(ExecuteMetric());
+  auto metrics_fn = ExecuteMetric;
+  if (options.eager_mode) {
+    metrics_fn = EagerExecuteMetric;
+  }
+  auto timed = std::make_shared<metrics::TimedSection>(metrics_fn());
   tsl::profiler::TraceMe activity("PjRtComputationClient::ExecuteComputation",
                                   tsl::profiler::TraceMeLevel::kInfo);
   TF_VLOG(1) << "Executing PjRt computation on " << device;
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 498f3b93536a..8f516f9016a4 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -83,10 +83,11 @@ XLATensorPtr XLATensor::Create(
     std::optional<at::ScalarType> logical_element_type) {
   XLATensorPtr xtensor = c10::make_intrusive<XLATensor>(
       XLATensor(std::move(ir_value), device, logical_element_type));
-  XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
-  if (UseEagerDebugMode()) {
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  graph_executor->RegisterTensor(xtensor->data());
+  if (UseEagerDebugMode() || graph_executor->UseEagerMode()) {
     std::vector<XLATensorPtr> xtensors({xtensor});
-    XLAGraphExecutor::Get()->ApplyEagerSync(xtensors);
+    graph_executor->ApplyEagerSync(xtensors);
   }
   return xtensor;
 }
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 55507050a7ed..74c3270a9666 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -79,7 +79,7 @@ bool ShouldSyncIrValue(const torch::lazy::Value& ir_value) {
 
 XLAGraphExecutor::ComputationCache* CreateComputationCache() {
   static const size_t kMaxCacheSize =
-      runtime::sys_util::GetEnvInt("XLA_COMPILATION_CACHE_SIZE", 1024);
+      runtime::sys_util::GetEnvInt("XLA_COMPILATION_CACHE_SIZE", 2048);
   static const bool readonlyPersistentCache =
       runtime::sys_util::GetEnvBool("XLA_PERSISTENT_CACHE_READ_ONLY", false);
   static std::string persistentCacheDir =
@@ -814,6 +814,7 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
 
   std::vector<torch::lazy::BackendDataPtr> results;
   if (async->cached_computation->is_sharded) {
+    // TODO(JackCaoG): handle eager mode
     std::vector<std::string> devices =
         runtime::GetComputationClient()->GetLocalDevices();
     runtime::ComputationClient::ExecuteReplicatedOptions execute_options;
@@ -1088,7 +1089,8 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph(
   std::shared_ptr<Async> async = std::make_shared<Async>(
       coll, std::move(parameters_data), std::move(tensors_data),
       std::move(cached_computation));
-  auto syncfn = [async, hash = coll->hash, sharding_specs = sharding_specs]() {
+  auto syncfn = [async, hash = coll->hash, sharding_specs = sharding_specs,
+                 use_eager_mode = UseEagerMode()]() {
     try {
       std::vector<torch::lazy::BackendDataPtr> results;
       // Execute replicated if the compiled computation is partitioned.
@@ -1117,9 +1119,13 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph(
         TF_VLOG(3) << "Executing IR graph hash "
                    << torch::lazy::HashToString(hash) << " on device "
                    << async->device << " ...";
-        results = torch::lazy::getBackend()->ExecuteComputation(
-            async->cached_computation->computation, async->parameters_data,
-            async->device);
+        std::vector<runtime::ComputationClient::DataPtr> outputs =
+            runtime::GetComputationClient()->ExecuteComputation(
+                *async->cached_computation->computation,
+                UnwrapXlaData(async->parameters_data), async->device.toString(),
+                {/*explode_tuple=*/true,
+                 /*eager_mode=*/use_eager_mode});
+        results = WrapXlaData(outputs);
         TORCH_LAZY_COUNTER("ExecuteComputation", 1);
         TF_VLOG(3) << "Executing IR graph hash "
                    << torch::lazy::HashToString(hash) << " on device "
@@ -1372,7 +1378,7 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
           runtime::GetComputationClient()->GetCompilationDevices(
               coll.device.toString(), devices),
           &shape, should_wrap_parameter, is_sharded});
-
+  instances.front().eager_mode = UseEagerMode();
   if (use_autosharding) {
     TF_VLOG(5) << "use_auto_spmd_partitioning is set.";
     TF_CHECK(is_sharded) << "Auto-sharding pass requires SPMD mode.";
diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h
index 3baf7d830634..bc057193e20c 100644
--- a/torch_xla/csrc/xla_graph_executor.h
+++ b/torch_xla/csrc/xla_graph_executor.h
@@ -187,6 +187,12 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
   void ClearPendingIrs(std::vector<XLATensorPtr> tensors,
                        const torch::lazy::BackendDevice& device);
 
+  void SetUseEagerMode(bool use_eager_mode) {
+    use_eager_mode_ = use_eager_mode;
+  }
+
+  bool UseEagerMode() { return use_eager_mode_; }
+
  private:
   // This is just to group results from compile(). Since our computation is
   // different, we don't reuse the upstream CompilationResult.
@@ -361,6 +367,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
       const SyncTensorsConfig& config, bool warm_up_cache_only = false);
 
   ComputationCache* computation_cache_;
+  bool use_eager_mode_ = false;
 };
 
 }  // namespace torch_xla
diff --git a/torch_xla/experimental/__init__.py b/torch_xla/experimental/__init__.py
index e69de29bb2d1..68bd36dc06df 100644
--- a/torch_xla/experimental/__init__.py
+++ b/torch_xla/experimental/__init__.py
@@ -0,0 +1,5 @@
+from .eager import eager_mode
+
+__all__ = [
+    "eager_mode",
+]
\ No newline at end of file
diff --git a/torch_xla/experimental/eager.py b/torch_xla/experimental/eager.py
new file mode 100644
index 000000000000..37861fdd8bf7
--- /dev/null
+++ b/torch_xla/experimental/eager.py
@@ -0,0 +1,9 @@
+import torch_xla
+
+
+def eager_mode(enable: bool):
+  """Configure torch_xla's default execution mode.
+  Under eager mode, only functions that were `torch_xla.compile`d will be
+  traced and compiled. Other torch ops will be executed eagerly.
+  """
+  torch_xla._XLAC._set_use_eager_mode(enable)
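
Usage note (not part of the patch itself): a minimal sketch of how the new experimental API is expected to be used, based only on calls that appear in this PR's example and tests; the tensor shapes and values below are illustrative.

# Minimal usage sketch for the eager mode added in this PR. Everything here
# mirrors calls exercised by train_decoder_only_eager.py and test_metrics.py;
# the shapes/values are illustrative only.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla.experimental.eager_mode(True)  # switch the default mode to eager

device = xm.xla_device()
t1 = torch.randn(5, 9, device=device)  # compiled and executed immediately
t2 = t1 + 100                          # also executed eagerly, no mark_step needed
xm.wait_device_ops()                   # wait for the eager executions to finish

xm.mark_step()  # expected to be a no-op while eager mode is enabled

torch_xla.experimental.eager_mode(False)  # restore the default lazy (tracing) mode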