diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py index 4f9f05fc0680..3c9badbaabfd 100644 --- a/torch/sparse/_triton_ops.py +++ b/torch/sparse/_triton_ops.py @@ -702,7 +702,7 @@ def bsr_scatter_mm_indices_data(bsr, other, indices_format='bsr_strided_mm_compr return indices_data -def bsr_scatter_mm(bsr, other, indices_data=None): +def bsr_scatter_mm(bsr, other, indices_data=None, out=None): """BSR @ strided -> strided """ @@ -717,11 +717,14 @@ def bsr_scatter_mm(bsr, other, indices_data=None): indices_format = indices_data[0] + if out is None: + out = torch.empty((Ms, Ns), dtype=bsr.dtype, device=bsr.device) + if bsr._nnz() == 0: - result = torch.zeros((Ms, Ns), dtype=bsr.dtype, device=bsr.device) + out.zero_() elif indices_format in {'bsr_strided_mm_compressed', 'bsr_strided_mm'}: - result = torch.zeros((Ms, Ns), dtype=bsr.dtype, device=bsr.device) - scatter_mm(bsr.values(), other, indices_data, accumulators=result) + out.zero_() + scatter_mm(bsr.values(), other, indices_data, accumulators=out) elif indices_format == 'scatter_mm': accumulators = torch.zeros((Ms // blocksize[0] * Ns // blocksize[0], blocksize[0], blocksize[0]), dtype=bsr.dtype, device=bsr.device) @@ -734,7 +737,7 @@ def bsr_scatter_mm(bsr, other, indices_data=None): scatter_mm(bsr.values(), others, indices_data, accumulators=accumulators) - result = (accumulators + out.copy_(accumulators .unflatten(0, (Ms // blocksize[0], Ns // blocksize[0])) .movedim((0, 1, 2, 3), (2, 0, 3, 1)) # equivalent to .transpose(0, 1).transpose(2, 3).transpose(1, 2) .reshape(Ns, Ms) @@ -742,7 +745,7 @@ def bsr_scatter_mm(bsr, other, indices_data=None): else: raise NotImplementedError(indices_format) - return result + return out if has_triton(): @@ -1218,6 +1221,9 @@ def bsr_dense_mm( blocksize = bsr.values().shape[-2:] + if max(blocksize) == 16 and bsr.dense_dim() == 0 and bsr.ndim == 2: + return bsr_scatter_mm(bsr, dense, out=out) + # NOTE: out is contiguous, so prepare_inputs will create a view. # out gets modified in-place, so we store a backup copy. out_backup = out