
GEMM with int8 datatype throws RuntimeError on GPU #49890

Open
ilovepytorch opened this issue Dec 28, 2020 · 13 comments
Labels
function request (A request for a new function or the addition of new arguments/modes to an existing function)
module: linear algebra (Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul)
needs design
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@ilovepytorch

ilovepytorch commented Dec 28, 2020

🐛 Bug

To Reproduce

import torch

def gemm(CPU):
    if CPU:
        A = torch.randint(1, 10, [3, 3]).type(torch.int8)
        B = torch.randint(1, 10, [3, 3]).type(torch.int8)
        C = A @ B  # int8 matmul on CPU works

    else:
        A = torch.randint(1, 10, [3, 3]).type(torch.int8).cuda()
        B = torch.randint(1, 10, [3, 3]).type(torch.int8).cuda()
        C = A @ B  # int8 matmul on CUDA raises RuntimeError

gemm(True)
gemm(False)

Simply running python3 gemm.py causes the following RuntimeError:

Traceback (most recent call last):
  File "gemm.py", line 15, in <module>
    gemm(False)
  File "gemm.py", line 12, in gemm
    C = A@B
RuntimeError: "addmm_cuda" not implemented for 'Char

gemm(True) passes, i.e. running GEMM with int8 data on CPU is fine.

Expected behavior

GEMM with int8 data on GPU should be supported.
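Until native int8 matmul lands on CUDA, one possible workaround (my own sketch, not an official API) is to do the multiplication in a dtype the CUDA backend already supports and cast the result back; for small int8 values like these the float32 accumulation is exact:

import torch

A = torch.randint(1, 10, [3, 3]).type(torch.int8).cuda()
B = torch.randint(1, 10, [3, 3]).type(torch.int8).cuda()

# Multiply in float32 (supported on CUDA) and cast back to an integer
# dtype; exact here because the accumulated sums stay far below 2**24.
C = (A.float() @ B.float()).to(torch.int32)
print(C)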

Environment

PyTorch version: 1.7.0+cu101
Is debug build: True
CUDA used to build PyTorch: 10.1
ROCM used to build PyTorch: N/A

OS: CentOS Linux release 7.7.1908 (Core) (x86_64)
GCC version: (GCC) 7.2.0
Clang version: 9.0.0 (tags/RELEASE_900/final)
CMake version: version 3.12.2

Python version: 3.6 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: Tesla V100-PCIE-16GB

Nvidia driver version: 455.45.01
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.7.5.0
/usr/lib64/libcudnn.so.8.0.5
/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7.5.0
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.18.0
[pip3] numpydoc==0.7.0
[pip3] pytorch-sublstm==0.0.2
[pip3] torch==1.7.0+cu101
[pip3] torch-tvm==0.0.1
[pip3] torchaudio==0.7.0
[pip3] torchfile==0.1.0
[pip3] torchnet==0.0.4
[pip3] torchvision==0.8.1+cu101
[conda] blas 1.0 mkl
[conda] cuda100 1.0 0 pytorch
[conda] cudatoolkit 10.1.243 h6bb024c_0
[conda] magma-cuda101 2.5.2 1 pytorch
[conda] mkl 2020.0 166
[conda] mkl-include 2020.0 166
[conda] mkl-service 2.3.0 py36he904b0f_0
[conda] mkl_fft 1.0.15 py36ha843d7b_0
[conda] mkl_random 1.1.0 py36hd6b4f25_0
[conda] numpy 1.18.0 pypi_0 pypi
[conda] numpydoc 0.7.0 py36h18f165f_0
[conda] pytorch-sublstm 0.0.2 pypi_0 pypi
[conda] torch 1.7.0+cu101 pypi_0 pypi
[conda] torch-tvm 0.0.1 pypi_0 pypi
[conda] torchaudio 0.7.0 pypi_0 pypi
[conda] torchfile 0.1.0 pypi_0 pypi
[conda] torchnet 0.0.4 pypi_0 pypi
[conda] torchvision 0.8.1+cu101 pypi_0 pypi

cc @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr

@mruberry mruberry added function request A request for a new function or the addition of new arguments/modes to an existing function. module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Dec 28, 2020
@mruberry
Collaborator

Thanks for suggesting this feature, @ilovepytorch. We've had requests for integer matmul support in the past, too. In fact, there's probably another issue with this same request.

I think we would accept a PR implementing integer matrix multiplication.

@jonykarki
Contributor

jonykarki commented Dec 29, 2020

I want to try to work on this. Could you give me some guidance on how to start?

@mruberry
Collaborator

I want to try to work on this. Could you give me some guidance on how to start?

Supporting integer matrix multiplication is a challenging task suitable for PyTorch and GEMM experts, so there probably won't be a guide for this. Would you be interested in a simpler task?

@jonykarki
Contributor

jonykarki commented Dec 29, 2020

Supporting integer matrix multiplication is a challenging task suitable for PyTorch and GEMM experts, so there probably won't be a guide for this. Would you be interested in a simpler task?

That's what I thought. The problem just sounded interesting. Yes, sure, I'll try something simpler.

@jianyuh
Member

jianyuh commented Dec 29, 2020

To support fast INT8 GEMM on GPUs, I think we need to change https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/CUDABlas.cpp to add support for INT8 GEMM (with cublasGemmEx, where the A/B data types are CUDA_R_8I and the C data type is CUDA_R_32I).

cc @ngimel

@ngimel
Collaborator

ngimel commented Dec 29, 2020

The cublasGemmEx API is reasonable in the sense that it produces an int32 result and thus won't be as prone to overflow as a regular int8 matrix multiply. However, PyTorch does not provide a convenient way to specify "multiply two int8 matrices and produce an int32 result". This is a problem that's better discussed in the context of quantization and the APIs that are needed for quantization.
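To make the overflow point concrete (my own illustration, not from the thread): on CPU, where int8 matmul is already implemented, the result tensor is also int8, so accumulated products beyond the int8 range wrap around, while an int32 result keeps the true value:

import torch

# Each entry of the product is 64 * (10 * 10) = 6400, which does not fit
# in int8 (max 127), so the int8 result wraps around.
A = torch.full((1, 64), 10, dtype=torch.int8)
B = torch.full((64, 1), 10, dtype=torch.int8)
print(A @ B)              # wrapped int8 value, not 6400
print(A.int() @ B.int())  # int32 accumulation recovers 6400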

@vadimkantorov
Contributor

Should there be a dtype argument for torch.matmul?

@ilovepytorch
Author

Should there be a dtype argument for torch.matmul?

According to the current doc, there isn't. From my perspective, it would be better if the user did not need to explicitly pass a dtype argument to the torch.matmul API. Ideally, the framework should automatically dispatch the right library call for the input datatype.

@vadimkantorov
Contributor

vadimkantorov commented Mar 23, 2021

But I think it would be good to allow configuring a custom accumulation dtype (if it makes sense) and an output dtype, in case the user wants to supply these for some reason (e.g. before we have finalized the casting defaults).
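Purely as a strawman for that discussion (nothing below exists in PyTorch today; the out_dtype argument is hypothetical), the call site could look something like:

import torch

# Hypothetical API sketch -- out_dtype is NOT a real torch.matmul argument.
A = torch.randint(1, 10, [3, 3], dtype=torch.int8, device="cuda")
B = torch.randint(1, 10, [3, 3], dtype=torch.int8, device="cuda")
C = torch.matmul(A, B, out_dtype=torch.int32)  # int8 inputs, int32 accumulation/result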

@ilovepytorch
Author

ilovepytorch commented Mar 23, 2021

@jonykarki @mruberry
Sorry to bother you again!
For the int8 datatype on GPU (namely adding .type(torch.int8).cuda() when initializing the data), there are actually a number of APIs (including some frequently used ones like torch.matmul, torch.mm, torch.nn.AvgPool2d, torch.nn.AvgPool3d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d) that throw a RuntimeError, listed below:

torch.addbmm
torch.addmm
torch.addmv
torch.baddbmm
torch.ceil
torch.chain_matmul
torch.cholesky_inverse
torch.cholesky_solve
torch.cholesky #actually it's because torch.mm is broken for int8
torch.einsum
torch.floor
torch.imag
torch.inverse
torch.logcumsumexp
torch.logit
torch.lu_solve
torch.lu_unpack
torch.matmul
torch.mm
torch.nn.AdaptiveAvgPool1d
torch.nn.AdaptiveAvgPool2d
torch.nn.AdaptiveAvgPool3d
torch.nn.AdaptiveMaxPool1d
torch.nn.AdaptiveMaxPool2d
torch.nn.AdaptiveMaxPool3d
torch.nn.AlphaDropout
torch.nn.AvgPool2d
torch.nn.AvgPool3d
torch.nn.BatchNorm2d
torch.nn.BatchNorm3d
torch.nn.CELU
torch.nn.CrossEntropyLoss
torch.nn.CTCLoss
torch.nn.Dropout
torch.nn.ELU
torch.nn.Fold
torch.nn.FractionalMaxPool2d
torch.nn.GELU
torch.nn.GroupNorm
torch.nn.Hardshrink
torch.nn.Hardswish
torch.nn.LayerNorm
torch.nn.LeakyReLU
torch.nn.LocalResponseNorm
torch.nn.LogSigmoid
torch.nn.LogSoftmax
torch.nn.MaxPool1d
torch.nn.MaxPool2d
torch.nn.MaxPool3d
torch.nn.MSELoss
torch.nn.NLLLoss
torch.nn.ReplicationPad3d
torch.nn.RReLU
torch.nn.SELU
torch.nn.SiLU
torch.nn.Softmax2d
torch.nn.Softmax
torch.nn.Softmin
torch.nn.Softplus
torch.nn.Softshrink
torch.qr
torch.svd
torch.symeig
torch.tensordot
torch.triangular_solve
torch.trunc

To be frank, a number of these are not that critical to fix, like torch.ceil (which is mainly used for floating-point computation), while fixing others, especially the frequently used GEMM-related, MaxPool/AvgPool-related, and loss-related ops, could be quite beneficial. When users want to train/test a neural network with special datatypes, fixing those APIs would be helpful to some extent :)
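For reference, here is a rough sketch of how such a list can be produced (only a few of the ops above are shown; the exact set of failures depends on the PyTorch version):

import torch

x = torch.randint(1, 10, [3, 3]).type(torch.int8).cuda()

# Call each op on a small int8 CUDA tensor and report which ones are
# unsupported instead of stopping at the first RuntimeError.
ops = {
    "torch.mm": lambda: torch.mm(x, x),
    "torch.matmul": lambda: torch.matmul(x, x),
    "torch.nn.MaxPool2d": lambda: torch.nn.MaxPool2d(2)(x[None, None]),
}

for name, fn in ops.items():
    try:
        fn()
        print(name, "ok")
    except RuntimeError as err:
        print(name, "RuntimeError:", err)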

@ilovepytorch
Author

But I think it would be good to allow to configure custom accumulation dtype (if it makes sense) and output dtype if the user for some reason wants to supply these (e.g. before we have finalized defaults for casting)

Yeah, this makes sense to me! Sorry I am not an expert at designing APIs:) Please feel free to discuss this with other experts as well.

@ngimel
Collaborator

ngimel commented Mar 23, 2021

What's the motivation behind enabling int8 for all those operations? If it's quantization and quantization-aware training, PyTorch supports a number of operations for quantized tensors, and in general such requests should be discussed with the quantization team. Enabling training operations for the plain int8 type doesn't make much sense to me, tbh.

@ilovepytorch
Author

Thanks for your answers! I have just checked quantization in PyTorch and found that "At the moment PyTorch doesn't provide quantized operator implementations on CUDA" and that this is listed as future work. I believe support for int8 on GPU will deliver better performance (especially for the most frequently used GEMM), and this seems to be the trend: the Nvidia A100 has Tensor Core support for int8 GEMM, and I suppose PyTorch could benefit a lot from exploiting those hardware intrinsics.
