🐛 Bug
I found that torch.jit.trace records a tensor's device during tracing; if we then run inference on a different device, the traced model crashes.
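As a quick illustration of what I mean (my own sketch, not part of the reproducing script below), printing the code of a traced function shows the device recorded as a constant:

import torch

def f(x):
    return torch.randn(1).to(x.device) + x

traced = torch.jit.trace(f, torch.randn(3))
print(traced.code)  # the generated TorchScript should contain device=torch.device("cpu") as a constant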
To Reproduce
Here is code to reproduce the issue:
import torch

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.li = torch.nn.Linear(10, 10)

    def forward(self, x):
        y = x + torch.randn(1).to(x.device)
        return self.li(y)

device = "cpu"
net = Net().to(device)
inputs = torch.randn(2, 10).to(device)

traced_model = torch.jit.trace(net, inputs, check_trace=False)
ret1 = traced_model(inputs)
traced_model.save('traced_model.cpt')

device = 'cuda'
model = torch.jit.load('traced_model.cpt').to(device)
inputs = inputs.to(device)
ret2 = model(inputs)

Traceback (most recent call last):
File "test.py", line 25, in <module>
ret2 = model(inputs)
File "~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
File "code/__torch__.py", line 15, in forward
_4 = torch.randn([_2, int(_3)], dtype=6, layout=None, device=torch.device("cpu"), pin_memory=False)
_5 = torch.to(_4, dtype=6, layout=0, device=torch.device("cpu"), pin_memory=None, non_blocking=False, copy=False, memory_format=None)
input = torch.add(x, _5, alpha=1)
~~~~~~~~~ <--- HERE
return (_0).forward(input, )
Traceback of TorchScript, original code (most recent call last):
test.py(10): forward
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(860): _slow_forward
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(887): _call_impl
~/anaconda3/lib/python3.7/site-packages/torch/jit/_trace.py(940): trace_module
~/anaconda3/lib/python3.7/site-packages/torch/jit/_trace.py(742): trace
test.py(18): <module>
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Expected behavior
The traced model should run correctly after being loaded and moved to a different device with .to(device).
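For context, here is a minimal sketch of a possible workaround (my own illustration, not a confirmed fix): scripting the module with torch.jit.script instead of tracing it keeps x.device as a value resolved at run time, so the saved module is not tied to the device used at export time.

import torch

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.li = torch.nn.Linear(10, 10)

    def forward(self, x):
        # under scripting, x.device is evaluated at run time rather than
        # being frozen into the graph as a constant
        y = x + torch.randn(1).to(x.device)
        return self.li(y)

scripted = torch.jit.script(Net())        # script instead of trace
scripted.save('scripted_model.cpt')

model = torch.jit.load('scripted_model.cpt').to('cuda')
ret = model(torch.randn(2, 10, device='cuda'))  # expected to run on CUDA without the mismatch (sketch)

Another option would be to trace the model once per target device, but scripting avoids maintaining multiple saved artifacts.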
Environment
Please copy and paste the output from our environment collection script (or fill out the checklist below manually).
You can get the script and run it with:
wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py
Collecting environment information...
PyTorch version: 1.8.1+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A
OS: Debian GNU/Linux 9.13 (stretch) (x86_64)
GCC version: (Debian 6.3.0-18+deb9u1) 6.3.0 20170516
Clang version: Could not collect
CMake version: version 3.19.4
Libc version: glibc-2.9
Python version: 3.7.0 (default, Jun 28 2018, 13:15:42) [GCC 7.2.0] (64-bit runtime)
Python platform: Linux-4.14.81.bm.26-amd64-x86_64-with-debian-9.13
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: Tesla T4
GPU 1: Tesla T4
GPU 2: Tesla T4
GPU 3: Tesla T4
GPU 4: Tesla T4
GPU 5: Tesla T4
GPU 6: Tesla T4
GPU 7: Tesla T4
Nvidia driver version: 440.118.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.21.0
[pip3] numpydoc==0.8.0
[pip3] pytorch-wpe==0.0.1
[pip3] torch==1.8.1+cu102
[pip3] torch-complex==0.2.1
[pip3] torchaudio==0.8.1
[pip3] torchvision==0.9.1+cu102
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.0.221 h6bb024c_0
[conda] mkl 2019.0 118
[conda] mkl-service 1.1.2 py37h90e4bf4_5
[conda] mkl_fft 1.0.4 py37h4414c95_1
[conda] mkl_random 1.0.1 py37h4414c95_1
[conda] numpy 1.15.1 py37h1d66e8a_0
[conda] numpy 1.21.0
[conda] numpy-base 1.15.1 py37h81de0dd_0
[conda] numpydoc 0.8.0 py37_0
[conda] pytorch-wpe 0.0.1
[conda] torch 1.8.1+cu102
[conda] torch-complex 0.2.1
[conda] torchaudio 0.8.1
[conda] torchvision 0.9.1+cu102
Additional context
cc @gmagogsfm