torch.jit.trace crashed by device mismatch #62961

@xin-w8023

Description

🐛 Bug

I found that torch.jit.trace records the tensor's device during tracing; if we then run inference on a different device, the traced model crashes.

To Reproduce

Here is code that reproduces the issue.

import torch


class Net(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.li = torch.nn.Linear(10, 10)

  def forward(self, x):
    # The .to(x.device) call is recorded with the concrete device seen at trace time.
    y = x + torch.randn(1).to(x.device)
    return self.li(y)

device = "cpu"

net = Net().to(device)

inputs = torch.randn(2, 10).to(device)
traced_model = torch.jit.trace(net, inputs, check_trace=False)
ret1 = traced_model(inputs)  # works: same device as during tracing
traced_model.save('traced_model.cpt')

device = 'cuda'
model = torch.jit.load('traced_model.cpt').to(device)
inputs = inputs.to(device)
ret2 = model(inputs)  # crashes: the trace still creates the randn tensor on CPU
Traceback (most recent call last):
  File "test.py", line 25, in <module>
    ret2 = model(inputs)
  File "~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__.py", line 15, in forward
    _4 = torch.randn([_2, int(_3)], dtype=6, layout=None, device=torch.device("cpu"), pin_memory=False)
    _5 = torch.to(_4, dtype=6, layout=0, device=torch.device("cpu"), pin_memory=None, non_blocking=False, copy=False, memory_format=None)
    input = torch.add(x, _5, alpha=1)
            ~~~~~~~~~ <--- HERE
    return (_0).forward(input, )

Traceback of TorchScript, original code (most recent call last):
test.py(10): forward
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(860): _slow_forward
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(887): _call_impl
~/anaconda3/lib/python3.7/site-packages/torch/jit/_trace.py(940): trace_module
~/anaconda3/lib/python3.7/site-packages/torch/jit/_trace.py(742): trace
test.py(18): <module>
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
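The serialized code in the traceback shows the root cause: torch.jit.trace records the concrete values it observes, so the device argument of torch.randn (and of the subsequent torch.to call) is frozen as the literal constant torch.device("cpu") instead of the dynamic expression x.device. A minimal sketch (reusing the Net module from the repro above) that makes the baked-in constant visible by printing the traced module's TorchScript source:

import torch

# Net is the module defined in the repro above.
net = Net().to("cpu")
inputs = torch.randn(2, 10)
traced_model = torch.jit.trace(net, inputs, check_trace=False)

# The printed TorchScript contains device=torch.device("cpu") as a literal
# constant rather than a reference to the input's device.
print(traced_model.code)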

Expected behavior

The traced model is expected to run on any device; device-dependent operations should follow the device of the input tensors rather than the device used at trace time.
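One possible workaround (a sketch on my side, not an official fix) is to use torch.jit.script instead of torch.jit.trace: scripting compiles forward from its Python source, so x.device remains a runtime expression rather than a trace-time constant:

import torch

# Net is the module defined in the repro above.
scripted = torch.jit.script(Net())
scripted.save('scripted_model.cpt')

model = torch.jit.load('scripted_model.cpt').to('cuda')
inputs = torch.randn(2, 10, device='cuda')
ret = model(inputs)  # works: .to(x.device) follows the input's device at runtime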

Environment

Please copy and paste the output from our environment collection script (or fill out the checklist below manually).

You can get the script and run it with:

wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py

Collecting environment information...
PyTorch version: 1.8.1+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 9.13 (stretch) (x86_64)
GCC version: (Debian 6.3.0-18+deb9u1) 6.3.0 20170516
Clang version: Could not collect
CMake version: version 3.19.4
Libc version: glibc-2.9

Python version: 3.7.0 (default, Jun 28 2018, 13:15:42) [GCC 7.2.0] (64-bit runtime)
Python platform: Linux-4.14.81.bm.26-amd64-x86_64-with-debian-9.13
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: Tesla T4
GPU 1: Tesla T4
GPU 2: Tesla T4
GPU 3: Tesla T4
GPU 4: Tesla T4
GPU 5: Tesla T4
GPU 6: Tesla T4
GPU 7: Tesla T4

Nvidia driver version: 440.118.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.0
[pip3] numpydoc==0.8.0
[pip3] pytorch-wpe==0.0.1
[pip3] torch==1.8.1+cu102
[pip3] torch-complex==0.2.1
[pip3] torchaudio==0.8.1
[pip3] torchvision==0.9.1+cu102
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.0.221 h6bb024c_0
[conda] mkl 2019.0 118
[conda] mkl-service 1.1.2 py37h90e4bf4_5
[conda] mkl_fft 1.0.4 py37h4414c95_1
[conda] mkl_random 1.0.1 py37h4414c95_1
[conda] numpy 1.15.1 py37h1d66e8a_0
[conda] numpy 1.21.0
[conda] numpy-base 1.15.1 py37h81de0dd_0
[conda] numpydoc 0.8.0 py37_0
[conda] pytorch-wpe 0.0.1
[conda] torch 1.8.1+cu102
[conda] torch-complex 0.2.1
[conda] torchaudio 0.8.1
[conda] torchvision 0.9.1+cu102

Additional context

cc @gmagogsfm

Labels

oncall: jit (Add this issue/PR to JIT oncall triage queue)
