    quantize_,
    autoquant,
)
+from torchao.sparsity import (
+    sparsify_,
+    semi_sparse_weight,
+)

torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.fx_graph_cache = True
@@ -40,10 +44,10 @@ def format_value(value):

    print(tabulate(main_table, headers=['Task', 'Metrics'], tablefmt='grid'))

-def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, save, batch_size, max_length):
+def run_evaluation(repo_id, tasks, limit, device, precision, quantization, sparsity, compile, save, batch_size, max_length):

    tokenizer = AutoTokenizer.from_pretrained(repo_id)
-    model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)
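+    # Load in the requested dtype and move directly to the target device,
+    # rather than pinning the model to CPU.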
+    model = AutoModelForCausalLM.from_pretrained(repo_id).to(dtype=precision, device=device)

    if quantization == "autoquant" and compile:
        model = torch.compile(model, mode="max-autotune", fullgraph=True)
@@ -61,6 +65,24 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
    if quantization != "autoquant" and compile:
        model = torch.compile(model, mode="max-autotune", fullgraph=True)

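+    # Swap the weights of eligible Linear layers for 2:4 semi-structured
+    # sparse weights; the lm_head projection is always left dense.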
+    if sparsity == "semi_sparse":
+        def all_linear(mod, name):
+            if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
+                return True
+            return False
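+        # Don't force the CUTLASS backend, letting PyTorch use the
+        # cuSPARSELt kernels for the sparse matmuls where available.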
+        torch.sparse.semi_structured._FORCE_CUTLASS = False
+        sparsify_(model, semi_sparse_weight(), filter_fn=all_linear)
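+    # Same as above, but restricted to the MLP (feed-forward) projections.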
+    elif sparsity == "semi_sparse_mlp_only":
+        def all_linear(mod, name):
+            if isinstance(mod, torch.nn.Linear) and "lm_head" not in name and "mlp" in name:
+                return True
+            return False
+        torch.sparse.semi_structured._FORCE_CUTLASS = False
+        sparsify_(model, semi_sparse_weight(), filter_fn=all_linear)
+
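+    # Compile after sparsification so the swapped-in sparse weights are
+    # traced into the compiled graph.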
+    # (sparsity defaults to the string "None", which is truthy, so compare
+    # against it explicitly.)
+    if sparsity != "None" and compile:
+        model = torch.compile(model, mode="max-autotune", fullgraph=True)
+
    with torch.no_grad():
        result = evaluate(
            HFLM(
@@ -90,10 +112,11 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
    parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
    parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
    parser.add_argument('-q', '--quantization', default="None", choices=["int8dq", "int8wo", "int4wo", "autoquant", "None"], help='Which quantization technique to apply')
+    parser.add_argument('-s', '--sparsity', default="None", choices=["semi_sparse", "semi_sparse_mlp_only", "None"], help='Which sparsity technique to apply')
    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
    parser.add_argument('--save', action='store_true', help='Whether to save the model.')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation; note int8wo and int4wo work best with small batch sizes, while int8dq works better with large batch sizes')
    parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')

    args = parser.parse_args()
-    run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.save, args.batch_size, args.max_length)
+    run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.sparsity, args.compile, args.save, args.batch_size, args.max_length)
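
For reference, a minimal invocation sketch with the new flag (the script name and the --repo_id flag are assumptions inferred from args.repo_id, neither is shown in this excerpt):

    python hf_eval.py --repo_id meta-llama/Llama-2-7b-hf -s semi_sparse --compile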