nn.Linear: dispatch to bsr_dense_mm for half and bfloat16
ghstack-source-id: 90a98a5b496a6da02080d1fae976f1c3d7de26bd
Pull Request resolved: #94825
nikitaved committed Apr 3, 2023
1 parent 1ea528e commit b75f23e
Showing 2 changed files with 150 additions and 54 deletions.
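For context, a minimal sketch of the path this commit enables: with the triton::_triton_bsr_dense_mm_out op registered (as the updated test below does), nn.Linear with a half or bfloat16 BSR weight on CUDA routes to the Triton bsr_dense_mm kernel. This sketch assumes a CUDA device with Triton available; bfloat16 additionally needs an SM80+ GPU.

import torch

# Sketch only: a 2D half-precision BSR weight with 16x16 blocks; sizes are
# chosen to satisfy the dispatch conditions added in SparseBlasImpl.cpp below.
weight = torch.randn(64, 64, device="cuda", dtype=torch.half).to_sparse_bsr(16)
x = torch.randn(32, 64, device="cuda", dtype=torch.half)

# Routes to the Triton kernel when the op is registered; otherwise the
# strided BSR fallback below the gate runs.
y = torch.nn.functional.linear(x, weight)
print(y.shape)  # torch.Size([32, 64])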
36 changes: 36 additions & 0 deletions aten/src/ATen/native/sparse/SparseBlasImpl.cpp
@@ -5,6 +5,9 @@
#include <ATen/native/sparse/SparseBlasImpl.h>
#include <ATen/SparseCsrTensorUtils.h>

// Required for checking whether Triton kernels are available
#include <ATen/core/dispatch/Dispatcher.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
@@ -73,6 +76,39 @@ Tensor& _compressed_row_strided_mm_out(const Tensor& compressed, const Tensor& s
    blocksize = {values.size(-2), values.size(-1)};
  }

  // No stable support for ROCm in Triton yet.
#ifndef USE_ROCM
  // Triton works only with block sizes that are powers of 2.
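  // NOTE: block sizes come from tensor dimensions and are always >= 1 here;
  // for v == 0 the expression below would also return true, but that case
  // cannot occur.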
  const auto is_power_of_2 = [](int64_t v) -> bool {
    return !(v & (v - 1));
  };

  // Dtype and block size checks for potential Triton usage.
  if ((strided.scalar_type() == ScalarType::Half
       || strided.scalar_type() == ScalarType::BFloat16)
      && is_power_of_2(blocksize[0]) && is_power_of_2(blocksize[1])
      && (blocksize[0] >= 16) && (blocksize[1] >= 16)
      // lhs is retiled to (b0, b1) while rhs is retiled to (b1, b0),
      // so the result is tiled to (b0, b0) and we need to make
      // sure that dense.size(-1) is divisible by b0.
      && n % blocksize[0] == 0) {
    try {
      const auto triton_kernel = c10::Dispatcher::singleton()
          .findSchemaOrThrow("triton::_triton_bsr_dense_mm_out", "")
          .typed<Tensor&(const Tensor&, const Tensor&, Tensor&)>();
      // Call Triton only if the dispatch key was overridden.
      // This is not strictly necessary since the definition is done in Python,
      // but we leave it here for extra safety.
      if (triton_kernel.hasKernelForDispatchKey(c10::DispatchKey::SparseCsrCUDA)) {
        return triton_kernel.call(compressed, strided, result);
      }
    } catch (const std::exception& e) {
      // The schema is not defined and/or the key is not overridden,
      // so skip and execute the code below.
    }
  }
#endif

  // (..., r, c) -> (..., r / b0, c / b1, b0, b1)
  // NOTE: this function ALWAYS creates a view upon successful execution.
  const auto tile_tensor = [compressed_layout](
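The gate above can be restated compactly in Python. The helper below, would_dispatch_to_triton, is hypothetical (not part of any API) and simply mirrors the conditions as written, with n being the trailing dimension of the dense operand (dense.size(-1)):

import torch

def would_dispatch_to_triton(dtype, blocksize, n):
    # Hypothetical mirror of the gate in _compressed_row_strided_mm_out.
    def is_power_of_2(v):
        return v >= 1 and (v & (v - 1)) == 0
    return (
        dtype in (torch.half, torch.bfloat16)
        and is_power_of_2(blocksize[0]) and is_power_of_2(blocksize[1])
        and blocksize[0] >= 16 and blocksize[1] >= 16
        and n % blocksize[0] == 0
    )

assert would_dispatch_to_triton(torch.half, (16, 16), 32)
assert not would_dispatch_to_triton(torch.float32, (16, 16), 32)  # dtype not half/bfloat16
assert not would_dispatch_to_triton(torch.half, (16, 16), 20)     # 20 % 16 != 0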
168 changes: 114 additions & 54 deletions test/test_sparse_csr.py
@@ -1460,60 +1460,6 @@ def run_test_block_addmm_addmv(self,
        self.assertEqual(actual, out)
        self.assertEqual(actual, expected, lambda msg: f"{msg}\na={a}\nc={c}\nb={b}\nalpha={alpha} beta={beta}")

@parametrize("block_size", [16, 32, 64])
@parametrize("index_dtype", [torch.int32, torch.int64])
@onlyCUDA
@skipIfRocm
@dtypes(torch.half, torch.bfloat16)
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [])
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
from functools import partial

from torch._inductor.utils import has_triton
from torch.sparse._triton_ops import bsr_dense_mm

if not has_triton():
self.skipTest("Triton is not available.")

# Note that each value in a non-zero block is in range block_size * [low^2, high^2).
tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)

# NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
batches = [(), (2,)]
size = [128, 256, 0]

# Whether to make inputs orthogonal so that the product is zero
make_orthogonal = [True, False]

for bd, bs, m, n, k, is_ortho in itertools.product(batches, batches, size, size, size, make_orthogonal):
bsr = tensor(bs + (m, k))
# NOTE: do not get confused, it will be transposed
dense = tensor(bd + (n, k))

if is_ortho:
bsr = torch.cat((bsr, torch.zeros_like(bsr)), dim=-1)
dense = torch.cat((torch.zeros_like(dense), dense), dim=-1)

bsr = bsr.to_sparse_bsr(block_size)

res_tri = bsr_dense_mm(bsr, dense.transpose(-2, -1))
res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
self.assertEqual(res_tri, res_dense)

# check whether bsr_dense_mm handles different grid sizes
# None means max possible grid size which is CUDA-dependent.
grid_size = (None, 2, 4)
grid_gen = itertools.product(grid_size, repeat=3)
for is_sparse_rowspace, grid in itertools.product((True, False), grid_gen):
res_tri = bsr_dense_mm(
bsr,
dense.transpose(-2, -1),
max_grid=grid,
is_sparse_rowspace_mode=is_sparse_rowspace
)
self.assertEqual(res_tri, res_dense)

    # TODO: block_size 1 is broken
    @parametrize("block_size", [2, 3])
    @parametrize("index_dtype", [torch.int32, torch.int64])
@@ -3364,9 +3310,123 @@ def test_sparse_to_sparse_compressed(self, device, dtype, coalesced, layout):
        self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values())


class _TritonLibrary(object):
    lib = torch.library.Library("triton", "DEF")
    ops = {}

    @classmethod
    def probablyRegisterOp(cls, op_key, full_schema, op_impl, dispatch_key):
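        # Register at most once per op_key: defining the same schema twice in
        # torch.library raises, and this method runs once per parametrized test.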
        if op_key not in cls.ops:
            cls.lib.define(full_schema)
            cls.lib.impl("triton::" + op_key, op_impl, dispatch_key)
            cls.ops[op_key] = op_impl

        return cls.ops[op_key]

class _WrappedKernel(object):
    def __init__(self, kernel):
        self.kernel = kernel
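        # Flipped to True on the first call; the test resets and re-checks it
        # to verify that nn.functional.linear actually dispatched to this kernel.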
        self.kernel_invoked = False

    def __call__(self, *args, **kwargs):
        res = self.kernel(*args, **kwargs)
        self.kernel_invoked = True
        return res

def skipIfNoTriton(cls):
    from torch._inductor.utils import has_triton

    # no-op if triton is present
    if has_triton():
        return cls
    else:

        @functools.wraps(cls, updated=())
        class skipped_cls(cls):
            def setUp(self):
                self.skipTest("Triton is not available.")

        return skipped_cls

@skipIfNoTriton
class TestSparseCompressedTritonKernels(TestCase):

@parametrize("block_size", [16, 32, 64])
@parametrize("index_dtype", [torch.int32, torch.int64])
@onlyCUDA
@skipIfRocm
@dtypes(torch.half, torch.bfloat16)
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [])
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
from functools import partial
from torch.sparse._triton_ops import bsr_dense_mm

kernel = _TritonLibrary.probablyRegisterOp(
"_triton_bsr_dense_mm_out",
"_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)",
_WrappedKernel(lambda *args, **kwargs: bsr_dense_mm(*args, skip_checks=True, **kwargs)),
"SparseCsrCUDA"
)

        # Note that each value in a non-zero block is in range block_size * [low^2, high^2).
        tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)

        # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
        batches = [(), (2,)]
        size = [128, 256, 0]

        # Whether to make inputs orthogonal so that the product is zero
        make_orthogonal = [True, False]

        for bd, bs, m, n, k, is_ortho in itertools.product(batches, batches, size, size, size, make_orthogonal):
            bsr = tensor(bs + (m, k))
            # NOTE: `dense` gets transposed before the matmul below.
            dense = tensor(bd + (n, k))

            if is_ortho:
                bsr = torch.cat((bsr, torch.zeros_like(bsr)), dim=-1)
                dense = torch.cat((torch.zeros_like(dense), dense), dim=-1)

            bsr = bsr.to_sparse_bsr(block_size)

            if bsr.dim() == 2:
                # Test against linear to check dispatch.
                res_tri = torch.nn.functional.linear(dense, bsr)
                res_dense = torch.nn.functional.linear(dense, bsr.to_dense())

                # Check that dispatch worked with non-trivial outputs
                if m > 0 and n > 0 and k > 0:
                    self.assertTrue(kernel.kernel_invoked)
                    kernel.kernel_invoked = False
            else:
                # Otherwise check correctness against bmm,
                # since nn.functional.linear does not support bsr.dim() > 2.
                res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
                res_tri_out = torch.empty_like(res_dense)
                res_tri = kernel(bsr, dense.transpose(-2, -1), out=res_tri_out)
                self.assertTrue(res_tri is res_tri_out)
            self.assertEqual(res_tri, res_dense)

            res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
            # check whether bsr_dense_mm handles different grid sizes
            # None means max possible grid size which is CUDA-dependent.
            grid_size = (None, 2, 4)
            grid_gen = itertools.product(grid_size, repeat=3)
            for is_sparse_rowspace, grid in itertools.product((True, False), grid_gen):
                res_tri = torch.sparse._triton_ops.bsr_dense_mm(
                    bsr,
                    dense.transpose(-2, -1),
                    max_grid=grid,
                    is_sparse_rowspace_mode=is_sparse_rowspace
                )
                self.assertEqual(res_tri, res_dense)


# e.g., TestSparseCSRCPU and TestSparseCSRCUDA
instantiate_device_type_tests(TestSparseCSR, globals())
instantiate_device_type_tests(TestSparseCompressed, globals())
instantiate_device_type_tests(TestSparseCompressedTritonKernels, globals())

if __name__ == '__main__':
    run_tests()
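Once probablyRegisterOp has run, the op is also reachable by name through torch.ops, which is exactly what the C++ findSchemaOrThrow lookup above relies on. A minimal sketch, assuming the registration above has already happened on a Triton-enabled CUDA build:

import torch

bsr = torch.randn(64, 64, device="cuda", dtype=torch.half).to_sparse_bsr(16)
dense = torch.randn(64, 32, device="cuda", dtype=torch.half)
out = torch.empty(64, 32, device="cuda", dtype=torch.half)

# Matches the schema defined via _TritonLibrary: out is keyword-only.
torch.ops.triton._triton_bsr_dense_mm_out(bsr, dense, out=out)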
