Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions test/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,25 @@ def test_metrics_report(self):
self.assertIn("InputOutputAliasCount", metric_names)
self.assertNotEqual(met.metric_data("InputOutputAliasCount"), None)

met.clear_metrics()
self.assertNotIn("InputOutputAliasCount", met.metric_names())
self.assertEqual(met.metric_data("InputOutputAliasCount"), None)

# timed metrics
self.assertIn("TensorToData", report)
self.assertIn("UnwrapXlaData", report)
self.assertIn("WrapXlaData", report)
self.assertIn("DeviceLockWait", report)
self.assertIn("TensorToData", metric_names)
self.assertNotEqual(met.metric_data("TensorToData"), None)
self.assertIn("UnwrapXlaData", metric_names)
self.assertNotEqual(met.metric_data("UnwrapXlaData"), None)
self.assertIn("WrapXlaData", metric_names)
self.assertNotEqual(met.metric_data("WrapXlaData"), None)
self.assertIn("DeviceLockWait", metric_names)
self.assertNotEqual(met.metric_data("DeviceLockWait"), None)

met.clear_metrics()
self.assertNotIn("InputOutputAliasCount", met.metric_names())
self.assertEqual(met.metric_data("InputOutputAliasCount"), None)
self.assertNotIn("TensorToData", met.metric_names())
self.assertEqual(met.metric_data("TensorToData"), None)

# repeat the same computation and expect to see the CachedCompile counter
t3 = t1 * 2
Expand Down
3 changes: 3 additions & 0 deletions third_party/xla_client/metrics.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ class TimedSection {
int64_t start_;
};

// XLA_TIMED should only be used within xla_client. Please use
// TORCH_LAZY_TIMED in pytorch/xla. For more information, see
// NOTE: [TORCH_LAZY_COUNTER v.s. XLA_COUNTER].
#define XLA_TIMED(name) \
static xla::metrics::Metric* timed_metric = \
new xla::metrics::Metric(name, xla::metrics::MetricFnTime); \
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1667,7 +1667,7 @@ void InitXlaModuleBindings(py::module m) {
std::vector<torch::lazy::BackendDataPtr> parameters_data;
torch::lazy::BackendDevice device = torch_xla::GetCurrentDevice();
{
XLA_TIMED("RunCachedGraphInputData");
TORCH_LAZY_TIMED("RunCachedGraphInputData");
// setup the parameters_data
int idx = 0;
for (auto& ivalue : graph_inputs) {
Expand All @@ -1688,7 +1688,7 @@ void InitXlaModuleBindings(py::module m) {
cachedComputation->computation, parameters_data, device);
std::vector<at::Tensor> retlist;
{
XLA_TIMED("RunCachedGraphOutputData");
TORCH_LAZY_TIMED("RunCachedGraphOutputData");
// Convert result back to at::tensor
int i = 0;
for (auto& data : results) {
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ xla::util::ExceptionCleanup LockDevice(
TF_VLOG(4) << "Waiting on device barrier for device " << device << " ...";
std::shared_ptr<DeviceLocker> locker;
{
XLA_TIMED("DeviceLockWait");
TORCH_LAZY_TIMED("DeviceLockWait");
locker = DeviceLockerArena::Get()->GetLocker(device);
locker->Lock();
}
Expand Down Expand Up @@ -806,7 +806,7 @@ torch::lazy::Value XLATensor::GetIrValueForTensor(
data = GetDeviceData(tensor, device);
read_only = true;
} else {
XLA_TIMED("IrValueTensorToXlaData");
TORCH_LAZY_TIMED("IrValueTensorToXlaData");
data = TensorToXlaData(tensor, device);
}
return CreateTensorNode(std::move(data), read_only);
Expand Down
16 changes: 8 additions & 8 deletions torch_xla/csrc/tensor_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ struct DataAsync {

void TransferToServerAsync(std::shared_ptr<DataAsync> async,
const std::vector<std::string>& devices) {
XLA_TIMED("TransferToServerAsync");
TORCH_LAZY_TIMED("TransferToServerAsync");

std::vector<xla::ComputationClient::DataPtr> async_xla_datas =
xla::ComputationClient::Get()->CreateAsyncDatas(async->source_tensors);
Expand Down Expand Up @@ -652,7 +652,7 @@ void PopulateTensorBuffer(const at::Tensor& tensor,
torch::lazy::BackendDataPtr TensorToXlaData(
const at::Tensor& tensor, const xla::Shape& shape,
const torch::lazy::BackendDevice& device) {
XLA_TIMED("TensorToData");
TORCH_LAZY_TIMED("TensorToData");
if (device.type() == (int8_t)XlaDeviceType::SPMD) {
// When SPMD is enabled, we want to delay the data transfer for XLA
// tensors until the data is sharded. So, we skip the data transfer
Expand Down Expand Up @@ -763,13 +763,13 @@ at::Tensor XlaLiteralToTensorHelper(const xla::Literal& literal,

xla::ComputationClient::DataPtr UnwrapXlaData(
const torch::lazy::BackendDataPtr& data) {
XLA_TIMED("UnwrapXlaData");
TORCH_LAZY_TIMED("UnwrapXlaData");
return dynamic_cast<XLAData*>(data.get())->xla_data();
}

std::vector<xla::ComputationClient::DataPtr> UnwrapXlaData(
absl::Span<const torch::lazy::BackendDataPtr> datas) {
XLA_TIMED("UnwrapXlaData");
TORCH_LAZY_TIMED("UnwrapXlaData");
std::vector<xla::ComputationClient::DataPtr> xla_datas;
xla_datas.reserve(datas.size());
for (const auto& data : datas) {
Expand All @@ -780,13 +780,13 @@ std::vector<xla::ComputationClient::DataPtr> UnwrapXlaData(

torch::lazy::BackendDataPtr WrapXlaData(
const xla::ComputationClient::DataPtr& xla_data) {
XLA_TIMED("WrapXlaData");
TORCH_LAZY_TIMED("WrapXlaData");
return std::make_shared<XLAData>(xla_data);
}

std::vector<torch::lazy::BackendDataPtr> WrapXlaData(
absl::Span<const xla::ComputationClient::DataPtr> xla_datas) {
XLA_TIMED("WrapXlaData");
TORCH_LAZY_TIMED("WrapXlaData");
std::vector<torch::lazy::BackendDataPtr> datas;
datas.reserve(xla_datas.size());
for (const auto& xla_data : xla_datas) {
Expand Down Expand Up @@ -868,7 +868,7 @@ torch::lazy::BackendDataPtr TensorToXlaData(
std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
const std::vector<at::Tensor>& tensors,
const std::vector<std::string>& devices, bool transfer_async) {
XLA_TIMED("TensorToData");
TORCH_LAZY_TIMED("TensorToData");
XLA_CHECK_EQ(tensors.size(), devices.size());
if (transfer_async) {
std::shared_ptr<DataAsync> async = std::make_shared<DataAsync>();
Expand Down Expand Up @@ -918,7 +918,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
const std::vector<at::Tensor>& tensors,
const std::vector<XLATensor::ShardingSpecPtr>& shardings,
const std::vector<std::string>& devices) {
XLA_TIMED("TensorToData");
TORCH_LAZY_TIMED("TensorToData");
XLA_CHECK_EQ(tensors.size(), shardings.size());
XLA_CHECK_EQ(tensors.size(), devices.size());

Expand Down