Skip to content

Conversation

tahsintunan
Copy link
Contributor

@tahsintunan tahsintunan commented Sep 5, 2025

Purpose

Extends QuantFP8 to support per-token-group quantization and adds a torch implementation of the QuantFP8 group quantization.

Addresses #24185

Changes

  • Add is_per_tensor(), is_per_token(), is_per_group() helper methods to GroupShape
  • Extend QuantFP8 to support arbitrary group sizes like GroupShape(1, 128)
  • Added test_fp8_quant_group.py and benchmark_quantfp8_group.py

Test Plan

Tested with existing test suite:

  • tests/kernels/quantization/test_fp8_quant.py
  • tests/kernels/quantization/test_fp8_quant_group.py
  • Verified that per-tensor, per-token, and per-group quantization modes work correctly

Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request successfully refactors FP8 quantization by extending QuantFP8 to support per-token-group quantization, unifying the quantization paths. The changes are well-structured, improving code clarity and maintainability. I've identified one potential issue where a group size of 1 is not handled correctly, which could lead to a runtime error. My review includes a suggestion to fix this.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
@tahsintunan tahsintunan marked this pull request as draft September 5, 2025 18:49
Copy link
Collaborator

@ProExpertProg ProExpertProg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if this is ready for review yet but this is the right start to address #20711 and also #24185!

A big piece that's missing is addressing #24185: we should have a torch implementation of group quant. If you want to leave that to somebody else to implement that's okay too!

@tahsintunan
Copy link
Contributor Author

Hey @ProExpertProg , it's not ready for review yet. Mainly opened this draft PR for early feedback. I'll address the comments and also add the torch implementation

@ProExpertProg
Copy link
Collaborator

Okay please keep us posted, really looking forward to this work!

Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
@mergify mergify bot added the performance Performance-related issues label Sep 7, 2025
@tahsintunan tahsintunan marked this pull request as ready for review September 8, 2025 01:05
Copy link
Collaborator

@ProExpertProg ProExpertProg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is moving in the right direction, but we should still try to move towards using QuantFP8 instead of a bunch of free function calls around. If you want to address that in a separate PR that's fine but then we shouldn't touch the MoE layers in this one at all and just add support for group quant to QuantFP8 here.

Copy link
Collaborator

@ProExpertProg ProExpertProg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sadly this refactor is going to be a massive can of worms but it has to be done. All of these free functions need to become classes. We can't pass QuantFP8 instances through.

@tahsintunan
Copy link
Contributor Author

@ProExpertProg understood. Then I’ll just revert the latest changes and keep only the torch impl in this PR. The refactoring seems more involved than I had imagined. This will also keep the PR short and easy to review.

Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
@ProExpertProg
Copy link
Collaborator

@tahsintunan can you post performance numbers for this? We might need to enable the custom path by default if the torch path is not fast enough

@ProExpertProg
Copy link
Collaborator

ProExpertProg commented Sep 11, 2025

Speedup over Torch (Compiled)
 col_major group_shape     CUDA   Triton
     False    (-1, -1) 5.188913      inf
     False     (1, -1) 0.928436      inf
     False     (1, 64) 0.636413 0.264822
     False    (1, 128) 0.851472 0.414274
      True     (1, 64) 0.654376 0.293654
      True    (1, 128) 0.910548 0.471288
B200 results
QuantFP8 performance:
     hidden_size  batch_size  col_major group_shape  Torch (Compiled)         CUDA        Triton
0              1           1      False    (-1, -1)          2.883091    12.171749      0.000000
1             16           1      False    (-1, -1)          7.367291    13.908626      0.000000
2             64           1      False    (-1, -1)         12.385259    14.125941      0.000000
3            128           1      False    (-1, -1)         19.442213    14.427254      0.000000
4            256           1      False    (-1, -1)         32.477913    14.494847      0.000000
5            512           1      False    (-1, -1)         57.917011    16.473432      0.000000
6           1024           1      False    (-1, -1)        110.735627    18.549175      0.000000
7           2048           1      False    (-1, -1)        210.799133    28.107615      0.000000
8           4096           1      False    (-1, -1)        417.107831    47.828042      0.000000
9              1          16      False    (-1, -1)          7.372255    13.896552      0.000000
10            16          16      False    (-1, -1)         32.464888    14.503759      0.000000
11            64          16      False    (-1, -1)        110.742910    18.602913      0.000000
12           128          16      False    (-1, -1)        210.841858    28.089034      0.000000
13           256          16      False    (-1, -1)        416.955492    47.807567      0.000000
14           512          16      False    (-1, -1)       1246.435165   100.707671      0.000000
15          1024          16      False    (-1, -1)       2532.836505   219.299457      0.000000
16          2048          16      False    (-1, -1)       5073.546727   425.454264      0.000000
17          4096          16      False    (-1, -1)      10333.471775  1023.561022      0.000000
18             1          32      False    (-1, -1)          9.229187    13.885136      0.000000
19            16          32      False    (-1, -1)         57.914544    16.429869      0.000000
20            64          32      False    (-1, -1)        210.633479    28.114413      0.000000
21           128          32      False    (-1, -1)        416.955492    47.903653      0.000000
22           256          32      False    (-1, -1)       1248.450089   100.709837      0.000000
23           512          32      False    (-1, -1)       2533.972604   218.814915      0.000000
24          1024          32      False    (-1, -1)       5073.210716   425.471658      0.000000
25          2048          32      False    (-1, -1)      10333.040237  1023.609037      0.000000
26          4096          32      False    (-1, -1)      20645.311356  2024.583975      0.000000
27             1          64      False    (-1, -1)         12.379711    14.163193      0.000000
28            16          64      False    (-1, -1)        110.958742    18.575772      0.000000
29            64          64      False    (-1, -1)        416.910959    48.109178      0.000000
30           128          64      False    (-1, -1)       1251.018651   100.739310      0.000000
31           256          64      False    (-1, -1)       2534.432002   218.634367      0.000000
32           512          64      False    (-1, -1)       5073.498726   425.464008      0.000000
33          1024          64      False    (-1, -1)      10338.255882  1023.516531      0.000000
34          2048          64      False    (-1, -1)      20644.271851  2024.601380      0.000000
35          4096          64      False    (-1, -1)      41266.143799  4024.530570      0.000000
36             1         128      False    (-1, -1)         19.457667    14.396398      0.000000
37            16         128      False    (-1, -1)        210.838531    28.155110      0.000000
38            64         128      False    (-1, -1)       1242.978160   100.497325      0.000000
39           128         128      False    (-1, -1)       2534.605707   219.093453      0.000000
40           256         128      False    (-1, -1)       5073.034604   425.403471      0.000000
41           512         128      False    (-1, -1)      10335.391998  1023.513048      0.000000
42          1024         128      False    (-1, -1)      20648.096085  2024.623950      0.000000
43          2048         128      False    (-1, -1)      41268.096924  4024.717331      0.000000
44          4096         128      False    (-1, -1)       8275.111675  8024.319967      0.000000
45             1           1      False     (1, -1)          3.134108     4.060020      0.000000
46            16           1      False     (1, -1)          3.427785     4.277234      0.000000
47            64           1      False     (1, -1)          3.520462     4.386433      0.000000
48           128           1      False     (1, -1)          3.609057     4.489328      0.000000
49           256           1      False     (1, -1)          4.421333     5.036036      0.000000
50           512           1      False     (1, -1)          5.965148     6.602702      0.000000
51          1024           1      False     (1, -1)          8.923592    10.081602      0.000000
52          2048           1      False     (1, -1)         15.122080    15.959404      0.000000
53          4096           1      False     (1, -1)         28.560640    29.943732      0.000000
54             1          16      False     (1, -1)          3.437647     4.305638      0.000000
55            16          16      False     (1, -1)          4.416621     5.032333      0.000000
56            64          16      False     (1, -1)          8.946080    10.084736      0.000000
57           128          16      False     (1, -1)         15.127093    15.956857      0.000000
58           256          16      False     (1, -1)         28.503059    29.912759      0.000000
59           512          16      False     (1, -1)         62.006400    63.806423      0.000000
60          1024          16      False     (1, -1)        122.140799   124.085050      0.000000
61          2048          16      False     (1, -1)        241.689558   244.872914      0.000000
62          4096          16      False     (1, -1)        674.907613   681.216407      0.000000
63             1          32      False     (1, -1)          3.472305     4.315280      0.000000
64            16          32      False     (1, -1)          5.970980     6.595621      0.000000
65            64          32      False     (1, -1)         15.113600    15.950691      0.000000
66           128          32      False     (1, -1)         28.523089    29.903118      0.000000
67           256          32      False     (1, -1)         62.005386    63.812394      0.000000
68           512          32      False     (1, -1)        122.147629   124.088054      0.000000
69          1024          32      False     (1, -1)        241.692217   244.890334      0.000000
70          2048          32      False     (1, -1)        674.927592   681.224394      0.000000
71          4096          32      False     (1, -1)       1346.823978  1357.751989      0.000000
72             1          64      False     (1, -1)          3.526970     4.387588      0.000000
73            16          64      False     (1, -1)          8.937098    10.084751      0.000000
74            64          64      False     (1, -1)         28.604990    29.990388      0.000000
75           128          64      False     (1, -1)         61.903499    63.724211      0.000000
76           256          64      False     (1, -1)        122.232325   124.162003      0.000000
77           512          64      False     (1, -1)        241.673470   244.848408      0.000000
78          1024          64      False     (1, -1)        674.841619   681.260395      0.000000
79          2048          64      False     (1, -1)       1346.858406  1357.757568      0.000000
80          4096          64      False     (1, -1)       2691.024017  2709.640026      0.000000
81             1         128      False     (1, -1)          3.605126     4.490015      0.000000
82            16         128      False     (1, -1)         15.113505    15.956237      0.000000
83            64         128      False     (1, -1)         62.000323    63.801745      0.000000
84           128         128      False     (1, -1)        122.064168   124.038590      0.000000
85           256         128      False     (1, -1)        241.680418   244.889308      0.000000
86           512         128      False     (1, -1)        674.919200   681.264400      0.000000
87          1024         128      False     (1, -1)       1346.860027  1357.774401      0.000000
88          2048         128      False     (1, -1)       2690.929604  2709.835243      0.000000
89          4096         128      False     (1, -1)       5376.534271  5411.731148      0.000000
90             1           1       True     (1, 64)          2.763574     3.034534      2.843692
91            16           1       True     (1, 64)          4.485459     3.360267      3.157333
92            64           1       True     (1, 64)          4.638418     3.624858      4.616264
93           128           1       True     (1, 64)          4.761500     4.147604      6.776809
94           256           1       True     (1, 64)          5.471333     5.721745     11.125818
95           512           1       True     (1, 64)          6.551305     8.487796     19.830472
96          1024           1       True     (1, 64)          9.289565    14.399479     37.414587
97          2048           1       True     (1, 64)         14.391912    25.775028     72.494545
98          4096           1       True     (1, 64)         26.188800    50.029220    144.374143
99             1          16       True     (1, 64)          4.491404     3.391747      3.130383
100           16          16       True     (1, 64)          5.460396     5.707412     11.117843
101           64          16       True     (1, 64)          9.313032    14.416736     37.431817
102          128          16       True     (1, 64)         14.407560    25.772228     72.499678
103          256          16       True     (1, 64)         26.124585    49.950143    144.322203
104          512          16       True     (1, 64)         55.314238   108.841251    291.141254
105         1024          16       True     (1, 64)        107.397452   216.629575    579.961861
106         2048          16       True     (1, 64)        210.952398   430.700790   1158.461178
107         4096          16       True     (1, 64)        612.862926  1052.128004   2509.212017
108            1          32       True     (1, 64)          4.554939     3.442185      3.467605
109           16          32       True     (1, 64)          6.555915     8.487949     19.836612
110           64          32       True     (1, 64)         14.390330    25.777934     72.491218
111          128          32       True     (1, 64)         26.200851    50.026226    144.356278
112          256          32       True     (1, 64)         55.381849   108.857566    291.198090
113          512          32       True     (1, 64)        107.496791   216.879459    580.223532
114         1024          32       True     (1, 64)        210.949333   430.684090   1158.512901
115         2048          32       True     (1, 64)        612.821080  1052.111999   2509.225965
116         4096          32       True     (1, 64)       1220.242791  2100.859642   5015.915871
117            1          64       True     (1, 64)          4.640990     3.627710      4.601067
118           16          64       True     (1, 64)          9.301277    14.401637     37.422345
119           64          64       True     (1, 64)         26.184000    50.024022    144.369103
120          128          64       True     (1, 64)         55.397333   108.856229    291.203826
121          256          64       True     (1, 64)        107.441274   216.847098    580.218820
122          512          64       True     (1, 64)        210.967398   430.737771   1158.439524
123         1024          64       True     (1, 64)        612.847286  1052.199654   2509.089947
124         2048          64       True     (1, 64)       1220.352712  2100.958564   5015.500069
125         4096          64       True     (1, 64)       2435.514190  4197.494507  10028.480053
126            1         128       True     (1, 64)          4.756784     4.148005      6.747685
127           16         128       True     (1, 64)         14.403556    25.767069     72.494712
128           64         128       True     (1, 64)         55.505582   108.979418    291.173010
129          128         128       True     (1, 64)        107.228561   216.699322    580.036724
130          256         128       True     (1, 64)        210.961998   430.722152   1158.431053
131          512         128       True     (1, 64)        612.867269  1052.191320   2508.980036
132         1024         128       True     (1, 64)       1220.253198  2101.005034   5015.872002
133         2048         128       True     (1, 64)       2435.470494  4197.084808  10028.879642
134         4096         128       True     (1, 64)       4975.279999  8390.416145  20054.080009
135            1           1      False     (1, 64)          2.762939     3.026017      2.850835
136           16           1      False     (1, 64)          2.970894     3.286393      3.116082
137           64           1      False     (1, 64)          3.079456     3.553044      4.595833
138          128           1      False     (1, 64)          3.204202     4.065799      6.740989
139          256           1      False     (1, 64)          3.912238     5.604339     11.124909
140          512           1      False     (1, 64)          5.043629     7.930898     19.832728
141         1024           1      False     (1, 64)          7.921280    13.187439     37.406577
142         2048           1      False     (1, 64)         13.305237    23.488458     72.477510
143         4096           1      False     (1, 64)         25.334297    45.476465    144.326158
144            1          16      False     (1, 64)          2.971765     3.316764      3.122925
145           16          16      False     (1, 64)          3.918891     5.584376     11.110261
146           64          16      False     (1, 64)          7.939918    13.205882     37.428651
147          128          16      False     (1, 64)         13.298080    23.488266     72.490606
148          256          16      False     (1, 64)         25.242139    45.407283    144.281244
149          512          16      False     (1, 64)         55.477978    98.384418    291.165622
150         1024          16      False     (1, 64)        109.047316   195.192896    579.942114
151         2048          16      False     (1, 64)        215.345359   387.919674   1158.402836
152         4096          16      False     (1, 64)        621.794128   967.087364   2509.188056
153            1          32      False     (1, 64)          2.999146     3.360393      3.446298
154           16          32      False     (1, 64)          5.043040     7.924601     19.809549
155           64          32      False     (1, 64)         13.304490    23.483716     72.490968
156          128          32      False     (1, 64)         25.351674    45.472836    144.342253
157          256          32      False     (1, 64)         55.448247    98.423913    291.184482
158          512          32      False     (1, 64)        109.250355   195.436285    580.162357
159         1024          32      False     (1, 64)        215.357850   387.921581   1158.435765
160         2048          32      False     (1, 64)        621.742566   967.111053   2509.225965
161         4096          32      False     (1, 64)       1240.242087  1930.636009   5015.887976
162            1          64      False     (1, 64)          3.072762     3.553000      4.596042
163           16          64      False     (1, 64)          7.937796    13.190251     37.413803
164           64          64      False     (1, 64)         25.305347    45.485774    144.355202
165          128          64      False     (1, 64)         55.459518    98.418832    291.171102
166          256          64      False     (1, 64)        109.258610   195.414715    580.223981
167          512          64      False     (1, 64)        215.344790   387.894402   1158.416916
168         1024          64      False     (1, 64)        621.843931   967.107849   2509.143949
169         2048          64      False     (1, 64)       1240.365194  1930.536032   5016.236067
170         4096          64      False     (1, 64)       2476.624055  3856.629372  10028.928280
171            1         128      False     (1, 64)          3.209663     4.061244      6.744589
172           16         128      False     (1, 64)         13.308408    23.484122     72.473600
173           64         128      False     (1, 64)         55.608907    98.541654    291.173252
174          128         128      False     (1, 64)        109.015318   195.239761    580.039529
175          256         128      False     (1, 64)        215.332869   387.912636   1158.418824
176          512         128      False     (1, 64)        621.805869   967.174416   2509.175897
177         1024         128      False     (1, 64)       1240.225419  1930.506627   5016.344070
178         2048         128      False     (1, 64)       2476.455255  3856.664022  10029.567719
179         4096         128      False     (1, 64)       4948.124695  7708.298683  20053.744316
180            1           1       True    (1, 128)          2.758435     3.043615      2.862564
181           16           1       True    (1, 128)          4.687816     3.360487      3.082453
182           64           1       True    (1, 128)          4.806833     3.485621      3.549333
183          128           1       True    (1, 128)          4.876295     3.537790      4.668267
184          256           1       True    (1, 128)          5.417000     4.259854      6.941365
185          512           1       True    (1, 128)          6.561702     5.844019     11.408527
186         1024           1       True    (1, 128)          8.844835     9.290385     20.606741
187         2048           1       True    (1, 128)         13.603574    15.730864     38.884781
188         4096           1       True    (1, 128)         24.508309    29.680863     76.973999
189            1          16       True    (1, 128)          4.654515     3.389630      3.083699
190           16          16       True    (1, 128)          5.410435     4.250747      6.946697
191           64          16       True    (1, 128)          8.844215     9.302109     20.635326
192          128          16       True    (1, 128)         13.597890    15.728170     38.860225
193          256          16       True    (1, 128)         24.463149    29.647506     76.932181
194          512          16       True    (1, 128)         51.645074    67.246371    156.497100
195         1024          16       True    (1, 128)         98.864721   132.336001    310.735233
196         2048          16       True    (1, 128)        193.024938   261.698594    620.206987
197         4096          16       True    (1, 128)        575.487041   714.297939   1432.509005
198            1          32       True    (1, 128)          4.720646     3.384399      3.160629
199           16          32       True    (1, 128)          6.589617     5.844182     11.409422
200           64          32       True    (1, 128)         13.612622    15.727933     38.853933
201          128          32       True    (1, 128)         24.503304    29.681223     76.953454
202          256          32       True    (1, 128)         51.589389    67.193798    156.517457
203          512          32       True    (1, 128)         98.974946   132.435554    311.055486
204         1024          32       True    (1, 128)        193.039614   261.693619    620.181145
205         2048          32       True    (1, 128)        575.506248   714.307476   1432.536006
206         4096          32       True    (1, 128)       1146.398697  1424.682115   2862.574100
207            1          64       True    (1, 128)          4.808653     3.486566      3.536356
208           16          64       True    (1, 128)          8.852301     9.289909     20.614112
209           64          64       True    (1, 128)         24.528000    29.737940     76.993302
210          128          64       True    (1, 128)         51.584862    67.195727    156.485455
211          256          64       True    (1, 128)         98.999913   132.440394    311.061087
212          512          64       True    (1, 128)        193.000862   261.681505    620.164133
213         1024          64       True    (1, 128)        575.446396   714.263787   1432.569027
214         2048          64       True    (1, 128)       1146.431999  1424.560848   2862.714052
215         4096          64       True    (1, 128)       2287.967975  2844.389280   5722.555876
216            1         128       True    (1, 128)          4.885333     3.536037      4.676622
217           16         128       True    (1, 128)         13.604716    15.722652     38.857116
218           64         128       True    (1, 128)         51.715833    67.393788    156.562606
219          128         128       True    (1, 128)         98.904182   132.288668    310.808379
220          256         128       True    (1, 128)        193.009881   261.704651    620.221496
221          512         128       True    (1, 128)        575.487351   714.327348   1432.586968
222         1024         128       True    (1, 128)       1146.414170  1424.435365   2862.849951
223         2048         128       True    (1, 128)       2287.958145  2844.369782   5722.423792
224         4096         128       True    (1, 128)       4566.914717  5683.860064  11441.159725
225            1           1      False    (1, 128)          2.765305     3.044451      2.853063
226           16           1      False    (1, 128)          3.220044     3.284058      3.075064
227           64           1      False    (1, 128)          3.309514     3.390130      3.531404
228          128           1      False    (1, 128)          3.406923     3.450224      4.652833
229          256           1      False    (1, 128)          3.946000     4.128372      6.932044
230          512           1      False    (1, 128)          5.103674     5.725761     11.398578
231         1024           1      False    (1, 128)          7.211200     8.609532     20.590400
232         2048           1      False    (1, 128)         11.854333    14.519183     38.856909
233         4096           1      False    (1, 128)         22.307041    27.416357     76.930134
234            1          16      False    (1, 128)          3.232615     3.311902      3.104344
235           16          16      False    (1, 128)          3.948167     4.125146      6.911830
236           64          16      False    (1, 128)          7.222465     8.623327     20.616870
237          128          16      False    (1, 128)         11.852206    14.519950     38.847461
238          256          16      False    (1, 128)         22.295834    27.390898     76.899306
239          512          16      False    (1, 128)         49.187169    61.770049    156.497410
240         1024          16      False    (1, 128)         95.324130   121.285744    310.716690
241         2048          16      False    (1, 128)        187.868772   239.616783    620.237935
242         4096          16      False    (1, 128)        566.639695   670.621267   1432.497978
243            1          32      False    (1, 128)          3.245624     3.310083      3.150945
244           16          32      False    (1, 128)          5.119360     5.723191     11.404182
245           64          32      False    (1, 128)         11.849796    14.518216     38.842607
246          128          32      False    (1, 128)         22.305292    27.424782     76.940182
247          256          32      False    (1, 128)         49.139918    61.770921    156.477429
248          512          32      False    (1, 128)         95.543319   121.374891    311.054472
249         1024          32      False    (1, 128)        187.874202   239.607905    620.220369
250         2048          32      False    (1, 128)        566.629245   670.680581   1432.590961
251         4096          32      False    (1, 128)       1129.735360  1337.570381   2862.668037
252            1          64      False    (1, 128)          3.314980     3.388164      3.536653
253           16          64      False    (1, 128)          7.212735     8.611956     20.597099
254           64          64      False    (1, 128)         22.354400    27.455734     76.961758
255          128          64      False    (1, 128)         49.142905    61.770176    156.480172
256          256          64      False    (1, 128)         95.530608   121.370670    311.042801
257          512          64      False    (1, 128)        187.858049   239.621327    620.180653
258         1024          64      False    (1, 128)        566.666659   670.606823   1432.489991
259         2048          64      False    (1, 128)       1129.783374  1337.552023   2862.748027
260         4096          64      False    (1, 128)       2255.399411  2670.203209   5722.807884
261            1         128      False    (1, 128)          3.400404     3.449124      4.650247
262           16         128      False    (1, 128)         11.853120    14.515769     38.839272
263           64         128      False    (1, 128)         49.261225    61.902635    156.534682
264          128         128      False    (1, 128)         95.339283   121.236208    310.799493
265          256         128      False    (1, 128)        187.863189   239.609483    620.258510
266          512         128      False    (1, 128)        566.599210   670.548299   1432.632983
267         1024         128      False    (1, 128)       1129.825262  1337.493610   2862.669945
268         2048         128      False    (1, 128)       2255.497859  2670.243168   5723.012209
269         4096         128      False    (1, 128)       4566.255887  5332.851219  11441.559792

Copy link
Collaborator

@ProExpertProg ProExpertProg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, just a few nits! I think we can actually extend the current benchmark instead, I can either push it to your branch or open a new PR - either way you should remove the benchmark you added


if padded_dim != hidden_dim:
padding = padded_dim - hidden_dim
x = F.pad(x, (0, padding), mode='constant', value=0.0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if there's a way to do this without padding - I worry the generated Triton kernel won't be able to eliminate the copy

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm realizing the padding is unlikely anyway as group shape will likely divide the hidden_size

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, padding will only be used for non-standard dimensions, which should be rare.

@ProExpertProg ProExpertProg mentioned this pull request Sep 11, 2025
@ProExpertProg
Copy link
Collaborator

@tahsintunan do you have a timeline for implementing the feedback? Based on the performance numbers we want to merge this as soon as feasible. I'd also be happy to finish your work and merge (you'll still be a coauthor on the commit) - let me know! Thanks for working on this

@tahsintunan
Copy link
Contributor Author

@ProExpertProg I can take care of the remaining items by Monday. If you need this merged sooner, feel free to wrap it up. If not, I'll get it done by Monday.

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
@ProExpertProg
Copy link
Collaborator

@tahsintunan yeah today or early tomorrow is totally fine! Thanks for taking this on.

Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
@ProExpertProg ProExpertProg added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 15, 2025
@mgoin
Copy link
Member

mgoin commented Sep 16, 2025

I ran the benchmark on H100 and found it was slower for 1x128 (the case we care about)

python benchmarks/kernels/bench_per_token_quant_fp8.py
INFO 09-15 22:12:41 [__init__.py:241] Automatically detected platform cuda.
WARNING 09-15 22:12:43 [__init__.py:4006] Current vLLM config is not set.
QuantFP8 performance:
     hidden_size  batch_size  col_major group_shape  Torch (Compiled)          CUDA        Triton
0              1           1      False    (-1, -1)          3.426930     11.369804      0.000000
1             16           1      False    (-1, -1)          8.376470     12.388232      0.000000
2             64           1      False    (-1, -1)         13.814244     12.817638      0.000000
3            128           1      False    (-1, -1)         20.951711     12.964803      0.000000
4            256           1      False    (-1, -1)         34.481871     13.696710      0.000000
5            512           1      False    (-1, -1)         61.226926     15.780297      0.000000
6           1024           1      False    (-1, -1)        114.341924     20.364526      0.000000
7           2048           1      False    (-1, -1)        236.421476     38.074056      0.000000
8           4096           1      False    (-1, -1)        591.811001     78.746200      0.000000
9              1          16      False    (-1, -1)          8.341590     12.388724      0.000000
10            16          16      False    (-1, -1)         34.546096     13.682885      0.000000
11            64          16      False    (-1, -1)        114.357070     20.343274      0.000000
12           128          16      False    (-1, -1)        236.438847     38.080545      0.000000
13           256          16      False    (-1, -1)        592.441022     78.781415      0.000000
14           512          16      False    (-1, -1)       1251.093356    159.822960      0.000000
15          1024          16      False    (-1, -1)       2498.866217    303.482584      0.000000
16          2048          16      False    (-1, -1)       4987.429301    588.579871      0.000000
17          4096          16      False    (-1, -1)      10070.744038   1262.570325      0.000000
18             1          32      False    (-1, -1)         10.093143     12.620657      0.000000
19            16          32      False    (-1, -1)         61.220099     15.966509      0.000000
20            64          32      False    (-1, -1)        230.268398     38.071063      0.000000
21           128          32      False    (-1, -1)        595.449001     78.403324      0.000000
22           256          32      False    (-1, -1)       1250.521596    159.590636      0.000000
23           512          32      False    (-1, -1)       2498.742785    303.515816      0.000000
24          1024          32      False    (-1, -1)       4986.783981    588.453322      0.000000
25          2048          32      False    (-1, -1)      10069.344044   1262.711525      0.000000
26          4096          32      False    (-1, -1)      20116.783142   2505.393982      0.000000
27             1          64      False    (-1, -1)         13.798554     12.771872      0.000000
28            16          64      False    (-1, -1)        114.416193     20.301805      0.000000
29            64          64      False    (-1, -1)        600.231749     78.803340      0.000000
30           128          64      False    (-1, -1)       1250.677299    159.436134      0.000000
31           256          64      False    (-1, -1)       2498.580524    303.701900      0.000000
32           512          64      False    (-1, -1)       4987.276077    588.464477      0.000000
33          1024          64      False    (-1, -1)      10068.376064   1262.856483      0.000000
34          2048          64      False    (-1, -1)      20116.191864   2504.941940      0.000000
35          4096          64      False    (-1, -1)      40226.768494   4989.788055      0.000000
36             1         128      False    (-1, -1)         20.818095     12.903089      0.000000
37            16         128      False    (-1, -1)        229.883889     38.125303      0.000000
38            64         128      False    (-1, -1)       1249.669329    159.467705      0.000000
39           128         128      False    (-1, -1)       2498.089109    303.633235      0.000000
40           256         128      False    (-1, -1)       4986.624002    589.265448      0.000000
41           512         128      False    (-1, -1)      10069.543839   1262.829220      0.000000
42          1024         128      False    (-1, -1)      20133.200645   2505.151987      0.000000
43          2048         128      False    (-1, -1)      40217.935562   4989.408016      0.000000
44          4096         128      False    (-1, -1)      11261.568069   9964.223862      0.000000
45             1           1      False     (1, -1)          3.766076      4.882628      0.000000
46            16           1      False     (1, -1)          4.083627      5.128845      0.000000
47            64           1      False     (1, -1)          4.328088      5.381153      0.000000
48           128           1      False     (1, -1)          4.347911      5.476528      0.000000
49           256           1      False     (1, -1)          5.408176      6.300647      0.000000
50           512           1      False     (1, -1)          7.660445      8.727967      0.000000
51          1024           1      False     (1, -1)         12.177438     13.583418      0.000000
52          2048           1      False     (1, -1)         25.081955     27.920139      0.000000
53          4096           1      False     (1, -1)         48.859199     52.755853      0.000000
54             1          16      False     (1, -1)          4.068348      5.120795      0.000000
55            16          16      False     (1, -1)          5.421156      6.265843      0.000000
56            64          16      False     (1, -1)         12.077187     13.651969      0.000000
57           128          16      False     (1, -1)         25.072359     27.831579      0.000000
58           256          16      False     (1, -1)         48.880177     52.696390      0.000000
59           512          16      False     (1, -1)         96.035775    103.241584      0.000000
60          1024          16      False     (1, -1)        190.820650    203.414500      0.000000
61          2048          16      False     (1, -1)        379.653332    402.185479      0.000000
62          4096          16      False     (1, -1)        861.698444    905.443350      0.000000
63             1          32      False     (1, -1)          4.115692      5.171992      0.000000
64            16          32      False     (1, -1)          7.618667      8.701331      0.000000
65            64          32      False     (1, -1)         25.055467     27.966720      0.000000
66           128          32      False     (1, -1)         48.874609     52.983442      0.000000
67           256          32      False     (1, -1)         95.862828    102.968303      0.000000
68           512          32      False     (1, -1)        190.917082    203.455826      0.000000
69          1024          32      False     (1, -1)        379.155776    402.248168      0.000000
70          2048          32      False     (1, -1)        862.017852    905.406674      0.000000
71          4096          32      False     (1, -1)       1719.159346   1804.554621      0.000000
72             1          64      False     (1, -1)          4.341913      5.397688      0.000000
73            16          64      False     (1, -1)         12.359822     13.573305      0.000000
74            64          64      False     (1, -1)         48.923089     52.619177      0.000000
75           128          64      False     (1, -1)         95.945486    102.920899      0.000000
76           256          64      False     (1, -1)        190.856379    203.439842      0.000000
77           512          64      False     (1, -1)        379.234501    402.700074      0.000000
78          1024          64      False     (1, -1)        861.697857    905.281345      0.000000
79          2048          64      False     (1, -1)       1719.143354   1804.654598      0.000000
80          4096          64      False     (1, -1)       3433.789412   3603.221258      0.000000
81             1         128      False     (1, -1)          4.373913      5.472046      0.000000
82            16         128      False     (1, -1)         25.271191     28.087846      0.000000
83            64         128      False     (1, -1)         96.118000    103.077620      0.000000
84           128         128      False     (1, -1)        190.537217    203.164836      0.000000
85           256         128      False     (1, -1)        379.256473    402.043109      0.000000
86           512         128      False     (1, -1)        861.623397    905.344685      0.000000
87          1024         128      False     (1, -1)       1719.219721   1804.641326      0.000000
88          2048         128      False     (1, -1)       3433.952014   3603.064060      0.000000
89          4096         128      False     (1, -1)       7063.557307   7200.703939      0.000000
90             1           1       True     (1, 64)          3.211600      3.733024      3.428513
91            16           1       True     (1, 64)          4.969014      4.198962      4.060923
92            64           1       True     (1, 64)          5.522023      4.665202      6.168615
93           128           1       True     (1, 64)          5.908279      5.350404      8.560178
94           256           1       True     (1, 64)          6.949885      7.440291     13.747636
95           512           1       True     (1, 64)          9.361156     11.759880     24.297195
96          1024           1       True     (1, 64)         13.973953     20.104721     45.274666
97          2048           1       True     (1, 64)         28.041333     41.274185     91.696920
98          4096           1       True     (1, 64)         53.999811     84.167932    180.267856
99             1          16       True     (1, 64)          4.965954      4.109570      4.063820
100           16          16       True     (1, 64)          6.911812      7.442560     13.755685
101           64          16       True     (1, 64)         14.122729     20.143829     45.297506
102          128          16       True     (1, 64)         28.134139     41.233683     91.670069
103          256          16       True     (1, 64)         54.123671     84.339004    180.293040
104          512          16       True     (1, 64)        105.026447    165.446920    359.945609
105         1024          16       True     (1, 64)        206.923956    328.232527    718.948753
106         2048          16       True     (1, 64)        410.122892    652.679984   1436.993232
107         4096          16       True     (1, 64)        921.450655   1408.382924   2978.645325
108            1          32       True     (1, 64)          5.136558      4.307520      4.729818
109           16          32       True     (1, 64)          9.366667     11.745251     24.289091
110           64          32       True     (1, 64)         28.065860     41.290563     91.683862
111          128          32       True     (1, 64)         54.025489     84.247964    180.237582
112          256          32       True     (1, 64)        104.790097    165.361623    359.877812
113          512          32       True     (1, 64)        207.243956    328.214661    719.029956
114         1024          32       True     (1, 64)        410.181796    652.729575   1436.598191
115         2048          32       True     (1, 64)        921.548009   1406.723150   2978.714625
116         4096          32       True     (1, 64)       1836.225351   2808.262825   5953.957240
117            1          64       True     (1, 64)          5.513788      4.664583      6.184711
118           16          64       True     (1, 64)         14.285242     20.085849     45.279999
119           64          64       True     (1, 64)         54.077209     84.197567    180.335067
120          128          64       True     (1, 64)        104.792090    165.554717    359.763197
121          256          64       True     (1, 64)        206.836593    328.357077    719.011554
122          512          64       True     (1, 64)        410.108587    652.651183   1436.903367
123         1024          64       True     (1, 64)        921.527982   1407.067744   2979.717414
124         2048          64       True     (1, 64)       1836.198648   2808.749744   5954.090754
125         4096          64       True     (1, 64)       3671.117306   5611.925125  11905.551910
126            1         128       True     (1, 64)          5.897195      5.287045      8.572138
127           16         128       True     (1, 64)         28.203722     41.419590     91.858258
128           64         128       True     (1, 64)        105.017906    165.531660    359.770369
129          128         128       True     (1, 64)        206.584713    328.138399    719.408565
130          256         128       True     (1, 64)        410.141275    652.689584   1436.685562
131          512         128       True     (1, 64)        921.536644   1406.915156   2978.573322
132         1024         128       True     (1, 64)       1836.305380   2809.049198   5954.288165
133         2048         128       True     (1, 64)       3671.101252   5613.439878  11904.176235
134         4096         128       True     (1, 64)       6618.202845  11332.128048  23803.744316
135            1           1      False     (1, 64)          3.213053      3.737146      3.434413
136           16           1      False     (1, 64)          3.461263      4.013728      4.044279
137           64           1      False     (1, 64)          3.889438      4.445747      6.169670
138          128           1      False     (1, 64)          4.269689      5.071347      8.572800
139          256           1      False     (1, 64)          5.317511      6.974913     13.752533
140          512           1      False     (1, 64)          7.764909     10.916158     24.291023
141         1024           1      False     (1, 64)         12.655436     18.521191     45.276088
142         2048           1      False     (1, 64)         26.595378     38.248682     91.709508
143         4096           1      False     (1, 64)         52.948673     77.945717    180.369657
144            1          16      False     (1, 64)          3.455824      4.003043      4.050667
145           16          16      False     (1, 64)          5.327624      6.987025     13.766359
146           64          16      False     (1, 64)         12.519731     18.552234     45.310529
147          128          16      False     (1, 64)         26.606066     38.244530     91.662001
148          256          16      False     (1, 64)         53.125866     77.968000    180.293940
149          512          16      False     (1, 64)        104.398328    153.584878    359.562362
150         1024          16      False     (1, 64)        207.137364    304.513476    719.475570
151         2048          16      False     (1, 64)        412.094673    605.723023   1437.410501
152         4096          16      False     (1, 64)        927.354018   1313.723981   2978.557428
153            1          32      False     (1, 64)          3.510506      4.086418      4.725451
154           16          32      False     (1, 64)          7.760727     10.927646     24.304180
155           64          32      False     (1, 64)         26.632533     38.256469     91.702578
156          128          32      False     (1, 64)         52.968628     77.945068    180.234485
157          256          32      False     (1, 64)        104.139274    153.343621    359.916791
158          512          32      False     (1, 64)        207.299853    304.529234    718.966519
159         1024          32      False     (1, 64)        412.570318    606.558979   1436.708890
160         2048          32      False     (1, 64)        928.287983   1313.728988   2978.645325
161         4096          32      False     (1, 64)       1851.147970   2622.544050   5954.175949
162            1          64      False     (1, 64)          3.892622      4.447304      6.166435
163           16          64      False     (1, 64)         12.727652     18.519504     45.272533
164           64          64      False     (1, 64)         53.008703     77.996370    180.311727
165          128          64      False     (1, 64)        104.147862    153.311369    359.488591
166          256          64      False     (1, 64)        207.200004    304.538345    718.913184
167          512          64      False     (1, 64)        412.189007    605.737507   1437.422752
168         1024          64      False     (1, 64)        927.140633   1313.645005   2978.600025
169         2048          64      False     (1, 64)       1851.236025   2626.507998   5954.240163
170         4096          64      False     (1, 64)       3697.957357   5249.995947  11905.071735
171            1         128      False     (1, 64)          4.272533      5.077902      8.564085
172           16         128      False     (1, 64)         26.748978     38.458702     91.843600
173           64         128      False     (1, 64)        104.432002    153.509793    359.726836
174          128         128      False     (1, 64)        207.112196    304.335756    719.082656
175          256         128      False     (1, 64)        412.117342    605.756491   1437.017881
176          512         128      False     (1, 64)        927.374005   1313.645005   2978.434722
177         1024         128      False     (1, 64)       1851.334651   2622.045994   5953.989347
178         2048         128      False     (1, 64)       3697.727998   5239.295959  11905.488014
179         4096         128      False     (1, 64)       6553.968112  10476.888180  23802.656174
180            1           1       True    (1, 128)          3.365634      3.719100      3.439059
181           16           1       True    (1, 128)          5.059778      4.148897      3.741449
182           64           1       True    (1, 128)          5.484424      4.512350      4.907259
183          128           1       True    (1, 128)          5.940381      4.580592      6.132315
184          256           1       True    (1, 128)          7.030278      5.499121      8.849302
185          512           1       True    (1, 128)          9.476329      8.142480     14.475759
186         1024           1       True    (1, 128)         14.305333     13.133095     25.609788
187         2048           1       True    (1, 128)         28.382844     27.406244     52.442402
188         4096           1       True    (1, 128)         54.402834     57.215752    101.486654
189            1          16       True    (1, 128)          5.022139      4.175365      3.745104
190           16          16       True    (1, 128)          7.016952      5.494959      8.845768
191           64          16       True    (1, 128)         14.399247     13.163304     25.643482
192          128          16       True    (1, 128)         28.430682     27.331508     52.386313
193          256          16       True    (1, 128)         54.526667     57.230854    101.583038
194          512          16       True    (1, 128)        105.662829    110.530924    202.171235
195         1024          16       True    (1, 128)        207.908003    217.277191    403.646703
196         2048          16       True    (1, 128)        412.774816    429.587820    806.832671
197         4096          16       True    (1, 128)        925.827344    960.459129   1718.465328
198            1          32       True    (1, 128)          5.247812      4.234421      4.079264
199           16          32       True    (1, 128)          9.488572      8.133026     14.474409
200           64          32       True    (1, 128)         28.405647     27.332091     52.444952
201          128          32       True    (1, 128)         54.421459     57.225975    101.491279
202          256          32       True    (1, 128)        105.501268    110.233694    201.829080
203          512          32       True    (1, 128)        208.135159    217.147912    403.939267
204         1024          32       True    (1, 128)        412.490216    430.136681    806.446671
205         2048          32       True    (1, 128)        925.950646    960.419448   1717.906634
206         4096          32       True    (1, 128)       1846.174637   1913.911299   3432.864030
207            1          64       True    (1, 128)          5.441861      4.509423      4.901517
208           16          64       True    (1, 128)         14.414493     13.144289     25.597209
209           64          64       True    (1, 128)         54.458604     57.044986    101.550769
210          128          64       True    (1, 128)        105.469633    110.230186    201.990778
211          256          64       True    (1, 128)        208.057423    217.129421    403.732572
212          512          64       True    (1, 128)        412.847661    429.523136    806.899985
213         1024          64       True    (1, 128)        925.762018    960.399296   1718.470653
214         2048          64       True    (1, 128)       1845.648050   1914.319992   3434.925397
215         4096          64       True    (1, 128)       3684.111913   3823.945427   6868.719737
216            1         128       True    (1, 128)          5.953882      4.573895      6.137195
217           16         128       True    (1, 128)         28.571238     27.449536     52.597903
218           64         128       True    (1, 128)        105.599806    110.564384    202.170922
219          128         128       True    (1, 128)        207.901509    216.756619    403.827589
220          256         128       True    (1, 128)        412.155233    429.694259    806.952675
221          512         128       True    (1, 128)        925.811330    960.475134   1719.492038
222         1024         128       True    (1, 128)       1845.808029   1914.247253   3436.186790
223         2048         128       True    (1, 128)       3684.325377   3823.353577   6860.954603
224         4096         128       True    (1, 128)       6081.194878   7704.312086  13786.191940
225            1           1      False    (1, 128)          3.392000      3.709895      3.442704
226           16           1      False    (1, 128)          3.698105      4.007026      3.731692
227           64           1      False    (1, 128)          4.852315      4.270214      4.890133
228          128           1      False    (1, 128)          6.048719      4.355876      6.131131
229          256           1      False    (1, 128)          8.776000      5.217311      8.844045
230          512           1      False    (1, 128)         14.401244      7.618571     14.475587
231         1024           1      False    (1, 128)         25.544177     12.264484     25.611402
232         2048           1      False    (1, 128)         52.538547     25.970244     52.638459
233         4096           1      False    (1, 128)        101.296184     53.646130    101.386791
234            1          16      False    (1, 128)          3.695121      3.979560      3.732465
235           16          16      False    (1, 128)          8.780404      5.225337      8.847461
236           64          16      False    (1, 128)         25.539057     12.247330     25.607192
237          128          16      False    (1, 128)         52.550728     25.977990     52.626182
238          256          16      False    (1, 128)        101.299454     53.637059    101.377092
239          512          16      False    (1, 128)        201.981955    104.020303    201.960930
240         1024          16      False    (1, 128)        403.625800    204.264998    403.473309
241         2048          16      False    (1, 128)        806.029995    403.998991    806.042035
242         4096          16      False    (1, 128)       1719.342709    909.742673   1718.052069
243            1          32      False    (1, 128)          4.034489      4.054815      4.095822
244           16          32      False    (1, 128)         14.391551      7.630044     14.464363
245           64          32      False    (1, 128)         52.550401     25.971356     52.631454
246          128          32      False    (1, 128)        101.291949     53.691414    101.386547
247          256          32      False    (1, 128)        201.894139    104.077008    201.974207
248          512          32      False    (1, 128)        403.412585    204.231997    403.438373
249         1024          32      False    (1, 128)        806.162675    403.963751    806.244016
250         2048          32      False    (1, 128)       1717.853308    910.474658   1718.659957
251         4096          32      False    (1, 128)       3432.936033   1814.298630   3436.389287
252            1          64      False    (1, 128)          4.850989      4.291952      4.899131
253           16          64      False    (1, 128)         25.543540     12.272197     25.604315
254           64          64      False    (1, 128)        101.272729     53.772713    101.347733
255          128          64      False    (1, 128)        202.097691    104.084703    202.006140
256          256          64      False    (1, 128)        403.653223    204.219669    403.555928
257          512          64      False    (1, 128)        805.734038    404.245883    806.359331
258         1024          64      False    (1, 128)       1718.660037    909.840027   1719.415983
259         2048          64      False    (1, 128)       3434.594631   1814.373334   3436.511993
260         4096          64      False    (1, 128)       6868.687948   3621.567885   6868.816058
261            1         128      False    (1, 128)          6.057422      4.369341      6.140132
262           16         128      False    (1, 128)         52.693885     26.060790     52.766023
263           64         128      False    (1, 128)        202.184475    104.496089    202.191638
264          128         128      False    (1, 128)        403.610444    204.033166    403.611767
265          256         128      False    (1, 128)        806.454023    403.920972    806.513309
266          512         128      False    (1, 128)       1718.478680    910.659313   1719.750722
267         1024         128      False    (1, 128)       3435.623964   1813.621362   3432.991982
268         2048         128      False    (1, 128)       6869.232178   3621.701399   6869.861285
269         4096         128      False    (1, 128)       6085.402489   7239.562670  13785.264015
/home/mgoin/code/vllm/benchmarks/kernels/bench_per_token_quant_fp8.py:176: FutureWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.        
  result = df.groupby(groupby_cols).apply(geo_speedup).reset_index()
Speedup over Torch (Compiled)
 col_major group_shape     CUDA   Triton
     False    (-1, -1) 4.535492      inf
     False     (1, -1) 0.901057      inf
     False     (1, 64) 0.719555 0.347461
     False    (1, 128) 1.713228 0.978972
      True     (1, 64) 0.733850 0.379045
      True    (1, 128) 1.030362 0.622426

Comment on lines +36 to +37
# recompile for different shapes
fwd = torch.compile(fn, fullgraph=True, dynamic=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why aren't we compiling with dynamic=True? I don't think we should be targeting shape specialization since we won't use that in practice

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In practice we specialize on all shapes except the first dim (num_tokens). The with_dyn_arg marks that shape as dynamic to fully simulate vLLM usage 👍

Comment on lines 59 to 96
batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
hidden_sizes = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
batch_sizes = [1, 16, 32, 64, 128]
group_shapes = [
GroupShape.PER_TENSOR,
GroupShape.PER_TOKEN,
GroupShape(1, 64),
GroupShape(1, 128),
]
column_major_scales = [True, False]

config_gen = itertools.product(
group_shapes,
column_major_scales,
batch_sizes,
hidden_sizes,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we make some of these args? It takes a really long time by default to run all of this

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can that be a follow-up it requires reworking the structure a lot (because currently this is passed to the function annotation).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed. Now you should be able to use something like this:

python3 benchmarks/kernels/bench_per_token_quant_fp8.py --hidden-sizes 1024 2048 4096 --batch-sizes 32 --group-sizes 128 --no-column-major

assert not static, "Group quantization only supports dynamic mode"
self.group_size = group_shape.col
else:
assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add an assert that column_major_scales is False if non group?


x_quant = x_quant.view(-1, padded_dim)
if padded_dim != hidden_dim:
x_quant = x_quant[..., :hidden_dim]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make sure this is contiguous after stripping padding?

Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
@tahsintunan
Copy link
Contributor Author

tahsintunan commented Sep 17, 2025

Ran this on 5090
python3 benchmarks/kernels/bench_per_token_quant_fp8.py --group-sizes 128
INFO 09-17 06:45:03 [__init__.py:241] Automatically detected platform cuda.
Running 90 configurations:
  Hidden sizes: [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
  Batch sizes: [1, 16, 32, 64, 128]
  Group shapes: ['GroupShape(row=1, col=128)']
  Column major scales: [True, False]

WARNING 09-17 06:45:03 [__init__.py:4006] Current vLLM config is not set.
QuantFP8 performance:
    hidden_size  batch_size  col_major group_shape  Torch (Compiled)         CUDA        Triton
0             1           1       True    (1, 128)          1.529624     1.799951      1.588296
1            16           1       True    (1, 128)          2.664061     1.896050      1.825615
2            64           1       True    (1, 128)          3.151059     2.031393      2.344483
3           128           1       True    (1, 128)          3.466842     2.615653      3.587483
4           256           1       True    (1, 128)          3.930947     3.003212      5.833511
5           512           1       True    (1, 128)          5.440418     4.756721     10.256148
6          1024           1       True    (1, 128)          7.227117     7.535516     19.673112
7          2048           1       True    (1, 128)         11.536611    13.036659     37.944444
8          4096           1       True    (1, 128)         20.485161    25.902077     74.268145
9             1          16       True    (1, 128)          3.010329     1.906905      1.607784
10           16          16       True    (1, 128)          4.271590     3.228557      5.841455
11           64          16       True    (1, 128)          7.228928     7.306826     19.668745
12          128          16       True    (1, 128)         11.863843    14.695586     37.683108
13          256          16       True    (1, 128)         20.506458    26.008614     74.272488
14          512          16       True    (1, 128)        106.840966   108.374787    190.798741
15         1024          16       True    (1, 128)        297.140576   300.404776    419.444882
16         2048          16       True    (1, 128)        605.138987   629.768256    824.126005
17         4096          16       True    (1, 128)       1288.587987  1209.307015   1671.729435
18            1          32       True    (1, 128)          2.705019     2.173277      2.066595
19           16          32       True    (1, 128)          5.098078     4.536950     10.480000
20           64          32       True    (1, 128)         11.871263    14.055407     37.717534
21          128          32       True    (1, 128)         20.197164    26.115301     74.276211
22          256          32       True    (1, 128)        106.793714   114.143557    190.792877
23          512          32       True    (1, 128)        314.867988   300.364841    421.562779
24         1024          32       True    (1, 128)        604.380145   635.277083    824.123303
25         2048          32       True    (1, 128)       1279.839993  1209.230959   1788.946845
26         4096          32       True    (1, 128)       2415.948033  2545.812011   3358.086395
27            1          64       True    (1, 128)          2.813626     2.037592      2.347145
28           16          64       True    (1, 128)          7.571660     7.540632     19.476136
29           64          64       True    (1, 128)         20.171550    26.119292     78.363492
30          128          64       True    (1, 128)        106.876798   108.313243    190.812635
31          256          64       True    (1, 128)        297.047749   300.422441    401.257457
32          512          64       True    (1, 128)        604.244839   604.824240    824.258010
33         1024          64       True    (1, 128)       1207.914948  1209.682047   1672.298171
34         2048          64       True    (1, 128)       2413.725972  2420.926094   3358.582306
35         4096          64       True    (1, 128)       4844.539881  4837.064028   6737.552166
36            1         128       True    (1, 128)          3.143795     2.404182      3.609469
37           16         128       True    (1, 128)         11.860484    13.418697     37.704471
38           64         128       True    (1, 128)        106.774714   115.208321    190.801904
39          128         128       True    (1, 128)        311.571240   300.358064    420.077051
40          256         128       True    (1, 128)        604.894012   639.263988    823.673328
41          512         128       True    (1, 128)       1272.551000  1208.703995   1806.311347
42         1024         128       True    (1, 128)       2412.207961  2540.745974   3359.260750
43         2048         128       True    (1, 128)       5110.131979  4835.279942   6736.232042
44         4096         128       True    (1, 128)       9736.680031  9677.936077  13501.760006
45            1           1      False    (1, 128)          1.518752     2.027383      1.807600
46           16           1      False    (1, 128)          1.568000     1.878482      1.840209
47           64           1      False    (1, 128)          2.546120     2.233692      2.346144
48          128           1      False    (1, 128)          3.369446     2.579003      3.604178
49          256           1      False    (1, 128)          5.787374     3.003303      5.608000
50          512           1      False    (1, 128)         10.339926     4.640000     10.475840
51         1024           1      False    (1, 128)         19.304242     7.163867     19.701369
52         2048           1      False    (1, 128)         37.804900    12.761695     37.726294
53         4096           1      False    (1, 128)         74.344102    25.194550     74.292790
54            1          16      False    (1, 128)          1.804048     1.877930      1.612863
55           16          16      False    (1, 128)          5.786667     3.229279      5.828779
56           64          16      False    (1, 128)         19.218204     7.163238     19.695894
57          128          16      False    (1, 128)         37.799704    12.757911     37.728331
58          256          16      False    (1, 128)         74.145927    25.240405     74.302148
59          512          16      False    (1, 128)        197.018080   108.189256    198.538299
60         1024          16      False    (1, 128)        393.509769   316.088734    401.330286
61         2048          16      False    (1, 128)        852.856000   605.046988    866.718650
62         4096          16      False    (1, 128)       1640.645345  1279.440999   1671.720765
63            1          32      False    (1, 128)          2.052503     1.914565      1.848519
64           16          32      False    (1, 128)         10.340355     4.645974     10.265247
65           64          32      False    (1, 128)         37.563259    12.756656     37.943685
66          128          32      False    (1, 128)         74.167902    25.657238     82.329922
67          256          32      False    (1, 128)        187.125431   108.236800    190.808001
68          512          32      False    (1, 128)        393.517437   299.987157    401.220575
69         1024          32      False    (1, 128)        808.468592   604.785457    824.419022
70         2048          32      False    (1, 128)       1641.431967  1209.881008   1672.231241
71         4096          32      False    (1, 128)       3296.864033  2419.297934   3359.558296
72            1          64      False    (1, 128)          2.325557     2.234767      2.346144
73           16          64      False    (1, 128)         19.316558     7.394712     19.705880
74           64          64      False    (1, 128)         81.798815    25.388851     74.662012
75          128          64      False    (1, 128)        187.248918   108.202608    190.831530
76          256          64      False    (1, 128)        393.480949   299.760719    401.333225
77          512          64      False    (1, 128)        808.546248   604.882002    824.189305
78         1024          64      False    (1, 128)       1641.737302  1210.404992   1673.396371
79         2048          64      False    (1, 128)       3299.048106  2420.530081   3360.531235
80         4096          64      False    (1, 128)       6614.485105  4843.434652   6733.455896
81            1         128      False    (1, 128)          3.374754     2.351524      3.611189
82           16         128      False    (1, 128)         37.774400    14.217525     37.755050
83           64         128      False    (1, 128)        187.869921   108.183866    190.848794
84          128         128      False    (1, 128)        393.634892   299.814542    401.497471
85          256         128      False    (1, 128)        808.779319   604.860595    825.055361
86          512         128      False    (1, 128)       1641.189337  1209.796011   1673.211618
87         1024         128      False    (1, 128)       3298.245271  2419.300079   3360.537529
88         2048         128      False    (1, 128)       6614.058812  4837.924004   6735.304117
89         4096         128      False    (1, 128)      10256.760120  9678.048134  13503.567696
Speedup over Torch (Compiled)
 col_major group_shape     CUDA   Triton
     False    (1, 128) 1.631587 0.980109
      True    (1, 128) 1.026367 0.622223

@simon-mo simon-mo merged commit cef3210 into vllm-project:main Sep 17, 2025
44 of 48 checks passed
@mgoin
Copy link
Member

mgoin commented Sep 17, 2025

Merging for now since it is valid and isn't on by default, thanks for the nice work!

@tahsintunan tahsintunan deleted the quantfp8-group branch September 17, 2025 19:22
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
…roject#24342)

Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025
…roject#24342)

Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: charlifu <charlifu@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants