
Commit eaf2908

Add semi-structured sparsity to hf eval (#576)
* Add hf example for semi-structured sparsity
  Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
* updated notebook
* update
* update hf example
* Update version.txt
* update hf_eval changes
* update
* remove notebook and add script
1 parent 7e69ee3 commit eaf2908

File tree

3 files changed (+173, -5 lines)

benchmarks/benchmark_semi_sparse_training.py

Lines changed: 43 additions & 2 deletions
@@ -16,6 +16,7 @@
 import torch.nn.functional as F
 from torch.utils import benchmark

+from torch.sparse import to_sparse_semi_structured
 from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear
 from torchao.sparsity.training.autograd import semi_structured_sparsify

@@ -118,6 +119,18 @@ def fw(self):
     def bw(self):
         self.out.backward(self.grad, retain_graph=True)

+class SemiSparseLinearOfflineCompressionTest(torch.nn.Module):
+    def __init__(self, mkn):
+        super().__init__()
+        m, k, n = mkn
+        self.model = torch.nn.Linear(k, n).cuda().half()
+        self.model.weight = torch.nn.Parameter(to_sparse_semi_structured(self.model.weight))
+        self.input = torch.randn([m, k], device='cuda', dtype=torch.half, requires_grad=True)
+        self.grad = torch.randn([m, n], device="cuda", dtype=torch.half)
+
+    def fw(self):
+        self.out = self.model(self.input)
+
 class SemiSparseLinearTest(LinearTest):
     def __init__(self, mkn):
         super().__init__(mkn)
@@ -170,8 +183,8 @@ def __init__(self, model_type, batch_size):

 if __name__ == "__main__":
     print("BENCHMARKING")
-    parser = argparse.ArgumentParser(description='run semi-structured spares training benchmarks')
-    parser.add_argument('--mode', type=str, choices=["linear", "vit"], help='nn.Linear/ViT-e2e benchmarking', default="vit")
+    parser = argparse.ArgumentParser(description='run semi-structured sparse training benchmarks')
+    parser.add_argument('--mode', type=str, choices=["linear", "llama3-8b", "vit"], help='nn.Linear/ViT-e2e benchmarking', default="vit")
     parser.add_argument('--save', action="store_true", help="save benchmarking results")
     args = parser.parse_args()
     if args.mode == "linear":
@@ -198,6 +211,34 @@ def __init__(self, model_type, batch_size):
             bw=True,
             cuda_graph=True,
             blocked_autorange=True)
+    elif args.mode == "llama3-8b":
+        functions = {
+            "dense_linear": LinearTest,
+            "semi_sparse_linear": SemiSparseLinearOfflineCompressionTest,
+        }
+        batch_size = 16
+        cases = list(
+            product_dict(
+                mkn=[
+                    # attn q and o
+                    (batch_size, 4096, 4096),
+                    # attn k and v
+                    (batch_size, 4096, 1024),
+                    # mlp up and gate
+                    (batch_size, 4096, 14336),
+                    # mlp down
+                    (batch_size, 14336, 4096),
+                ],
+            )
+        )
+
+        df = benchmark_helper(
+            functions,
+            cases,
+            fw=True,
+            bw=False,
+            cuda_graph=True,
+            blocked_autorange=True)

     elif args.mode == "vit":
         functions = {
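Note: the new llama3-8b mode pairs a dense LinearTest against SemiSparseLinearOfflineCompressionTest, i.e. a linear layer whose weight has been compressed offline with to_sparse_semi_structured, at the matmul shapes of Llama-3-8B's attention and MLP projections. A minimal sketch of that offline-compression path (not part of this commit; the hand-built 2:4 mask and the mlp.down shape are illustrative assumptions):

import torch
from torch.sparse import to_sparse_semi_structured

# toy stand-in for a checkpoint weight: prune to a 2:4 pattern by hand, then compress;
# real checkpoints (e.g. the SparseLlama ones used later in this commit) already ship pruned weights
linear = torch.nn.Linear(14336, 4096).cuda().half()  # Llama-3-8B mlp.down shape
mask = torch.tensor([1, 1, 0, 0], dtype=torch.bool, device="cuda").tile(
    linear.weight.shape[0], linear.weight.shape[1] // 4  # keep 2 of every 4 inputs
)
linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight.detach() * mask))

x = torch.randn(16, 14336, device="cuda", dtype=torch.half)  # batch_size 16, as in the benchmark
out = linear(x)  # should dispatch to a 2:4 semi-structured sparse matmul kernel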

scripts/hf_eval.py

Lines changed: 26 additions & 3 deletions
@@ -21,6 +21,10 @@
     quantize_,
     autoquant,
 )
+from torchao.sparsity import (
+    sparsify_,
+    semi_sparse_weight,
+)

 torch._inductor.config.force_fuse_int_mm_with_mul = True
 torch._inductor.config.fx_graph_cache = True
@@ -40,10 +44,10 @@ def format_value(value):

     print(tabulate(main_table, headers=['Task', 'Metrics'], tablefmt='grid'))

-def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, save, batch_size, max_length):
+def run_evaluation(repo_id, tasks, limit, device, precision, quantization, sparsity, compile, save, batch_size, max_length):

     tokenizer = AutoTokenizer.from_pretrained(repo_id)
-    model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)
+    model = AutoModelForCausalLM.from_pretrained(repo_id).to(dtype=precision, device=device)

     if quantization == "autoquant" and compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
@@ -61,6 +65,24 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
     if quantization != "autoquant" and compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)

+    if sparsity == "semi_sparse":
+        def all_linear(mod, name):
+            if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
+                return True
+            return False
+        torch.sparse.semi_structured._FORCE_CUTLASS = False
+        sparsify_(model, semi_sparse_weight(), filter_fn=all_linear)
+    elif sparsity == "semi_sparse_mlp_only":
+        def all_linear(mod, name):
+            if isinstance(mod, torch.nn.Linear) and "lm_head" not in name and "mlp" in name:
+                return True
+            return False
+        torch.sparse.semi_structured._FORCE_CUTLASS = False
+        sparsify_(model, semi_sparse_weight(), filter_fn=all_linear)
+
+    if sparsity and compile:
+        model = torch.compile(model, mode="max-autotune", fullgraph=True)
+
     with torch.no_grad():
         result = evaluate(
             HFLM(
@@ -90,10 +112,11 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
     parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
     parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
     parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply')
+    parser.add_argument('-s', '--sparsity', default = "None", choices=["semi_sparse", "semi_sparse_mlp_only", "None"], help='Which sparsity technique to apply')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--save', action='store_true', help='Whether to save the model.')
     parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation, note int8wo and int4wo work best with small batchsizes, int8dq works better with large batchsizes')
     parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')

     args = parser.parse_args()
-    run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.save, args.batch_size, args.max_length)
+    run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.sparsity, args.compile, args.save, args.batch_size, args.max_length)
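Note: the new -s/--sparsity option only changes how the model's linear weights are stored: sparsify_ with semi_sparse_weight() swaps the weight of every nn.Linear accepted by filter_fn for a 2:4 semi-structured sparse tensor. A minimal, self-contained sketch of that call outside the eval harness (the toy Sequential and its random weights are assumptions; for meaningful accuracy the checkpoint's weights should already be pruned to a 2:4 pattern):

import torch
from torchao.sparsity import sparsify_, semi_sparse_weight

model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096),
    torch.nn.ReLU(),
    torch.nn.Linear(4096, 4096),
).half().cuda()

# same filter shape as in hf_eval.py: every nn.Linear whose qualified name does not contain "lm_head"
def all_linear(mod, name):
    return isinstance(mod, torch.nn.Linear) and "lm_head" not in name

sparsify_(model, semi_sparse_weight(), filter_fn=all_linear)
out = model(torch.randn(16, 4096, device="cuda", dtype=torch.half))

The semi_sparse_mlp_only choice differs only in the filter, which additionally requires "mlp" to appear in the module name.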
Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+# This script shows how to accelerate an off-the-shelf 2:4 sparse checkpoint
+# using pytorch's `to_sparse_semi_structured`
+
+# It takes advantage of the model checkpoints offered by neuralmagic:
+# https://huggingface.co/nm-testing/SparseLlama-3-8B-pruned_50.2of4-FP8
+
+import os
+import torch
+from torchao.sparsity import sparsify_, semi_sparse_weight
+
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false" # silence warnings when compiling
+
+torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = True
+torch.set_float32_matmul_precision('high')
+
+def timed(fn):
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    start.record()
+    result = fn()
+    end.record()
+    torch.cuda.synchronize()
+    return result, start.elapsed_time(end) / 1000
+
+
+def benchmark(fn, WARMUP=5, N=25):
+    time_per_batch = []
+    with torch.no_grad():
+        # warmup steps
+        for _ in range(WARMUP):
+            timed(fn)
+
+        # benchmark
+        for _ in tqdm(range(N)):
+            with torch.no_grad():
+                _ , time_sec = timed(fn)
+                time_per_batch.append(time_sec)
+
+    # each time we generate 128 tokens - 7 for the prompt = 121 tokens at a time.
+    total_time = sum(time_per_batch)
+    tokens_per_second = 121 * N / total_time
+    print(f"Total time: {total_time:.3f}s | Tokens/second: {tokens_per_second:.3f}")
+
+# define model and tokenizer
+model = AutoModelForCausalLM.from_pretrained("nm-testing/SparseLlama-3-8B-pruned_50.2of4", torch_dtype=torch.float16).cuda()
+tokenizer = AutoTokenizer.from_pretrained("nm-testing/SparseLlama-3-8B-pruned_50.2of4")
+
+# Even though we need to pad the matmul shapes from (1, hidden) @ (hidden, output)
+# to (8, hidden) @ (hidden, output) we are still able to achieve speedups on
+# the mlp.up and mlp.gate linear layers of the FFN.
+def is_mlp_up_or_mlp_gate(mod, name):
+    return isinstance(mod, torch.nn.Linear) and ('mlp.gate' in name or 'mlp.up' in name)
+
+# apply sparsity
+sparsify_(model, semi_sparse_weight(), filter_fn=is_mlp_up_or_mlp_gate)
+
+# Specify the max length (including both the prompt and the response)
+# When calling `generate` with `cache_implementation="static" later, this is also used to create a `StaticCache` object
+# with sequence length = `max_length`. The longer the more you will re-use it
+model.generation_config.max_length = 128
+model.generation_config.pad_token_id = tokenizer.eos_token_id
+model.generation_config.cache_implementation = "static"
+
+prompt = "Why dogs are so cute?"
+inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
+
+# without `torch.compile`: each call takes ~ 5.0 seconds (on A100 80G + torch 2.3)
+# Total time: 168.715s | Tokens/second: 17.930
+outputs = model.generate(**inputs)
+response = tokenizer.batch_decode(outputs)[0]
+print(response)
+
+# `torch.compile(model, ...)` is not recommended as you compile callbacks
+# and full generate. We recommend compiling only the forward for now.
+# "reduce-overhead" will use cudagraphs.
+torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit = None
+
+model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
+
+benchmark(lambda: model.generate(**inputs))
+
+# sanity check we get same output as non-compiled model
+outputs = model.generate(**inputs)
+response = tokenizer.batch_decode(outputs)[0]
+print(response)
+
+## Run torch.compile baseline
+
+del model
+model = AutoModelForCausalLM.from_pretrained("nm-testing/SparseLlama-3-8B-pruned_50.2of4", torch_dtype=torch.float16).cuda()
+
+model.generation_config.max_length = 128
+model.generation_config.pad_token_id = tokenizer.eos_token_id
+model.generation_config.cache_implementation = "static"
+
+model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
+benchmark(lambda: model.generate(**inputs))
+
+outputs = model.generate(**inputs)
+response = tokenizer.batch_decode(outputs)[0]
+print(response)
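Since both the eval path and the generation script above assume weights that already follow a 2:4 pattern, a small illustrative helper (an assumption of this note, not part of the commit) for sanity-checking a weight before compressing it:

import torch

def follows_two_four_pattern(weight: torch.Tensor) -> bool:
    # at most 2 non-zeros in every contiguous group of 4 along the input dimension;
    # assumes the input dimension is divisible by 4, which holds for the shapes above
    groups = weight.detach().reshape(weight.shape[0], -1, 4)
    return bool(((groups != 0).sum(dim=-1) <= 2).all())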
