Skip to content

Commit

Permalink
Use more performant bsr_scatter_mm within bsr_dense_mm when blocksize…
Browse files Browse the repository at this point in the history
… is 16. (pytorch#111489)

Pull Request resolved: pytorch#111489
Approved by: https://github.com/cpuhrsch
ghstack dependencies: pytorch#110396, pytorch#111470
  • Loading branch information
pearu authored and xuhancn committed Nov 8, 2023
1 parent e1a4751 commit 4b96070
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions torch/sparse/_triton_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ def bsr_scatter_mm_indices_data(bsr, other, indices_format='bsr_strided_mm_compr
return indices_data


def bsr_scatter_mm(bsr, other, indices_data=None):
def bsr_scatter_mm(bsr, other, indices_data=None, out=None):
"""BSR @ strided -> strided
"""

Expand All @@ -717,11 +717,14 @@ def bsr_scatter_mm(bsr, other, indices_data=None):

indices_format = indices_data[0]

if out is None:
out = torch.empty((Ms, Ns), dtype=bsr.dtype, device=bsr.device)

if bsr._nnz() == 0:
result = torch.zeros((Ms, Ns), dtype=bsr.dtype, device=bsr.device)
out.zero_()
elif indices_format in {'bsr_strided_mm_compressed', 'bsr_strided_mm'}:
result = torch.zeros((Ms, Ns), dtype=bsr.dtype, device=bsr.device)
scatter_mm(bsr.values(), other, indices_data, accumulators=result)
out.zero_()
scatter_mm(bsr.values(), other, indices_data, accumulators=out)
elif indices_format == 'scatter_mm':
accumulators = torch.zeros((Ms // blocksize[0] * Ns // blocksize[0], blocksize[0], blocksize[0]),
dtype=bsr.dtype, device=bsr.device)
Expand All @@ -734,15 +737,15 @@ def bsr_scatter_mm(bsr, other, indices_data=None):

scatter_mm(bsr.values(), others, indices_data, accumulators=accumulators)

result = (accumulators
out.copy_(accumulators
.unflatten(0, (Ms // blocksize[0], Ns // blocksize[0]))
.movedim((0, 1, 2, 3), (2, 0, 3, 1)) # equivalent to .transpose(0, 1).transpose(2, 3).transpose(1, 2)
.reshape(Ns, Ms)
.transpose(0, 1))
else:
raise NotImplementedError(indices_format)

return result
return out


if has_triton():
Expand Down Expand Up @@ -1218,6 +1221,9 @@ def bsr_dense_mm(

blocksize = bsr.values().shape[-2:]

if max(blocksize) == 16 and bsr.dense_dim() == 0 and bsr.ndim == 2:
return bsr_scatter_mm(bsr, dense, out=out)

# NOTE: out is contiguous, so prepare_inputs will create a view.
# out gets modified in-place, so we store a backup copy.
out_backup = out
Expand Down

0 comments on commit 4b96070

Please sign in to comment.