
RuntimeError: cuSPARSELT not supported on your machine. When I was calling: torch._cslt_compress() #115077

Open
silence1024 opened this issue Dec 4, 2023 · 19 comments
Labels
module: binaries Anything related to official binaries that we release to users module: sparse Related to torch.sparse triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@silence1024

silence1024 commented Dec 4, 2023

🐛 Describe the bug

I ran into this issue when running:

import torch
from torch.sparse import to_sparse_semi_structured
A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
A_sparse = to_sparse_semi_structured(A)

The error occurred:

/opt/anaconda3/envs/sm_sparse/lib/python3.8/site-packages/torch/sparse/semi_structured.py:86: UserWarning: The PyTorch API of SparseSemiStructuredTensor is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.sparse module for further information about the project.
warnings.warn(
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/anaconda3/envs/sm_sparse/lib/python3.8/site-packages/torch/sparse/semi_structured.py", line 434, in to_sparse_semi_structured
    return SparseSemiStructuredTensor(original_tensor, original_shape=original_tensor.shape, transposed=transposed)
  File "/opt/anaconda3/envs/sm_sparse/lib/python3.8/site-packages/torch/sparse/semi_structured.py", line 212, in __init__
    compressed_tensor = torch._cslt_compress(original_tensor)
RuntimeError: cuSPARSELT not supported on your machine.

Versions

My relevant versions are:

PyTorch version: 2.1.0
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.1 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.8.18 (default, Sep 11 2023, 13:40:15) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-113-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 12.1.66
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA RTX A6000
GPU 1: NVIDIA RTX A6000
GPU 2: NVIDIA RTX A6000
GPU 3: NVIDIA RTX A6000
GPU 4: NVIDIA A100 80GB PCIe
GPU 5: NVIDIA A100 80GB PCIe
GPU 6: NVIDIA A100 80GB PCIe
GPU 7: NVIDIA RTX A6000
GPU 8: NVIDIA A100 80GB PCIe

Nvidia driver version: 530.30.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.6.0
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn.so.8.4.0
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.4.0
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.4.0
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.4.0
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.4.0
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.4.0
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.4.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 43 bits physical, 48 bits virtual
CPU(s): 256
On-line CPU(s) list: 0-254
Off-line CPU(s) list: 255
Thread(s) per core: 1
Core(s) per socket: 64
Socket(s): 2
NUMA node(s): 8
Vendor ID: AuthenticAMD
CPU family: 23
Model: 49
Model name: AMD EPYC 7H12 64-Core Processor
Stepping: 0
Frequency boost: enabled
CPU MHz: 2117.314
CPU max MHz: 2600.0000
CPU min MHz: 1500.0000
BogoMIPS: 5200.05
Virtualization: AMD-V
L1d cache: 2 MiB
L1i cache: 2 MiB
L2 cache: 32 MiB
L3 cache: 256 MiB
NUMA node0 CPU(s): 0-15,128-143
NUMA node1 CPU(s): 16-31,144-159
NUMA node2 CPU(s): 32-47,160-175
NUMA node3 CPU(s): 48-63,176-191
NUMA node4 CPU(s): 64-79,192-207
NUMA node5 CPU(s): 80-95,208-223
NUMA node6 CPU(s): 96-111,224-239
NUMA node7 CPU(s): 112-127,240-254
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate sme ssbd mba sev ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif umip rdpid overflow_recov succor smca

Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] torch==2.1.0
[pip3] torchaudio==2.1.0
[pip3] torchvision==0.16.0
[pip3] triton==2.1.0
[conda] blas 1.0 mkl
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] mkl-service 2.4.0 py38h5eee18b_1
[conda] mkl_fft 1.3.8 py38h5eee18b_0
[conda] mkl_random 1.2.4 py38hdb19cb5_0
[conda] numpy 1.24.3 py38hf6e8229_1
[conda] numpy-base 1.24.3 py38h060ed82_1
[conda] pytorch 2.1.0 py3.8_cuda12.1_cudnn8.9.2_0 pytorch
[conda] pytorch-cuda 12.1 ha16c6d3_5 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torch 2.1.1 pypi_0 pypi
[conda] torchaudio 2.1.1 pypi_0 pypi
[conda] torchtriton 2.1.0 py38 pytorch
[conda] torchvision 0.16.1 pypi_0 pypi
[conda] triton 2.1.0 pypi_0 pypi

cc @ezyang @gchanan @zou3519 @kadeng @seemethere @malfet @osalpekar @atalman @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer

@malfet malfet added module: sparse Related to torch.sparse module: binaries Anything related to official binaries that we release to users labels Dec 4, 2023
@malfet malfet added this to the 2.2.0 milestone Dec 4, 2023
@malfet malfet added high priority triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Dec 4, 2023
@alexsamardzic
Collaborator

Instead of:

from torch.sparse import to_sparse_semi_structured

can you try with:

from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
SparseSemiStructuredTensor._FORCE_CUTLASS = True

@silence1024
Author

Thanks! It works after I added this.

import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
SparseSemiStructuredTensor._FORCE_CUTLASS = True
A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().to("cuda:1")
A_sparse = to_sparse_semi_structured(A)

/opt/anaconda3/envs/sm_sparse/lib/python3.8/site-packages/torch/sparse/semi_structured.py:86: UserWarning: The PyTorch API of SparseSemiStructuredTensor is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.sparse module for further information about the project.
warnings.warn(

A_sparse

SparseSemiStructuredTensor(shape=torch.Size([128, 128]), transposed=Falsevalues=tensor([[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]], device='cuda:1', dtype=torch.float16)metadata=tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
...,
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:1',
dtype=torch.int16))

@jcaip
Contributor

jcaip commented Dec 4, 2023

@silence1024 Are you trying to run with cuSPARSELt or CUTLASS? Please be aware that setting _FORCE_CUTLASS = True will use the CUTLASS kernels instead of cuSPARSELt.

@silence1024
Author

@silence1024 Are you trying to run with cuSPARSELt or CUTLASS? Please be aware that setting _FORCE_CUTLASS = True will use the CUTLASS kernels instead of cuSPARSELt.

I'm trying to run with cuSPARSELt, so there is still an issue.

@jcaip
Contributor

jcaip commented Dec 5, 2023

@silence1024 Are you building PyTorch from scratch? That is currently the only way to run with cuSPARSELt v0.5.0, as only v0.4.0 is in the nightlies. If so, can you share the command used to compile PyTorch?

@silence1024
Author

@silence1024 Are you building PyTorch from scratch? That is currently the only way to run with cuSPARSELt v0.5.0, as only v0.4.0 is in the nightlies. If so, can you share the command used to compile PyTorch?

Thank you for your response!

  1. I installed PyTorch version 2.1.0, following the instructions available on the PyTorch official website. The installation was completed using this command:

    conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
    
  2. Currently, the version of cuSPARSELt installed on my server is v0.5.0. Are you suggesting that in order to integrate cuSPARSELt v0.5.0, I should compile PyTorch from scratch?

@jcaip
Contributor

jcaip commented Dec 5, 2023

Okay there are two issues here:

Unfortunately, the conda builds of the PyTorch nightlies do not come with cuSPARSELt installed; we have a task to add conda support here: #115085

Note that you can pip install a nightly built with cuSPARSELt; see my comment here: #113776 (comment)

However, this will be with cuSPARSELt v0.4.0, not v0.5.0, which we are in the process of integrating.

If you want v0.5.0 support, you will have to compile PyTorch from scratch by setting the env vars

USE_CUSPARSELT=1
CUSPARSELT_ROOT=/path/to/cusparselt_v0.5.0/download

Can I ask what you are trying to do with semi-structured sparsity? I would recommend you just pip install the nightlies and use cuSPARSELt v0.4.0, which should be sufficient for most cases.

Our tutorial may also be helpful: https://pytorch.org/tutorials/prototype/semi_structured_sparse.html. In general it is not recommended to use the private APIs _cslt_compress and _cslt_sparse_mm; just use the tensor subclass instead, which will take care of a lot of things for you, like padding.
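
For reference, here is a minimal sketch of that subclass flow; the shapes and the 2:4 mask mirror the tutorial and are only illustrative:

import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor

# use the CUTLASS kernels so this also runs on builds without cuSPARSELt
SparseSemiStructuredTensor._FORCE_CUTLASS = True

# 2:4-sparse fp16 weight: every group of four elements keeps two non-zeros
mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool()
linear = torch.nn.Linear(10240, 3072).half().cuda().eval()
linear.weight = torch.nn.Parameter(mask * linear.weight)

x = torch.rand(3072, 10240, dtype=torch.float16, device="cuda")
with torch.inference_mode():
    dense_output = linear(x)
    # swap in the subclass; the F.linear dispatch, padding, etc. are handled for you
    linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight))
    sparse_output = linear(x)
    assert torch.allclose(dense_output, sparse_output, atol=1e-3)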

@silence1024
Author

Thanks for the explanation. I will try your suggestions and post my progress here.

For the questions:

Can I ask what you are trying to do with semi-structured sparsity? I would recommend you just pip install the nightlies and use cuSPARSELt v0.4.0, which should be sufficient for most cases.

Reply: I'm trying to accelerate the inference speed of LLMs such as LLaMA by using semi-structured sparsity. I will try cuSPARSELt v0.4.0.

Our tutorial may also be helpful: https://pytorch.org/tutorials/prototype/semi_structured_sparse.html. In general it is not recommended to use the private APIs _cslt_compress and _cslt_sparse_mm; just use the tensor subclass instead, which will take care of a lot of things for you, like padding.

Reply:

  1. I read the tutorial. As you commented above, it sets _FORCE_CUTLASS = True, so it doesn't really use cuSPARSELt.
    My question is: will cuSPARSELt offer better acceleration than CUTLASS?
  2. The error RuntimeError: cuSPARSELT not supported on your machine. occurred when I ran to_sparse_semi_structured().
    The code inside this function is:
if self._FORCE_CUTLASS:
    ...
else:
    compressed_tensor = torch._cslt_compress(original_tensor)

This implies that the error happened because I had not set _FORCE_CUTLASS = True.
Am I right?

@jcaip
Contributor

jcaip commented Dec 6, 2023

Cool @silence1024, I am actually looking into this myself (LLM inference acceleration via sparsity) at the moment, so please keep me updated.

The short answer is: You can't with what's currently available, at least for the text generation use case.

This is because for text generation, we generate one token at a time, so the matmul shapes are [1, hidden] @ [hidden, output]. And unfortunately there are shape constraints on both the dense and sparse matrices for sparse matmul. So we need to pad our matrix to be size [8, hidden] @ [hidden, output]. The code in the subclass will do this for you, but this padding eats into the sparse matmul speedup.
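
For illustration, a rough sketch of that padding with made-up LLaMA-like sizes (the subclass does this for you internally, and the minimum row count depends on the backend):

import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor

SparseSemiStructuredTensor._FORCE_CUTLASS = True  # CUTLASS path
min_rows = 32  # current padding target for CUTLASS; 8 would be enough for cuSPARSELt

hidden, output = 4096, 11008  # made-up LLaMA-like sizes
w = torch.Tensor([0, 0, 1, 1]).tile((output, hidden // 4)).half().cuda()  # 2:4 weight
w_sparse = to_sparse_semi_structured(w)
b = torch.zeros(output, dtype=torch.float16, device="cuda")

x = torch.rand(1, hidden, dtype=torch.float16, device="cuda")  # one decoded token
# pad the single activation row up to the minimum supported size, run the
# sparse matmul, then slice the padding rows back off
x_padded = torch.nn.functional.pad(x, (0, 0, 0, min_rows - x.shape[0]))
y = torch.nn.functional.linear(x_padded, w_sparse, b)[: x.shape[0]]  # shape [1, output]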

We also observed that the default alg_id for cuSPARSELt that we use is not optimal for these small sizes, I am actually working on a PR to fix this now: #115178. We're seeing around a 10% speedup over the baseline in gpt-fast (114 tok/s vs 104 tok/s) when we select the optimal alg_id.

But you'll need cuSPARSELt v0.5.0, my PR with algid=1, and to update the padding shapes (we pad to 32 currently because CUTLASS has different shape constraints). So this is kind of a lot of work to set up. I am working on releasing this in core, but it will probably be at least a week or two.

However, for non-batch-size-1 use cases, we can offer speedups out of the box, e.g. if you tune llama to get embeddings, or for text classification; we can accelerate those use cases, since the matmul shapes are [batch_size, hidden] @ [hidden, output].

Alternatively, you can try and see if you can get speedups with the CUTLASS kernels. It really depends on what your baseline is. If it's just eager mode python, you still may see speedups, but our baseline is torch.compile, so it's a bit more competitive.

Will cuSPARSELt offer better acceleration than CUTLASS?

In general yes, but CUTLASS is comparable for certain cases. I would say overall cuSPARSELt is faster, especially for int8 / the llama matmul shapes. CUTLASS however comes with CUDA so there is no cuSPARSELt dependency, which is why we use it in the tutorials.
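
As a sketch, one way to compare the two backends on a given shape is to time the same fp16 linear under each flag setting (the cuSPARSELt half only runs on a build that actually ships cuSPARSELt; shapes below are illustrative):

import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.utils.benchmark import Timer

x = torch.rand(3072, 10240, dtype=torch.float16, device="cuda")
w = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).half().cuda()  # 2:4 weight
b = torch.zeros(3072, dtype=torch.float16, device="cuda")

def bench(force_cutlass):
    # the flag decides which backend is used when the weight is compressed
    SparseSemiStructuredTensor._FORCE_CUTLASS = force_cutlass
    w_s = to_sparse_semi_structured(w)
    return Timer(stmt="torch.nn.functional.linear(x, w_s, b)",
                 globals={"x": x, "w_s": w_s, "b": b}).blocked_autorange().median * 1e3

print(f"CUTLASS:    {bench(True):.3f} ms")
print(f"cuSPARSELt: {bench(False):.3f} ms")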

This implies that the error happened because I had not set _FORCE_CUTLASS = True.

Yes, I think this is a bug. In the nightlies we force CUTLASS by default, but cuSPARSELt was the default option in the version you downloaded, I think. So that's why you are seeing that error.

cc @atalman what's the status of the 2.1.2 release fix for: https://github.com/pytorch/pytorch/pull/114593/files?
This PR enables CUTLASS by default and would fix these kinds of errors.

@RanchiZhao

RanchiZhao commented Dec 6, 2023

I am wondering, is cuSPARSELt available for int8?
Here is my code demo; I find that the result of out_sparse is not equal to out_dense (or its int8 cast):

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch.nn.functional as F
import torch
torch.manual_seed(114)
torch.cuda.manual_seed(114)

class Linear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, W: torch.Tensor, B: torch.Tensor):
        ctx.wcsplt = torch.ops.aten._cslt_compress(W)
        ctx.wcspltT = torch.ops.aten._cslt_compress(W.t().contiguous())
        ctx.save_for_backward(x)
        return torch.ops.aten._cslt_sparse_mm(ctx.wcsplt, dense_B=x.t(), bias=B, transpose_result=True)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        x = ctx.saved_tensors[0]
        return (
            torch.ops.aten._cslt_sparse_mm(ctx.wcsplt, grad_out.t(), transpose_result=True),
            x.t() @ grad_out,
            None
        )

m, k, n = 16384, 4096, 512
stride = 512
scale_factor = 2

# m, k, n = 16, 16, 16
# stride = 4
# scale_factor = 2
a = (torch.randn([m, k], device="cuda") * scale_factor).to(torch.int8)
W = (torch.randn([n, k], device="cuda") * scale_factor).to(torch.int8)
B = (torch.randn([n], device="cuda") * scale_factor).to(torch.int8)
W_pruned = W.clone()
# zero out two of every four elements in each row so W_pruned is 2:4 sparse
for i in range(n):
    for j in range(k//stride):
        W_pruned[i][j*stride:j*stride+stride] = W_pruned[i][j*stride:j*stride+stride] * torch.tensor([1, 0, 1, 0] * (stride // 4), device=W.device)

out_sparse = Linear.apply(a, W_pruned, B)
print("out_sparse: ",out_sparse)
torch.cuda.synchronize()


a_float = a.to(torch.float32)
W_pruned_float = W_pruned.to(torch.float32)
B_float = B.to(torch.float32)

out_dense = F.linear(a_float, W_pruned_float, B_float)
out_dense_int8 = out_dense.to(torch.int8)
print("out_dense: ",out_dense)
print("out_dense_int8: ",out_dense_int8)
torch.cuda.synchronize()

@RanchiZhao

RanchiZhao commented Dec 6, 2023

I am wondering, is cuSPARSELt available for int8? Here is my code demo; I find that the result of out_sparse is not equal to out_dense (or its int8 cast):

Maybe the error is in the private APIs _cslt_compress and _cslt_sparse_mm? I'm not sure.

@RanchiZhao

Maybe this is a more correct demo, where I try to avoid those private APIs:

import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.utils.benchmark import Timer
SparseSemiStructuredTensor._FORCE_CUTLASS = False
import numpy as np
torch.manual_seed(114)
torch.cuda.manual_seed(114)

# mask Linear weight to be 2:4 sparse
mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool()
linear = torch.nn.Linear(10240, 3072).half().cuda().eval()
# print(torch.nn.Parameter((mask * linear.weight)))
linear.weight = torch.nn.Parameter((mask * linear.weight * 118).to(torch.int8), requires_grad=False)
bias = torch.zeros(3072, dtype=torch.int8).cuda()

x = (torch.rand(3072, 10240).half() * 127).cuda().to(torch.int8)
print("x: ",x)
print("linear.weight: ",linear.weight)
with torch.inference_mode():
    dense_output = torch.nn.functional.linear(x.to(torch.float16), linear.weight.to(torch.float16), bias.to(torch.float16))
    out_dense_cpu = dense_output.cpu()
    dense_output_int8 = torch.from_numpy(np.clip(out_dense_cpu.numpy(), -128, 127)).type(torch.int8).cuda()

    dense_t = Timer(stmt="torch.nn.functional.linear(x, weight, bias)",
                    globals={"x": x.to(torch.float16), "weight": linear.weight.to(torch.float16), "bias": bias.to(torch.float16)}).blocked_autorange().median * 1e3
    
    # accelerate via SparseSemiStructuredTensor
    linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight), requires_grad=False)
    sparse_output_int8 = torch.nn.functional.linear(x, linear.weight, bias)
    sparse_t = Timer(stmt="torch.nn.functional.linear(x, weight, bias)",
                    globals={"x": x, "weight": linear.weight, "bias": bias}).blocked_autorange().median * 1e3


    # sparse and dense matmul are numerically equivalent
    print("dense: ", dense_output_int8)
    print("sparse: ", sparse_output_int8)
    # assert torch.allclose(sparse_output_int8, dense_output_int8, atol=1e-3)
    print(f"Dense: {dense_t:.3f}ms Sparse: {sparse_t:.3f}ms | Speedup: {(dense_t / sparse_t):.3f}x")

@alexsamardzic
Collaborator

The int8 multiplication will produce int32 output. So the bias should be int32 and, more importantly, there is no need to clip.

@RanchiZhao

Yes, int32 is needed for the bias, but isn't the clip needed to check that the result of the int8*int8 GEMM is correct?

@alexsamardzic
Collaborator

Ok, after checking the code, it seems there is a discrepancy here: the CUTLASS backend produces an int32 result for int8 operands (and then the bias is expected to be int32 too), while the cuSPARSELt version produces either an int8 or a float16 result. So for the cuSPARSELt version, the clip is probably needed; my comment above was for the CUTLASS backend, and in that case the following would produce correct results (note that the bias is changed to be non-zero too):

import torch
from torch.utils.benchmark import Timer

from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
SparseSemiStructuredTensor._FORCE_CUTLASS = True

torch.manual_seed(114)

mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool()
linear = torch.nn.Linear(10240, 3072).half().cuda().eval()
linear.weight = torch.nn.Parameter((mask * linear.weight * 118).to(torch.int8), requires_grad=False)
bias = torch.randint(-128, 127, (3072,)).to(torch.int32).cuda()
x = (torch.rand(3072, 10240).half() * 127).cuda().to(torch.int8)

with torch.inference_mode():
    x_dense = x.to(torch.half)
    weight_dense = linear.weight.to(torch.half)
    bias_dense = bias.to(torch.half)
    dense_output = torch.nn.functional.linear(x_dense, weight_dense, bias_dense)
    dense_t = Timer(stmt="torch.nn.functional.linear(x, weight, bias)",
                    globals={"x": x_dense, "weight": weight_dense, "bias": bias_dense}).blocked_autorange().median * 1e3

    # accelerate via SparseSemiStructuredTensor
    linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight), requires_grad=False)
    sparse_output = torch.nn.functional.linear(x, linear.weight, bias)
    sparse_t = Timer(stmt="torch.nn.functional.linear(x, weight, bias)",
                     globals={"x": x, "weight": linear.weight, "bias": bias}).blocked_autorange().median * 1e3

    assert torch.allclose(sparse_output.to(torch.half), dense_output, rtol=1e-3)

(An additional note: as far as the timing comparison is concerned, this would not be apples to apples. For a proper comparison, the input datatypes to linear() should be the same for both the dense and sparse versions.)

@cpuhrsch
Contributor

cpuhrsch commented Dec 8, 2023

We do have a higher order operator to control the output dtype of int8 operations. Otherwise the behavior should model that of regular dense Tensors. @jcaip does it currently match?

For example

from torch._higher_order_ops.out_dtype import out_dtype
out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)

We might want to document and make this operator public if there are additional use cases, as detailed in this issue.

@alexsamardzic
Collaborator

The problem is that dense MM is not supported for int8, so we don't have a point of reference:

>>> import torch
>>> x = torch.ones((2, 2), dtype=torch.int8, device="cuda")
>>> torch.mm(x, x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: "addmm_cuda" not implemented for 'Char'

The tensor cores themselves produce an int32 result for sparse GEMM with int8 operands, as mentioned here:

The integer mma.sp operation is performed with .s32 accumulators.

So CUTLASS just decides to keep it as is, while it seems cuSPARSELt makes it possible to return either int8 or float16 (which means there is an additional cast involved). When the result is int32, it's natural to have the bias also be int32, so this is how the CUTLASS backend works at the moment. With upcoming improved epilogue support for sparse GEMM in CUTLASS, it will be easy to incorporate arbitrary casting (fused with the GEMM itself - this is important for performance!), both of the result of the multiplication and of the bias. But in any case, we should probably strive for unification here. Maybe we should open a new issue to track this?

@jcaip
Contributor

jcaip commented Dec 8, 2023

@cpuhrsch As mentioned, there's no reference behavior for CUDA, but I believe it is consistent with the CPU implementation, as int8 * int8 defaults to int8 output:

>>> A = torch.ones(8, 8).to(torch.int8)
>>> torch.mm(A, A)
tensor([[8, 8, 8, 8, 8, 8, 8, 8],
        [8, 8, 8, 8, 8, 8, 8, 8],
        [8, 8, 8, 8, 8, 8, 8, 8],
        [8, 8, 8, 8, 8, 8, 8, 8],
        [8, 8, 8, 8, 8, 8, 8, 8],
        [8, 8, 8, 8, 8, 8, 8, 8],
        [8, 8, 8, 8, 8, 8, 8, 8],
        [8, 8, 8, 8, 8, 8, 8, 8]], dtype=torch.int8)
>>>

We do have a higher order operator to control the output dtype of int8 operations. Otherwise the behavior should model that of regular dense Tensors. @jcaip does it currently match?

No, currently we specify out_dtype as an arg that's passed into _cslt_sparse_mm. Yeah, @alexsamardzic, can you make an issue to consolidate semi-structured sparse mixed-dtype behavior into a common UI? Does that capture the issue correctly?

Also @alexsamardzic @RanchiZhao, just FYI: currently cuSPARSELt v0.4.0 doesn't support i8 * i8 -> i32, but this was added in v0.5.0 at our request and we have a PR to enable that functionality (#110499); we'll need to update the cuSPARSELt version in the nightlies (and CI) first.

@alexsamardzic
Collaborator

New issue created: #115420.
