In [1]:
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

# Ops on XLA tensors are traced lazily: each one only records an IR node,
# and nothing runs on the device until a value is actually needed.
# NOTE(review): newer torch_xla releases offer torch_xla.device() as the
# preferred spelling of xm.xla_device() — confirm for the installed version.
device = xm.xla_device()
t1 = torch.tensor(500, device=device)
t2 = torch.tensor(600, device=device)
t3 = t1 + t2  # records an aten::add node
t4 = t3 * t2  # records an aten::mul node on top of t3's pending graph

# Dump the pending IR of t3, then of t4. Besides the aten ops, each graph
# still contains xla::device_data leaves and a prim::Constant operand.
get_ir_text = torch_xla._XLAC._get_xla_tensors_text
print(get_ir_text([t3]))
print(get_ir_text([t4]))

IR {
  %0 = s64[] prim::Constant(), xla_shape=s64[]
  %1 = s64[] xla::device_data(), xla_shape=s64[]
  %2 = s64[] xla::device_data(), xla_shape=s64[]
  %3 = s64[] aten::add(%2, %1, %0), xla_shape=s64[], ROOT=0
}

IR {
  %0 = s64[] xla::device_data(), xla_shape=s64[]
  %1 = s64[] prim::Constant(), xla_shape=s64[]
  %2 = s64[] xla::device_data(), xla_shape=s64[]
  %3 = s64[] aten::add(%2, %0, %1), xla_shape=s64[]
  %4 = s64[] aten::mul(%3, %0), xla_shape=s64[], ROOT=0
}



In [2]:
# torch_xla bumps this counter once per XLA graph execution. The counter does
# not exist until the first execution, so before then short_metrics_report()
# prints nothing and counter_value() may return None.
EXEC_COUNTER = "ExecuteComputation"


def execution_count() -> int:
    """Return how many XLA graph executions have been recorded (0 if none yet)."""
    return met.counter_value(EXEC_COUNTER) or 0


def print_execution_report() -> None:
    """Print the short metrics report restricted to the execution counter."""
    print(met.short_metrics_report([EXEC_COUNTER], [EXEC_COUNTER]))


# Baseline: nothing has executed yet, so this report is still empty.
executions_before = execution_count()
print_execution_report()

# Synchronously evaluate t3: printing forces its pending graph to run.
print(t3)
print_execution_report()
assert execution_count() == executions_before + 1, "evaluating t3 should run exactly one graph"

# After materialization, t3's IR collapses to a single `xla::device_data()` node.
print(torch_xla._XLAC._get_xla_tensors_text([t3]))

# Evaluate t3 again; it already holds device data, so it must not recompute.
for _ in range(10):
    print(t3)

print_execution_report()
assert execution_count() == executions_before + 1, "re-evaluating t3 should not recompute"


tensor(1100, device='xla:0')
Counter: ExecuteComputation
  Value: 1

IR {
  %0 = s64[] xla::device_data(), xla_shape=s64[], ROOT=0
}

tensor(1100, device='xla:0')
tensor(1100, device='xla:0')
tensor(1100, device='xla:0')
tensor(1100, device='xla:0')
tensor(1100, device='xla:0')
tensor(1100, device='xla:0')
tensor(1100, device='xla:0')
tensor(1100, device='xla:0')
tensor(1100, device='xla:0')
tensor(1100, device='xla:0')
Counter: ExecuteComputation
  Value: 1



In [3]:
# t4 was never evaluated, so materializing t3 left t4's pending graph as it
# was: it still records the aten::add and aten::mul nodes instead of reading
# t3's now-materialized device data.
t4_ir = torch_xla._XLAC._get_xla_tensors_text([t4])
print(t4_ir)


IR {
  %0 = s64[] xla::device_data(), xla_shape=s64[]
  %1 = s64[] prim::Constant(), xla_shape=s64[]
  %2 = s64[] xla::device_data(), xla_shape=s64[]
  %3 = s64[] aten::add(%2, %0, %1), xla_shape=s64[]
  %4 = s64[] aten::mul(%3, %0), xla_shape=s64[], ROOT=0
}



In [4]:
# Full metrics report: every counter and metric torch_xla has recorded so far
# (e.g. LazyTracing timings, TensorsGraphSize). `met` comes from the imports
# in the first cell. The report is long; short_metrics_report() with explicit
# names is a more focused alternative.
report = met.metrics_report()
print(report)

Metric: DeviceLockWait
  TotalSamples: 2
  Accumulator: 042.891us
  ValueRate: 03s681ms687.500us / second
  Rate: 125000 / second
  Percentiles: 1%=008.231us; 5%=008.231us; 10%=008.231us; 20%=008.231us; 50%=034.660us; 80%=034.660us; 90%=034.660us; 95%=034.660us; 99%=034.660us
Metric: LazyTracing
  TotalSamples: 21
  Accumulator: 107ms121.189us
  ValueRate: 854ms274.856us / second
  Rate: 167.472 / second
  Percentiles: 1%=021.440us; 5%=033.420us; 10%=149.160us; 20%=445.780us; 50%=593.730us; 80%=948.209us; 90%=003ms769.310us; 95%=004ms861.050us; 99%=091ms950.293us
Metric: TensorToData
  TotalSamples: 2
  Accumulator: 480.390us
  ValueRate: 558ms705.234us / second
  Rate: 2321.89 / second
  Percentiles: 1%=165.140us; 5%=165.140us; 10%=165.140us; 20%=165.140us; 50%=315.250us; 80%=315.250us; 90%=315.250us; 95%=315.250us; 99%=315.250us
Metric: TensorsGraphSize
  TotalSamples: 1
  Accumulator: 4.00
  Percentiles: 1%=4.00; 5%=4.00; 10%=4.00; 20%=4.00; 50%=4.00; 80%=4.00; 90%=4.00; 95%=4.00; 9

In [7]:
# Only a cell's last expression is echoed, so the counter names are printed
# explicitly instead of being computed and silently discarded.
# met.counter_names() / met.counter_value() are the torch_xla.debug.metrics
# equivalents of the private torch_xla._XLAC._xla_counter_* calls.
print(met.counter_names())
met.counter_value("ExecuteComputation")

1