import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
results = torch.randn((20, 20), requires_grad=True).to(device)
torch.argmax(results, 1)  # raises the error below on the XLA device
RuntimeError: max(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.
The same snippet runs without error on CPU and CUDA devices. The failure can be reproduced quickly on Colab.
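Until this is fixed, one possible workaround is sketched below. argmax returns integer indices and carries no gradient, so calling it on a detached copy of the input should give the same indices. It is untested on XLA: the assumption is that removing `requires_grad` before the call keeps the lowering from hitting the `out=` autograd check. The helper names `get_device` and `argmax_rows` are hypothetical. The sketch falls back to CPU when `torch_xla` is not installed, where the report says the original call already works.

```python
import torch


def get_device() -> torch.device:
    # Use an XLA device if torch_xla is available; otherwise fall back to CPU,
    # where argmax on a grad-requiring tensor is reported to work.
    try:
        import torch_xla.core.xla_model as xm
        return xm.xla_device()
    except ImportError:
        return torch.device("cpu")


def argmax_rows(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: argmax has no gradient, so detaching loses nothing.
    # Assumption (untested on XLA): this avoids the out=... autograd error.
    return torch.argmax(x.detach(), dim=1)


device = get_device()
results = torch.randn((20, 20), requires_grad=True).to(device)
idx = argmax_rows(results)
print(idx.shape, idx.dtype)
```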