Repro script (a.py):

import torch
from torchao.quantization import IntxWeightOnlyConfig, quantize_
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.granularity import PerGroup

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m: int, n: int, k: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

model = ToyLinearModel(32, 32, 32).eval()

# Optional: compile model for faster inference and generation
# model = torch.compile(model, mode="max-autotune", fullgraph=True)
# model_bf16 = copy.deepcopy(model)

config = IntxWeightOnlyConfig(torch.int4, PerGroup(1), mapping_type=MappingType.ASYMMETRIC)
quantize_(model, config)

inp = torch.ones(2, 32)
model(inp)

$ python a.py
Traceback (most recent call last):
  File "/home/wzy/Desktop/ao/a.py", line 25, in <module>
    quantize_(model, config)
    ~~~~~~~~~^^^^^^^^^^^^^^^
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_api.py", line 498, in quantize_
    _replace_with_custom_fn_if_matches_filter(
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        model,
        ^^^^^^
        ...<3 lines>...
        extra_args=(config,),
        ^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_api.py", line 214, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
        child,
        ...<4 lines>...
        extra_args,
    )
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_api.py", line 209, in _replace_with_custom_fn_if_matches_filter
    model = replacement_fn(model, *extra_args)
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_api.py", line 2375, in _intx_weight_only_transform
    new_weight = _intx_weight_only_quantize_tensor(
        module.weight,
        ...<2 lines>...
        custom_zero_point=custom_zero_point,
    )
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_api.py", line 2320, in _intx_weight_only_quantize_tensor
    new_weight = IntxUnpackedToInt8Tensor.from_hp(
        weight,
        ...<5 lines>...
        intx_choose_qparams_algorithm=intx_choose_qparams_algorithm,
    )
  File "/home/wzy/Desktop/ao/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py", line 233, in from_hp
    qdata = quantize_affine(
        hp_tensor,
        ...<5 lines>...
        quant_max=qmax,
    )
  File "/usr/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_primitives.py", line 357, in quantize_affine
    return _quantize_affine(
        input,
        ...<5 lines>...
        quant_max,
    )
  File "/usr/lib/python3.13/site-packages/torch/_ops.py", line 1158, in __call__
    return self._op(*args, **(kwargs or {}))
           ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_primitives.py", line 403, in _quantize_affine
    return _quantize_affine_no_dtype_cast(
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        input,
        ^^^^^^
        ...<4 lines>...
        quant_max,
        ^^^^^^^^^^
    ).to(output_dtype)
    ^
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_primitives.py", line 460, in _quantize_affine_no_dtype_cast
    scale = scale.view(shape_after_reduction)
RuntimeError: shape '[32, 32]' is invalid for input of size 1

Other PerGroup(x) values work; only PerGroup(1) triggers this error.
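For contrast, a minimal sketch of the working path: the same config as above, but with an assumed group size of 16 (16 is an arbitrary choice that divides the 32-wide weight; per the note above, any group size other than 1 quantizes without error):

import torch
from torchao.quantization import IntxWeightOnlyConfig, quantize_
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.granularity import PerGroup

# Single 32x32 linear layer, same weight shape as each layer in the repro
model = torch.nn.Sequential(torch.nn.Linear(32, 32, bias=False)).eval()

# Assumed-working group size; only PerGroup(1) hits the scale.view() error
config = IntxWeightOnlyConfig(torch.int4, PerGroup(16), mapping_type=MappingType.ASYMMETRIC)
quantize_(model, config)

print(model(torch.ones(2, 32)))  # succeeds

Reading the shapes in the error: for a (32, 32) weight, PerGroup(1) means one scale per element, so scale.view([32, 32]) expects 32 * 32 = 1024 values, yet the scale tensor reaching _quantize_affine_no_dtype_cast has a single element. That suggests the qparams computation collapses for group_size=1, though this is inferred from the traceback, not a verified root cause.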