Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use more performant bsr_scatter_mm within bsr_dense_mm when blocksize is 16. #111489

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions torch/sparse/_triton_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ def bsr_scatter_mm_indices_data(bsr, other, indices_format='bsr_strided_mm_compr
return indices_data


def bsr_scatter_mm(bsr, other, indices_data=None):
def bsr_scatter_mm(bsr, other, indices_data=None, out=None):
"""BSR @ strided -> strided
"""

Expand All @@ -717,11 +717,14 @@ def bsr_scatter_mm(bsr, other, indices_data=None):

indices_format = indices_data[0]

if out is None:
out = torch.empty((Ms, Ns), dtype=bsr.dtype, device=bsr.device)

if bsr._nnz() == 0:
result = torch.zeros((Ms, Ns), dtype=bsr.dtype, device=bsr.device)
out.zero_()
elif indices_format in {'bsr_strided_mm_compressed', 'bsr_strided_mm'}:
result = torch.zeros((Ms, Ns), dtype=bsr.dtype, device=bsr.device)
scatter_mm(bsr.values(), other, indices_data, accumulators=result)
out.zero_()
scatter_mm(bsr.values(), other, indices_data, accumulators=out)
elif indices_format == 'scatter_mm':
accumulators = torch.zeros((Ms // blocksize[0] * Ns // blocksize[0], blocksize[0], blocksize[0]),
dtype=bsr.dtype, device=bsr.device)
Expand All @@ -734,15 +737,15 @@ def bsr_scatter_mm(bsr, other, indices_data=None):

scatter_mm(bsr.values(), others, indices_data, accumulators=accumulators)

result = (accumulators
out.copy_(accumulators
.unflatten(0, (Ms // blocksize[0], Ns // blocksize[0]))
.movedim((0, 1, 2, 3), (2, 0, 3, 1)) # equivalent to .transpose(0, 1).transpose(2, 3).transpose(1, 2)
.reshape(Ns, Ms)
.transpose(0, 1))
else:
raise NotImplementedError(indices_format)

return result
return out


if has_triton():
Expand Down Expand Up @@ -1218,6 +1221,9 @@ def bsr_dense_mm(

blocksize = bsr.values().shape[-2:]

if max(blocksize) == 16 and bsr.dense_dim() == 0 and bsr.ndim == 2:
return bsr_scatter_mm(bsr, dense, out=out)

# NOTE: out is contiguous, so prepare_inputs will create a view.
# out gets modified in-place, so we keep a reference to the original tensor (no data is copied).
out_backup = out
Expand Down
Loading