
"upsample_nearest2d_out_frame" not implemented for 'BFloat16' #86679

Closed
patil-suraj opened this issue Oct 11, 2022 · 6 comments
Labels
module: bfloat16 · module: interpolation · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@patil-suraj

🐛 Describe the bug

Nearest upsampling with torch.nn.functional.interpolate does not work in bfloat16. Minimal code to reproduce:

import torch
import torch.nn.functional as F

image = torch.randn(1, 4, 32, 32).to(device="cuda", dtype=torch.bfloat16)
out = F.interpolate(image, size=(64, 64), mode="nearest")

This throws the following error:

File ~/.pyenv/versions/3.9.14/envs/diffusers-env/lib/python3.9/site-packages/torch/nn/functional.py:3910, in interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)
   3908     return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors)
   3909 if input.dim() == 4 and mode == "nearest":
-> 3910     return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
   3911 if input.dim() == 5 and mode == "nearest":
   3912     return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors)

RuntimeError: "upsample_nearest2d_out_frame" not implemented for 'BFloat16'

F.interpolate with nearest mode is used heavily in UNets, which are the backbone of diffusion models such as Stable Diffusion. Because of this, it is currently not possible to run Stable Diffusion in bfloat16 without manual casting. cf. huggingface/diffusers#792
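On affected versions, the manual casting can be wrapped in a small helper. The sketch below is hypothetical (`interpolate_nearest_bf16` is not a PyTorch API): it upcasts bfloat16 input to float32, runs nearest interpolation, and casts the result back. Since float32 represents every bfloat16 value exactly and nearest-neighbour upsampling only copies values, the round trip should not change the output.

```python
import torch
import torch.nn.functional as F


def interpolate_nearest_bf16(x, size=None, scale_factor=None):
    """Hypothetical helper: nearest interpolation that tolerates bfloat16.

    bfloat16 inputs are upcast to float32, interpolated, then cast back.
    Other dtypes are passed straight through to F.interpolate.
    """
    if x.dtype == torch.bfloat16:
        out = F.interpolate(x.float(), size=size, scale_factor=scale_factor, mode="nearest")
        return out.to(torch.bfloat16)
    return F.interpolate(x, size=size, scale_factor=scale_factor, mode="nearest")


# Usage with the same shapes as the reproduction; falls back to CPU without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
image = torch.randn(1, 4, 32, 32).to(device=device, dtype=torch.bfloat16)
out = interpolate_nearest_bf16(image, size=(64, 64))
print(out.shape, out.dtype)  # torch.Size([1, 4, 64, 64]) torch.bfloat16
```

The extra float32 copy costs memory and bandwidth, so this is a stopgap rather than a replacement for a native bfloat16 kernel.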

Versions

PyTorch version: 1.12.1+cu116
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 10 (buster) (x86_64)
GCC version: (Debian 8.3.0-6) 8.3.0
Clang version: Could not collect
CMake version: version 3.13.4
Libc version: glibc-2.28

Python version: 3.9.14 (main, Sep 22 2022, 15:50:51)  [GCC 8.3.0] (64-bit runtime)
Python platform: Linux-4.19.0-22-cloud-amd64-x86_64-with-glibc2.28
Is CUDA available: True
CUDA runtime version: 11.0.221
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB

Nvidia driver version: 510.47.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.3
[pip3] torch==1.12.1+cu116
[pip3] torchaudio==0.12.1+cu116
[pip3] torchvision==0.13.1+cu116
[conda] numpy                     1.19.5           py37h3e96413_3    conda-forge
@samdow added the triaged, module: bfloat16, and module: interpolation labels on Oct 12, 2022
@ptrblck (Collaborator) commented Nov 1, 2022

#80340 seems to be related

@noamsgl (Contributor) commented Mar 12, 2023

Same problem when using nn.Upsample(scale_factor=(4, 1))(x):

return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
RuntimeError: "upsample_nearest2d_out_frame" not implemented for 'BFloat16'
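For code that builds layers with `nn.Upsample`, such as the `scale_factor=(4, 1)` case above, manual casting can be packaged as a drop-in module. `BF16SafeUpsample` is a hypothetical name, not part of `torch.nn`; it subclasses `nn.Upsample` and only adds dtype handling around the parent's `forward`.

```python
import torch
import torch.nn as nn


class BF16SafeUpsample(nn.Upsample):
    """Hypothetical nn.Upsample subclass that runs bfloat16 inputs in float32."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dtype == torch.bfloat16:
            # Upcast, upsample with the parent's settings, then restore bfloat16.
            return super().forward(x.float()).to(torch.bfloat16)
        return super().forward(x)


# Usage mirroring the report: scale height by 4, keep width unchanged.
up = BF16SafeUpsample(scale_factor=(4, 1), mode="nearest")
x = torch.randn(2, 3, 8, 5).to(torch.bfloat16)
y = up(x)
print(y.shape, y.dtype)  # torch.Size([2, 3, 32, 5]) torch.bfloat16
```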

@coderpiaobozhe

Same problem when using the F.interpolate function. Is there any way to solve this issue?

@bibhabasumohapatra

@patil-suraj any solution?

vladmandic referenced this issue in vladmandic/automatic Jun 3, 2023
@dsuess commented Jul 27, 2023

@bibhabasumohapatra see #88536

@d4l3k (Collaborator) commented Oct 29, 2023

@dsuess this is fixed in PyTorch 2.1.
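Since the fix is reported to ship in PyTorch 2.1, code that must also run on older releases could apply manual casting (upcasting bfloat16 to float32 before interpolating) only when needed. The helper below is a hypothetical sketch: it parses the major/minor components of a version string such as `torch.__version__`, assumes the 2.1 cutoff stated in this thread, and does not check per-device kernel coverage.

```python
import torch


def needs_bf16_upcast(version: str = torch.__version__) -> bool:
    """Hypothetical check: True if nearest upsampling in bfloat16 may be unsupported.

    Assumes the fix shipped in PyTorch 2.1, as stated in this issue thread.
    Local build tags ("+cu116") and pre-release suffixes are ignored.
    """
    release = str(version).split("+")[0]
    major, minor = (int(part) for part in release.split(".")[:2])
    return (major, minor) < (2, 1)


print(needs_bf16_upcast("1.12.1+cu116"))  # True
print(needs_bf16_upcast("2.1.0+cu121"))   # False
```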

@d4l3k closed this as completed on Oct 29, 2023