📚 Documentation
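If we write a loop like this:

```python
import time
import torch_xla.core.xla_model as xm

# run_model() and num_steps are placeholders for the user's step function and step count.
start = time.time()
for step in range(num_steps):
    run_model()
    xm.mark_step()
end = time.time()
```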
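Then `end - start` mostly measures the host-side tracing time, not the device execution time, because `xm.mark_step()` dispatches the traced graph for asynchronous execution and returns without waiting for the device to finish. To measure execution time, we need to call `torch_xla.sync(wait=True)` before reading `end`, which blocks until device execution completes.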
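The pitfall can be illustrated without an XLA device. The sketch below assumes that a single-worker thread pool is a fair stand-in for an asynchronous device queue; `time_steps`, `fake_step`, and `fake_sync` are hypothetical names, not torch_xla APIs. `fake_step` enqueues work and returns immediately (the role `run_model()` followed by `xm.mark_step()` plays above), and `fake_sync` blocks until the queue is empty (the role of `torch_xla.sync(wait=True)`). Timing without the final sync only measures how long it took to enqueue the work.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, List


def time_steps(step_fn: Callable[[], None],
               sync_fn: Callable[[], None],
               num_steps: int,
               wait: bool = True) -> float:
    """Return wall-clock seconds for num_steps calls of step_fn.

    When wait is True, sync_fn is called before the clock stops, so work that
    step_fn only enqueued is included in the measurement.
    """
    sync_fn()  # start from an idle queue so earlier work is not counted
    start = time.perf_counter()
    for _ in range(num_steps):
        step_fn()
    if wait:
        sync_fn()  # block until every queued step has finished
    return time.perf_counter() - start


# Toy stand-in for an asynchronous device queue (hypothetical, not torch_xla).
executor = ThreadPoolExecutor(max_workers=1)  # one worker: steps run in order
pending: List[Future] = []


def fake_step() -> None:
    # Plays the role of run_model(); xm.mark_step(): enqueue and return at once.
    pending.append(executor.submit(time.sleep, 0.05))


def fake_sync() -> None:
    # Plays the role of torch_xla.sync(wait=True): block until the queue drains.
    for fut in pending:
        fut.result()
    pending.clear()


no_wait = time_steps(fake_step, fake_sync, num_steps=10, wait=False)
fake_sync()  # drain the backlog left behind by the unsynchronized run
with_wait = time_steps(fake_step, fake_sync, num_steps=10, wait=True)
executor.shutdown()
print(f"without wait: {no_wait:.3f}s, with wait: {with_wait:.3f}s")
```

On a real XLA device, one would presumably pass a `step_fn` that calls `run_model()` followed by `xm.mark_step()`, and a `sync_fn` that calls `torch_xla.sync(wait=True)`. The blocking sync is called both before starting and before stopping the clock, so that neither work left over from earlier steps nor work still in flight at the end skews the measurement.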
We should probably document this in a "Common FAQs / sharp edges" section.