21 changes: 6 additions & 15 deletions test/test_operations.py
@@ -65,12 +65,8 @@ def _is_on_tpu():
return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU'


def _is_on_eager_debug_mode():
return xu.getenv_as('XLA_USE_EAGER_DEBUG_MODE', bool, defval=False)


skipOnTpu = unittest.skipIf(_is_on_tpu(), 'Not supported on TPU')
skipOnEagerDebug = unittest.skipIf(_is_on_eager_debug_mode(),
skipOnEagerDebug = unittest.skipIf(torch_xla.experimental.is_eager_mode(),
'skip on eager debug mode')


@@ -604,7 +600,7 @@ def test_rand(self):
self.assertEqual(x.device.type, 'xla')

def test_randperm(self):
x = torch.randperm(3, device=xm.xla_device())
x = torch.randperm(3, device=xm.xla_device(), dtype=torch.int32)
self.assertEqual(x.device.type, 'xla')

def test_randn_like(self):
@@ -1690,10 +1686,8 @@ def test_binaryop_order(self):
y = torch.rand(5)
self.assertEqual(x + y, y + x)

@unittest.skipIf(
os.environ.get('XLA_USE_EAGER_DEBUG_MODE'),
'Since in eager mode the tensor would be materialized and hence _get_xla_tensors_text would not show the prim::Constant node.'
)
# Since in eager mode the tensor would be materialized and hence _get_xla_tensors_text would not show the prim::Constant node.
@skipOnEagerDebug
def test_pow_constant(self):
t1 = torch.pow(torch.tensor([2.0, 3.0], device=xm.xla_device()), 5)
hlo_text = torch_xla._XLAC._get_xla_tensors_text([t1])
@@ -2360,16 +2354,13 @@ def test_wait_device_ops(self):
xm.mark_step()
xm.wait_device_ops()
self.assertTrue("ExecuteTime" in met.metric_names() or
"ExecuteChainedTime" in met.metric_names())
"EagerOpExecuteTime" in met.metric_names())


class TestDebuggingUtil(test_utils.XlaTestCase):

@skipOnEagerDebug
def test_get_xla_tensor_debug_info(self):
if xu.getenv_as('XLA_USE_EAGER_DEBUG_MODE', str, '1'):
# ignore this test for eager debug mode since it will
# mess up the IR.
return
device = xm.xla_device()
# test non xla tensor
cpu_t1 = torch.randn(5)
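The comment carried over onto `test_pow_constant` explains the skip: once eager mode is on, the tensor is materialized immediately, so `_get_xla_tensors_text` no longer reports the pending `prim::Constant` node. Below is a minimal sketch of that behaviour, not part of the change itself; it assumes an available XLA device and the `eager_mode` API introduced in this PR.

```python
# Hedged sketch: why test_pow_constant is skipped when eager mode is enabled.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla.experimental.eager_mode(True)  # the API the skip decorator checks via is_eager_mode()

t1 = torch.pow(torch.tensor([2.0, 3.0], device=xm.xla_device()), 5)
# The op already executed eagerly, so the pending-IR dump no longer contains
# the prim::Constant node that the lazy-mode assertion in test_pow_constant expects.
print(torch_xla._XLAC._get_xla_tensors_text([t1]))
```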
4 changes: 4 additions & 0 deletions torch_xla/__init__.py
@@ -243,6 +243,10 @@ def _init_xla_lazy_backend():
plugins.use_dynamic_plugins()
plugins.register_installed_plugins()

if os.getenv('XLA_USE_EAGER_DEBUG_MODE', '0') == '1':
Contributor: Is this just for backward compatibility, since we already have the api now?

Collaborator (author): Yeah, it is mostly for backward compatibility. I think I can also tweak it and use it in tests to exercise both eager and non-eager.

from .experimental import eager_mode
eager_mode(True)

from .torch_xla import *

# register all custom kenels and decomp by default
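As the review thread above notes, this env-var branch exists mainly for backward compatibility: setting `XLA_USE_EAGER_DEBUG_MODE=1` now just flips the new eager-mode API at import time. A hedged sketch of the two equivalent entry points, using only the `eager_mode`/`is_eager_mode` helpers that appear in this PR:

```python
# Sketch: legacy env var vs. the new API (assumes a working torch_xla install).
import os

# Legacy path: set the flag before torch_xla is imported; __init__.py then
# calls eager_mode(True) on our behalf.
os.environ["XLA_USE_EAGER_DEBUG_MODE"] = "1"
import torch_xla
import torch_xla.experimental as xp

# New path: call the API directly. Either way the process ends up in eager
# mode, which tests can query (e.g. the skipOnEagerDebug decorator above).
xp.eager_mode(True)
assert xp.is_eager_mode()
```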
31 changes: 11 additions & 20 deletions torch_xla/csrc/tensor.cpp
@@ -86,8 +86,7 @@ XLATensorPtr XLATensor::Create(
XLATensor(std::move(ir_value), device, logical_element_type));
XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
graph_executor->RegisterTensor(xtensor->data());
if ((UseEagerDebugMode() || graph_executor->UseEagerMode()) &&
!delay_eager_executation) {
if (graph_executor->UseEagerMode() && !delay_eager_executation) {
std::vector<XLATensorPtr> xtensors({xtensor});
graph_executor->ApplyEagerSync(xtensors);
}
@@ -330,7 +329,8 @@ void XLATensor::SetXlaData(torch::lazy::BackendDataPtr handle, bool sync) {
data()->is_cloned = false;
}

void XLATensor::SetIrValue(torch::lazy::Value ir_value, bool inplace) {
void XLATensor::SetIrValue(torch::lazy::Value ir_value, bool inplace,
bool delay_eager_executation) {
data()->handle = nullptr;
data()->tensor_data = std::nullopt;
if (data()->view != nullptr && inplace) {
@@ -346,11 +346,15 @@ void XLATensor::SetIrValue(torch::lazy::Value ir_value, bool inplace) {
AssignIrValue(std::move(ir_value));
TryLimitGraphSize();
}
if (UseEagerDebugMode() && ShouldSyncIrNode()) {
data()->is_cloned = false;

XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
// Update should also be triggered eagerly if configured
if (graph_executor->UseEagerMode() && !delay_eager_executation &&
ShouldSyncIrNode()) {
std::vector<XLATensorPtr> xtensors({c10::make_intrusive<XLATensor>(*this)});
XLAGraphExecutor::Get()->ApplyEagerSync(xtensors);
graph_executor->ApplyEagerSync(xtensors);
}
data()->is_cloned = false;
}

void XLATensor::SetInPlaceIrValue(torch::lazy::Value ir_value,
Expand All @@ -360,14 +364,7 @@ void XLATensor::SetInPlaceIrValue(torch::lazy::Value ir_value,
ir_value =
torch::lazy::MakeNode<Cast>(ir_value, xla_shape.get().element_type());
}
SetIrValue(std::move(ir_value), /*inplace=*/true);
XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();

// in place update should also be triggered eagerly if configured
if (graph_executor->UseEagerMode() && !delay_eager_executation) {
std::vector<XLATensorPtr> xtensors({c10::make_intrusive<XLATensor>(*this)});
graph_executor->ApplyEagerSync(xtensors);
}
SetIrValue(std::move(ir_value), /*inplace=*/true, delay_eager_executation);
}

void XLATensor::AssignIrValue(torch::lazy::Value ir_value) const {
@@ -661,12 +658,6 @@ void XLATensor::ApplyPendingGraph() {
}
}

bool XLATensor::UseEagerDebugMode() {
static const bool use_eager_debug_mode =
runtime::sys_util::GetEnvBool("XLA_USE_EAGER_DEBUG_MODE", false);
return use_eager_debug_mode;
}

bool XLATensor::ShouldSyncIrNode() {
if (!this->data()->ir_value) {
return false;
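With the `SetIrValue` change above, an update to an existing tensor now also triggers an eager sync when eager mode is enabled (previously only the `XLA_USE_EAGER_DEBUG_MODE` path did, via the removed `UseEagerDebugMode`). A rough way to observe this from Python, hedged and assuming an XLA device plus the `EagerOpExecuteTime` metric referenced in the updated test:

```python
# Sketch: in-place updates go through the eager-sync path in SetIrValue.
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

torch_xla.experimental.eager_mode(True)
t = torch.ones(4, device=xm.xla_device())

t.add_(1)            # routed through SetInPlaceIrValue -> SetIrValue -> ApplyEagerSync
xm.wait_device_ops()
# Expectation (assumption): the eager execution metric shows up without any
# explicit mark_step(), because the update was executed eagerly.
print("EagerOpExecuteTime" in met.metric_names())
```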
5 changes: 2 additions & 3 deletions torch_xla/csrc/tensor.h
@@ -231,7 +231,8 @@ class XLATensor : public torch::lazy::LazyTensor {
// internal state of the object.
// TODO(alanwaketan): Reuse the upstream ones once Functionalization is done.
torch::lazy::Value GetIrValue() const;
void SetIrValue(torch::lazy::Value ir_value, bool inplace = true);
void SetIrValue(torch::lazy::Value ir_value, bool inplace = true,
bool delay_eager_executation = false);
void SetInPlaceIrValue(torch::lazy::Value ir_value,
bool delay_eager_executation = false);

@@ -334,8 +335,6 @@
const at::Tensor& tensor,
const torch::lazy::BackendDevice& device) const final;

static bool UseEagerDebugMode();

bool ShouldSyncIrNode();

// We store two shared_ptr of Data in a XLATensor.
12 changes: 6 additions & 6 deletions torch_xla/csrc/tensor_methods.cpp
@@ -2035,9 +2035,9 @@ void max_out(XLATensorPtr& max, XLATensorPtr& max_values,
torch::lazy::GetCanonicalDimensionIndex(dim, input->shape().get().rank());
torch::lazy::NodePtr node = torch::lazy::MakeNode<MaxInDim>(
input->GetIrValue(), canonical_dim, keepdim);
max->SetIrValue(torch::lazy::Value(node, 0),
max->SetIrValue(torch::lazy::Value(node, 0), /*inplace=*/true,
/*delay_eager_executation=*/true);
max_values->SetIrValue(torch::lazy::Value(node, 1),
max_values->SetIrValue(torch::lazy::Value(node, 1), /*inplace=*/true,
/*delay_eager_executation=*/true);
XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
if (graph_executor->UseEagerMode()) {
@@ -2143,9 +2143,9 @@ void min_out(XLATensorPtr& min, XLATensorPtr& min_indices,
torch::lazy::GetCanonicalDimensionIndex(dim, input->shape().get().rank());
torch::lazy::NodePtr node = torch::lazy::MakeNode<MinInDim>(
input->GetIrValue(), canonical_dim, keepdim);
min->SetIrValue(torch::lazy::Value(node, 0),
min->SetIrValue(torch::lazy::Value(node, 0), /*inplace=*/true,
/*delay_eager_executation=*/true);
min_indices->SetIrValue(torch::lazy::Value(node, 1),
min_indices->SetIrValue(torch::lazy::Value(node, 1), /*inplace=*/true,
/*delay_eager_executation=*/true);
XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
if (graph_executor->UseEagerMode()) {
@@ -2296,13 +2296,13 @@ std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> native_batch_norm(
running_mean->SetIrValue(
torch::lazy::MakeNode<LinearInterpolation>(
mean->GetIrValue(), running_mean->GetIrValue(), momentum),
/*delay_eager_executation=*/true);
/*inplace=*/true, /*delay_eager_executation=*/true);
}
if (running_var) {
running_var->SetIrValue(
torch::lazy::MakeNode<LinearInterpolation>(
torch::lazy::Value(node, 2), running_var->GetIrValue(), momentum),
/*delay_eager_executation=*/true);
/*inplace=*/true, /*delay_eager_executation=*/true);
}
} else {
at::Tensor at_input = bridge::AtenFromXlaTensor(input);
Expand Down