nn.Linear: dispatch to bsr_dense_mm for half and bfloat16 #94825

Closed (wants to merge 27 commits; changes shown are from 13 commits)

Commits (27)
fd4af14
nn.Linear: dispatch to bsr_dense_mm for half and bfloat16
nikitaved Feb 14, 2023
2c759e0
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Feb 14, 2023
018ba1d
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Feb 15, 2023
5d29b30
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Feb 15, 2023
96ae529
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Feb 15, 2023
9f22de0
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Feb 15, 2023
9272076
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Feb 16, 2023
f14766f
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Feb 17, 2023
e06462c
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Mar 2, 2023
adee82d
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Mar 2, 2023
01b7989
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Mar 2, 2023
aa474cb
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Mar 3, 2023
4875767
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Mar 7, 2023
4eee596
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Mar 10, 2023
25117b2
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Mar 13, 2023
1eec63f
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Mar 13, 2023
72af18f
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Mar 13, 2023
32a519a
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Mar 20, 2023
6562d75
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Apr 3, 2023
2eee8fc
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Apr 3, 2023
61d6e16
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Apr 3, 2023
c4aa47a
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Apr 4, 2023
5502b19
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Apr 4, 2023
6101854
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Apr 5, 2023
d45a198
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Apr 6, 2023
ce75dec
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Apr 12, 2023
bc7ed5c
Update on "nn.Linear: dispatch to bsr_dense_mm for half and bfloat16"
nikitaved Apr 14, 2023

6 changes: 6 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -6508,6 +6508,12 @@
SparseCPU: s_addmm_sparse_dense_cpu_
SparseCUDA: s_addmm_sparse_dense_cuda_

- func: _triton_bsr_dense_mm(Tensor bsr, Tensor dense) -> Tensor

Collaborator:
Why do you need this to be a native function? You can define it directly in Python if you never need it here.

Collaborator Author (@nikitaved, Mar 10, 2023):
Well, we actually need it as it is being called inside addmm.

Collaborator:
If you call it as at::_triton_bsr_dense_mm you do indeed. But since you already have the op, you could also call it via triton_kernel.call(...), and that doesn't require a native function.
Being a native function just gives you all the codegen niceties (namespaces, out= variant, etc.). But if you don't need any of these, you most likely don't want it to be native, as it adds more constraints (BC guarantees in particular).

This is not a big deal, just wanted to mention it in case you actually want the flexibility.
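
A minimal sketch of that non-native route via torch.library, modelled on the test code further down in this PR (the namespace and helper name here are illustrative, not part of the change):

import torch
from torch.sparse._triton_ops import bsr_dense_mm

# Define a custom op in its own namespace; no native_functions.yaml entry involved.
triton_lib = torch.library.Library("triton", "DEF")
triton_lib.define("_triton_bsr_dense_mm(Tensor bsr, Tensor dense) -> Tensor")

def _triton_bsr_dense_mm_impl(bsr, dense):
    # Delegate to the Python/Triton kernel.
    return bsr_dense_mm(bsr, dense)

# Register the implementation for the sparse CSR CUDA dispatch key only.
triton_lib.impl("_triton_bsr_dense_mm", _triton_bsr_dense_mm_impl, "SparseCsrCUDA")

# The op is then callable from Python as torch.ops.triton._triton_bsr_dense_mm(bsr, dense),
# but C++ code such as addmm can only reach it through an explicit dispatcher lookup.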

Collaborator Author (@nikitaved, Mar 10, 2023):
Sure, I am aware of the limitations and all the nice things (using out here, actually). But, unfortunately, linear itself is a native function, and so is addmm. We could indeed wrap these in Python and call into the C++ implementations if needed, but that, if simple enough, could be done in a follow-up, I guess...
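
A rough, purely illustrative sketch of that Python-side wrapping (the wrapper name and the dtype gate below are hypothetical; this PR keeps the native-function route instead):

import torch
from torch.sparse._triton_ops import bsr_dense_mm

def linear_with_bsr_weight(x, weight, bias=None):
    # Hypothetical Python-level wrapper around nn.functional.linear: route
    # half/bfloat16 BSR weights to the Triton kernel, fall back otherwise.
    if weight.layout == torch.sparse_bsr and weight.dtype in (torch.half, torch.bfloat16):
        # linear(x, W) = x @ W^T = (W @ x^T)^T, and bsr_dense_mm computes W @ dense.
        out = bsr_dense_mm(weight, x.transpose(-2, -1)).transpose(-2, -1)
        return out if bias is None else out + bias
    return torch.nn.functional.linear(x, weight, bias)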

Collaborator:
I'm not sure I follow your comment. How is it linked to the linear and addmm implementations?

Collaborator Author:
Thank you, @bdhirsh, @albanD, I will give it a shot.

Collaborator Author (@nikitaved, Mar 13, 2023):
@bdhirsh, whether the schema includes an out= or not, I get the following failure:

=============================================== FAILURES ===============================================
_______ TestSparseCSRCUDA.test_triton_bsr_dense_bmm_block_size_32_int32_cuda_float16 _______
Traceback (most recent call last):
  File "/home/nik/git/Quansight/pytorch/test/test_sparse_csr.py", line 1527, in test_triton_bsr_dense_bmm
    res_tri = torch.ops.triton._triton_bsr_dense_mm_out(bsr, dense.transpose(-2, -1), out=res_tri_out)
  File "/home/nik/git/Quansight/pytorch/torch/_ops.py", line 527, in __call__
    return self._op(*args, **kwargs or {})
ValueError: vector::reserve

This issue is non-deterministic and only manifests when several tests are run in a group, say, with pytest -sv test/test_sparse_csr.py -k triton. I was not able to make any individual test fail on its own, however.

EDIT: apparently, some CIs even get segfaults...

Contributor:
That seems... very strange (although the error seems more related to triton than to the dispatcher?).

@nikitaved What if you try making the dispatcher op itself not an out= op, and you implement it by allocating a tensor manually inside of the op?

def triton_bsr_dense_mm_impl(bsr, dense):
    out = torch.empty(....)
    my_triton_kernel[grid](bsr, dense, out)
    return out

triton_lib.impl("_triton_bsr_dense_mm", triton_bsr_dense_mm_impl)

That error looks pretty cryptic - a Google search turns up a similar error in JAX around "passing a static arg by keyword".

Collaborator Author (@nikitaved, Mar 13, 2023):
@bdhirsh, unfortunately, same issue:

diff --git a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
index 1baa7cffe66..39fc7597a87 100644
--- a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
@@ -91,13 +91,13 @@ Tensor& _compressed_row_strided_mm_out(const Tensor& compressed, const Tensor& s
    && n % blocksize[0] == 0) {
     try {
       const auto triton_kernel = c10::Dispatcher::singleton()
-        .findSchemaOrThrow("triton::_triton_bsr_dense_mm_out", "")
-        .typed<Tensor&(const Tensor&, const Tensor&, Tensor&)>();
+        .findSchemaOrThrow("triton::_triton_bsr_dense_mm", "")
+        .typed<Tensor(const Tensor&, const Tensor&)>();
       // Call Triton only if dispatch key was overwritten.
       // This is not strictly necessary since the definition is done in Python,
       // but we leave it here for extra safety.
       if (triton_kernel.hasKernelForDispatchKey(c10::DispatchKey::SparseCsrCUDA)) {
-        return triton_kernel.call(compressed, strided, result);
+        return result.copy_(triton_kernel.call(compressed, strided));
       }
     } catch (const std::exception& e) {
       // The schema is not defined and/or the key is not overwritten,
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index 77c32175a9c..c900028d384 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -1481,13 +1481,13 @@ class TestSparseCSR(TestCase):
         kernel_invoked = [False]
 
         lib = torch.library.Library("triton", "DEF")
-        lib.define("_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)")
+        lib.define("_triton_bsr_dense_mm(Tensor bsr, Tensor dense) -> Tensor")
 
-        def impl(*args, **kwargs):
+        def impl(bsr, dense):
             kernel_invoked[0] = True
-            return bsr_dense_mm(*args, skip_checks=True, **kwargs)
+            return bsr_dense_mm(bsr, dense, skip_checks=True)
 
-        lib.impl("triton::_triton_bsr_dense_mm_out", impl, "SparseCsrCUDA")
+        lib.impl("triton::_triton_bsr_dense_mm", impl, "SparseCsrCUDA")
 
         # Note that each value in a non-zero block is in range block_size * [low^2, high^2).
         tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)
@@ -1523,9 +1523,7 @@ class TestSparseCSR(TestCase):
                 # Otherwise check correctness against bmm
                 # since nn.linear does not support bsr.dim() > 2.
                 res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
-                res_tri_out = torch.empty_like(res_dense)
-                res_tri = torch.ops.triton._triton_bsr_dense_mm_out(bsr, dense.transpose(-2, -1), out=res_tri_out)
-                self.assertTrue(res_tri is res_tri_out)
+                res_tri = torch.ops.triton._triton_bsr_dense_mm(bsr, dense.transpose(-2, -1))
             self.assertEqual(res_tri, res_dense)

_______ TestSparseCSRCUDA.test_triton_bsr_dense_bmm_block_size_16_int64_cuda_float16 _______
Traceback (most recent call last):
  File "/home/nik/git/Quansight/pytorch/test/test_sparse_csr.py", line 1526, in test_triton_bsr_dense_bmm
    res_tri = torch.ops.triton._triton_bsr_dense_mm(bsr, dense.transpose(-2, -1))
  File "/home/nik/git/Quansight/pytorch/torch/_ops.py", line 527, in __call__
    return self._op(*args, **kwargs or {})
ValueError: vector::reserve

I did not have such issues with the native function, however.

Collaborator Author:
Shall we investigate that further, or do that in a follow-up?

variants: function
dispatch:
CUDA: triton_bsr_dense_mm
autogen: _triton_bsr_dense_mm.out

- func: _addmm_activation.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
30 changes: 30 additions & 0 deletions aten/src/ATen/native/sparse/SparseBlasImpl.cpp
@@ -4,6 +4,10 @@
#include <ATen/native/sparse/SparseBlasImpl.h>
#include <ATen/SparseCsrTensorUtils.h>

// Required for checking whether Triton kernels are available
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <ATen/core/dispatch/Dispatcher.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
@@ -12,6 +16,7 @@
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/_triton_bsr_dense_mm.h>
#endif

namespace at {
@@ -70,6 +75,31 @@ Tensor& _compressed_row_strided_mm_out(const Tensor& compressed, const Tensor& s
blocksize = {values.size(-2), values.size(-1)};
}

// No stable support for ROCM in Triton yet.

Collaborator:
Isn't there? We have a Triton ROCm binary as a dependency for the install.

Collaborator Author:
Something might have changed since then; I will try it out.

Collaborator:
OK! This can be done in a follow-up as well if this PR is already tricky enough!

#ifndef USE_ROCM
// Triton works only with blocksizes which are powers of 2.
const auto is_power_of_2 = [](int64_t v) -> bool {
return !(v & (v - 1));
};

// Dtype and blocksize checks for potential Triton usage.
if ((strided.scalar_type() == ScalarType::Half
|| strided.scalar_type() == ScalarType::BFloat16)
&& is_power_of_2(blocksize[0]) && is_power_of_2(blocksize[1])
&& (blocksize[0] >= 16) && (blocksize[1] >= 16)
// lhs is retiled to (b0, b1) while rhs is to (b1, b0),
// so the result is tiled to (b0, b0) and we need to make
// sure that dense.size(-1) is divisible by b0.
&& n % blocksize[0] == 0) {
const auto triton_kernel = c10::Dispatcher::singleton()
.findOp(torch::jit::parseName("aten::_triton_bsr_dense_mm"));
// Call Triton only if dispatch key was overwritten.
if (triton_kernel->hasKernelForDispatchKey(c10::DispatchKey::SparseCsrCUDA)) {
return at::_triton_bsr_dense_mm_out(result, compressed, strided);
}
}
#endif

// (..., r, c) -> (..., r / b0, c / b1, b0, b1)
// NOTE: this function ALWAYS creates a view upon successful execution.
const auto tile_tensor = [compressed_layout](
1 change: 0 additions & 1 deletion aten/src/ATen/native/sparse/SparseMatMul.cpp
@@ -274,6 +274,5 @@ Tensor sparse_sparse_matmul_cpu(const Tensor& mat1_, const Tensor& mat2_) {
return output;
}


} // namespace native
} // namespace at
7 changes: 7 additions & 0 deletions aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu
@@ -639,5 +639,12 @@ Tensor _sparse_csr_prod_cuda(const Tensor& input, IntArrayRef dims_to_reduce, bo
return result;
}

Tensor triton_bsr_dense_mm(
const Tensor& bsr,
const Tensor& dense) {
TORCH_CHECK(false, "_triton_bsr_dense_mm: Triton kernel should be overwritten in Python.");
return Tensor {};
}

} // namespace native
} // namespace at
2 changes: 2 additions & 0 deletions test/expect/HasDecompTest.test_has_decomposition.expect
@@ -517,6 +517,8 @@ aten::_transformer_encoder_layer_fwd
aten::_transformer_encoder_layer_fwd.out
aten::_trilinear
aten::_trilinear.out
aten::_triton_bsr_dense_mm
aten::_triton_bsr_dense_mm.out
aten::_triton_multi_head_attention
aten::_triton_multi_head_attention.out
aten::_triton_scaled_dot_attention
31 changes: 27 additions & 4 deletions test/test_sparse_csr.py
@@ -1475,9 +1475,19 @@ def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
from torch._inductor.utils import has_triton
from torch.sparse._triton_ops import bsr_dense_mm

if not has_triton():
if not has_triton() or bsr_dense_mm is None:
self.skipTest("Triton is not available.")

kernel_invoked = [False]

lib = torch.library.Library("aten", "IMPL")

def impl(*args, **kwargs):
kernel_invoked[0] = True
return bsr_dense_mm(*args, skip_checks=True, **kwargs)

lib.impl("aten::_triton_bsr_dense_mm", impl, "SparseCsrCUDA")

# Note that each value in a non-zero block is in range block_size * [low^2, high^2).
tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)

@@ -1499,16 +1509,29 @@ def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):

bsr = bsr.to_sparse_bsr(block_size)

res_tri = bsr_dense_mm(bsr, dense.transpose(-2, -1))
res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
if bsr.dim() == 2:
# Test against linear to check dispatch.
res_tri = torch.nn.functional.linear(dense, bsr)
res_dense = torch.nn.functional.linear(dense, bsr.to_dense())

# Check dispatch worked with non-trivial outputs
if m > 0 and n > 0 and k > 0:
self.assertTrue(kernel_invoked[0])
kernel_invoked[0] = False
else:
# Otherwise check correctness against bmm
# since nn.linear does not support bsr.dim() > 2.
res_tri = torch._triton_bsr_dense_mm(bsr, dense.transpose(-2, -1))
res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
self.assertEqual(res_tri, res_dense)

res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
# check whether bsr_dense_mm handles different grid sizes
# None means max possible grid size which is CUDA-dependent.
grid_size = (None, 2, 4)
grid_gen = itertools.product(grid_size, repeat=3)
for is_sparse_rowspace, grid in itertools.product((True, False), grid_gen):
res_tri = bsr_dense_mm(
res_tri = torch.sparse._triton_ops.bsr_dense_mm(
bsr,
dense.transpose(-2, -1),
max_grid=grid,