
nn.Linear: dispatch to bsr_dense_mm for half and bfloat16
ghstack-source-id: bf2abe355707ed10c0fd913a002a3a8cbb7f2e9f
Pull Request resolved: #94825
nikitaved committed Feb 15, 2023
1 parent a8fd5d3 commit 463feb3
Showing 6 changed files with 72 additions and 5 deletions.
6 changes: 6 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -6486,6 +6486,12 @@
     SparseCPU: s_addmm_sparse_dense_cpu_
     SparseCUDA: s_addmm_sparse_dense_cuda_
 
+- func: _triton_bsr_dense_mm(Tensor bsr, Tensor dense) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: triton_bsr_dense_mm
+  autogen: _triton_bsr_dense_mm.out
+
 - func: _addmm_activation.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False, Tensor(a!) out) -> Tensor(a!)
   structured: True
   dispatch:
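
The YAML entry above only registers the operator schema with a CUDA dispatch stub; the actual kernel is supplied from Python (see the test changes below). A minimal sketch of the resulting Python surface, assuming a CUDA build — the shapes here are illustrative assumptions, not from this commit:

```python
import torch

# Assumed shapes: a 64x64 half-precision weight with 16x16 BSR blocks.
bsr = torch.randn(64, 64, dtype=torch.half, device="cuda").to_sparse_bsr(16)
dense = torch.randn(64, 32, dtype=torch.half, device="cuda")

# Computes bsr @ dense. Without a Triton kernel registered for the
# SparseCsrCUDA key this call fails (the CUDA entry is a stub that
# unconditionally raises); with one registered, as in the test below,
# it runs through Triton.
out = torch._triton_bsr_dense_mm(bsr, dense)
```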
30 changes: 30 additions & 0 deletions aten/src/ATen/native/sparse/SparseBlasImpl.cpp
@@ -4,6 +4,10 @@
 #include <ATen/native/sparse/SparseBlasImpl.h>
 #include <ATen/SparseCsrTensorUtils.h>
 
+// Required for checking whether Triton kernels are available
+#include <torch/csrc/jit/frontend/function_schema_parser.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
@@ -12,6 +16,7 @@
 #include <ATen/ops/_convert_indices_from_csr_to_coo.h>
 #include <ATen/ops/empty_like.h>
 #include <ATen/ops/zeros.h>
+#include <ATen/ops/_triton_bsr_dense_mm.h>
 #endif
 
 namespace at {
@@ -70,6 +75,31 @@ Tensor& _compressed_row_strided_mm_out(const Tensor& compressed, const Tensor& s
     blocksize = {values.size(-2), values.size(-1)};
   }
 
+  // No stable support for ROCm in Triton yet.
+#ifndef USE_ROCM
+  // Triton works only with blocksizes which are powers of 2.
+  const auto is_power_of_2 = [](int64_t v) -> bool {
+    return !(v & (v - 1));
+  };
+
+  // Dtype and blocksize checks for potential Triton usage.
+  if ((strided.scalar_type() == ScalarType::Half
+       || strided.scalar_type() == ScalarType::BFloat16)
+      && is_power_of_2(blocksize[0]) && is_power_of_2(blocksize[1])
+      && (blocksize[0] >= 16) && (blocksize[1] >= 16)
+      // lhs is retiled to (b0, b1) while rhs is retiled to (b1, b0),
+      // so the result is tiled to (b0, b0) and we need to make
+      // sure that dense.size(-1) is divisible by b0.
+      && n % blocksize[0] == 0) {
+    const auto triton_kernel = c10::Dispatcher::singleton()
+        .findOp(torch::jit::parseName("aten::_triton_bsr_dense_mm"));
+    // Call Triton only if the dispatch key was overridden.
+    if (triton_kernel->hasKernelForDispatchKey(c10::DispatchKey::SparseCsrCUDA)) {
+      return at::_triton_bsr_dense_mm_out(result, compressed, strided);
+    }
+  }
+#endif
+
   // (..., r, c) -> (..., r / b0, c / b1, b0, b1)
   // NOTE: this function ALWAYS creates a view upon successful execution.
   const auto tile_tensor = [compressed_layout](
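
For reference, a Python mirror of the gating logic above — a sketch for illustration only; `could_use_triton` is not a function introduced by this commit:

```python
import torch

def could_use_triton(dtype: torch.dtype, b0: int, b1: int, n: int) -> bool:
    """Mirror of the C++ eligibility check: half/bfloat16 inputs,
    power-of-2 block sizes of at least 16, and the dense operand's
    last dimension divisible by b0."""
    def is_power_of_2(v: int) -> bool:
        # Same bit trick as the C++ lambda; the v == 0 corner case is
        # irrelevant here because of the v >= 16 requirement.
        return not (v & (v - 1))

    return (
        dtype in (torch.half, torch.bfloat16)
        and is_power_of_2(b0) and is_power_of_2(b1)
        and b0 >= 16 and b1 >= 16
        and n % b0 == 0
    )

assert could_use_triton(torch.half, 16, 16, 64)          # eligible
assert not could_use_triton(torch.float32, 16, 16, 64)   # wrong dtype
assert not could_use_triton(torch.half, 24, 16, 64)      # 24 is not a power of 2
assert not could_use_triton(torch.bfloat16, 32, 32, 48)  # 48 % 32 != 0
```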
1 change: 0 additions & 1 deletion aten/src/ATen/native/sparse/SparseMatMul.cpp
@@ -274,6 +274,5 @@ Tensor sparse_sparse_matmul_cpu(const Tensor& mat1_, const Tensor& mat2_) {
   return output;
 }
 
-
 } // namespace native
 } // namespace at
7 changes: 7 additions & 0 deletions aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu
@@ -634,5 +634,12 @@ Tensor _sparse_csr_prod_cuda(const Tensor& input, IntArrayRef dims_to_reduce, bo
   return result;
 }
 
+Tensor triton_bsr_dense_mm(
+    const Tensor& bsr,
+    const Tensor& dense) {
+  TORCH_CHECK(false, "_triton_bsr_dense_mm: Triton kernel should be overwritten in Python.");
+  return Tensor {};
+}
+
 } // namespace native
 } // namespace at
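
The CUDA entry is deliberately a failing stub: it gives the operator a schema and a dispatch slot, and the Triton implementation is registered over it from Python. A sketch of that registration, modeled on the test below (`skip_checks=True` because the C++ caller has already validated dtype and block sizes):

```python
import torch
from torch.sparse._triton_ops import bsr_dense_mm  # None when Triton is absent

if bsr_dense_mm is not None:
    lib = torch.library.Library("aten", "IMPL")

    def triton_impl(bsr, dense):
        # Route the ATen op to the Triton kernel.
        return bsr_dense_mm(bsr, dense, skip_checks=True)

    # Override only the sparse-CSR CUDA dispatch key, matching the
    # hasKernelForDispatchKey(SparseCsrCUDA) check in SparseBlasImpl.cpp.
    lib.impl("aten::_triton_bsr_dense_mm", triton_impl, "SparseCsrCUDA")
```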
2 changes: 2 additions & 0 deletions test/expect/HasDecompTest.test_has_decomposition.expect
@@ -508,6 +508,8 @@ aten::_transformer_encoder_layer_fwd
 aten::_transformer_encoder_layer_fwd.out
 aten::_trilinear
 aten::_trilinear.out
+aten::_triton_bsr_dense_mm
+aten::_triton_bsr_dense_mm.out
 aten::_triton_multi_head_attention
 aten::_triton_multi_head_attention.out
 aten::_triton_scaled_dot_attention
31 changes: 27 additions & 4 deletions test/test_sparse_csr.py
@@ -1475,9 +1475,19 @@ def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
         from torch._inductor.utils import has_triton
         from torch.sparse._triton_ops import bsr_dense_mm
 
-        if not has_triton():
+        if not has_triton() or bsr_dense_mm is None:
             self.skipTest("Triton is not available.")
 
+        kernel_invoked = [False]
+
+        lib = torch.library.Library("aten", "IMPL")
+
+        def impl(*args, **kwargs):
+            kernel_invoked[0] = True
+            return bsr_dense_mm(*args, skip_checks=True, **kwargs)
+
+        lib.impl("aten::_triton_bsr_dense_mm", impl, "SparseCsrCUDA")
+
         # Note that each value in a non-zero block is in range block_size * [low^2, high^2).
         tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)
 
@@ -1499,16 +1509,29 @@ def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
 
             bsr = bsr.to_sparse_bsr(block_size)
 
-            res_tri = bsr_dense_mm(bsr, dense.transpose(-2, -1))
-            res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
+            if bsr.dim() == 2:
+                # Test against linear to check dispatch.
+                res_tri = torch.nn.functional.linear(dense, bsr)
+                res_dense = torch.nn.functional.linear(dense, bsr.to_dense())
+
+                # Check dispatch worked with non-trivial outputs
+                if m > 0 and n > 0 and k > 0:
+                    self.assertTrue(kernel_invoked[0])
+                kernel_invoked[0] = False
+            else:
+                # Otherwise check correctness against bmm,
+                # since nn.linear does not support bsr.dim() > 2.
+                res_tri = torch._triton_bsr_dense_mm(bsr, dense.transpose(-2, -1))
+                res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
             self.assertEqual(res_tri, res_dense)
 
+            res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
             # check whether bsr_dense_mm handles different grid sizes
             # None means max possible grid size which is CUDA-dependent.
            grid_size = (None, 2, 4)
             grid_gen = itertools.product(grid_size, repeat=3)
             for is_sparse_rowspace, grid in itertools.product((True, False), grid_gen):
-                res_tri = bsr_dense_mm(
+                res_tri = torch.sparse._triton_ops.bsr_dense_mm(
                     bsr,
                     dense.transpose(-2, -1),
                     max_grid=grid,
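
Putting it together, what the commit enables end to end — a hedged sketch with assumed shapes; per the gating in SparseBlasImpl.cpp, the batch dimension must be divisible by the block size for the Triton path to trigger:

```python
import torch
import torch.nn.functional as F

# Assumed setup: half-precision BSR weight with 16x16 blocks on CUDA.
weight = torch.randn(64, 64, dtype=torch.half, device="cuda").to_sparse_bsr(16)
x = torch.randn(32, 64, dtype=torch.half, device="cuda")

# With a Triton kernel registered for SparseCsrCUDA (as sketched above),
# F.linear on half/bfloat16 BSR weights now reaches
# _compressed_row_strided_mm_out and dispatches to _triton_bsr_dense_mm.
y = F.linear(x, weight)  # same result as x @ weight.to_dense().t()
```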
