TorchScript model doesn't work with autocast #140279

Open
@ktsumura

🐛 Describe the bug

Hi, I want to use autocast with a scripted (TorchScript) model. The eager model runs correctly, but calling the scripted model fails with the following error.

import torch
import torch.nn as nn
from torch.amp import autocast

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)

    def forward(self, x):
        with autocast(device_type="cuda", enabled=True):
            return self.conv1(x)

device = torch.device('cuda')

# Create an instance of the network
net = SimpleCNN()
net.to(device)

# Create a sample input tensor
input_tensor = torch.randn(1, 3, 28, 28)
input_tensor = input_tensor.to(device)

# Pass the input through the network
output = net(input_tensor)
print(output.shape)

# Pass the input through the script network
script_net = torch.jit.script(net)
output2 = script_net(input_tensor)
print(output2.shape)

Output:

torch.Size([1, 6, 24, 24])
Traceback (most recent call last):
  File "D:\workspace_tf\Einstein_reg_8\src\neosoft\misc\SimpleCNN.py", line 39, in <module>
    output2 = script_net(input_tensor)
  File "C:\Python39\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Python39\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "D:\workspace_tf\Einstein_reg_8\src\neosoft\misc\SimpleCNN.py", line 21, in forward
            # x = torch.relu(self.fc1(x))
            # return x
            return self.conv1(x)
                   ~~~~~~~~~~ <--- HERE
  File "C:\Python39\lib\site-packages\torch\nn\modules\conv.py", line 554, in forward
    def forward(self, input: Tensor) -> Tensor:
        return self._conv_forward(input, self.weight, self.bias)
               ~~~~~~~~~~~~~~~~~~ <--- HERE
  File "C:\Python39\lib\site-packages\torch\nn\modules\conv.py", line 549, in _conv_forward
                self.groups,
            )
        return F.conv2d(
               ~~~~~~~~ <--- HERE
            input, weight, bias, self.stride, self.padding, self.dilation, self.groups
        )
RuntimeError: Input type (struct c10::Half) and bias type (float) should be the same
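
For context on the dtypes named in the error message, here is a minimal eager-mode sketch (not part of the original report). It assumes the same layer shape as the repro and falls back to CPU when CUDA is unavailable; the variable names are illustrative. It only shows that eager autocast leaves the parameters in float32 while producing a lower-precision output, and it does not diagnose the TorchScript failure itself.

```python
import torch
import torch.nn as nn

# Assumption: use CUDA when present, otherwise CPU autocast (bfloat16 by default).
device_type = "cuda" if torch.cuda.is_available() else "cpu"

# Same layer configuration as conv1 in the repro above.
conv = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5).to(device_type)
x = torch.randn(1, 3, 28, 28, device=device_type)

# Eager-mode autocast: inputs and weights are cast per operation,
# the stored parameters themselves are not modified.
with torch.autocast(device_type=device_type):
    y = conv(x)

param_dtypes = {name: p.dtype for name, p in conv.named_parameters()}
print("parameter dtypes:", param_dtypes)
print("output dtype under autocast:", y.dtype)
print("output shape:", tuple(y.shape))
```

In eager mode the float32 bias coexists with the lower-precision input because autocast handles the casts inside the op call; the traceback above suggests that this did not happen for the bias on the scripted path.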

Versions

Collecting environment information...
PyTorch version: 2.5.1+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11 Pro (10.0.22631 64-bit)
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.9.7 (tags/v3.9.7:1016ef3, Aug 30 2021, 20:19:38) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.22631-SP0
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA RTX A6000
GPU 1: NVIDIA RTX A6000

Nvidia driver version: 528.02
cuDNN version: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin\cudnn_ops_train64_8.dll
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Name: Intel(R) Core(TM) i9-10900F CPU @ 2.80GHz
Manufacturer: GenuineIntel
Family: 207
Architecture: 9
ProcessorType: 3
DeviceID: CPU0
CurrentClockSpeed: 2808
MaxClockSpeed: 2808
L2CacheSize: 2560
L2CacheSpeed: None
Revision: None

Versions of relevant libraries:
[pip3] numpy==1.22.4
[pip3] onnx==1.14.0
[pip3] onnxruntime-gpu==1.16.3
[pip3] pytorch-lightning==2.2.4
[pip3] tf2onnx==1.15.1
[pip3] torch==2.5.1+cu118
[pip3] torch-dct==0.1.6
[pip3] torch-tb-profiler==0.4.3
[pip3] torchaudio==2.5.1+cu118
[pip3] torchdiffeq==0.2.3
[pip3] torchmetrics==1.4.0
[pip3] torchvision==0.20.1+cu118
[conda] Could not collect

cc @EikanWang @jgong5 @wenzhe-nrv @sanchitintel

Labels: oncall: jit (add this issue/PR to the JIT oncall triage queue)