Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1757,6 +1757,35 @@ def test_wait_device_ops(self):
"ExecuteChainedTime" in met.metric_names())


class TestDebuggingUtil(test_utils.XlaTestCase):
  """Tests for the _XLAC tensor debug-info helper."""

  def test_get_xla_tensor_debug_info(self):
    """Checks debug info for a CPU tensor, a pending-IR tensor, and
    a materialized (post-mark_step) tensor."""
    # Skip under eager debug mode, since eager execution materializes
    # tensors immediately and would mess up the IR this test inspects.
    # BUG FIX: the previous guard read the env var as `str` with default
    # '1'; any non-empty string (including the default) is truthy, so the
    # test body was unconditionally skipped. Read it as a bool instead.
    if xu.getenv_as('XLA_USE_EAGER_DEBUG_MODE', bool, False):
      return
    device = xm.xla_device()

    # A plain CPU tensor is not backed by an XLATensor.
    cpu_t1 = torch.randn(5)
    cpu_t1_info = torch_xla._XLAC._get_xla_tensor_debug_info(cpu_t1)
    self.assertIn('Not a XLATensor', cpu_t1_info)

    # A freshly computed lazy tensor has pending IR but no device data yet.
    t1 = cpu_t1.to(device)
    t2 = t1 + 5
    t2_info = torch_xla._XLAC._get_xla_tensor_debug_info(t2)
    self.assertIn('XLA Shape: f32[5]', t2_info)
    self.assertIn('aten::add', t2_info)
    self.assertIn('XLAData: None', t2_info)

    # After mark_step the graph is executed, so device data should be
    # present and the pending IR should have been cleared.
    xm.mark_step()
    t2_info_new = torch_xla._XLAC._get_xla_tensor_debug_info(t2)
    self.assertNotIn('XLAData: None', t2_info_new)
    self.assertIn('Data Shape: f32[5]', t2_info_new)
    self.assertIn('IR: None', t2_info_new)


class TestOpBuilder(test_utils.XlaTestCase):

def runOpBuilderTest(self,
Expand Down
45 changes: 45 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,47 @@ std::string GetTensorsHloGraph(const std::vector<at::Tensor>& tensors,
return XLAGraphExecutor::Get()->DumpHloComputation(xtensors, get_stable_hlo);
}

// Builds a human-readable, multi-line description of the XLATensor backing
// `tensor`: unique id, device, XLA shape, pending IR (if any), device data
// handle (if any), and host-side tensor data (if any). A tensor that is not
// backed by an XLATensor yields a short "Not a XLATensor" marker instead.
std::string GetXLATensorDebugInfo(const at::Tensor& tensor) {
  auto xtensor = bridge::TryGetXlaTensor(tensor);
  if (!xtensor) {
    return "Not a XLATensor\n";
  }

  std::stringstream sink;
  sink << "XLATensor {\n"
       << "TensorID: " << xtensor->GetUniqueId() << "\n"
       << "Device: " << xtensor->GetDevice() << "\n"
       << "XLA Shape: " << xtensor->shape().get().ToString() << "\n";

  // Pending lazy IR; present until the graph is executed.
  sink << "IR: ";
  if (torch::lazy::Value ir_value = xtensor->CurrentIrValue()) {
    sink << ir_value.node->ToString() << "\n";
  } else {
    sink << "None\n";
  }

  // Device-side data handle; present once the tensor has been materialized.
  sink << "XLAData: ";
  if (torch::lazy::BackendDataPtr handle = xtensor->CurrentDataHandle()) {
    auto xla_data = UnwrapXlaData(handle);
    sink << "\n Data Device: " << xla_data->device() << "\n"
         << " Data Shape: " << xla_data->shape().ToString() << "\n";
  } else {
    sink << "None\n";
  }

  // Host-side (CPU) tensor data, if the value currently lives on the host.
  sink << "Tensor on host: ";
  if (auto host_tensor = xtensor->CurrentTensorData()) {
    sink << " with size " << host_tensor->sizes() << "\n";
  } else {
    sink << "None\n";
  }

  sink << "}\n";
  return sink.str();
}

std::string GetLiveTensorsReport(size_t nodes_threshold,
const std::string& device_str) {
auto opt_device = GetOptionalDevice(device_str);
Expand Down Expand Up @@ -853,6 +894,10 @@ void InitXlaModuleBindings(py::module m) {
[](const std::vector<at::Tensor>& tensors) -> std::string {
return GetTensorsHloGraph(tensors);
});
// Expose GetXLATensorDebugInfo to Python as
// torch_xla._XLAC._get_xla_tensor_debug_info: returns a multi-line debug
// string describing the XLATensor (or a "Not a XLATensor" marker).
m.def("_get_xla_tensor_debug_info",
      [](const at::Tensor& tensor) -> std::string {
        return GetXLATensorDebugInfo(tensor);
      });
py::class_<XLATensor::ShardingSpec, XLATensor::ShardingSpecPtr>(
m, "XlaShardingSpec")
.def(py::init([](at::Tensor tensor, const py::list& tile_assignment,
Expand Down