grid sampler op need to register fp32 autocast #7305

Open
lingzhi98 opened this issue Jun 18, 2024 · 4 comments

@lingzhi98
Contributor

lingzhi98 commented Jun 18, 2024

🐛 Bug

grid_sampler cannot run under automatic mixed precision (autocast) mode.

Steps to reproduce the behavior:

import torch
import torch.nn.functional as F
import numpy as np
import torch_xla
import torch_xla.core.xla_model as xm

xla_device = xm.xla_device()

sz = 5
input_arr = torch.from_numpy(np.arange(sz * sz).reshape(1, 1, sz, sz)).to(xla_device, dtype=torch.bfloat16)
indices = torch.from_numpy(np.array([-1, -1, -0.5, -0.5, 0,0, 0.5, 0.5, 1,1]).reshape(1, 1, 5, 2)).to(xla_device, dtype=torch.bfloat16)

with torch.amp.autocast("xla", dtype=torch.bfloat16):
  out = F.grid_sample(input_arr, indices)
xm.mark_step()
print(input_arr)
print(out)

RuntimeError: grid_sampler_2d_cpu not implemented for BFloat16

Expected behavior

Autocast should cast the inputs of grid_sampler to fp32.
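Until such an autocast rule exists, the expected behavior can be approximated by hand. The following is only a sketch of a user-side workaround, not the fix proposed in this issue: `grid_sample_fp32` is a hypothetical helper name, the demonstration runs on CPU only, and it has not been tested on an XLA device.

```python
import torch
import torch.nn.functional as F


def grid_sample_fp32(input, grid, **kwargs):
    """Hypothetical helper: run grid_sample in float32, then restore the input dtype."""
    # Disable autocast locally so the explicit float32 casts are not overridden.
    with torch.autocast(device_type=input.device.type, enabled=False):
        out = F.grid_sample(input.float(), grid.float(), **kwargs)
    return out.to(input.dtype)


# CPU-only demonstration using the same shapes as the reproduction above.
sz = 5
input_arr = torch.arange(sz * sz, dtype=torch.float32).reshape(1, 1, sz, sz)
input_arr = input_arr.to(torch.bfloat16)
indices = torch.tensor([-1, -1, -0.5, -0.5, 0, 0, 0.5, 0.5, 1, 1], dtype=torch.float32)
indices = indices.reshape(1, 1, 5, 2).to(torch.bfloat16)

out = grid_sample_fp32(input_arr, indices, align_corners=False)
print(out.dtype, tuple(out.shape))
```

The helper casts back to the caller's dtype so surrounding bfloat16 code is unaffected; whether that final downcast is desirable depends on the model.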

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CPU / other PJRT device (Intel GPU) / I think TPU has this issue as well

Additional context

torch.cuda sets the autocast mode of grid_sampler to promote, but torch_xla cannot do the same. Because torch_xla has no lowering for grid_sampler_2d/grid_sampler_3d, the op falls back to the torch CPU implementation, which does not support bfloat16. Maybe we should set the autocast mode to fp32 first, and change it to promote once the lowering is ready. By the way, grid_sampler itself does not need a lowering, because this op does nothing except dispatch to grid_sampler_2d, grid_sampler_3d and cudnn_grid_sampler. https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/GridSampler.cpp#L1046
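To make the difference between the two policies concrete, here is a toy model in plain Python. It is a simplification, not the real autocast implementation: dtypes are plain strings, `apply_policy` and `CPU_FALLBACK_SUPPORTED` are hypothetical names, and the set of dtypes accepted by the CPU fallback is an assumption inferred from the error message above.

```python
# Toy model of two autocast cast policies; no torch required.
BITS = {"bfloat16": 16, "float32": 32}

# Assumption: the CPU fallback kernel accepts float32 but not bfloat16.
CPU_FALLBACK_SUPPORTED = {"float32"}


def apply_policy(policy, input_dtypes):
    """Return the dtypes the op would see after the given cast policy."""
    if policy == "fp32":
        # Every low-precision input is cast up to float32.
        return ["float32" for _ in input_dtypes]
    if policy == "promote":
        # All inputs are cast to the widest dtype present among them.
        widest = max(input_dtypes, key=lambda d: BITS[d])
        return [widest for _ in input_dtypes]
    raise ValueError(f"unknown policy: {policy}")


inputs = ["bfloat16", "bfloat16"]  # input and grid, as in the reproduction
for policy in ("promote", "fp32"):
    seen = apply_policy(policy, inputs)
    ok = all(d in CPU_FALLBACK_SUPPORTED for d in seen)
    print(policy, seen, "runs on CPU fallback" if ok else "fails on CPU fallback")
```

Under this model, promote leaves two bfloat16 inputs unchanged, so the fallback still sees bfloat16, while fp32 always hands the fallback a dtype it supports.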

@JackCaoG
Collaborator

Yeah, I confirmed that we don't have a lowering for this op:

Counter: aten::grid_sampler_2d
  Value: 1

I guess the solution here would be to actually lower this op.
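A counter like the one above can be picked out of a longer report mechanically. The sketch below assumes the report text uses the `Counter:` / `Value:` layout shown in this comment and that counters prefixed with `aten::` mark ops without an XLA lowering; `fallback_counters` is a hypothetical helper, and the `xla::add` entry in the sample text is made up for illustration.

```python
import re


def fallback_counters(report: str) -> dict:
    """Collect counters whose names start with 'aten::' from a report string."""
    pattern = re.compile(r"Counter:\s+(aten::\S+)\s+Value:\s+(\d+)")
    return {name: int(value) for name, value in pattern.findall(report)}


# Illustrative report text; only the first counter comes from this thread.
sample_report = """Counter: aten::grid_sampler_2d
  Value: 1
Counter: xla::add
  Value: 3
"""

print(fallback_counters(sample_report))
```

In a real session the report string would come from the torch_xla metrics module rather than a literal.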

@JackCaoG
Collaborator

also found #6581

@lingzhi98
Contributor Author

Lowering grid_sampler_2d is not enough; we should also lower grid_sampler_3d to fully support the grid_sampler op. That's why I suggest setting the autocast mode to fp32 first. Or is there a plan to support grid_sampler_3d? I haven't found one so far.

@ManfeiBai ManfeiBai assigned ManfeiBai and wonjoolee95 and unassigned ManfeiBai Jun 25, 2024
@ManfeiBai
Collaborator

Hi @wonjoolee95, is it OK to assign this ticket to you?
