nn.Linear with BSR inputs: spare the user from explicit Triton kernel registrations

ghstack-source-id: 7958af1d4d433c13833604e40268d8e5d1e813af
Pull Request resolved: #98403
nikitaved committed May 25, 2023
1 parent 4882cd0 commit 2dbd27b
Showing 4 changed files with 84 additions and 28 deletions.
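For context on what the commit enables: multiplying a BSR-layout weight by a dense input on CUDA no longer requires the user to register the Triton kernel through torch.library by hand; the registration happens lazily on first CUDA initialization (see torch/cuda/__init__.py below). A minimal usage sketch, assuming Triton is installed; the shapes, dtype and block size are illustrative, not taken from the diff:

import torch

# Hedged sketch: the triton::_triton_bsr_dense_mm_out op defined by this commit
# is reachable through the dispatcher without any user-side registration.
if torch.cuda.is_available():
    weight = torch.randn(64, 64, device="cuda", dtype=torch.half)
    bsr_weight = weight.to_sparse_bsr((16, 16))   # block-sparse (BSR) layout
    dense = torch.randn(64, 64, device="cuda", dtype=torch.half)
    out = torch.empty(64, 64, device="cuda", dtype=torch.half)
    # Creating the CUDA tensors above initializes CUDA, which runs the lazy registration.
    torch.ops.triton._triton_bsr_dense_mm_out(bsr_weight, dense, out=out)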
16 changes: 16 additions & 0 deletions test/test_cuda.py
@@ -5728,6 +5728,22 @@ def no_pool():
self.assertEqual(len(get_cudagraph_segments(pool)), 0)


def test_no_triton_on_import(self):
script = """import sys; import torch; torch.rand(2, device='cuda'); exit(2 if "triton" in sys.modules else 0)
"""

try:
subprocess.check_output(
[sys.executable, '-c', script],
stderr=subprocess.STDOUT,
# On Windows, opening the subprocess with the default CWD makes `import torch`
# fail, so just set CWD to this script's directory
cwd=os.path.dirname(os.path.realpath(__file__)))
except subprocess.CalledProcessError as e:
if e.returncode == 2:
self.assertTrue(False, "Triton was imported when importing torch!")


instantiate_parametrized_tests(TestCuda)

if __name__ == '__main__':
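Why this test can pass even though a kernel is registered during CUDA initialization: the code added in torch/cuda/__init__.py (further down) only probes for Triton with importlib.util.find_spec, which locates the package without executing it, and the actual import of torch.sparse._triton_ops is deferred until the kernel is first invoked. A small illustration of that find_spec property, assuming triton has not already been imported in the process:

import importlib.util
import sys

spec = importlib.util.find_spec("triton")  # ModuleSpec if installed, else None
# Locating a top-level package does not import it, so the subprocess above sees
# no "triton" entry in sys.modules after torch.rand(2, device='cuda').
assert "triton" not in sys.modules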
40 changes: 13 additions & 27 deletions test/test_sparse_csr.py
@@ -3312,29 +3312,6 @@ def test_sparse_to_sparse_compressed(self, device, dtype, coalesced, layout):
self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values())


class _TritonLibrary(object):
lib = torch.library.Library("triton", "DEF")
ops = {}

@classmethod
def probablyRegisterOp(cls, op_key, full_schema, op_impl, dispatch_key):
if op_key not in cls.ops:
cls.lib.define(full_schema)
cls.lib.impl("triton::" + op_key, op_impl, dispatch_key)
cls.ops[op_key] = op_impl

return cls.ops[op_key]

class _WrappedKernel(object):
def __init__(self, kernel):
self.kernel = kernel
self.kernel_invoked = False

def __call__(self, *args, **kwargs):
res = self.kernel(*args, **kwargs)
self.kernel_invoked = True
return res

def skipIfNoTriton(cls):
from torch._inductor.utils import has_triton

@@ -3359,18 +3336,26 @@ class TestSparseCompressedTritonKernels(TestCase):
@skipIfRocm
@dtypes(torch.half, torch.bfloat16, torch.float)
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU or torch._running_with_deploy(),
"Skipped for deploy and internal with remote GPUs")
def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
from functools import partial
from torch.sparse._triton_ops import bsr_dense_mm

kernel = _TritonLibrary.probablyRegisterOp(
def kernel_impl(*args, **kwargs):
return bsr_dense_mm(*args, skip_checks=True, **kwargs)

kernel = torch._TritonLibrary.registerOp(
"_triton_bsr_dense_mm_out",
"_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)",
_WrappedKernel(lambda *args, **kwargs: bsr_dense_mm(*args, skip_checks=True, **kwargs)),
kernel_impl,
"SparseCsrCUDA"
)

# kernel != kernel_impl means dispatch was already registered.
# This is exactly what we need!
self.assertTrue(kernel is not kernel_impl)

# Note that each value in a non-zero block is in range block_size * [low^2, high^2).
tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)

@@ -3427,7 +3412,8 @@ def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
@onlyCUDA
@skipIfRocm
@dtypes(torch.half)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU or torch._running_with_deploy(),
"Skipped for deploy and internal with remote GPUs")
def test_triton_bsr_dense_bmm_error_messages(self, device, dtype):
from torch.sparse._triton_ops import bsr_dense_mm

17 changes: 16 additions & 1 deletion torch/__init__.py
@@ -35,7 +35,7 @@ def _running_with_deploy():
else:
from .torch_version import __version__ as __version__

from typing import Any, Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TYPE_CHECKING, Union
import builtins

__all__ = [
@@ -1693,5 +1693,20 @@ def _sparse_coo_tensor_unsafe(*args, **kwargs):
# Register MPS specific decomps
torch.backends.mps._init()

if not _running_with_deploy():
class _TritonLibrary(object):
lib = torch.library.Library("triton", "DEF")
ops_table: Dict[Tuple[str, str], Callable] = {}

@classmethod
def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):
if (op_key, dispatch_key) not in cls.ops_table:
cls.lib.define(full_schema)
cls.lib.impl("triton::" + op_key, op_impl, dispatch_key)
cls.ops_table[(op_key, dispatch_key)] = op_impl

return cls.ops_table[(op_key, dispatch_key)]


from . import _logging
_logging._init_logs()
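A note on the helper above: registerOp is keyed on (op name, dispatch key), so a second registration attempt does not redefine the schema and simply returns the implementation that was stored first; the test in test_sparse_csr.py relies on exactly this to detect that torch.cuda already registered the kernel. A hedged sketch of that behavior, with a made-up op name and assuming a regular (non torch::deploy) build:

import torch

# Illustrative only: "_example_op" and both implementations are invented here.
def impl_a(a, b, *, out):
    return out

def impl_b(a, b, *, out):
    return out

schema = "_example_op(Tensor a, Tensor b, *, Tensor(a!) out) -> Tensor(a!)"
first = torch._TritonLibrary.registerOp("_example_op", schema, impl_a, "SparseCsrCUDA")
second = torch._TritonLibrary.registerOp("_example_op", schema, impl_b, "SparseCsrCUDA")
assert first is impl_a and second is impl_a  # the second implementation is ignored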
39 changes: 39 additions & 0 deletions torch/cuda/__init__.py
@@ -10,6 +10,8 @@

import contextlib
import os
import sys
import importlib
import torch
from torch.types import Device
import traceback
@@ -1137,6 +1139,43 @@ def _dtype(self):
torch._storage_classes.add(ComplexDoubleStorage)
torch._storage_classes.add(ComplexFloatStorage)


class _WrappedTritonKernel(object):
""" Just a simple wrapper to store some metadata for testing purposes.
"""

def __init__(self, kernel):
self.kernel = kernel
self.kernel_invoked = False

def __call__(self, *args, **kwargs):
res = self.kernel(*args, **kwargs)
self.kernel_invoked = True
return res


def _register_triton_kernels():
if torch._running_with_deploy():
return

@_WrappedTritonKernel
def kernel_impl(*args, **kwargs):
from torch.sparse._triton_ops import bsr_dense_mm
return bsr_dense_mm(*args, skip_checks=True, **kwargs)

has_triton = importlib.util.find_spec("triton") is not None
if has_triton:
torch._TritonLibrary.registerOp(
"_triton_bsr_dense_mm_out",
"_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)",
kernel_impl,
"SparseCsrCUDA"
)


_lazy_call(_register_triton_kernels)


from . import sparse
from . import profiler
from . import nvtx
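Because the registered implementation is the _WrappedTritonKernel above, test code can observe whether the Triton path actually ran by reading its kernel_invoked flag. A hedged sketch, assuming Triton is installed so the lazy registration actually happened; fetching the wrapper back out of _TritonLibrary.ops_table is one way to reach it, not an official API:

import torch

torch.rand(2, device="cuda")  # initialize CUDA so _register_triton_kernels() has run
wrapper = torch._TritonLibrary.ops_table[("_triton_bsr_dense_mm_out", "SparseCsrCUDA")]
assert isinstance(wrapper, torch.cuda._WrappedTritonKernel)
assert wrapper.kernel_invoked is False  # flips to True once the kernel is actually called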
