diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index a9f34dd7914d..bc4f7534d89a 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1451,6 +1451,7 @@ void InitXlaModuleBindings(py::module m) { m.def("_xla_mark_sharding", [](const at::Tensor& input, const py::list& tile_assignment, bool replicated = false, bool manual = false) { + TORCH_LAZY_COUNTER("XlaMarkSharding", 1); xla::OpSharding sharding = ShardingUtil::CreateOpSharding(tile_assignment, replicated, manual); XLATensorPtr xtensor = bridge::GetXlaTensor(input); @@ -1464,6 +1465,7 @@ void InitXlaModuleBindings(py::module m) { at::Tensor cpu_tensor; if (xla::sys_util::GetEnvBool("XLA_USE_SPMD", false) && xtensor->CurrentTensorData().has_value()) { + TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1); // When virtual device is enabled for SPMD, we defer the initial data // transfer to the device and retain the original data on the host, until // the sharded data transfer.