Torch.onnx.export, RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! #72175

@chipkajb

🐛 Describe the bug

I'm having trouble converting a simple PyTorch model to ONNX with FP16 precision.
I'm using the following command:

    torch.onnx.export(
        model,
        input,
        "test.onnx",
        input_names=["input"],
        output_names=["output"],
        export_params=True,
    )

The expected result is a successful export to ONNX after converting both the model and the example input to FP16 and moving them to the GPU (model.cuda(), model.half(), etc.). I get this result when testing with the sample torchvision ResNet-18 model:

    import torchvision.models as models
    model = models.resnet18()

But when I try it with a simple custom-built PyTorch model, it fails with the following error:

File "/home/{USER}/.local/lib/python3.8/site-packages/torch/onnx/utils.py", line 628, in _model_to_graph
    params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

When I inspect graph at line 628 of site-packages/torch/onnx/utils.py, I see a couple of onnx::Constant tensors that did not make it onto the GPU. See the full graph below (notice values %9 and %10, which are annotated with device=cpu while everything else is on cuda:0).

graph(%input : Half(1, 3, 32, 32, strides=[3072, 1024, 32, 1], requires_grad=0, device=cuda:0),
      %features.0.weight : Half(32, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=1, device=cuda:0),
      %features.0.bias : Half(32, strides=[1], requires_grad=1, device=cuda:0),
      %features.1.running_mean : Half(32, strides=[1], requires_grad=0, device=cuda:0),
      %features.1.running_var : Half(32, strides=[1], requires_grad=0, device=cuda:0),
      %features.1.num_batches_tracked : Long(requires_grad=0, device=cuda:0),
      %classifier.0.weight : Half(10, 8192, strides=[8192, 1], requires_grad=1, device=cuda:0),
      %classifier.0.bias : Half(10, strides=[1], requires_grad=1, device=cuda:0)):
  %8 : Half(1, 32, 32, 32, strides=[32768, 1024, 32, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%input, %features.0.weight, %features.0.bias) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:490:0
  %9 : Half(32, strides=[1], device=cpu) = onnx::Constant[value=<Tensor>]()
  %10 : Half(32, strides=[1], device=cpu) = onnx::Constant[value=<Tensor>]()
  %11 : Half(1, 32, 32, 32, strides=[32768, 1024, 32, 1], requires_grad=1, device=cuda:0) = onnx::BatchNormalization[epsilon=1.0000000000000001e-05, momentum=0.90000000000000002](%8, %9, %10, %features.1.running_mean, %features.1.running_var) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:2282:0
  %12 : Half(1, 32, 32, 32, strides=[32768, 1024, 32, 1], requires_grad=1, device=cuda:0) = onnx::Relu(%11) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:1169:0
  %13 : Half(1, 32, 16, 16, strides=[8192, 256, 16, 1], requires_grad=1, device=cuda:0) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%12) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:719:0
  %14 : Half(1, 8192, strides=[8192, 1], requires_grad=1, device=cuda:0) = onnx::Flatten[axis=1](%13) # toy_example.py:26:0
  %15 : Half(1, 10, strides=[10, 1], requires_grad=1, device=cuda:0) = onnx::Gemm[alpha=1., beta=1., transB=1](%14, %classifier.0.weight, %classifier.0.bias) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:1848:0
  %output : Half(1, 10, strides=[10, 1], requires_grad=1, device=cuda:0) = onnx::Softmax[axis=1](%15) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:1680:0
  return (%output)
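
For longer graphs, the values that stayed on the CPU can be picked out mechanically instead of by eye. The following is a minimal sketch that works on the printed graph text only; the helper name `values_by_device` and the regular expression are illustrative assumptions, not part of torch.onnx, and the pattern is only known to match the dump format shown above.

```python
import re

# Matches the start of lines such as:
#   %9 : Half(32, strides=[1], device=cpu) = onnx::Constant[value=<Tensor>]()
# and captures the value name and the device in its type annotation.
VALUE_RE = re.compile(
    r"%(?P<name>[\w.]+)\s*:\s*\w+\(.*?device=(?P<device>[\w:]+)\)"
)


def values_by_device(graph_text):
    """Group value names in a printed graph by their annotated device."""
    devices = {}
    for line in graph_text.splitlines():
        match = VALUE_RE.search(line)  # first annotated value on the line
        if match:
            devices.setdefault(match["device"], []).append("%" + match["name"])
    return devices


# Two lines copied from the graph dump above.
sample = """\
  %features.1.running_var : Half(32, strides=[1], requires_grad=0, device=cuda:0),
  %9 : Half(32, strides=[1], device=cpu) = onnx::Constant[value=<Tensor>]()
"""
print(values_by_device(sample))
# {'cuda:0': ['%features.1.running_var'], 'cpu': ['%9']}
```

Run on the full dump, the "cpu" entry contains only %9 and %10, the second and third inputs of the onnx::BatchNormalization node.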

I brought this up on the PyTorch Forum, but didn't have much luck. Could you advise me whether I'm doing something wrong on my end?
A simplified version of my full code is below.

import torch
import torch.nn as nn

class SVHN(nn.Module):
    def __init__(self):
        super().__init__()
        conv = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        bn = nn.BatchNorm2d(32, affine=False)
        relu = nn.ReLU()
        dropout = nn.Dropout(0.3)
        maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.features = nn.Sequential(
            conv,
            bn,
            relu,
            dropout,
            maxpool,
        )
        self.classifier = nn.Sequential(nn.Linear(8192, 10))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        x = self.softmax(x)
        return x

# load model
model = SVHN()
model = model.eval()
model = model.cuda()
model = model.half()

# create sample input
input = torch.randn(1, 3, 32, 32)  # random values; torch.Tensor(1, 3, 32, 32) would leave the memory uninitialized
input = input.cuda()
input = input.half()

# attempt export to ONNX
torch.onnx.export(
    model,
    input,
    "test.onnx",
    input_names=["input"],
    output_names=["output"],
    export_params=True,
)
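
One thing that may be worth trying, sketched below under an assumption this report does not confirm: in the graph above, %9 and %10 are the scale and bias inputs of onnx::BatchNormalization (the operator takes X, scale, B, mean, var in that order), and the model builds its BatchNorm2d with affine=False, so the layer has no weight or bias of its own to move to the GPU. If the exporter is synthesizing those two constants on the CPU, giving the layer explicit affine parameters could sidestep the problem. The helper name `with_explicit_affine` is hypothetical, and whether this avoids the error on 1.10.2 has not been verified here.

```python
import torch
import torch.nn as nn


def with_explicit_affine(bn):
    """Return an affine BatchNorm2d equivalent to one built with affine=False.

    The new layer keeps PyTorch's default initialization (weight of ones,
    bias of zeros), so it computes the same result, but scale and bias now
    exist as parameters that follow the module through .cuda() and .half().
    """
    new_bn = nn.BatchNorm2d(
        bn.num_features,
        eps=bn.eps,
        momentum=bn.momentum,
        affine=True,
        track_running_stats=bn.track_running_stats,
    )
    if bn.track_running_stats:
        # Carry over the running statistics gathered so far.
        with torch.no_grad():
            new_bn.running_mean.copy_(bn.running_mean)
            new_bn.running_var.copy_(bn.running_var)
            new_bn.num_batches_tracked.copy_(bn.num_batches_tracked)
    # Freeze scale and bias so they stay at the identity values.
    new_bn.weight.requires_grad_(False)
    new_bn.bias.requires_grad_(False)
    new_bn.train(bn.training)
    return new_bn
```

In the SVHN model this would be applied before the model is moved and converted, for example `bn = with_explicit_affine(nn.BatchNorm2d(32, affine=False))`, followed by the same model.eval(), model.cuda(), model.half() sequence as above.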

Versions

Collecting environment information...
PyTorch version: 1.10.2+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-97-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.3.58
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 465.19.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.2.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.2.1
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.21.0
[pip3] pytorch-ranger==0.1.1
[pip3] torch==1.10.2+cu113
[pip3] torch-optimizer==0.3.0
[pip3] torch-tensorrt==1.0.0
[pip3] torchaudio==0.10.2+cu113
[pip3] torchfile==0.1.0
[pip3] torchvision==0.11.3+cu113
[conda] mypy-extensions 0.4.3 pypi_0 pypi
[conda] numpy 1.21.0 pypi_0 pypi

Labels

module: onnx, onnx-triaged, triaged
