Commit cfabc13

Update method names to support intx and floatx changes (#775)
1 parent: f538027

13 files changed (+54, −55 lines)

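The change is a pure rename: every call site keeps its arguments, and only the constructor name changes, with the new `intx`/`floatx` suffix spelling out the target dtype family. A minimal before/after sketch (the weight `W` is illustrative; the argument list follows the int8 weight-only call sites in this diff):

```python
import torch
from torchao.quantization.quant_primitives import MappingType
from torchao.dtypes import to_affine_quantized_intx  # was: to_affine_quantized

W = torch.randn(64, 128)  # high-precision ("hp") weight to quantize

# Before: to_affine_quantized(W, ...)
# After: identical arguments, renamed constructor.
quantized = to_affine_quantized_intx(
    W,
    MappingType.SYMMETRIC,              # symmetric int8 mapping
    (1, W.shape[1]),                    # per-row quantization block
    torch.int8,                         # integer target dtype -> "intx"
    eps=torch.finfo(torch.float32).eps,
    zero_point_dtype=torch.int64,
)
```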

docs/source/api_ref_dtypes.rst

Lines changed: 3 additions & 1 deletion
```diff
@@ -11,7 +11,9 @@ torchao.dtypes
    :nosignatures:

    to_nf4
-   to_affine_quantized
+   to_affine_quantized_intx
+   to_affine_quantized_floatx
+   to_affine_quantized_intx_static
    AffineQuantizedTensor

 ..
```

test/dtypes/test_affine_quantized.py

Lines changed: 0 additions & 3 deletions
```diff
@@ -10,9 +10,6 @@
     int8_dynamic_activation_int8_semi_sparse_weight,
     float8_weight_only,
 )
-from torchao.dtypes import (
-    to_affine_quantized,
-)
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

 import torch
```

test/hqq/test_hqq_affine.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,7 +1,7 @@
 import unittest
 import torch
 from torchao.dtypes.affine_quantized_tensor import (
-    to_affine_quantized,
+    to_affine_quantized_intx,
     ZeroPointDomain,
     PlainAQTLayout,
     PlainLayoutType,
@@ -49,7 +49,7 @@ def _eval_hqq(nbits, layout_type):
     if isinstance(layout_type, TensorCoreTiledLayoutType):
         target_dtype = torch.uint8 if TORCH_VERSION_AT_LEAST_2_5 else torch.int32

-    q_tensor_hqq = to_affine_quantized(
+    q_tensor_hqq = to_affine_quantized_intx(
         input_float=W,
         mapping_type=mapping_type,
         block_size=block_size,
```

torchao/dtypes/__init__.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -3,8 +3,8 @@
 from .uint4 import UInt4Tensor
 from .affine_quantized_tensor import (
     AffineQuantizedTensor,
-    to_affine_quantized,
-    to_affine_quantized_static,
+    to_affine_quantized_intx,
+    to_affine_quantized_intx_static,
     to_affine_quantized_floatx,
     LayoutType,
     PlainLayoutType,
@@ -17,8 +17,8 @@
     "to_nf4",
     "UInt4Tensor"
     "AffineQuantizedTensor",
-    "to_affine_quantized",
-    "to_affine_quantized_static",
+    "to_affine_quantized_intx",
+    "to_affine_quantized_intx_static",
     "to_affine_quantized_floatx",
     "LayoutType",
     "PlainLayoutType",
```

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 12 additions & 12 deletions
```diff
@@ -187,7 +187,7 @@ def __tensor_unflatten__(
         )

     @classmethod
-    def from_float(
+    def from_hp_to_intx(
         cls,
         input_float: torch.Tensor,
         mapping_type: MappingType,
@@ -213,16 +213,16 @@ def from_float(
             group_size = max(block_size)
             compute_dtype = zero_point_dtype if (zero_point_dtype is not None) else input_float.dtype
             device = input_float.device
-            int_data, scale, zero_point, _ = quantize_affine_hqq(input_float, nbits=nbits, group_size=group_size, axis=axis, compute_dtype=compute_dtype, device=device, verbose=False, raw_output=False)
-            int_data = int_data.to(target_dtype)
+            data, scale, zero_point, _ = quantize_affine_hqq(input_float, nbits=nbits, group_size=group_size, axis=axis, compute_dtype=compute_dtype, device=device, verbose=False, raw_output=False)
+            data = data.to(target_dtype)
         else:
             scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
-            int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+            data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
             # Note: output will be uint8 tensor for sub byte tensors for now

-        int_data = layout_type.post_process(int_data)
+        data = layout_type.post_process(data)
         layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
-        layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type)
+        layout_tensor = layout_tensor_ctr(data, scale, zero_point, layout_type)
         return cls(
             layout_tensor,
             block_size,
@@ -234,7 +234,7 @@ def from_float(
         )

     @classmethod
-    def from_float_static(
+    def from_hp_to_intx_static(
         cls,
         input_float: torch.Tensor,
         scale: torch.Tensor,
@@ -266,15 +266,15 @@ def from_float_static(
         )

     @classmethod
-    def from_float_to_floatx(
+    def from_hp_to_floatx(
         cls,
         input_float: torch.Tensor,
         block_size: Tuple[int, ...],
         target_dtype: torch.dtype = torch.float8_e4m3fn,
         layout_type: LayoutType = PlainLayoutType(),
     ):
         if target_dtype in FP8_TYPES:
-            cls.from_float(
+            return cls.from_hp_to_intx(
                 input_float=input_float,
                 mapping_type=MappingType.SYMMETRIC,
                 block_size=block_size,
@@ -1004,9 +1004,9 @@ def _(func, types, args, kwargs):
     )
     return return_and_correct_aliasing(func, args, kwargs, new)

-to_affine_quantized = AffineQuantizedTensor.from_float
-to_affine_quantized_static = AffineQuantizedTensor.from_float_static
-to_affine_quantized_floatx = AffineQuantizedTensor.from_float_to_floatx
+to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx
+to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
+to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx

 if TORCH_VERSION_AT_LEAST_2_5:
     # Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True`
```
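The module-level names remain thin aliases for the classmethods, so the alias and the classmethod are interchangeable. A sketch of the equivalence (tensor and settings are illustrative):

```python
import torch
from torchao.dtypes.affine_quantized_tensor import (
    AffineQuantizedTensor,
    to_affine_quantized_intx,
)
from torchao.quantization.quant_primitives import MappingType

W = torch.randn(32, 64)

# The alias is literally AffineQuantizedTensor.from_hp_to_intx, so these match:
q1 = to_affine_quantized_intx(W, MappingType.SYMMETRIC, (1, W.shape[1]), torch.int8)
q2 = AffineQuantizedTensor.from_hp_to_intx(W, MappingType.SYMMETRIC, (1, W.shape[1]), torch.int8)
assert type(q1) is type(q2) is AffineQuantizedTensor
```

Note the floatx hunk also fixes a latent bug: `from_float_to_floatx` dropped its result, while `from_hp_to_floatx` now returns `cls.from_hp_to_intx(...)`.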

torchao/prototype/hqq/example.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -1,7 +1,7 @@
 import torch
 from torchao.prototype.hqq.core import HQQQuantizer
 from torchao.dtypes.affine_quantized_tensor import (
-    to_affine_quantized,
+    to_affine_quantized_intx,
     ZeroPointDomain,
     PlainAQTLayout,
     PlainLayoutType,
@@ -38,7 +38,7 @@

 for nbits in list(range(2, 9))[::-1]:
     print('------------------------------------------------------------------------------')
-    q_tensor_default = to_affine_quantized(
+    q_tensor_default = to_affine_quantized_intx(
         input_float=W,
         mapping_type=mapping_type,
         block_size=block_size,
@@ -57,7 +57,7 @@
     # nbits 4 | Default Dot product error 0.005926903802901506


-    q_tensor_hqq = to_affine_quantized(
+    q_tensor_hqq = to_affine_quantized_intx(
         input_float=W,
         mapping_type=mapping_type,
         block_size=block_size,
@@ -99,7 +99,7 @@
     # nbits 4 | Default Dot product error 0.0015244047390297055


-    q_tensor_hqq = to_affine_quantized(
+    q_tensor_hqq = to_affine_quantized_intx(
         input_float=W,
         mapping_type=mapping_type,
         block_size=block_size,
```

torchao/quantization/README.md

Lines changed: 3 additions & 3 deletions
````diff
@@ -82,7 +82,7 @@ as an example:
 ```python
 import torch
 from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
-from torchao.dtypes import to_affine_quantized
+from torchao.dtypes import to_affine_quantized_intx
 import copy
 from torchao.quantization.quant_api import (
     quantize_,
@@ -142,9 +142,9 @@ speedup: 2.2715200981216173

 What we do underlying the APIs are roughly the following:
 ```
-from torchao.dtypes import to_affine_quantized
+from torchao.dtypes import to_affine_quantized_intx
 def int8wo_quant(weight):
-    return to_affine_quantized(weight, MappingType.SYMMETRIC, (1, weight.shape[1]), torch.int8, eps=torch.finfo(torch.float32).eps, zero_point_dtype=torch.int64)
+    return to_affine_quantized_intx(weight, MappingType.SYMMETRIC, (1, weight.shape[1]), torch.int8, eps=torch.finfo(torch.float32).eps, zero_point_dtype=torch.int64)

 for n, m in model.named_modules():
     if isinstance(m, torch.nn.Linear):
````
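For comparison, the same int8 weight-only transform is what the high-level API applies for you; a sketch, assuming the `int8_weight_only` helper exported by `torchao.quantization.quant_api` in this release:

```python
import torch
from torchao.quantization.quant_api import quantize_, int8_weight_only

model = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(torch.bfloat16)
# Swaps each Linear weight for an AffineQuantizedTensor built by
# to_affine_quantized_intx (see the quant_api.py hunks below).
quantize_(model, int8_weight_only())
```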

torchao/quantization/autoquant.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -284,7 +284,7 @@ def from_float(cls, weight):
         # return weight

         # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
+        from torchao.dtypes import to_affine_quantized_intx
         # weight settings
         mapping_type = MappingType.SYMMETRIC
         def get_weight_block_size(x):
@@ -306,10 +306,10 @@ def get_per_token_block_size(x):
         input_quant_min = -127
         input_quant_max = 127
         layout_type = PlainLayoutType()
-        input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+        input_quant_func = lambda x: to_affine_quantized_intx(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)

         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
+        weight = to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
         weight = super(AQInt8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
         return weight
@@ -371,7 +371,7 @@ def from_float(cls, weight):
         eps = torch.finfo(torch.float32).eps
         zero_point_dtype = torch.int64
         block_size = (1, weight.shape[1])
-        return super(AQWeightOnlyQuantizedLinearWeight, cls).from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+        return super(AQWeightOnlyQuantizedLinearWeight, cls).from_hp_to_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)


 class AQWeightOnlyQuantizedLinearWeight2(AQWeightOnlyQuantizedLinearWeight, AQMixin):
```
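The input path above quantizes activations per token: the block has size 1 on every leading dimension and spans the full last dimension. A standalone sketch of that recipe (the helper here reimplements `get_per_token_block_size` for illustration):

```python
import torch
from torchao.dtypes import to_affine_quantized_intx
from torchao.quantization.quant_primitives import MappingType

def per_token_block_size(x: torch.Tensor):
    # one quantization group per token: size 1 everywhere except the last dim
    return (1,) * (x.dim() - 1) + (x.shape[-1],)

x = torch.randn(2, 16, 128, dtype=torch.float16)
xq = to_affine_quantized_intx(
    x, MappingType.SYMMETRIC, per_token_block_size(x), torch.int8,
    eps=1e-5, quant_min=-127, quant_max=127,  # reduced range, as in the hunk
    scale_dtype=torch.float32 if x.dtype == torch.float16 else None,
)
```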

torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -21,7 +21,7 @@ def intN_weight_only(group_size=32, n=8, symmetric=False):
     # for asymmetric quantization
     def apply_intN_weight_only_quant_asym(weight):
         # avoid circular dependency
-        from torchao.dtypes import to_affine_quantized
+        from torchao.dtypes import to_affine_quantized_intx
         mapping_type = MappingType.ASYMMETRIC
         block_size = (1, group_size)
         target_dtype = torch.uint8
@@ -31,20 +31,20 @@ def apply_intN_weight_only_quant_asym(weight):
         preserve_zero = True
         zero_point_dtype = torch.int64
         zero_point_domain = ZeroPointDomain.INT
-        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype)#, preserve_zero=preserve_zero,zero_point_domain=zero_point_domain)
+        return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype)#, preserve_zero=preserve_zero,zero_point_domain=zero_point_domain)

     # for symmetric quantization
     def apply_intN_weight_only_quant_sym(weight):
         # avoid circular dependency
-        from torchao.dtypes import to_affine_quantized
+        from torchao.dtypes import to_affine_quantized_intx
         mapping_type = MappingType.SYMMETRIC
         block_size = (1, group_size)
         target_dtype = torch.int8
         quant_min = -2**(n-1)
         quant_max = 2**(n-1)-1
         eps = 1e-6
         zero_point_dtype = torch.int64
-        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps=eps, zero_point_dtype=zero_point_dtype)
+        return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps=eps, zero_point_dtype=zero_point_dtype)

     try:
         assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]"
```
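In the symmetric branch the range follows directly from the bit width n: `quant_min = -2**(n-1)` and `quant_max = 2**(n-1) - 1`, so n = 4 gives [-8, 7]. A quick check of that arithmetic over the accepted widths:

```python
for n in [8, 6, 5, 4, 3, 2]:  # widths accepted by intN_weight_only
    quant_min, quant_max = -2 ** (n - 1), 2 ** (n - 1) - 1
    print(f"n={n}: [{quant_min}, {quant_max}]")
# n=8: [-128, 127], n=6: [-32, 31], ..., n=2: [-2, 1]
```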

torchao/quantization/prototype/qat/api.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -49,7 +49,7 @@ def int8_dynamic_activation_int4_weight_fake_quantize(group_size=32):
     quantize_(model, int8_dynamic_activation_int4_weight_fake_quantize(group_size=32))
     """
     # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
+    from torchao.dtypes import to_affine_quantized_intx

     def _apply_weight_fake_quant(weight: torch.Tensor):
         mapping_type = MappingType.SYMMETRIC
```

torchao/quantization/quant_api.py

Lines changed: 10 additions & 10 deletions
```diff
@@ -23,7 +23,7 @@

 from torchao.dtypes.uintx.Uintx import UintxLayoutType
 from torchao.dtypes import (
-    to_affine_quantized,
+    to_affine_quantized_intx,
     TensorCoreTiledLayoutType,
     PlainLayoutType,
     AffineQuantizedTensor,
@@ -323,11 +323,11 @@ def quantize_(
     # You can also add your own apply_tensor_subclass by manually calling tensor subclass constructor
     # on weight

-    from torchao.dtypes import to_affine_quantized
+    from torchao.dtypes import to_affine_quantized_intx

     # weight only uint4 asymmetric groupwise quantization
     groupsize = 32
-    apply_weight_quant = lambda x: to_affine_quantized(
+    apply_weight_quant = lambda x: to_affine_quantized_intx(
         x, "asymmetric", (1, groupsize), torch.int32, 0, 15, 1e-6,
         zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain="float")

@@ -356,7 +356,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
 def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
     mapping_type = MappingType.ASYMMETRIC
     target_dtype = torch.int8
-    return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype)
+    return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype)

 def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32):
     if weight.shape[-1] % group_size != 0:
@@ -373,7 +373,7 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32):
     # input settings
     input_quant_func = _int8_asymm_per_token_quant

-    weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
+    weight = to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
     weight = to_linear_activation_quantized(weight, input_quant_func)
     return weight

@@ -424,7 +424,7 @@ def apply_int4_weight_only_quant(weight, use_hqq=False):
         preserve_zero = False
         zero_point_dtype = torch.bfloat16
         zero_point_domain = ZeroPointDomain.FLOAT
-        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type)
+        return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type)

     return _get_linear_subclass_inserter(apply_int4_weight_only_quant)

@@ -439,7 +439,7 @@ def apply_int8wo_quant(weight):
         eps = torch.finfo(torch.float32).eps
         zero_point_dtype = torch.int64
         block_size = (1, weight.shape[1])
-        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+        return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

     return _get_linear_subclass_inserter(apply_int8wo_quant)

@@ -449,7 +449,7 @@ def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
     eps = 1e-5
     quant_min = -127
     quant_max = 127
-    return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype, eps=eps, quant_min=quant_min, quant_max=quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+    return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype, eps=eps, quant_min=quant_min, quant_max=quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)


 def int8_dynamic_activation_int8_weight(layout_type=PlainLayoutType()):
@@ -475,7 +475,7 @@ def get_weight_block_size(x):
     input_quant_func = _int8_symm_per_token_reduced_range_quant

     block_size = get_weight_block_size(weight)
-    weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
+    weight = to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
     weight = to_linear_activation_quantized(weight, input_quant_func)
     return weight

@@ -527,7 +527,7 @@ def apply_uintx_weight_only_quant(weight):
         zero_point_dtype = torch.int32
         zero_point_domain = ZeroPointDomain.INT

-        return to_affine_quantized(
+        return to_affine_quantized_intx(
             weight, mapping_type, block_size, dtype,
             eps=eps, zero_point_dtype=zero_point_dtype,
             zero_point_domain=zero_point_domain,
```
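The docstring hunk above shows the manual route: pass your own `apply_tensor_subclass` callable to `quantize_`. A sketch of that uint4 groupwise recipe, spelled with explicit enums rather than the docstring's string shorthands:

```python
import torch
from torchao.dtypes import to_affine_quantized_intx
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

groupsize = 32

def apply_weight_quant(x: torch.Tensor):
    # uint4 values [0, 15] stored in int32, one group per `groupsize` columns
    return to_affine_quantized_intx(
        x, MappingType.ASYMMETRIC, (1, groupsize), torch.int32, 0, 15, 1e-6,
        zero_point_dtype=torch.bfloat16, preserve_zero=False,
        zero_point_domain=ZeroPointDomain.FLOAT,
    )
```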

tutorials/calibration_flow/gptq_like.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -33,7 +33,7 @@
 import gc
 from typing import Tuple, Dict, Any
 from torchao.quantization.utils import compute_error
-from torchao.dtypes import to_affine_quantized_static
+from torchao.dtypes import to_affine_quantized_intx_static
 from torchao.quantization import quantize_
 from torchao.quantization import to_linear_activation_quantized
 from torchao.quantization import LinearActivationQuantizedTensor
@@ -229,7 +229,7 @@ def _apply_activation_static_quant(observed_linear):

     # activation quantization
     act_scale, act_zero_point = observed_linear.input_scale, observed_linear.input_zp
-    input_quant_func = lambda x: to_affine_quantized_static(x, act_scale, act_zero_point, x.shape, target_dtype)
+    input_quant_func = lambda x: to_affine_quantized_intx_static(x, act_scale, act_zero_point, x.shape, target_dtype)
     observed_linear.weight = torch.nn.Parameter(to_linear_activation_quantized(observed_linear.weight, input_quant_func), requires_grad=False)

     del observed_linear.input_scale
```
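Unlike the dynamic variants, the static constructor takes calibrated scale and zero-point instead of deriving them from the input. A standalone sketch (the qparam values stand in for ones an observer would record):

```python
import torch
from torchao.dtypes import to_affine_quantized_intx_static

x = torch.randn(8, 128)
act_scale = torch.tensor(0.02)                       # from calibration (illustrative)
act_zero_point = torch.tensor(0, dtype=torch.int64)  # illustrative

# block_size == x.shape -> a single scale/zero_point for the whole tensor
xq = to_affine_quantized_intx_static(x, act_scale, act_zero_point, x.shape, torch.int8)
```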
