Unexpected f64 div #4574

@ymwangg

Description

🐛 Bug

We observed an unexpected f64 div in the backward pass of nn.MultiheadAttention.

To Reproduce

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch import nn

device = xm.xla_device()
mod = nn.MultiheadAttention(768, 12, batch_first=True)
mod.to(device)

x = torch.rand(24,512,768).to(device)
y,_ = mod(x,x,x,need_weights=False)
y.sum().backward()

print(torch_xla._XLAC._get_xla_tensors_text([y]))
print(torch_xla._XLAC._get_xla_tensors_text([p.grad for p in mod.parameters() if p.requires_grad]))

In the forward pass IR:

%31 = f32[] xla::device_data(), xla_shape=f32[], device=GPU:0
%35 = f32[24,12,512,64]{3,2,1,0} aten::div(%34, %31), xla_shape=f32[24,12,512,64]{3,2,1,0}

In the backward pass IR:

%73 = f64[] xla::device_data(), xla_shape=f64[], device=GPU:0
%78 = f64[24,12,512,64]{3,2,1,0} aten::div(%77, %73), xla_shape=f64[24,12,512,64]{3,2,1,0}
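
A quick way to double-check this is to scan the dumped IR text for f64 result shapes instead of reading it by eye; something like this (rough sketch) after running the repro above:

# Count IR lines whose result shape is f64 in the dumped graph text.
def count_f64_ops(tensors):
    ir_text = torch_xla._XLAC._get_xla_tensors_text(tensors)
    return sum(1 for line in ir_text.splitlines() if "= f64[" in line)

grads = [p.grad for p in mod.parameters() if p.requires_grad]
print("f64 ops in forward IR:", count_f64_ops([y]))     # forward graph (f32 in the dump above)
print("f64 ops in backward IR:", count_f64_ops(grads))  # backward graph (contains the f64 div above)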

One interesting thing I noticed is that the aten::div in the forward pass goes through the at::Scalar& other signature:

at::Tensor XLANativeFunctions::div(const at::Tensor& self,
                                   const at::Scalar& other) {

while the aten::div in the backward pass goes through the at::Tensor& other signature:

at::Tensor XLANativeFunctions::div(
    const at::Tensor& self, const at::Tensor& other,
    c10::optional<c10::string_view> rounding_mode) {

In both cases, other.scalar_type() is f64.
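
For comparison, if I understand PyTorch's type-promotion rules correctly, eager mode would keep this division in f32, since a Python scalar or a zero-dim tensor of the same dtype category does not promote the result. A minimal sketch of that rule:

import torch

x = torch.rand(4, dtype=torch.float32)

# Dividing by a Python float (wrapped as an f64 scalar) keeps the tensor dtype.
print((x / 8.0).dtype)  # torch.float32

# A zero-dim f64 tensor also does not promote: zero-dim tensors of the same
# dtype category do not participate in type promotion.
scale = torch.tensor(8.0, dtype=torch.float64)
print((x / scale).dtype)  # torch.float32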

Expected behavior

There should not be an f64 div, as f64 arithmetic is very costly on some GPUs.

Environment

  • Reproducible on XLA backend [CPU/TPU]: GPU
  • torch_xla version: master

Additional context

The div is coming from here.

I'm not sure whether this is an issue in PyTorch or whether it should be fixed in torch_xla.
