Skip to content

Commit 1a67542

Browse files
committed
Add test to make sure print tensor only execute graph once
1 parent e7af313 commit 1a67542

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

test/test_operations.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1649,6 +1649,34 @@ def test_cached_addcdiv(self):
16491649
xm.mark_step()
16501650
self.assertEqual(met.metric_data("TransferToServerTime")[0], 4)
16511651

1652+
def test_print_executation(self):
1653+
xla_device = xm.xla_device()
1654+
xm.mark_step()
1655+
met.clear_all()
1656+
1657+
# case 1 mark_step
1658+
t1 = torch.randn(1, 4, device=xla_device)
1659+
xm.mark_step()
1660+
xm.wait_device_ops()
1661+
self.assertEqual(met.metric_data('ExecuteTime')[0], 1)
1662+
for _ in range(3):
1663+
print(t1.cpu())
1664+
self.assertEqual(met.metric_data('ExecuteTime')[0], 1)
1665+
1666+
# case 2 no mark_step, directly print
1667+
met.clear_all()
1668+
t1 = torch.randn(1, 4, device=xla_device)
1669+
for _ in range(3):
1670+
print(t1)
1671+
self.assertEqual(met.metric_data('ExecuteTime')[0], 1)
1672+
1673+
# case 2 no mark_step, print with .cpu
1674+
met.clear_all()
1675+
t1 = torch.randn(1, 4, device=xla_device)
1676+
for _ in range(3):
1677+
print(t1.cpu())
1678+
self.assertEqual(met.metric_data('ExecuteTime')[0], 1)
1679+
16521680
def test_index_types(self):
16531681

16541682
def test_fn(*indices):

0 commit comments

Comments
 (0)