JIT vs. eager mismatches for jit.traced int8 to int32 casting #66930

Open
ptrblck opened this issue Oct 20, 2021 · 1 comment
Labels: NNC, oncall: jit

Comments

ptrblck (Collaborator) commented Oct 20, 2021

🐛 Bug

A traced model seems to create overflows when int8 tensor values are cast to int32 on the GPU: the CUDA outputs differ from the CPU outputs of the same traced module.

To Reproduce

import torch
import torch.nn as nn

class AddSubNet(nn.Module):
    def __init__(self, *args):
        self.torch_output0_dtype = args[0][0]
        self.torch_output1_dtype = args[0][1]
        super(AddSubNet, self).__init__()

    def forward(self, input0, input1):
        return (input0 + input1).to(self.torch_output0_dtype), \
               (input0 - input1).to(self.torch_output1_dtype)

device = 'cpu'
model = AddSubNet((torch.int32, torch.int32)).to(device)
x1 = torch.randint(-127, 128, (16,)).to(torch.int8).to(device)
x2 = torch.randint(-127, 128, (16,)).to(torch.int8).to(device)
model_cpu = torch.jit.trace(model, (x1, x2))

print('input')
print(x1)
print(x2)

out1_cpu, out2_cpu = model_cpu(x1, x2)
print('cpu output')
print(out1_cpu)
print(out2_cpu)

device = 'cuda'
model.to(device)
x1 = x1.to(device)
x2 = x2.to(device)

model_gpu = torch.jit.trace(model, (x1, x2))
#print(model_gpu.graph)

out1, out2 = model_gpu(x1, x2)
print('cuda output')
print(out1)
print(out2)

Output:

input
tensor([ -14,  127,  -24,    9, -115,  -24, -102,   -5,    5,   93,   45,  -69,
         -74,   46,  109,  -90], dtype=torch.int8)
tensor([  32,  -46,   13,   78,  109,  -84, -104,   76,   29,  -97,  -90,   73,
          17, -105,   34,  117], dtype=torch.int8)
cpu output
tensor([  18,   81,  -11,   87,   -6, -108,   50,   71,   34,   -4,  -45,    4,
         -57,  -59, -113,   27], dtype=torch.int32)
tensor([ -46,  -83,  -37,  -69,   32,   60,    2,  -81,  -24,  -66, -121,  114,
         -91, -105,   75,   49], dtype=torch.int32)
cuda output
tensor([  18,   81,  -11,   87,   -6, -108, -206,   71,   34,   -4,  -45,    4,
         -57,  -59,  143,   27], device='cuda:0', dtype=torch.int32)
tensor([ -46,  173,  -37,  -69, -224,   60,    2,  -81,  -24,  190,  135, -142,
         -91,  151,   75, -207], device='cuda:0', dtype=torch.int32)

The mismatches look like int8 overflow being handled differently (e.g. 50 on CPU vs. -206 on CUDA for the element where -102 + -104 is computed).
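
As a minimal eager-mode sketch (my addition, not part of the original report) of where those two values come from, assuming standard PyTorch type promotion: int8 + int8 stays int8, so the eager/CPU path wraps the sum before the cast, while the fused CUDA kernel appears to widen to int32 first.

import torch

a = torch.tensor([-102], dtype=torch.int8)
b = torch.tensor([-104], dtype=torch.int8)

# int8 addition wraps: -206 + 256 = 50, and the wrapped value is then cast to int32
print((a + b).to(torch.int32))                # tensor([50], dtype=torch.int32)

# widening to int32 before the add keeps the mathematically exact value
print(a.to(torch.int32) + b.to(torch.int32))  # tensor([-206], dtype=torch.int32)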

Environment

PyTorch version: 1.11.0.dev20211019+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.21.3
Libc version: glibc-2.31

Reproduced on different GPUs.

Additional information

Enabling nvfuser via:

torch._C._jit_set_nvfuser_enabled(True)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_bailout_depth(20)

yields matching values.
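
As a hedged sanity check (my addition, not from the report; assumes model, x1, x2, out1 and out2 from the repro above are still in scope), the traced CUDA outputs can also be compared against the eager module, which presumably follows the wraparound behaviour shown in the CPU run:

# eager execution on the same CUDA inputs as a reference
out1_ref, out2_ref = model(x1, x2)
print(torch.equal(out1, out1_ref), torch.equal(out2, out2_ref))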

CC @malfet as we've discussed this issue. (Initially I thought it would be ARM-specific, but that turned out to be wrong.)

facebook-github-bot added the oncall: jit label on Oct 20, 2021; eellison added the NNC label on Oct 22, 2021.
deadeyegoodwin pushed commits to triton-inference-server/server that referenced this issue (Nov 10–11, 2021).
zeruniverse commented:

Is there any plan to fix this bug?
