🐛 Describe the bug
PyTorch cannot serialize BFloat16 tensors that live on HPU or on a few other backends, including XLA.
Example with XLA (the same issue occurs with HPU tensors):
```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
device = xm.xla_device()
t0 = torch.randn(2, 2, device=device).to(torch.bfloat16)
t1 = torch.randn(2, 2, device=device).to(torch.bfloat16)
tensors = (t0, t1)
torch.save(tensors, 'tensors.pt')  # raises TypeError for bfloat16 XLA tensors
```
This example fails with:
```
TypeError Traceback (most recent call last)
in ()
10 tensors = (t0, t1)
11
---> 12 torch.save(tensors, 'tensors.pt')
3 frames
/usr/local/lib/python3.7/dist-packages/torch/_tensor.py in _reduce_ex_internal(self, proto)
201 # and serialize them one by one.
202 if self.device.type in ['xla', 'ort', 'mlc']:
--> 203 return (torch._utils._rebuild_device_tensor_from_numpy, (self.cpu().numpy(),
204 self.dtype,
205 str(self.device),
TypeError: Got unsupported ScalarType BFloat16
```
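The traceback shows that, for `xla`/`ort`/`mlc` devices, `_reduce_ex_internal` converts the tensor with `self.cpu().numpy()`. NumPy has no bfloat16 dtype, so the conversion itself is what fails, and it can be reproduced on CPU without any accelerator. The sketch below shows that, plus two possible workarounds under that assumption: moving tensors to CPU before `torch.save` (CPU bfloat16 tensors go through the regular storage path), and reinterpreting the bfloat16 bits as `int16` so they can pass through NumPy. `bf16_to_numpy` and `numpy_to_bf16` are hypothetical helper names, not PyTorch APIs, and this is not necessarily how PyTorch resolves the issue upstream.

```python
import io

import numpy as np
import torch

# CPU-only reproduction: Tensor.numpy() rejects bfloat16 because NumPy
# has no matching dtype.
t = torch.randn(2, 2).to(torch.bfloat16)
err = None
try:
    t.numpy()
except TypeError as e:
    err = e
    print(e)

# Workaround 1 (assumption: moving to CPU is acceptable): CPU bfloat16
# tensors are serialized via their storage, so NumPy is never involved.
buf = io.BytesIO()
torch.save((t.cpu(),), buf)
buf.seek(0)
(loaded,) = torch.load(buf)

# Workaround 2 (hypothetical helpers): reinterpret the 16-bit payload as
# int16, which NumPy supports, then reinterpret it back as bfloat16.
def bf16_to_numpy(x: torch.Tensor) -> np.ndarray:
    return x.cpu().view(torch.int16).numpy()

def numpy_to_bf16(a: np.ndarray) -> torch.Tensor:
    return torch.from_numpy(a).view(torch.bfloat16)

restored = numpy_to_bf16(bf16_to_numpy(t))
```

Reinterpreting bits with `Tensor.view(dtype)` is lossless because `bfloat16` and `int16` have the same element size, whereas casting to `float32` would also work but doubles the size of the saved data.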
Versions
PyTorch version: 1.11.0+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.22.4
Libc version: glibc-2.26
Python version: 3.7.13 (default, Apr 24 2022, 01:04:09) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.188+-x86_64-with-Ubuntu-18.04-bionic
Is CUDA available: False
CUDA runtime version: 11.1.105
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.21.6
[pip3] torch==1.11.0+cu113
[pip3] torch-xla==1.11
[pip3] torchaudio==0.11.0+cu113
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.12.0
[pip3] torchvision==0.12.0+cu113
[conda] Could not collect