nn.Linear with BSR inputs: spare the user from explicit Triton kernel registrations

ghstack-source-id: d073b22acdbcb0d651efeeb1a45b4ffa129d66e7
Pull Request resolved: #98403
nikitaved committed Apr 14, 2023
1 parent 14b39fe commit cfd58af
Showing 2 changed files with 52 additions and 25 deletions.
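In user terms, the change means the Triton-backed BSR matmul no longer has to be registered by hand before a linear layer with a BSR weight can use it; the kernel is now registered lazily when CUDA initializes (see torch/__init__.py below). A minimal sketch of the intended usage, assuming a CUDA device with Triton available; the shapes, block size, dtype, and the F.linear call are illustrative rather than taken from this commit:

import torch
import torch.nn.functional as F

# Illustrative shapes and block size; half precision is what the Triton BSR
# kernel targets. No torch.library registration is done by the user here:
# _register_triton_kernels() has already run via torch.cuda._lazy_call.
x = torch.randn(8, 64, device="cuda", dtype=torch.float16)
weight = torch.randn(32, 64, device="cuda", dtype=torch.float16)
bsr_weight = weight.to_sparse_bsr((16, 16))

y = F.linear(x, bsr_weight)  # expected shape (8, 32)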
34 changes: 9 additions & 25 deletions test/test_sparse_csr.py
@@ -3310,29 +3310,6 @@ def test_sparse_to_sparse_compressed(self, device, dtype, coalesced, layout):
self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values())


class _TritonLibrary(object):
lib = torch.library.Library("triton", "DEF")
ops = {}

@classmethod
def probablyRegisterOp(cls, op_key, full_schema, op_impl, dispatch_key):
if op_key not in cls.ops:
cls.lib.define(full_schema)
cls.lib.impl("triton::" + op_key, op_impl, dispatch_key)
cls.ops[op_key] = op_impl

return cls.ops[op_key]

class _WrappedKernel(object):
def __init__(self, kernel):
self.kernel = kernel
self.kernel_invoked = False

def __call__(self, *args, **kwargs):
res = self.kernel(*args, **kwargs)
self.kernel_invoked = True
return res

def skipIfNoTriton(cls):
from torch._inductor.utils import has_triton

@@ -3362,13 +3339,20 @@ def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
from functools import partial
from torch.sparse._triton_ops import bsr_dense_mm

kernel = _TritonLibrary.probablyRegisterOp(
def kernel_impl(*args, **kwargs):
return bsr_dense_mm(*args, skip_checks=True, **kwargs)

kernel = torch._TritonLibrary.probablyRegisterOp(
"_triton_bsr_dense_mm_out",
"_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)",
_WrappedKernel(lambda *args, **kwargs: bsr_dense_mm(*args, skip_checks=True, **kwargs)),
kernel_impl,
"SparseCsrCUDA"
)

# kernel != kernel_impl means dispatch was already registered.
# This is exactly what we need!
self.assertTrue(kernel is not kernel_impl)

# Note that each value in a non-zero block is in range block_size * [low^2, high^2).
tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)

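For context on the assertion above: by the time this CUDA test runs, _register_triton_kernels() (added to torch/__init__.py below) has already defined "_triton_bsr_dense_mm_out", so probablyRegisterOp returns the previously stored wrapper rather than kernel_impl. A small sketch of that idempotent behaviour, using a hypothetical toy op rather than the real one:

import torch

def impl_a(t):
    return t + 1

def impl_b(t):
    return t + 2

# "_toy_add_example" is a made-up op key/schema for illustration only.
schema = "_toy_add_example(Tensor t) -> Tensor"

first = torch._TritonLibrary.probablyRegisterOp(
    "_toy_add_example", schema, impl_a, "CPU")
second = torch._TritonLibrary.probablyRegisterOp(
    "_toy_add_example", schema, impl_b, "CPU")

assert first is impl_a       # first call defines the op and stores impl_a
assert second is impl_a      # later calls return the stored impl unchanged
assert second is not impl_b  # impl_b is never registered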
43 changes: 43 additions & 0 deletions torch/__init__.py
@@ -1650,5 +1650,48 @@ def _sparse_coo_tensor_unsafe(*args, **kwargs):
return torch.sparse_coo_tensor(*args, **kwargs)


class _TritonLibrary(object):
lib = torch.library.Library("triton", "DEF")
ops_table = {}

@classmethod
def probablyRegisterOp(cls, op_key, full_schema, op_impl, dispatch_key):
if op_key not in cls.ops_table:
cls.lib.define(full_schema)
cls.lib.impl("triton::" + op_key, op_impl, dispatch_key)
cls.ops_table[op_key] = op_impl

return cls.ops_table[op_key]


class _WrappedTritonKernel(object):
""" Just a simple wrapper to store some metadata for testing purposes.
"""

def __init__(self, kernel):
self.kernel = kernel
self.kernel_invoked = False

def __call__(self, *args, **kwargs):
res = self.kernel(*args, **kwargs)
self.kernel_invoked = True
return res


def _register_triton_kernels():
from torch.sparse._triton_ops import bsr_dense_mm

if bsr_dense_mm is not None:
_TritonLibrary.probablyRegisterOp(
"_triton_bsr_dense_mm_out",
"_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)",
_WrappedTritonKernel(lambda *args, **kwargs: bsr_dense_mm(*args, skip_checks=True, **kwargs)),
"SparseCsrCUDA"
)


if torch.cuda.is_available():
torch.cuda._lazy_call(_register_triton_kernels)

from . import _logging
_logging._init_logs()
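Once _register_triton_kernels has run (lazily, on first CUDA initialization), the op is reachable through the torch.ops.triton namespace with the schema defined above. A rough sketch of a direct call, assuming a CUDA device with Triton; the shapes, block size, and dtype are illustrative:

import torch

bsr = (torch.randn(64, 64, device="cuda", dtype=torch.float16)
       .to_sparse_bsr((16, 16)))
dense = torch.randn(64, 32, device="cuda", dtype=torch.float16)
out = torch.empty(64, 32, device="cuda", dtype=torch.float16)

# Dispatches to the _WrappedTritonKernel registered for SparseCsrCUDA.
torch.ops.triton._triton_bsr_dense_mm_out(bsr, dense, out=out)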
