bsr_dense_bmm(): remove sparse_rowspace kernel and some dead code #100876

Closed · wants to merge 6 commits
3 changes: 1 addition & 2 deletions test/test_sparse_csr.py
@@ -3389,12 +3389,11 @@ def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
         # None means max possible grid size which is CUDA-dependent.
         grid_size = (None, 2, 4)
         grid_gen = itertools.product(grid_size, repeat=3)
-        for is_sparse_rowspace, grid in itertools.product((True, False), grid_gen):
+        for grid in grid_gen:
             res_tri = torch.sparse._triton_ops.bsr_dense_mm(
                 bsr,
                 dense.transpose(-2, -1),
                 max_grid=grid,
-                is_sparse_rowspace_mode=is_sparse_rowspace
             )
             self.assertEqual(res_tri, res_dense)
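Reviewer note, not part of the diff: the sweep now covers launch grids only. For reference, a small illustrative script showing the grid combinations the test exercises:

```python
import itertools

# Each grid axis is either a concrete size or None, where None means
# "maximal CUDA-dependent grid size" per the comment in the test.
grid_size = (None, 2, 4)
grids = list(itertools.product(grid_size, repeat=3))
print(len(grids))  # 27 combinations, e.g. (None, 2, 4), (4, 4, 4), ...
```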

238 changes: 7 additions & 231 deletions torch/sparse/_triton_ops.py
@@ -13,28 +13,6 @@ def _has_triton():
         return False


-def compressed_indices_to_plain_indices(cidx, pidx):
-    nnz = pidx.shape[-1]
-    cdim = cidx.shape[-1] - 1
-    batch_numel = cidx.shape[0]
-    batch_offset = torch.arange(batch_numel, dtype=cidx.dtype, device=cidx.device)[
-        :, None
-    ]
-
-    cidx_batch_offsetted = cidx[:, :-1] + nnz * batch_offset
-    cidx_linear = torch.empty(
-        (batch_numel * cdim + 1,), dtype=cidx.dtype, device=cidx.device
-    )
-    cidx_linear[:-1] = cidx_batch_offsetted.reshape(-1)
-    cidx_linear[-1] = nnz * batch_numel
-
-    idx_linear = torch._convert_indices_from_csr_to_coo(
-        cidx_linear, pidx.reshape(-1), out_int32=(cidx.dtype == torch.int32)
-    ).select(0, 0)
-
-    return idx_linear.reshape(batch_numel, -1).sub_(cdim * batch_offset)
-
-
 def make_triton_contiguous(t):
     # Triton does not distinguish between row- and col-majorness
     # and will be fast as long as there is a contiguous dimension.
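Reviewer note, not part of the diff: the deleted helper expanded batched compressed (CSR-style) row pointers into plain per-nonzero row indices and had no remaining callers. A minimal sketch of the single-batch equivalent using public ops:

```python
import torch

# Row pointers for one 3-row matrix with 4 nonzeros:
# row 0 holds 2 entries, row 1 none, row 2 two.
crow = torch.tensor([0, 2, 2, 4])

nnz_per_row = crow[1:] - crow[:-1]          # tensor([2, 0, 2])
plain_rows = torch.repeat_interleave(
    torch.arange(nnz_per_row.numel()), nnz_per_row
)                                           # tensor([0, 0, 2, 2])
```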
@@ -237,159 +215,6 @@ def _bsr_strided_dense_rowspace_kernel(
         tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))


-    @triton.jit
-    def _bsr_strided_sparse_rowspace_kernel(
-        BLOCKSIZE_ROW: tl.constexpr,
-        BLOCKSIZE_COL: tl.constexpr,
-        batch_idx_ptr,
-        row_idx_ptr,
-        nnz_per_row_ptr,
-        nnz_per_row_cumsum_ptr,
-        col_indices_ptr,
-        col_indices_stride,
-        # values prologue
-        values_ptr,
-        values_nnz_stride,
-        values_row_block_stride,
-        values_col_block_stride,
-        # values epilogue
-        # dense prologue
-        dense_ptr,
-        dense_batch_stride,
-        dense_tiled_row_stride,
-        dense_tiled_col_stride,
-        dense_row_block_stride,
-        dense_col_block_stride,
-        # dense epilogue
-        # output prologue
-        output_ptr,
-        output_batch_stride,
-        output_tiled_row_stride,
-        output_tiled_col_stride,
-        output_row_block_stride,
-        output_col_block_stride,
-        # output epilogue
-        GROUP_SIZE_ROW: tl.constexpr,
-    ):
-        row_block_pid = tl.program_id(axis=0)
-        col_block_pid = tl.program_id(axis=1)
-        n_block_rows = tl.num_programs(axis=0)
-        n_block_cols = tl.num_programs(axis=1)
-
-        row_block_pid, col_block_pid = tl.swizzle2d(
-            row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
-        )
-
-        batch_idx = tl.load(batch_idx_ptr + row_block_pid)
-        row_idx = tl.load(row_idx_ptr + row_block_pid)
-        row_idx_nnz = tl.load(nnz_per_row_ptr + row_block_pid)
-        row_idx_nnz_cumsum = tl.load(nnz_per_row_cumsum_ptr + row_block_pid)
-        row_idx_nnz_offset = row_idx_nnz_cumsum - row_idx_nnz
-
-        row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
-        col_block_arange = tl.arange(0, BLOCKSIZE_COL)
-
-        # Pointers are set to the first block of the current row.
-        values_block_ptrs = (
-            values_ptr
-            + values_nnz_stride * row_idx_nnz_offset
-            + values_row_block_stride * row_block_arange[:, None]
-            + values_col_block_stride * col_block_arange[None, :]
-        )
-
-        # NOTE: dense is advanced into all dimensions but the tiled row one.
-        # That will be advanced in the loop according to values in col_indices.
-        dense_block_ptrs = (
-            dense_ptr
-            + dense_batch_stride * batch_idx
-            + dense_tiled_col_stride * col_block_pid
-            + dense_row_block_stride * col_block_arange[:, None]
-            + dense_col_block_stride * row_block_arange[None, :]
-        )
-
-        # Pointers are set to exact write-to locations
-        output_ptrs = (
-            output_ptr
-            + output_batch_stride * batch_idx
-            + output_tiled_row_stride * row_idx
-            + output_tiled_col_stride * col_block_pid
-            + output_row_block_stride * row_block_arange[:, None]
-            + output_col_block_stride * row_block_arange[None, :]
-        )
-
-        output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), tl.float32)
-        col_index_nnz_ptr = col_indices_ptr + row_idx_nnz_offset * col_indices_stride
-        for _ in range(row_idx_nnz):
-            values_block = tl.load(values_block_ptrs)
-
-            # find which row of dense needs to get loaded
-            # for multiplication with values_block.
-            dense_row_idx = tl.load(col_index_nnz_ptr)
-            dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx)
-
-            # do block mm
-            output_acc_block += tl.dot(values_block, dense_block)
-
-            # move val/col_index ptrs to the next block in the row
-            values_block_ptrs += values_nnz_stride
-            col_index_nnz_ptr += col_indices_stride
-
-        # write back the result
-        tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))
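Reviewer note, not part of the diff: each program instance of the deleted kernel produced one BLOCKSIZE_ROW x BLOCKSIZE_ROW output tile by walking the stored blocks of a single nonzero BSR row. A plain-PyTorch sketch of that accumulation (tensor layouts here are assumptions for illustration):

```python
import torch

def one_output_tile(values, col_indices, dense_tiles, first_nnz, row_nnz):
    # values:      (nnz, BLOCKSIZE_ROW, BLOCKSIZE_COL) stored BSR blocks
    # col_indices: (nnz,) block-column index of each stored block
    # dense_tiles: (n_tiled_rows, BLOCKSIZE_COL, BLOCKSIZE_ROW), one tiled
    #              column of the dense operand
    acc = torch.zeros(values.shape[-2], dense_tiles.shape[-1])
    for k in range(first_nnz, first_nnz + row_nnz):
        # Gather the dense tile selected by the block's column index,
        # then accumulate the block matmul, as the Triton loop above did.
        acc += values[k] @ dense_tiles[col_indices[k]]
    return acc
```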


-    def _run_sparse_rowspace_kernel(
-        blocksize, values, crow_indices, col_indices, dense, output, max_grid
-    ):
-        # Compute a vector of per-row non-zero element counts.
-        # We ultimately want to iterate over non-zero rows only.
-        nnz_per_row = crow_indices[:, 1:] - crow_indices[:, :-1]
-
-        # Compute indices of non-zero counts.
-        # batch_idx maps to a broadcasted batch index, while
-        # row_idx tracks non-zero rows of the sparse argument
-        # and rows of the output that get modified.
-        batch_idx, row_idx = nnz_per_row.nonzero(as_tuple=True)
-
-        # Compress the vector of counts to hold only non-zero values.
-        nnz_per_row = nnz_per_row[batch_idx, row_idx]
-        # Compute cumulative counts which along with nnz_per_row
-        # are used to compute offsets into nnz values.
-        nnz_per_row_cumsum = nnz_per_row.cumsum(-1)
-
-        n_nnz_block_rows = row_idx.size(-1)
-        n_block_cols = dense.size(-3)
-
-        full_grid = (n_block_cols, n_nnz_block_rows)
-        if max_grid is not None:
-            grid_blocks = tuple(max_grid[:2][::-1]) + (None,) * (2 - len(max_grid[:2]))
-        else:
-            grid_blocks = None
-        tensor_dims_map = {
-            batch_idx: (None, 0),
-            row_idx: (None, 0),
-            nnz_per_row: (None, 0),
-            nnz_per_row_cumsum: (None, 0),
-            col_indices: (None, None),
-            values: (None, None),
-            dense: (-3, None),
-            output: (-3, None),
-        }
-
-        def kernel(grid, *sliced_tensors):
-            _bsr_strided_sparse_rowspace_kernel[grid](
-                *blocksize,
-                # First 4 tensors are contiguous, skip strides.
-                *sliced_tensors[:4],
-                *ptr_stride_extractor(*sliced_tensors[4:]),
-                GROUP_SIZE_ROW=4,
-                num_stages=1,
-                num_warps=4
-            )
-
-        launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)
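Reviewer note, not part of the diff: the launcher's index bookkeeping on a tiny example, using the same tensor names:

```python
import torch

crow_indices = torch.tensor([[0, 2, 2, 4]])  # one batch, 3 block rows, 4 blocks
nnz_per_row = crow_indices[:, 1:] - crow_indices[:, :-1]  # tensor([[2, 0, 2]])

# Only rows that actually hold blocks get a program instance.
batch_idx, row_idx = nnz_per_row.nonzero(as_tuple=True)   # (tensor([0, 0]), tensor([0, 2]))

nnz_per_row = nnz_per_row[batch_idx, row_idx]             # tensor([2, 2])
nnz_per_row_cumsum = nnz_per_row.cumsum(-1)               # tensor([2, 4])
# Offset of each row's first stored block: cumsum - count -> tensor([0, 2])
```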


     def _run_dense_rowspace_kernel(
         blocksize, values, crow_indices, col_indices, dense, output, max_grid
     ):
@@ -427,7 +252,6 @@ def bsr_dense_mm(
         dense: torch.Tensor,
         *,
         skip_checks: bool = False,
-        is_sparse_rowspace_mode: Optional[bool] = None,
         max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
         out: Optional[torch.Tensor] = None,
     ):
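Reviewer note, not part of the diff: with the keyword gone the call site simplifies. A hedged usage sketch (shapes and device are illustrative; requires CUDA and Triton, and block sizes that are powers of two, at least 16):

```python
import torch
from torch.sparse._triton_ops import bsr_dense_mm

x = torch.randn(256, 256, device="cuda").to_sparse_bsr((64, 64))
y = torch.randn(256, 128, device="cuda")

# No is_sparse_rowspace_mode any more; the dense rowspace kernel is
# always used.
res = bsr_dense_mm(x, y)
print((res - x.to_dense() @ y).abs().max())  # expect float32-level error
```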
@@ -522,10 +346,6 @@ def is_compatible_blocksize(b):
         if bsr._nnz() == 0:
             return out

-        # TODO: insert switch
-        if is_sparse_rowspace_mode is None:
-            is_sparse_rowspace_mode = False
-
         # Introduce fake batch dimension if not present for convenience.
         crow_indices = bsr.crow_indices().unsqueeze(0)
         col_indices = bsr.col_indices().unsqueeze(0)
@@ -549,22 +369,12 @@ def batch_broadcast_and_squash(t, batch_dims, invariant_dims):
             crow_indices, batch_dims_broadcasted, (-1,)
         )

-        if is_sparse_rowspace_mode:
-            # Flatten batch dimension with nnz dimension
-            # as required by the sparse rowspace kernel.
-            col_indices = batch_broadcast_and_squash(
-                col_indices, batch_dims_broadcasted + (-1,), ()
-            )
-            values = batch_broadcast_and_squash(
-                values, batch_dims_broadcasted + (values.shape[-3],), values.shape[-2:]
-            )
-        else:
-            col_indices = batch_broadcast_and_squash(
-                col_indices, batch_dims_broadcasted, (-1,)
-            )
-            values = batch_broadcast_and_squash(
-                values, batch_dims_broadcasted, values.shape[-3:]
-            )
+        col_indices = batch_broadcast_and_squash(
+            col_indices, batch_dims_broadcasted, (-1,)
+        )
+        values = batch_broadcast_and_squash(
+            values, batch_dims_broadcasted, values.shape[-3:]
+        )

         dense = batch_broadcast_and_squash(dense, batch_dims_broadcasted, dense.shape[-2:])
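Reviewer note, not part of the diff: `batch_broadcast_and_squash` is defined in the unchanged context above; conceptually it broadcasts each operand's batch dimensions to a common shape, then flattens them into a single leading batch dimension. Roughly:

```python
import torch

t = torch.randn(1, 3, 5, 7)                 # batch dims (1, 3); invariant dims (5, 7)
broadcasted = t.broadcast_to((2, 3, 5, 7))  # align with the other operands' batches
squashed = broadcasted.reshape(-1, 5, 7)    # single squashed batch dim: (6, 5, 7)
```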

@@ -597,43 +407,9 @@ def tile_to_blocksize(t, blocksize):
         out = tile_to_blocksize(out, (blocksize[0], blocksize[0]))

         # Launch kernel
-        if is_sparse_rowspace_mode:
-            kernel = _run_sparse_rowspace_kernel
-        else:
-            kernel = _run_dense_rowspace_kernel
-
+        kernel = _run_dense_rowspace_kernel
         kernel(blocksize, values, crow_indices, col_indices, dense, out, max_grid)

         return out_backup
 else:
     bsr_dense_mm = None  # type: ignore[assignment]


-if __name__ == "__main__":
-    from torch._inductor.utils import has_triton
-
-    if has_triton():
-        torch.manual_seed(13)
-        dtype = torch.float32
-        p = 0.5
-        mask_size = (8, 8)
-        block_size = (64, 64)
-        size = (mask_size[0] * block_size[0], mask_size[1] * block_size[1])
-
-        n_exp = 512
-        diff = torch.ones(n_exp, device="cuda", dtype=torch.float32)
-        for i in range(n_exp):
-            mask = torch.rand(*mask_size, device="cuda") < p
-            x = torch.rand(*mask_size, *block_size, dtype=dtype, device="cuda") / 10
-            x = (
-                (mask[:, :, None, None] * x)
-                .transpose(-3, -2)
-                .reshape(*size)
-                .to_sparse_bsr(*block_size)
-            )
-            y = torch.rand(5, *size, dtype=dtype, device="cuda") / 10
-            res_dense = x.to_dense() @ y
-            res = bsr_dense_mm(x, y)
-            diff[i] = (res - res_dense).abs().max()
-        print(f"mean: {diff.mean()}, std: {diff.std()}")
-        print(f"max diff: {diff.max()}")