Eliminate unnecessary multiplications by 1 in addmm with sparse compressed tensor operand (#114026)

This PR:
- updates `torch/sparse/_triton_ops_meta.py` for the API change in `triton.testing.do_bench`
- forces `num_stages` to be 1 when the blocksize is 128x128, to avoid an out-of-resources exception when `bsr_dense_mm` is called from `nn.linear`
- eliminates the multiplications by 1, as in the title. As a result, the performance of `nn.linear` on BSR tensor weights (dtypes `float16` and `bfloat16`) improves as follows (`NVIDIA A100-SXM4-80GB`):
  - for blocksize 16x16, the average/maximum speedup is about 11%/20%
  - for blocksize 32x32, the average/maximum speedup is about 15%/24%
  - for blocksize 64x64, the average/maximum speedup is about 18%/26%
  - for blocksize 128x128, the average/maximum speedup is about 15%/28%
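
The arithmetic being optimized is `result = beta * self + alpha * (mat1 @ mat2)`, where scaling by `alpha == 1` or `beta == 1` is a no-op that the old code still paid for. A minimal pure-Python sketch of the scalar-folding idea (the `addmm` helper and list-of-lists matrices here are illustrative only, not PyTorch's C++ kernel):

```python
def addmm(self_, mat1, mat2, beta=1.0, alpha=1.0):
    """Compute beta * self_ + alpha * (mat1 @ mat2) on nested lists,
    skipping the scalar multiplications when alpha or beta is exactly 1,
    mirroring the branches this PR adds to the sparse-compressed kernel."""
    m, k = len(mat1), len(mat1[0])
    n = len(mat2[0])
    # mm = mat1 @ mat2
    mm = [[sum(mat1[i][t] * mat2[t][j] for t in range(k)) for j in range(n)]
          for i in range(m)]
    if alpha != 1.0:  # skip the no-op scaling by 1
        mm = [[alpha * x for x in row] for row in mm]
    if beta == 0.0:   # self_ does not contribute at all
        return mm
    if beta == 1.0:   # add self_ directly, without scaling it first
        return [[mm[i][j] + self_[i][j] for j in range(n)] for i in range(m)]
    return [[mm[i][j] + beta * self_[i][j] for j in range(n)] for i in range(m)]
```

With the default `alpha=1.0, beta=1.0` (the values `nn.linear` uses), no elementwise scaling pass runs at all, which is where the speedups above come from.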

Pull Request resolved: #114026
Approved by: https://github.com/cpuhrsch
pearu authored and pytorchmergebot committed Nov 19, 2023
1 parent 826ab0e commit 12f95df
Showing 2 changed files with 99 additions and 79 deletions.
31 changes: 23 additions & 8 deletions aten/src/ATen/native/sparse/SparseBlasImpl.cpp
@@ -218,27 +218,42 @@ Tensor& _compressed_row_strided_addmm_out(
     const Scalar& beta,
     const Scalar& alpha,
     Tensor& result) {
+  auto alpha_val = alpha.toComplexDouble();
+  auto beta_val = beta.toComplexDouble();
   // If result is not the same as self, it could always be used as out argument to mm.
   if (!result.is_same(self)) {
-    _compressed_row_strided_mm_out(mat1, mat2, result).mul_(alpha);
-
+    _compressed_row_strided_mm_out(mat1, mat2, result);
+    if (alpha_val != 1.) {
+      result.mul_(alpha);
+    }
     // Process beta
-    if (beta.toComplexDouble() != 0.) {
-      result.add_(self.mul(beta));
+    if (beta_val != 0.) {
+      if (beta_val == 1.) {
+        result.add_(self);
+      } else {
+        result.add_(self.mul(beta));
+      }
     }
   }
   // Otherwise we need to allocate external memory for mm if beta != 0.
   else {
     // Process beta
-    if (beta.toComplexDouble() != 0.) {
-      result.mul_(beta);
+    if (beta_val != 0.) {
+      if (beta_val != 1.) {
+        result.mul_(beta);
+      }
       auto mm = at::empty_like(result);
       _compressed_row_strided_mm_out(mat1, mat2, mm);
-      mm.mul_(alpha);
+      if (alpha_val != 1.) {
+        mm.mul_(alpha);
+      }
       result.add_(mm);
     }
     else {
-      _compressed_row_strided_mm_out(mat1, mat2, result).mul_(alpha);
+      _compressed_row_strided_mm_out(mat1, mat2, result);
+      if (alpha_val != 1.) {
+        result.mul_(alpha);
+      }
     }
   }
torch/sparse/_triton_ops_meta.py (diff not loaded)
