[core][pruning][sparse][feature] SparseSemiStructured tensor subclass
This PR adds support for semi-structured sparsity via a tensor
subclass. It currently uses the CUTLASS kernels merged in PR #100881.

In the future we plan to add in cuSPARSELt support (see the other PRs in
the stack), which will give us larger performance gains.

This PR adds two things:
    - a tensor subclass, `SparseSemiStructuredTensor`, that stores the
      sparse tensor in compressed form and overrides `__torch_dispatch__`
      (a minimal sketch of this pattern is shown below).
    - a conversion function that takes in a dense tensor and a
      semi-structured sparse bool mask and creates an instance of the
      subclass.
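
For readers unfamiliar with the mechanism, here is a minimal sketch of the wrapper-subclass pattern, assuming only the public `__torch_dispatch__` extension point and the private `torch.Tensor._make_wrapper_subclass` helper; the class and attribute names are placeholders, not the real implementation:

```
import torch


class PackedSparseTensor(torch.Tensor):
    """Illustrative only: the real SparseSemiStructuredTensor also stores
    metadata and implements the ops listed later in this description."""

    @staticmethod
    def __new__(cls, packed, original_shape):
        # A wrapper subclass: the logical shape is the original dense
        # shape, but the only real storage is the compressed buffer.
        return torch.Tensor._make_wrapper_subclass(
            cls, original_shape, dtype=packed.dtype, device=packed.device
        )

    def __init__(self, packed, original_shape):
        self.packed = packed  # compressed specified values

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Intercept supported ops (mm, addmm, linear, t) and route them
        # to sparse kernels; everything else raises.
        raise NotImplementedError(f"{func} is not supported")
```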

The subclass stores the dense tensor in a contiguous flattened tensor
for future compatibility with cuSPARSELt, which expects this format.
Note that the CUTLASS kernels do not have this limitation, as the
specified values and the metadata are passed separately in
`_structured_sparse_linear`. In the future we can use the cuSPARSELt bindings
[here](#103700) for faster matmul, better dtype coverage, and relaxed shape
constraints.
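
As a rough illustration of what the compression buys (this is only the 2:4 counting argument, not the actual packed layout, which also carries metadata):

```
import torch

# 2:4 semi-structured sparsity keeps exactly 2 values in every group of 4,
# so the specified-values tensor has half the columns of the dense one.
dense = torch.tensor([[1.0, 0.0, 0.0, 2.0, 3.0, 0.0, 4.0, 0.0]])
specified = dense[dense != 0].reshape(dense.shape[0], -1)  # shape (1, 4)
assert specified.shape[-1] == dense.shape[-1] // 2
```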

Since we currently don't have a way to go back from the sparse
representation to the dense representation, and we store the weights in
compressed form, we don't have a great way to handle `.t()`.

Instead, we keep track of how many times `transpose` has been called on
the tensor and throw an error if the count is unexpected: when the first
argument is sparse we expect an even number of calls, while when the
second argument is sparse we expect an odd number. This is because we
support second-argument sparse matrix multiplications by using transpose
properties.
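
The identity being exploited is `mm(dense, sparse) == mm(sparse.t(), dense.t()).t()`, which can be sanity-checked with plain dense tensors:

```
import torch

A = torch.randn(8, 16)
B = torch.randn(16, 4)

# the kernels only accept the sparse operand first, so the second-argument
# case is rewritten via this transpose identity
assert torch.allclose(A @ B, (B.t() @ A.t()).t(), atol=1e-5)
```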

We have implemented support for the following ops for `torch.float16`
and `torch.int8`:
```
torch.addmm(bias, dense, sparse.t())
torch.mm(dense, sparse)
torch.mm(sparse, dense)
aten.linear.default
aten.t.default
aten.t.detach
```
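
For example, a first-argument-sparse matmul looks like this (a sketch mirroring the benchmark script in this commit; it assumes a CUDA device with sparse tensor core support and shapes that satisfy the kernels' alignment constraints):

```
import torch
from torch.sparse import to_sparse_semi_structured

# build a 1:2 (and therefore 2:4) sparse fp16 matrix and compress it
mask = torch.tensor([0, 1], dtype=torch.float16, device="cuda").tile(128, 64)
A = torch.rand(128, 128, dtype=torch.float16, device="cuda") * mask
B = torch.rand(128, 128, dtype=torch.float16, device="cuda")

sA = to_sparse_semi_structured(A)
out = torch.mm(sA, B)  # dispatches to the sparse CUTLASS kernel
```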

The end user interface to accelerate an `nn.Linear` module with the
subclass looks like this:

```
import torch
from torch import nn
from torch.sparse import to_sparse_semi_structured

mask = torch.Tensor([0, 0, 1, 1]).tile(128, 32).cuda().bool()
linear = nn.Linear(128, 128).half().cuda()

linear.weight = nn.Parameter(
    to_sparse_semi_structured(linear.weight, mask=mask)
)
```
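
After the swap, inference code is unchanged; calling the module routes the underlying `aten.linear.default` op through `__torch_dispatch__` to the sparse kernels:

```
x = torch.rand(128, 128).half().cuda()
output = linear(x)  # runs the semi-structured sparse matmul
```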

This also updates the tests and the `torch.sparse` module docstring to
reflect these changes.

ghstack-source-id: bef558914004d60a149964afcad7a0b87ae199df
Pull Request resolved: #102135
jcaip committed Jun 17, 2023
1 parent 22e8a61 commit 9801742
Showing 5 changed files with 944 additions and 2 deletions.
245 changes: 245 additions & 0 deletions benchmarks/sparse/benchmark_semi_structured_sparsity.py
@@ -0,0 +1,245 @@
import argparse
import random

import pandas as pd
import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from tqdm import tqdm


torch.set_printoptions(
    precision=2,
    threshold=None,
    edgeitems=16,
    linewidth=480,
    profile=None,
    sci_mode=False,
)


# helper model definition for the pruner
class Model(nn.Module):
    def __init__(self, m, k, dtype=None):
        super().__init__()
        # nn.Linear stores its weight as (out_features, in_features),
        # so the (m, k) dimensions are reversed here
        self.linear = nn.Linear(k, m)

    def forward(self, x):
        return self.linear(x)


def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """
    This function returns a 1:2 sparse matrix of size (r, c).
    Note that this means this matrix will also be 2:4 and 4:8 sparse as well.
    """
    choices = [[0, 1], [1, 0]]
    mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]

    return (
        torch.tensor(mask_entries, dtype=dtype, device=device)
        .reshape(r, c)
        .contiguous()
    )


def test_linear(m, k, n, dtype, contiguous, backend):
    SparseSemiStructuredTensor.fuse_transpose = contiguous
    mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
    sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask
    input_tensor = torch.zeros(n, k).to(dtype).cuda()
    model = Model(m, k).to(dtype).cuda().eval()

    dense_measurement = benchmark.Timer(
        stmt="model(input_tensor)",
        globals=locals(),
    ).blocked_autorange()

    dense_output = model(input_tensor)

    # sparsify weights
    model.linear.weight = nn.Parameter(to_sparse_semi_structured(sparse_weight))

    sparse_output = model(input_tensor)

    sparse_measurement = benchmark.Timer(
        stmt="model(input_tensor)",
        globals=locals(),
    ).blocked_autorange()

    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)

    return {
        "test_function": "linear",
        "m": m,
        "k": k,
        "n": n,
        "dtype": str(dtype),
        "backend": backend,
        "sparse_latency (ms)": sparse_measurement.median * 1000,
        "dense_latency (ms)": dense_measurement.median * 1000,
        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
        "correct": correct,
        "contiguous": sparse_output.is_contiguous(),
    }


def test_tensor(m, k, n, dtype, contiguous, backend):
    A = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
    B = torch.zeros(k, n).to(dtype).cuda()
    bias = torch.rand(n).to(dtype).cuda()

    sA = to_sparse_semi_structured(A)

    # torch.mm calculation
    if dtype is not torch.int8:
        dense_output = torch.mm(A, B)

        dense_measurement = benchmark.Timer(
            stmt="torch.mm(A, B)",
            globals=locals(),
        ).blocked_autorange()

    else:
        print("int8 baseline not supported")
        dense_output = torch.mm(sA, B)

        dense_measurement = benchmark.Timer(
            stmt="torch.mm(sA, B)",
            globals=locals(),
        ).blocked_autorange()

    sparse_output = torch.mm(sA, B)
    sparse_measurement = benchmark.Timer(
        stmt="torch.mm(sA, B)",
        globals=locals(),
    ).blocked_autorange()

    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)

    return {
        "test_function": "tensor",
        "m": m,
        "k": k,
        "n": n,
        "dtype": str(dtype),
        "backend": backend,
        "sparse_latency (ms)": sparse_measurement.median * 1000,
        "dense_latency (ms)": dense_measurement.median * 1000,
        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
        "correct": correct,
        "contiguous": sparse_output.is_contiguous(),
    }


if __name__ == "__main__":
    dtype_lookup = {
        "int8": torch.int8,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }

    parser = argparse.ArgumentParser(description="Semi-Structured Sparsity Benchmarks")
    parser.add_argument(
        "--mode",
        type=str,
        choices=[
            "nvidia-bert",
            "nvidia-fixed-k",
            "nvidia-fixed-mn",
        ],
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=dtype_lookup.keys(),
        default="fp16",
    )
    parser.add_argument(
        "--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt"
    )
    parser.add_argument("-contiguous", action="store_true")
    parser.add_argument("-e2e", action="store_true")
    parser.add_argument("-save", action="store_true")
    args = parser.parse_args()

    if args.e2e:
        eval_fn = test_linear
    else:
        eval_fn = test_tensor

    print(f"Started benchmark: {args.mode} | dtype: {args.dtype}")
    dtype = dtype_lookup[args.dtype]

    if args.mode == "nvidia-bert":
        bert_shapes = [
            (3072, 1024, 16384),
            (4096, 1024, 16384),
            (1024, 1024, 16384),
            (1024, 4096, 16384),
        ]
        results = (
            eval_fn(m, k, n, dtype, args.contiguous, args.backend)
            for (m, k, n) in tqdm(bert_shapes)
        )

    elif args.mode == "nvidia-fixed-k":
        mn_vals = [
            3072, 4096, 5120, 6144, 7168, 8192, 9216, 10240, 11264,
            12288, 13312, 14336, 15360, 16384, 17408, 18432, 19456, 20480,
        ]
        results = (
            eval_fn(mn, 10240, mn, dtype, args.contiguous, args.backend)
            for mn in tqdm(mn_vals)
        )

    elif args.mode == "nvidia-fixed-mn":
        k_vals = [
            2560, 3840, 5120, 6400, 7680, 8960, 10240, 11520,
            12800, 14080, 15360, 16640, 17920, 19200, 20480,
        ]
        results = (
            eval_fn(10240, k, 10240, dtype, args.contiguous, args.backend)
            for k in tqdm(k_vals)
        )

    df = pd.DataFrame.from_records(results)
    if args.save:
        save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv"
        df.to_csv(save_file)
        print(f"Finished benchmark: {args.mode} saved results to {save_file}")
    print(df)
