diff --git a/test/test_operations.py b/test/test_operations.py
index b36700bae8e1..e6fe22330607 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -1757,6 +1757,35 @@ def test_wait_device_ops(self):
                     "ExecuteChainedTime" in met.metric_names())
 
 
+class TestDebuggingUtil(test_utils.XlaTestCase):
+
+  def test_get_xla_tensor_debug_info(self):
+    if xu.getenv_as('XLA_USE_EAGER_DEBUG_MODE', str, '1'):
+      # ignore this test for eager debug mode since it will
+      # mess up the IR.
+      return
+    device = xm.xla_device()
+    # test non xla tensor
+    cpu_t1 = torch.randn(5)
+    cpu_t1_info = torch_xla._XLAC._get_xla_tensor_debug_info(cpu_t1)
+    self.assertIn('Not a XLATensor', cpu_t1_info)
+
+    # test a tensor with IR
+    t1 = cpu_t1.to(device)
+    t2 = t1 + 5
+    t2_info = torch_xla._XLAC._get_xla_tensor_debug_info(t2)
+    self.assertIn('XLA Shape: f32[5]', t2_info)
+    self.assertIn('aten::add', t2_info)
+    self.assertIn('XLAData: None', t2_info)
+
+    # after mark_step XLAData should be present
+    xm.mark_step()
+    t2_info_new = torch_xla._XLAC._get_xla_tensor_debug_info(t2)
+    self.assertNotIn('XLAData: None', t2_info_new)
+    self.assertIn('Data Shape: f32[5]', t2_info_new)
+    self.assertIn('IR: None', t2_info_new)
+
+
 class TestOpBuilder(test_utils.XlaTestCase):
 
   def runOpBuilderTest(self,
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index e8c70a719ab4..6e501e997f15 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -361,6 +361,47 @@ std::string GetTensorsHloGraph(const std::vector<at::Tensor>& tensors,
   return XLAGraphExecutor::Get()->DumpHloComputation(xtensors, get_stable_hlo);
 }
 
+std::string GetXLATensorDebugInfo(const at::Tensor& tensor) {
+  auto xtensor = bridge::TryGetXlaTensor(tensor);
+  if (!xtensor) {
+    return "Not a XLATensor\n";
+  }
+  std::stringstream ss;
+  ss << "XLATensor {\n";
+  ss << "TensorID: " << xtensor->GetUniqueId() << "\n";
+  ss << "Device: " << xtensor->GetDevice() << "\n";
+  ss << "XLA Shape: " << xtensor->shape().get().ToString() << "\n";
+
+  torch::lazy::Value ir_value = xtensor->CurrentIrValue();
+  ss << "IR: ";
+  if (ir_value) {
+    ss << ir_value.node->ToString() << "\n";
+  } else {
+    ss << "None\n";
+  }
+
+  torch::lazy::BackendDataPtr handle = xtensor->CurrentDataHandle();
+  ss << "XLAData: ";
+  if (handle) {
+    auto data = UnwrapXlaData(handle);
+    ss << "\n  Data Device: " << data->device() << "\n";
+    ss << "  Data Shape: " << data->shape().ToString() << "\n";
+  } else {
+    ss << "None\n";
+  }
+
+  auto at_tensor = xtensor->CurrentTensorData();
+  ss << "Tensor on host: ";
+  if (at_tensor) {
+    ss << " with size " << at_tensor->sizes() << "\n";
+  } else {
+    ss << "None\n";
+  }
+
+  ss << "}\n";
+  return ss.str();
+}
+
 std::string GetLiveTensorsReport(size_t nodes_threshold,
                                  const std::string& device_str) {
   auto opt_device = GetOptionalDevice(device_str);
@@ -853,6 +894,10 @@ void InitXlaModuleBindings(py::module m) {
       [](const std::vector<at::Tensor>& tensors) -> std::string {
         return GetTensorsHloGraph(tensors);
       });
+  m.def("_get_xla_tensor_debug_info",
+        [](const at::Tensor& tensor) -> std::string {
+          return GetXLATensorDebugInfo(tensor);
+        });
   py::class_<XLATensor::ShardingSpec, XLATensor::ShardingSpecPtr>(
       m, "XlaShardingSpec")
       .def(py::init([](at::Tensor tensor, const py::list& tile_assignment,