From 57e5875bff16a9659ab2a01f3a195163b5d8c3e3 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Thu, 3 Jun 2021 19:23:27 +0200
Subject: [PATCH] Add fine-grained benchmarks for sddmm (#144)

Use configurations from real models
---
Reviewer notes (not part of the commit message; this region between the
separator above and the first file diff is not applied):
* Line breaks and Python indentation were reconstructed from a
  whitespace-collapsed copy of this patch. Added/removed/context markers,
  hunk line counts and all code tokens are unchanged.
* _get_fn() has no fallback branch, so an unrecognised backend name reaches
  "return fn" with fn never assigned.
* benchmark_sddmm.py runs all three benchmark suites at module level (no
  __main__ guard) and creates its tensors on torch.device("cuda").
* In core.py the "multiple of 4" NOTE is left untouched by this patch even
  though the modulus is now the divisible_by argument.

 benchmarks/benchmark_sddmm.py         | 112 ++++++++++++++++++++++++++
 xformers/components/attention/core.py |   4 +-
 2 files changed, 114 insertions(+), 2 deletions(-)
 create mode 100644 benchmarks/benchmark_sddmm.py

diff --git a/benchmarks/benchmark_sddmm.py b/benchmarks/benchmark_sddmm.py
new file mode 100644
index 000000000..7ca3c27f4
--- /dev/null
+++ b/benchmarks/benchmark_sddmm.py
@@ -0,0 +1,112 @@
+import itertools
+
+import torch
+from torch.utils import benchmark
+
+from xformers.components.attention._sputnik_sparse import _csr_to_coo
+from xformers.components.attention.core import SparseCS, _create_random_sparsity
+
+MIN_RUN_TIME = 0.2
+
+
+def _get_fn(backend):
+    if backend == "csr_ge":
+        fn = torch.ops.xformers.csr_sddmm
+    elif backend == "csr_sputnik":
+        fn = torch.ops.xformers.sddmm_sputnik
+    elif backend == "coo_ge":
+
+        def fn(a, b, row_indices, row_offsets, column_indices):
+            row_coo, _ = _csr_to_coo(
+                a.shape[-2], b.shape[-2], row_offsets, column_indices
+            )
+            return torch.ops.xformers.coo_sddmm(
+                a, b, row_indices, row_coo, column_indices
+            )
+
+    elif backend == "csr_to_coo":
+
+        def fn(a, b, row_indices, row_offsets, column_indices):
+            row_coo, _ = _csr_to_coo(
+                a.shape[-2], b.shape[-2], row_offsets, column_indices
+            )
+            return row_coo
+
+    return fn
+
+
+def bench_sddmm(configs):
+    min_run_time = MIN_RUN_TIME
+
+    device = torch.device("cuda")
+    results = []
+
+    for (B, M, K), prob in configs:
+        a = torch.rand(B, M, K, device=device)
+        b = torch.rand(B, M, K, device=device)
+
+        mask = _create_random_sparsity(
+            torch.ones(1, M, M, dtype=torch.bool), prob, divisible_by=16
+        )
+        aa = a
+        bb = b
+        mask = SparseCS(mask, device)
+        row_indices = mask.row_indices
+        row_offsets = mask.row_offsets
+        column_indices = mask.column_indices
+
+        for backend in ["csr_sputnik", "csr_ge", "coo_ge", "csr_to_coo"]:
+
+            fn_str = "fn(a, b, row_indices, row_offsets, column_indices)"
+            fn = _get_fn(backend)
+
+            results.append(
+                benchmark.Timer(
+                    stmt=fn_str,
+                    globals={
+                        "a": aa,
+                        "b": bb,
+                        "mask": mask,
+                        "row_indices": row_indices,
+                        "row_offsets": row_offsets,
+                        "column_indices": column_indices,
+                        "fn": fn,
+                    },
+                    label="sddmm",
+                    sub_label=f"B={B:>4d}, M={M:>4d}, K={K:>3d}, prob={prob:0.4f}",
+                    description=backend,
+                ).blocked_autorange(min_run_time=min_run_time)
+            )
+
+    compare = benchmark.Compare(results)
+    compare.print()
+    return results
+
+
+# batch size 32, for different layers
+SWIN_T_SIZES = [(96, 3136, 32), (192, 784, 32), (384, 196, 32), (768, 49, 32)]
+swin_t_config = list(zip(SWIN_T_SIZES, (0.9844, 0.9375, 0.75, 0.0)))
+
+# some random values
+BASIC_SIZES = [(32, 1024, 32), (32, 1024, 128), (8, 4096, 32), (8, 4096, 128)]
+SPARSITIES = [0.90, 0.93, 0.95, 0.97, 0.98, 0.99, 0.995, 0.999]
+basic_config = list(itertools.product(BASIC_SIZES, SPARSITIES))
+
+# batch size 32 here
+vit_sizes = [
+    (192, 785, 64),  # deit_small_patch8_224
+    (192, 197, 64),  # deit_small_patch16_224
+    (384, 785, 64),  # deit_base_patch8_224
+    (384, 197, 64),  # deit_base_patch16_224
+]
+SPARSITIES = [0.70, 0.80, 0.85, 0.90, 0.93, 0.95, 0.97]
+vit_config = list(itertools.product(vit_sizes, SPARSITIES))
+
+results = []
+
+print("Swin Transformer")
+results += bench_sddmm(swin_t_config)
+print("ViT")
+results += bench_sddmm(vit_config)
+print("Basic cases")
+results += bench_sddmm(basic_config)
diff --git a/xformers/components/attention/core.py b/xformers/components/attention/core.py
index fa8f4409a..fbca766c4 100644
--- a/xformers/components/attention/core.py
+++ b/xformers/components/attention/core.py
@@ -6,13 +6,13 @@
 from ._sputnik_sparse import SparseCS
 
 
-def _create_random_sparsity(matrix, sparsity):
+def _create_random_sparsity(matrix, sparsity, divisible_by=4):
     assert matrix.ndim == 3
     keep = torch.rand_like(matrix[0], dtype=torch.float32) > sparsity
     nonzero = torch.nonzero(keep)
     nnz = nonzero.shape[0]
     # NOTE: need to make it a multiple of 4 for sputnik
-    nonzero = nonzero[: (nnz - nnz % 4)]
+    nonzero = nonzero[: (nnz - nnz % divisible_by)]
     i, j = nonzero.unbind(1)
     output = torch.zeros_like(matrix)
     bdim = torch.arange(matrix.shape[0], device=matrix.device)[:, None]