Skip to content

Commit

Permalink
Add improved heuristics for using COO kernels (facebookresearch#145)
Browse files Browse the repository at this point in the history
* Add improved heuristics for using COO kernels

* Update vision_transformer notebook with new numbers
  • Loading branch information
fmassa committed Jun 3, 2021
1 parent c0b0d56 commit c8571e0
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 8 deletions.
15 changes: 8 additions & 7 deletions docs/source/vision_transformers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -553,17 +553,18 @@
"Forward only\n",
"<torch.utils.benchmark.utils.common.Measurement object at 0x7fbea8033d00>\n",
"profile\n",
" Median: 464.69 ms\n",
" IQR: 0.38 ms (464.41 to 464.80)\n",
" 5 measurements, 1 runs per measurement, 1 thread\n",
"Memory used: 8020.283203125 MB\n",
" Median: 194.43 ms\n",
" IQR: 0.94 ms (193.73 to 194.67)\n",
" 11 measurements, 1 runs per measurement, 1 thread\n",
"Memory used: 8022.40283203125 MB\n",
"\n",
"Forward + backward\n",
"<torch.utils.benchmark.utils.common.Measurement object at 0x7fbdf04af4f0>\n",
"profile\n",
" Median: 1.17 s\n",
" 2 measurements, 1 runs per measurement, 1 thread\n",
"Memory used: 8205.25732421875 MB\n"
" Median: 633.81 ms\n",
" IQR: 3.66 ms (632.41 to 636.07)\n",
" 4 measurements, 1 runs per measurement, 1 thread\n",
"Memory used: 8207.640625 MB\n"
]
}
],
Expand Down
32 changes: 31 additions & 1 deletion xformers/components/attention/_sputnik_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,39 @@ def _csr_to_coo(m, n, row_offsets, column_indices):
return row_coo, column_indices


def _should_use_coo(a, sparsity):
if not a.is_cuda:
return False
B, M, K = a.shape
# amortize overhead of converting from csr to coo
if B < 32 and M < 4096:
return False
if sparsity > 0.995:
return False
if sparsity < 0.9:
return False
if K > 64:
return False
# let's be overly cautious here for now
return sparsity > 0.97


def _should_use_csr_ge(a, sparsity):
if not a.is_cuda:
return False
return sparsity > 0.99


def _sddmm_func(a, b, row_indices, row_offsets, column_indices):
sparsity = 1 - column_indices.shape[0] / (a.shape[1] * b.shape[1])
if sparsity > 0.99 and a.is_cuda:
if _should_use_coo(a, sparsity):
m = a.shape[-2]
n = b.shape[-2]
# converting from csr to coo has a constant overhead of ~150us
# so only dispatch to it for reasonably large problem sizes
ro, ci = _csr_to_coo(m, n, row_offsets, column_indices)
return torch.ops.xformers.coo_sddmm(a, b, row_indices, ro, ci)
elif _should_use_csr_ge(a, sparsity):
return torch.ops.xformers.csr_sddmm(
a, b, row_indices, row_offsets, column_indices
)
Expand Down

0 comments on commit c8571e0

Please sign in to comment.