
nn.Linear with BSR inputs: spare the user from explicit Triton kernel registrations #98403

Closed · wants to merge 22 commits

Changes from 16 of 22 commits:
51146ac  nn.Linear with BSR inputs: spare the user from explicit Triton kernel…  (nikitaved, Apr 5, 2023)
7daa9a9  Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…  (nikitaved, Apr 5, 2023)
cc85de7  Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…  (nikitaved, Apr 5, 2023)
a24e2a7  Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…  (nikitaved, Apr 6, 2023)
9b755ca  Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…  (nikitaved, Apr 12, 2023)
96e60fa  Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…  (nikitaved, Apr 14, 2023)
8d42793  Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…  (nikitaved, Apr 18, 2023)
ff1a274  Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…  (nikitaved, Apr 18, 2023)
613f509  Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…  (nikitaved, Apr 18, 2023)
f9a17aa  Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…  (nikitaved, May 12, 2023)
6c328e8  Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…  (nikitaved, May 12, 2023)
2c07540  Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…  (nikitaved, May 12, 2023)
8b51bc6  Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…  (nikitaved, May 15, 2023)
b81cf98  Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…  (nikitaved, May 15, 2023)
7ab8829  Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…  (nikitaved, May 15, 2023)
c066f2e  Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…  (nikitaved, May 24, 2023)
fec4a83  Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…  (nikitaved, May 25, 2023)
d111c12  Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…  (nikitaved, May 25, 2023)
7a33b0e  Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…  (nikitaved, May 25, 2023)
08f7a6a  Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…  (nikitaved, May 30, 2023)
6f9af1d  Update test/test_cuda.py  (malfet, May 31, 2023)
cec539a  Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…  (nikitaved, May 31, 2023)
68 changes: 43 additions & 25 deletions test/test_sparse_csr.py
@@ -3312,29 +3312,6 @@ def test_sparse_to_sparse_compressed(self, device, dtype, coalesced, layout):
        self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values())


class _TritonLibrary(object):
    lib = torch.library.Library("triton", "DEF")
    ops = {}

    @classmethod
    def probablyRegisterOp(cls, op_key, full_schema, op_impl, dispatch_key):
        if op_key not in cls.ops:
            cls.lib.define(full_schema)
            cls.lib.impl("triton::" + op_key, op_impl, dispatch_key)
            cls.ops[op_key] = op_impl

        return cls.ops[op_key]

class _WrappedKernel(object):
    def __init__(self, kernel):
        self.kernel = kernel
        self.kernel_invoked = False

    def __call__(self, *args, **kwargs):
        res = self.kernel(*args, **kwargs)
        self.kernel_invoked = True
        return res

def skipIfNoTriton(cls):
    from torch._inductor.utils import has_triton
@@ -3361,16 +3338,28 @@ class TestSparseCompressedTritonKernels(TestCase):
    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
    def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
        import sys

        if torch._running_with_deploy():
            self.skipTest("Skipped for torch_deploy")

        from functools import partial
        from torch.sparse._triton_ops import bsr_dense_mm

        def kernel_impl(*args, **kwargs):
            return bsr_dense_mm(*args, skip_checks=True, **kwargs)

        # was: kernel = _TritonLibrary.probablyRegisterOp(
        kernel = torch._TritonLibrary.registerOp(
            "_triton_bsr_dense_mm_out",
            "_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)",
            kernel_impl,  # was: _WrappedKernel(lambda *args, **kwargs: bsr_dense_mm(*args, skip_checks=True, **kwargs)),
            "SparseCsrCUDA"
        )

Collaborator: Shouldn't this be removed now that it is done during cuda init?

Author (nikitaved): It is needed for the tests below:
  • we check that it is not possible to overwrite an already registered kernel.
  • we check that the kernel was actually invoked.

        # kernel != kernel_impl means dispatch was already registered.
        # This is exactly what we need!
        self.assertTrue(kernel is not kernel_impl)

        # Note that each value in a non-zero block is in range block_size * [low^2, high^2).
        tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)

@@ -3468,6 +3457,35 @@ def test_triton_bsr_dense_bmm_error_messages(self, device, dtype):
        out = torch.rand(32, 32, 2, dtype=dtype, device=device).transpose(0, -1)
        bsr_dense_mm(lhs, rhs, out=out)

    @onlyCUDA
    def test_no_triton_on_import(self):
        import os
        import subprocess
        import sys

        script = """
import sys
import torch

a = torch.rand(2, device="cuda")

if "triton" in sys.modules:
    exit(2)
else:
    exit(0)
"""
        try:
            subprocess.check_output(
                [sys.executable, '-c', script],
                stderr=subprocess.STDOUT,
                # On Windows, opening the subprocess with the default CWD makes `import torch`
                # fail, so just set CWD to this script's directory
                cwd=os.path.dirname(os.path.realpath(__file__)))
        except subprocess.CalledProcessError as e:
            # The script exits with code 2 iff triton ended up in sys.modules.
            if e.returncode == 2:
                self.assertTrue(False, "Triton was imported when importing torch!")
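
The exit-code convention above generalizes; a sketch of a hypothetical helper (not part of this PR, names are illustrative) that runs the same side-effect check for any pair of modules in a fresh interpreter:

# Hypothetical helper, not in the PR: returns True if importing `pkg` in a
# fresh interpreter pulls `mod` into sys.modules as a side effect.
import subprocess
import sys

def import_pulls_in(pkg: str, mod: str) -> bool:
    script = f"import sys, {pkg}; sys.exit(2 if {mod!r} in sys.modules else 0)"
    proc = subprocess.run([sys.executable, "-c", script])
    return proc.returncode == 2

# e.g. import_pulls_in("torch", "triton") should be False after this PR.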


# e.g., TestSparseCSRCPU and TestSparseCSRCUDA
instantiate_device_type_tests(TestSparseCSR, globals())
17 changes: 16 additions & 1 deletion torch/__init__.py
@@ -35,7 +35,7 @@ def _running_with_deploy():
else:
    from .torch_version import __version__ as __version__

# was: from typing import Any, Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TYPE_CHECKING, Union
import builtins

__all__ = [
@@ -1693,5 +1693,20 @@ def _sparse_coo_tensor_unsafe(*args, **kwargs):
# Register MPS specific decomps
torch.backends.mps._init()

if not _running_with_deploy():
    class _TritonLibrary(object):
        lib = torch.library.Library("triton", "DEF")
        ops_table: Dict[Tuple[str, str], Callable] = {}

        @classmethod
        def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):
            if (op_key, dispatch_key) not in cls.ops_table:
                cls.lib.define(full_schema)
                cls.lib.impl("triton::" + op_key, op_impl, dispatch_key)
                cls.ops_table[(op_key, dispatch_key)] = op_impl

            return cls.ops_table[(op_key, dispatch_key)]

Collaborator: nit: we should move this to another file, maybe torch.backends.triton?

Contributor: @nikitaved - Can you address this one?


from . import _logging
_logging._init_logs()
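
For reference, a minimal sketch of how this registry behaves, assuming a build with this patch applied; "my_op" and the two impls below are hypothetical names, not part of the PR. A second registerOp call for the same (op_key, dispatch_key) pair is a no-op that returns the originally registered impl:

import torch

def impl_a(bsr, dense, *, out):
    return out  # placeholder body, illustration only

def impl_b(bsr, dense, *, out):
    return out

schema = "my_op(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)"
first = torch._TritonLibrary.registerOp("my_op", schema, impl_a, "SparseCsrCUDA")
second = torch._TritonLibrary.registerOp("my_op", schema, impl_b, "SparseCsrCUDA")
assert first is impl_a and second is impl_a  # impl_b was never registered

This is exactly the property the test above relies on when it asserts that its own kernel_impl was not the callable actually registered.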
40 changes: 40 additions & 0 deletions torch/cuda/__init__.py
@@ -1137,6 +1137,46 @@ def _dtype(self):
torch._storage_classes.add(ComplexDoubleStorage)
torch._storage_classes.add(ComplexFloatStorage)


class _WrappedTritonKernel(object):
    """Just a simple wrapper to store some metadata for testing purposes."""

    def __init__(self, kernel):
        self.kernel = kernel
        self.kernel_invoked = False

    def __call__(self, *args, **kwargs):
        res = self.kernel(*args, **kwargs)
        self.kernel_invoked = True
        return res

Collaborator: nit: we should move all of this to triton specific code. It makes sense to me to have this file check the triton backends file.

Contributor: Actually, this has nothing to do with Triton, i.e. it should be called something like IsCallableInvokedWrapper, because it can be attached to any callable to inspect whether or not it was called by anyone.

Collaborator: If we have a generic util file we can put that there for sure. If we only use it within these triton bindings, it's fine to put it there for now, I think.

Contributor: @albanD @malfet - How should @nikitaved best address this?
Comment on lines +1143 to +1156:

Contributor: Don't you plan to remove it in the final version of the PR?

Author (nikitaved, May 16, 2023): Do we have to? Otherwise how could I test that the dispatch in nn.Linear actually worked? Unfortunately, Library does not provide any query interface...

Contributor: But otherwise, this introduces overhead in the actual use cases of the kernel, outside of tests, doesn't it?

Author (nikitaved, May 25, 2023): Yes, but my opinion on that:
  • It is still an experimental feature, so comprehensive test coverage with some overhead is better for my peace of mind. However, I am very much open to suggestions, as I hope this approach is just a temporary solution.
  • This overhead can be tolerated for now as this operation is quite computationally heavy. The target sparsity level is around 50%, so we can expect it to be at least as slow as a regular gemm.
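
To make the testing story concrete, a minimal usage sketch of the wrapper, assuming the class lands in torch.cuda as in this diff (fake_kernel is a stand-in, not part of the PR):

from torch.cuda import _WrappedTritonKernel

def fake_kernel(x):
    return x * 2

wrapped = _WrappedTritonKernel(fake_kernel)
assert not wrapped.kernel_invoked   # nothing called yet

assert wrapped(21) == 42            # *args/**kwargs forwarded to the kernel
assert wrapped.kernel_invoked       # flag flips after the first invocation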

def _register_triton_kernels():
    import sys
    import importlib

    if sys.executable == "torch_deploy":
        return

    @_WrappedTritonKernel
    def kernel_impl(*args, **kwargs):
        from torch.sparse._triton_ops import bsr_dense_mm
        return bsr_dense_mm(*args, skip_checks=True, **kwargs)

    has_triton = importlib.util.find_spec("triton") is not None
    if has_triton:
Comment on lines +1166 to +1167:

Contributor: Hmm, please help me understand the logic here: we would only register the kernel if somebody imported triton previously, which would not be the case for any of the eager workflows, wouldn't it?

Author (nikitaved, May 25, 2023): No, the registration takes place if triton is importable:

In [1]: import importlib, sys

In [2]: importlib.util.find_spec("triton") is None
Out[2]: False

In [3]: "triton" in sys.modules
Out[3]: False

And then triton will be imported when it is actually needed.

Collaborator: For context, this is basically what import triton does the first time:

spec = importlib.util.find_spec("triton")
triton = importlib.util.module_from_spec(spec)
spec.loader.exec_module(triton)

find_spec searches for the module, then exec_module is where the module's initialization code is actually run. So find_spec checks that a module named "triton" exists, but doesn't run any code or guarantee that it will import successfully.
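
Relatedly (not part of the PR), the stdlib offers an official recipe for the deferred half of this scheme: importlib.util.LazyLoader creates the module object eagerly but only executes its code on first attribute access:

import importlib.util
import sys

def lazy_import(name):
    spec = importlib.util.find_spec(name)        # locate only; runs no module code
    loader = importlib.util.LazyLoader(spec.loader)
    spec.loader = loader
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    loader.exec_module(module)                   # deferred until first attribute access
    return module

The PR achieves the same effect more simply by importing torch.sparse._triton_ops inside the kernel body, so nothing triton-related runs at torch import time.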

        torch._TritonLibrary.registerOp(
            "_triton_bsr_dense_mm_out",
            "_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)",
            kernel_impl,
            "SparseCsrCUDA"
        )


_lazy_call(_register_triton_kernels)


from . import sparse
from . import profiler
from . import nvtx
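
Putting it all together, a minimal sketch of what this PR enables from the user's side, assuming a CUDA build with triton installed and a dtype/block size the kernel supports (the tests above use half/bfloat16/float32 with block sizes 16 and 32); note that no torch.library registration calls appear in user code:

import torch
import torch.nn.functional as F

w = torch.randn(64, 64, dtype=torch.half, device="cuda").to_sparse_bsr((16, 16))
x = torch.randn(8, 64, dtype=torch.half, device="cuda")

# F.linear with a BSR weight is routed to the pre-registered
# triton::_triton_bsr_dense_mm_out kernel, no user-side setup required.
y = F.linear(x, w)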