🐛 Bug
We observed an unexpected f64 div in the backward pass of nn.MultiheadAttention.
To Reproduce
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch import nn

device = xm.xla_device()
mod = nn.MultiheadAttention(768, 12, batch_first=True)
mod.to(device)
x = torch.rand(24, 512, 768).to(device)
y, _ = mod(x, x, x, need_weights=False)
y.sum().backward()
print(torch_xla._XLAC._get_xla_tensors_text([y]))
print(torch_xla._XLAC._get_xla_tensors_text([p.grad for p in mod.parameters() if p.requires_grad]))
In the forward pass IR:
%31 = f32[] xla::device_data(), xla_shape=f32[], device=GPU:0
%35 = f32[24,12,512,64]{3,2,1,0} aten::div(%34, %31), xla_shape=f32[24,12,512,64]{3,2,1,0}
In the backward pass IR:
%73 = f64[] xla::device_data(), xla_shape=f64[], device=GPU:0
%78 = f64[24,12,512,64]{3,2,1,0} aten::div(%77, %73), xla_shape=f64[24,12,512,64]{3,2,1,0}
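For convenience, here is a small helper (purely illustrative, not a torch_xla API) that scans the dumped IR text for f64-typed nodes; run after the repro above, I'd expect it to report nothing for the forward graph and the f64 device_data/aten::div pair for the gradients:

# Hypothetical helper: keep only IR lines with an f64-typed value.
def find_f64_nodes(ir_text):
    return [line.strip() for line in ir_text.splitlines() if " f64[" in line]

fwd_ir = torch_xla._XLAC._get_xla_tensors_text([y])
bwd_ir = torch_xla._XLAC._get_xla_tensors_text(
    [p.grad for p in mod.parameters() if p.requires_grad])
print(find_f64_nodes(fwd_ir))  # expected: []
print(find_f64_nodes(bwd_ir))  # expected: the f64 device_data and aten::div lines shown above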
One interesting thing I noticed is that the aten::div in the forward pass goes through the at::Scalar& other signature:
at::Tensor XLANativeFunctions::div(const at::Tensor& self,
const at::Scalar& other) {
while the aten::div in the backward pass goes through the at::Tensor& other signature:
at::Tensor XLANativeFunctions::div(
const at::Tensor& self, const at::Tensor& other,
c10::optional<c10::string_view> rounding_mode) {
In both cases, other.scalar_type() is f64.
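Here is a minimal sketch that may isolate the same asymmetry without nn.MultiheadAttention. It assumes (on my side) that the scale comes from dividing by a Python float such as math.sqrt(head_dim), which is a double; if the f64 div shows up here too, the problem is independent of the attention module and is about how the Python-float scale is replayed in the backward:

import math

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.rand(24, 512, 768, device=device, requires_grad=True)
# Dividing an f32 tensor by a Python float (a double); the forward stays f32.
y = x / math.sqrt(64)
y.sum().backward()
print(torch_xla._XLAC._get_xla_tensors_text([y]))       # forward: f32 div expected
print(torch_xla._XLAC._get_xla_tensors_text([x.grad]))  # backward: check for an f64 div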
Expected behavior
There should not be an f64 div in the graph, since f64 arithmetic is very costly on some GPUs.
Environment
- Reproducible on XLA backend [CPU/TPU]: GPU
- torch_xla version: master
Additional context
The div is coming from here.
I'm not sure whether this is an issue in PyTorch or whether it should be fixed in torch_xla.