diff --git a/test/test_operations.py b/test/test_operations.py
index e3b3d212e19d..843bb52f609f 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -65,12 +65,8 @@ def _is_on_tpu():
   return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU'


-def _is_on_eager_debug_mode():
-  return xu.getenv_as('XLA_USE_EAGER_DEBUG_MODE', bool, defval=False)
-
-
 skipOnTpu = unittest.skipIf(_is_on_tpu(), 'Not supported on TPU')
-skipOnEagerDebug = unittest.skipIf(_is_on_eager_debug_mode(),
+skipOnEagerDebug = unittest.skipIf(torch_xla.experimental.is_eager_mode(),
                                    'skip on eager debug mode')


@@ -604,7 +600,7 @@ def test_rand(self):
     self.assertEqual(x.device.type, 'xla')

   def test_randperm(self):
-    x = torch.randperm(3, device=xm.xla_device())
+    x = torch.randperm(3, device=xm.xla_device(), dtype=torch.int32)
     self.assertEqual(x.device.type, 'xla')

   def test_randn_like(self):
@@ -1690,10 +1686,8 @@ def test_binaryop_order(self):
     y = torch.rand(5)
     self.assertEqual(x + y, y + x)

-  @unittest.skipIf(
-      os.environ.get('XLA_USE_EAGER_DEBUG_MODE'),
-      'Since in eager mode the tensor would be materialized and hence _get_xla_tensors_text would not show the prim::Constant node.'
-  )
+  # Since in eager mode the tensor would be materialized and hence _get_xla_tensors_text would not show the prim::Constant node.
+  @skipOnEagerDebug
   def test_pow_constant(self):
     t1 = torch.pow(torch.tensor([2.0, 3.0], device=xm.xla_device()), 5)
     hlo_text = torch_xla._XLAC._get_xla_tensors_text([t1])
@@ -2360,16 +2354,13 @@ def test_wait_device_ops(self):
     xm.mark_step()
     xm.wait_device_ops()
     self.assertTrue("ExecuteTime" in met.metric_names() or
-                    "ExecuteChainedTime" in met.metric_names())
+                    "EagerOpExecuteTime" in met.metric_names())


 class TestDebuggingUtil(test_utils.XlaTestCase):

+  @skipOnEagerDebug
   def test_get_xla_tensor_debug_info(self):
-    if xu.getenv_as('XLA_USE_EAGER_DEBUG_MODE', str, '1'):
-      # ignore this test for eager debug mode since it will
-      # mess up the IR.
-      return
     device = xm.xla_device()
     # test non xla tensor
     cpu_t1 = torch.randn(5)
diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index d1639e5715c8..bcbecac5636e 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -243,6 +243,10 @@ def _init_xla_lazy_backend():
   plugins.use_dynamic_plugins()
   plugins.register_installed_plugins()

+if os.getenv('XLA_USE_EAGER_DEBUG_MODE', '0') == '1':
+  from .experimental import eager_mode
+  eager_mode(True)
+
 from .torch_xla import *

 # register all custom kenels and decomp by default
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 9baaeb04a535..de75a3737530 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -86,8 +86,7 @@ XLATensorPtr XLATensor::Create(
       XLATensor(std::move(ir_value), device, logical_element_type));
   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
   graph_executor->RegisterTensor(xtensor->data());
-  if ((UseEagerDebugMode() || graph_executor->UseEagerMode()) &&
-      !delay_eager_executation) {
+  if (graph_executor->UseEagerMode() && !delay_eager_executation) {
     std::vector<XLATensorPtr> xtensors({xtensor});
     graph_executor->ApplyEagerSync(xtensors);
   }
@@ -330,7 +329,8 @@ void XLATensor::SetXlaData(torch::lazy::BackendDataPtr handle, bool sync) {
   data()->is_cloned = false;
 }

-void XLATensor::SetIrValue(torch::lazy::Value ir_value, bool inplace) {
+void XLATensor::SetIrValue(torch::lazy::Value ir_value, bool inplace,
+                           bool delay_eager_executation) {
   data()->handle = nullptr;
   data()->tensor_data = std::nullopt;
   if (data()->view != nullptr && inplace) {
@@ -346,11 +346,15 @@
     AssignIrValue(std::move(ir_value));
     TryLimitGraphSize();
   }
-  if (UseEagerDebugMode() && ShouldSyncIrNode()) {
+  data()->is_cloned = false;
+
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  // Update should also be triggered eagerly if configured
+  if (graph_executor->UseEagerMode() && !delay_eager_executation &&
+      ShouldSyncIrNode()) {
     std::vector<XLATensorPtr> xtensors({c10::make_intrusive<XLATensor>(*this)});
-    XLAGraphExecutor::Get()->ApplyEagerSync(xtensors);
+    graph_executor->ApplyEagerSync(xtensors);
   }
-  data()->is_cloned = false;
 }

 void XLATensor::SetInPlaceIrValue(torch::lazy::Value ir_value,
@@ -360,14 +364,7 @@ void XLATensor::SetInPlaceIrValue(torch::lazy::Value ir_value,
     ir_value =
         torch::lazy::MakeNode<Cast>(ir_value, xla_shape.get().element_type());
   }
-  SetIrValue(std::move(ir_value), /*inplace=*/true);
-  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
-
-  // in place update should also be triggered eagerly if configured
-  if (graph_executor->UseEagerMode() && !delay_eager_executation) {
-    std::vector<XLATensorPtr> xtensors({c10::make_intrusive<XLATensor>(*this)});
-    graph_executor->ApplyEagerSync(xtensors);
-  }
+  SetIrValue(std::move(ir_value), /*inplace=*/true, delay_eager_executation);
 }

 void XLATensor::AssignIrValue(torch::lazy::Value ir_value) const {
@@ -661,12 +658,6 @@ void XLATensor::ApplyPendingGraph() {
   }
 }

-bool XLATensor::UseEagerDebugMode() {
-  static const bool use_eager_debug_mode =
-      runtime::sys_util::GetEnvBool("XLA_USE_EAGER_DEBUG_MODE", false);
-  return use_eager_debug_mode;
-}
-
 bool XLATensor::ShouldSyncIrNode() {
   if (!this->data()->ir_value) {
     return false;
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index d837f2c2ab58..725a851255d2 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -231,7 +231,8 @@ class XLATensor : public torch::lazy::LazyTensor {
   // internal state of the object.
   // TODO(alanwaketan): Reuse the upstream ones once Functionalization is done.
   torch::lazy::Value GetIrValue() const;
-  void SetIrValue(torch::lazy::Value ir_value, bool inplace = true);
+  void SetIrValue(torch::lazy::Value ir_value, bool inplace = true,
+                  bool delay_eager_executation = false);
   void SetInPlaceIrValue(torch::lazy::Value ir_value,
                          bool delay_eager_executation = false);

@@ -334,8 +335,6 @@ class XLATensor : public torch::lazy::LazyTensor {
       const at::Tensor& tensor,
       const torch::lazy::BackendDevice& device) const final;

-  static bool UseEagerDebugMode();
-
   bool ShouldSyncIrNode();

   // We store two shared_ptr of Data in a XLATensor.
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index ab58576d6242..dce5c6f6f3d8 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -2035,9 +2035,9 @@ void max_out(XLATensorPtr& max, XLATensorPtr& max_values,
       torch::lazy::GetCanonicalDimensionIndex(dim, input->shape().get().rank());
   torch::lazy::NodePtr node = torch::lazy::MakeNode<MaxInDim>(
       input->GetIrValue(), canonical_dim, keepdim);
-  max->SetIrValue(torch::lazy::Value(node, 0),
+  max->SetIrValue(torch::lazy::Value(node, 0), /*inplace=*/true,
                   /*delay_eager_executation=*/true);
-  max_values->SetIrValue(torch::lazy::Value(node, 1),
+  max_values->SetIrValue(torch::lazy::Value(node, 1), /*inplace=*/true,
                          /*delay_eager_executation=*/true);
   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
   if (graph_executor->UseEagerMode()) {
@@ -2143,9 +2143,9 @@ void min_out(XLATensorPtr& min, XLATensorPtr& min_indices,
       torch::lazy::GetCanonicalDimensionIndex(dim, input->shape().get().rank());
   torch::lazy::NodePtr node = torch::lazy::MakeNode<MinInDim>(
       input->GetIrValue(), canonical_dim, keepdim);
-  min->SetIrValue(torch::lazy::Value(node, 0),
+  min->SetIrValue(torch::lazy::Value(node, 0), /*inplace=*/true,
                   /*delay_eager_executation=*/true);
-  min_indices->SetIrValue(torch::lazy::Value(node, 1),
+  min_indices->SetIrValue(torch::lazy::Value(node, 1), /*inplace=*/true,
                           /*delay_eager_executation=*/true);
   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
   if (graph_executor->UseEagerMode()) {
@@ -2296,13 +2296,13 @@ std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> native_batch_norm(
       running_mean->SetIrValue(
           torch::lazy::MakeNode<LinearInterpolation>(
               mean->GetIrValue(), running_mean->GetIrValue(), momentum),
-          /*delay_eager_executation=*/true);
+          /*inplace=*/true, /*delay_eager_executation=*/true);
     }
     if (running_var) {
       running_var->SetIrValue(
           torch::lazy::MakeNode<LinearInterpolation>(
               torch::lazy::Value(node, 2), running_var->GetIrValue(), momentum),
-          /*delay_eager_executation=*/true);
+          /*inplace=*/true, /*delay_eager_executation=*/true);
     }
   } else {
     at::Tensor at_input = bridge::AtenFromXlaTensor(input);
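
Note: after this change, setting XLA_USE_EAGER_DEBUG_MODE=1 is just a startup alias for torch_xla.experimental.eager_mode(True). A minimal usage sketch of the consolidated API (illustrative only, assuming a build that includes this patch; the tensor math is an arbitrary example, not part of the diff):

    # The env var is read once at import time in torch_xla/__init__.py,
    # so it must be set before torch_xla is imported.
    import os
    os.environ['XLA_USE_EAGER_DEBUG_MODE'] = '1'

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm

    # __init__.py called torch_xla.experimental.eager_mode(True) for us,
    # which is the same predicate the tests now check via skipOnEagerDebug.
    assert torch_xla.experimental.is_eager_mode()

    # In eager mode each op is materialized as it runs instead of being
    # queued in the lazy IR graph, which is why IR-inspection tests
    # (e.g. test_pow_constant) are skipped in this mode.
    t1 = torch.pow(torch.tensor([2.0, 3.0], device=xm.xla_device()), 5)
    print(t1.cpu())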