bsr_dense_bmm(): remove sparse_rowspace kernel and some dead code
ghstack-source-id: 5d0f0b1a20325ae4e0084a149401abf44fbd9d15
Pull Request resolved: #100876
nikitaved committed May 8, 2023
1 parent 6f8766f commit 8f50e5b
Showing 2 changed files with 8 additions and 233 deletions.
test/test_sparse_csr.py (3 changes: 1 addition & 2 deletions)
@@ -3389,12 +3389,11 @@ def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
         # None means max possible grid size which is CUDA-dependent.
         grid_size = (None, 2, 4)
         grid_gen = itertools.product(grid_size, repeat=3)
-        for is_sparse_rowspace, grid in itertools.product((True, False), grid_gen):
+        for grid in grid_gen:
             res_tri = torch.sparse._triton_ops.bsr_dense_mm(
                 bsr,
                 dense.transpose(-2, -1),
                 max_grid=grid,
-                is_sparse_rowspace_mode=is_sparse_rowspace
             )
             self.assertEqual(res_tri, res_dense)

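For context, this leaves `torch.sparse._triton_ops.bsr_dense_mm` callable without the removed keyword. A minimal usage sketch, assuming a CUDA device with Triton available (shapes, block size, and tolerance are illustrative, not from this PR):

import torch
from torch.sparse._triton_ops import bsr_dense_mm

# Build a block-sparse (BSR) matrix and a dense right-hand side.
bsr = torch.rand(256, 256, device="cuda").to_sparse_bsr((64, 64))
dense = torch.rand(256, 256, device="cuda")

# max_grid remains an optional knob; is_sparse_rowspace_mode is gone.
res = bsr_dense_mm(bsr, dense, max_grid=(None, 2, 4))
assert torch.allclose(res, bsr.to_dense() @ dense, atol=1e-4)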
torch/sparse/_triton_ops.py (238 changes: 7 additions & 231 deletions)
@@ -13,28 +13,6 @@ def _has_triton():
         return False


-def compressed_indices_to_plain_indices(cidx, pidx):
-    nnz = pidx.shape[-1]
-    cdim = cidx.shape[-1] - 1
-    batch_numel = cidx.shape[0]
-    batch_offset = torch.arange(batch_numel, dtype=cidx.dtype, device=cidx.device)[
-        :, None
-    ]
-
-    cidx_batch_offsetted = cidx[:, :-1] + nnz * batch_offset
-    cidx_linear = torch.empty(
-        (batch_numel * cdim + 1,), dtype=cidx.dtype, device=cidx.device
-    )
-    cidx_linear[:-1] = cidx_batch_offsetted.reshape(-1)
-    cidx_linear[-1] = nnz * batch_numel
-
-    idx_linear = torch._convert_indices_from_csr_to_coo(
-        cidx_linear, pidx.reshape(-1), out_int32=(cidx.dtype == torch.int32)
-    ).select(0, 0)
-
-    return idx_linear.reshape(batch_numel, -1).sub_(cdim * batch_offset)
-
-
 def make_triton_contiguous(t):
     # Triton does not distinguish between row- and col-majorness
     # and will be fast as long as there is a contiguous dimension.
@@ -237,159 +215,6 @@ def _bsr_strided_dense_rowspace_kernel(
         tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))


-    @triton.jit
-    def _bsr_strided_sparse_rowspace_kernel(
-        BLOCKSIZE_ROW: tl.constexpr,
-        BLOCKSIZE_COL: tl.constexpr,
-        batch_idx_ptr,
-        row_idx_ptr,
-        nnz_per_row_ptr,
-        nnz_per_row_cumsum_ptr,
-        col_indices_ptr,
-        col_indices_stride,
-        # values prologue
-        values_ptr,
-        values_nnz_stride,
-        values_row_block_stride,
-        values_col_block_stride,
-        # values epilogue
-        # dense prologue
-        dense_ptr,
-        dense_batch_stride,
-        dense_tiled_row_stride,
-        dense_tiled_col_stride,
-        dense_row_block_stride,
-        dense_col_block_stride,
-        # dense epilogue
-        # output prologue
-        output_ptr,
-        output_batch_stride,
-        output_tiled_row_stride,
-        output_tiled_col_stride,
-        output_row_block_stride,
-        output_col_block_stride,
-        # output epilogue
-        GROUP_SIZE_ROW: tl.constexpr,
-    ):
-        row_block_pid = tl.program_id(axis=0)
-        col_block_pid = tl.program_id(axis=1)
-        n_block_rows = tl.num_programs(axis=0)
-        n_block_cols = tl.num_programs(axis=1)
-
-        row_block_pid, col_block_pid = tl.swizzle2d(
-            row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
-        )
-
-        batch_idx = tl.load(batch_idx_ptr + row_block_pid)
-        row_idx = tl.load(row_idx_ptr + row_block_pid)
-        row_idx_nnz = tl.load(nnz_per_row_ptr + row_block_pid)
-        row_idx_nnz_cumsum = tl.load(nnz_per_row_cumsum_ptr + row_block_pid)
-        row_idx_nnz_offset = row_idx_nnz_cumsum - row_idx_nnz
-
-        row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
-        col_block_arange = tl.arange(0, BLOCKSIZE_COL)
-
-        # Pointers are set to the first block of the current row.
-        values_block_ptrs = (
-            values_ptr
-            + values_nnz_stride * row_idx_nnz_offset
-            + values_row_block_stride * row_block_arange[:, None]
-            + values_col_block_stride * col_block_arange[None, :]
-        )
-
-        # NOTE: dense is advanced into all dimensions but the tiled row one.
-        # That will be advanced in the loop according to values in col_indices.
-        dense_block_ptrs = (
-            dense_ptr
-            + dense_batch_stride * batch_idx
-            + dense_tiled_col_stride * col_block_pid
-            + dense_row_block_stride * col_block_arange[:, None]
-            + dense_col_block_stride * row_block_arange[None, :]
-        )
-
-        # Pointers are set to exact write-to locations
-        output_ptrs = (
-            output_ptr
-            + output_batch_stride * batch_idx
-            + output_tiled_row_stride * row_idx
-            + output_tiled_col_stride * col_block_pid
-            + output_row_block_stride * row_block_arange[:, None]
-            + output_col_block_stride * row_block_arange[None, :]
-        )
-
-        output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), tl.float32)
-        col_index_nnz_ptr = col_indices_ptr + row_idx_nnz_offset * col_indices_stride
-        for _ in range(row_idx_nnz):
-            values_block = tl.load(values_block_ptrs)
-
-            # find which row of dense needs to get loaded
-            # for multiplication with values_block.
-            dense_row_idx = tl.load(col_index_nnz_ptr)
-            dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx)
-
-            # do block mm
-            output_acc_block += tl.dot(values_block, dense_block)
-
-            # move val/col_index ptrs to the next block in the row
-            values_block_ptrs += values_nnz_stride
-            col_index_nnz_ptr += col_indices_stride
-
-        # write back the result
-        tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))
-
-
-    def _run_sparse_rowspace_kernel(
-        blocksize, values, crow_indices, col_indices, dense, output, max_grid
-    ):
-        # Compute a vector of non-zero elements numbers per each row.
-        # We want to ultimately iterate over non-zero rows.
-        nnz_per_row = crow_indices[:, 1:] - crow_indices[:, :-1]
-
-        # Compute indices of non-zero counts.
-        # batch_idx maps to a broadcasted batch index, while
-        # row_idx tracks non-zero rows of the sparse argument
-        # and rows of the output that get modified.
-        batch_idx, row_idx = nnz_per_row.nonzero(as_tuple=True)
-
-        # Compress the vector of counts to hold only non-zero values.
-        nnz_per_row = nnz_per_row[batch_idx, row_idx]
-        # Compute cumulative counts which along with nnz_per_row
-        # are used to compute offsets into nnz values.
-        nnz_per_row_cumsum = nnz_per_row.cumsum(-1)
-
-        n_nnz_block_rows = row_idx.size(-1)
-        n_block_cols = dense.size(-3)
-
-        full_grid = (n_block_cols, n_nnz_block_rows)
-        if max_grid is not None:
-            grid_blocks = tuple(max_grid[:2][::-1]) + (None,) * (2 - len(max_grid[:2]))
-        else:
-            grid_blocks = None
-        tensor_dims_map = {
-            batch_idx: (None, 0),
-            row_idx: (None, 0),
-            nnz_per_row: (None, 0),
-            nnz_per_row_cumsum: (None, 0),
-            col_indices: (None, None),
-            values: (None, None),
-            dense: (-3, None),
-            output: (-3, None),
-        }
-
-        def kernel(grid, *sliced_tensors):
-            _bsr_strided_sparse_rowspace_kernel[grid](
-                *blocksize,
-                # First 4 tensors are contiguous, skip strides.
-                *sliced_tensors[:4],
-                *ptr_stride_extractor(*sliced_tensors[4:]),
-                GROUP_SIZE_ROW=4,
-                num_stages=1,
-                num_warps=4
-            )
-
-        launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)
-
-
     def _run_dense_rowspace_kernel(
         blocksize, values, crow_indices, col_indices, dense, output, max_grid
     ):
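Aside: the host-side metadata that the removed `_run_sparse_rowspace_kernel` derived from `crow_indices` can be retraced with plain tensor ops. A small worked sketch (input values are hypothetical; variable names mirror the deleted code):

import torch

crow_indices = torch.tensor([[0, 0, 2, 3]])               # one batch; rows 1 and 2 are non-empty
nnz_per_row = crow_indices[:, 1:] - crow_indices[:, :-1]  # tensor([[0, 2, 1]])
batch_idx, row_idx = nnz_per_row.nonzero(as_tuple=True)   # indices of non-empty rows
nnz_per_row = nnz_per_row[batch_idx, row_idx]             # tensor([2, 1])
nnz_per_row_cumsum = nnz_per_row.cumsum(-1)               # tensor([2, 3])
# Each program handled one non-empty block row; cumsum minus count gave its
# offset into the flattened values/col_indices buffers (here: 0 and 2).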
@@ -427,7 +252,6 @@ def bsr_dense_mm(
         dense: torch.Tensor,
         *,
         skip_checks: bool = False,
-        is_sparse_rowspace_mode: Optional[bool] = None,
         max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
         out: Optional[torch.Tensor] = None,
     ):
@@ -522,10 +346,6 @@ def is_compatible_blocksize(b):
         if bsr._nnz() == 0:
             return out

-        # TODO: insert switch
-        if is_sparse_rowspace_mode is None:
-            is_sparse_rowspace_mode = False
-
         # Introduce fake batch dimension if not present for convenience.
         crow_indices = bsr.crow_indices().unsqueeze(0)
         col_indices = bsr.col_indices().unsqueeze(0)
@@ -549,22 +369,12 @@ def batch_broadcast_and_squash(t, batch_dims, invariant_dims):
             crow_indices, batch_dims_broadcasted, (-1,)
         )

-        if is_sparse_rowspace_mode:
-            # Flatten batch dimension with nnz dimension
-            # as required by the sparse rowspace kernel.
-            col_indices = batch_broadcast_and_squash(
-                col_indices, batch_dims_broadcasted + (-1,), ()
-            )
-            values = batch_broadcast_and_squash(
-                values, batch_dims_broadcasted + (values.shape[-3],), values.shape[-2:]
-            )
-        else:
-            col_indices = batch_broadcast_and_squash(
-                col_indices, batch_dims_broadcasted, (-1,)
-            )
-            values = batch_broadcast_and_squash(
-                values, batch_dims_broadcasted, values.shape[-3:]
-            )
+        col_indices = batch_broadcast_and_squash(
+            col_indices, batch_dims_broadcasted, (-1,)
+        )
+        values = batch_broadcast_and_squash(
+            values, batch_dims_broadcasted, values.shape[-3:]
+        )

         dense = batch_broadcast_and_squash(dense, batch_dims_broadcasted, dense.shape[-2:])
@@ -597,43 +407,9 @@ def tile_to_blocksize(t, blocksize):
         out = tile_to_blocksize(out, (blocksize[0], blocksize[0]))

         # Launch kernel
-        if is_sparse_rowspace_mode:
-            kernel = _run_sparse_rowspace_kernel
-        else:
-            kernel = _run_dense_rowspace_kernel
-
+        kernel = _run_dense_rowspace_kernel
         kernel(blocksize, values, crow_indices, col_indices, dense, out, max_grid)

         return out_backup
 else:
     bsr_dense_mm = None  # type: ignore[assignment]
-
-
-if __name__ == "__main__":
-    from torch._inductor.utils import has_triton
-
-    if has_triton():
-        torch.manual_seed(13)
-        dtype = torch.float32
-        p = 0.5
-        mask_size = (8, 8)
-        block_size = (64, 64)
-        size = (mask_size[0] * block_size[0], mask_size[1] * block_size[1])
-
-        n_exp = 512
-        diff = torch.ones(n_exp, device="cuda", dtype=torch.float32)
-        for i in range(n_exp):
-            mask = torch.rand(*mask_size, device="cuda") < p
-            x = torch.rand(*mask_size, *block_size, dtype=dtype, device="cuda") / 10
-            x = (
-                (mask[:, :, None, None] * x)
-                .transpose(-3, -2)
-                .reshape(*size)
-                .to_sparse_bsr(*block_size)
-            )
-            y = torch.rand(5, *size, dtype=dtype, device="cuda") / 10
-            res_dense = x.to_dense() @ y
-            res = bsr_dense_mm(x, y)
-            diff[i] = (res - res_dense).abs().max()
-        print(f"mean: {diff.mean()}, std: {diff.std()}")
-        print(f"max diff: {diff.max()}")
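The removed `compressed_indices_to_plain_indices` helper had no remaining callers. For reference, what it computed, plain (row) indices from batched compressed indices, can be reproduced with stock tensor ops; a minimal sketch with a hypothetical name (all batches of a batched compressed tensor share one nnz, so the per-batch results stack):

import torch

def plain_row_indices(cidx):
    # cidx: (batch, n_compressed + 1) compressed indices, e.g. crow_indices
    rows = torch.arange(cidx.shape[-1] - 1, dtype=cidx.dtype, device=cidx.device)
    return torch.stack([rows.repeat_interleave(c.diff()) for c in cidx])

crow = torch.tensor([[0, 1, 3], [0, 2, 3]])  # batch of two CSR row pointers
print(plain_row_indices(crow))               # tensor([[0, 1, 1], [0, 0, 1]])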
