diff --git a/TROUBLESHOOTING.md b/TROUBLESHOOTING.md
index eedec835310..aac0abcf464 100644
--- a/TROUBLESHOOTING.md
+++ b/TROUBLESHOOTING.md
@@ -95,7 +95,7 @@ If your model shows bad performance, keep in mind the following caveats:
 - For most ops we can lower them to XLA to fix it. Checkout [metrics report section](#metrics-report) to find out the missing ops and open a feature request on [GitHub](https://github.com/pytorch/xla/issues).
 - Even when a PyTorch tensor is known as a scalar, avoid using `tensor.item()`. Keep it as a tensor and use tensor operations on it.
 - Use `torch.where` to substitute control flow when applicable.
- E.g. The control flow with `item()` used in [clip_grad_norm_](https://github.com/pytorch/pytorch/blob/de19eeee99a2a282fc441f637b23d8e50c75ecd1/torch/nn/utils/clip_grad.py#L33) can be simply replaced by `torch.where` with dramatical performance improvement.
+ E.g. The control flow with `item()` used in [clip_grad_norm_](https://github.com/pytorch/pytorch/blob/de19eeee99a2a282fc441f637b23d8e50c75ecd1/torch/nn/utils/clip_grad.py#L33) is problematic and impacts performance, so we have [patched](https://github.com/pytorch/xla/blob/master/torch_patches/X10-clip_grad.diff) `clip_grad_norm_` by calling `torch.where` instead, which gives us a dramatic performance improvement.
 ```python
 ...
 else: