5 changes: 4 additions & 1 deletion src/llmcompressor/modifiers/quantization/calibration.py
@@ -147,7 +147,6 @@ def update_weight_global_scale(module: Module):
should_calculate_gparam=True,
should_calculate_qparams=False,
)
module.weight_observer.reset()


def update_weight_zp_scale(module: Module):
@@ -199,6 +198,10 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
calculate_gparam = True

# (..., 1, hidden_dim)
# the second to last dim indicates that activations have one output channel
value = value.flatten(0, -2).unsqueeze(-2)

call_observer(
module=module,
base_name=base_name,
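For context on the new reshape in `calibrate_activations` (a minimal sketch, not part of the PR; the tensor shape is hypothetical): flattening the leading dims and adding a singleton second-to-last dimension gives activations the `(..., 1, hidden_dim)` layout the observer expects, i.e. a single output channel per token.

```python
import torch

# Hypothetical activation tensor: batch=2, seq_len=3, hidden_dim=8
value = torch.randn(2, 3, 8)

# Collapse all leading dims into one, then insert a singleton
# "output channel" dim second-to-last, giving (..., 1, hidden_dim)
reshaped = value.flatten(0, -2).unsqueeze(-2)

print(reshaped.shape)  # torch.Size([6, 1, 8])
```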
84 changes: 41 additions & 43 deletions src/llmcompressor/observers/base.py
@@ -1,4 +1,3 @@
from math import ceil
from typing import Any, Iterable, Optional, Tuple, Union

import torch
@@ -8,7 +7,7 @@
QuantizationArgs,
QuantizationStrategy,
)
from compressed_tensors.quantization.utils import is_fp4
from compressed_tensors.quantization.utils import is_fp4, strategy_cdiv
from compressed_tensors.registry.registry import RegistryMixin
from loguru import logger
from torch import FloatTensor, IntTensor, Tensor
@@ -127,83 +126,77 @@ def get_qparams(
:param global_scale: optional scale to further scale local quantization scales
:return: tuple of scale and zero point based on last observed value
"""
if observed is not None:
group_size = self.quantization_args.group_size
strategy = self.quantization_args.strategy

if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
if observed is not None:
if strategy == QuantizationStrategy.TENSOR:
# re-calculate scale and zero point, update the stored value
self._scale, self._zero_point = self.calculate_qparams(observed)

elif self.quantization_args.strategy in (
elif strategy in (
QuantizationStrategy.TENSOR_GROUP,
QuantizationStrategy.GROUP,
):
rows = observed.shape[0]
columns = observed.shape[1]
num_groups = int(ceil(columns / group_size))
if num_groups * group_size != columns:
logger.bind(log_once=True).warning(
"Attempting to quantize a module weight whose columns "
f"({columns}) are not divisible by group_size ({group_size}). "
"This scheme is not supported by vLLM, please consider "
"adjusting the group_size for modules with this number of "
"columns",
)
# should be identical implementation to first half of
# `_process_quantization`

self._scale = torch.empty(
(rows, num_groups), dtype=observed.dtype, device=observed.device
)
# get shapes
assert observed.ndim >= 2
rows, columns = observed.shape[-2:]
group_size = self.quantization_args.group_size
num_groups = strategy_cdiv(columns, group_size, strategy)

# FP4: cast zp type
if is_fp4(quantization_args=self.quantization_args):
zp_dtype = FP8_E4M3_DATA.dtype
else:
zp_dtype = self.quantization_args.pytorch_dtype()

# allocate qparams
self._scale = torch.empty(
(rows, num_groups), dtype=observed.dtype, device=observed.device
)
self._zero_point = torch.empty(
(rows, num_groups), dtype=zp_dtype, device=observed.device
)

# support column-order (default) quantization as well as other orderings
# such as activation ordering. Below checks if g_idx has initialized
is_column_order = g_idx is None or -1 in g_idx
if is_column_order:
group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
else:
group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
group_sizes = group_sizes[torch.argsort(group_indices)]

observed = observed.index_select(-1, g_idx)
# permute groups
if g_idx is not None:
perm = torch.argsort(g_idx)
observed = observed.index_select(-1, perm)

# TODO: experiment with vectorizing for loop for performance
# all reduce all dims except the second to last one
end = 0
for group_index, group_count in enumerate(group_sizes):
for group_index in range(num_groups):
start = end
end = start + group_count
end = start + group_size
scale, zero_point = self.get_qparams_along_dim(
observed[:, start:end],
0,
observed[..., start:end],
dim=-2,
tensor_id=group_index,
global_scale=global_scale,
)

self._scale[:, group_index] = scale.squeeze(1)
self._zero_point[:, group_index] = zero_point.squeeze(1)

elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
# assume observed is transposed, because its the output, hence use dim 0
self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
elif strategy == QuantizationStrategy.CHANNEL:
# all reduce all dims except the second to last one
self._scale, self._zero_point = self.get_qparams_along_dim(observed, -2)

elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
elif strategy == QuantizationStrategy.TOKEN:
# use dims {0, 1}, assume observed.shape = [batch, token, hidden]
# resulting qparams shape should be [batch, token]
self._scale, self._zero_point = self.get_qparams_along_dim(
observed,
dim={0, 1},
)

elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
elif strategy == QuantizationStrategy.BLOCK:
# Block-wise quantization: one scale/zero_point per block of shape
# [block_rows, block_cols]
rows, cols = observed.shape[:2]
rows, cols = observed.shape[-2:]
bs = self.quantization_args.block_structure
if not (
isinstance(bs, (list, tuple))
@@ -215,8 +208,8 @@
f"Must be a list of two ints [rows, cols]."
)
block_rows, block_cols = bs
num_br = int(ceil(rows / block_rows))
num_bc = int(ceil(cols / block_cols))
num_br = strategy_cdiv(rows, block_rows, strategy)
num_bc = strategy_cdiv(cols, block_cols, strategy)

# allocate per-block scale and zero_point
self._scale = torch.empty(
@@ -255,15 +248,20 @@ def get_qparams(

def get_qparams_along_dim(
self,
observed,
observed: torch.Tensor,
dim: Union[int, Iterable[int]],
tensor_id: Optional[Any] = None,
global_scale: Optional[Tensor] = None,
):
# cast to set
if isinstance(dim, int):
dim = [dim]
dim = set(dim)

# convert negative dims
dim = [d if d >= 0 else observed.ndim + d for d in dim]

# reduce all dimensions except the ones passed as arguments to this function
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
return self.calculate_qparams(
observed,
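To make the new group-wise path easier to follow, here is a minimal plain-PyTorch sketch of the same pattern: split the last dim into fixed-size groups (assuming, unlike `strategy_cdiv`, that the columns divide evenly) and reduce every dim except the second-to-last when computing per-group scales. It uses none of the Observer machinery; the function name and defaults are illustrative only.

```python
import torch

def groupwise_minmax_scales(observed: torch.Tensor, group_size: int, n_bits: int = 4):
    """Symmetric per-group scales over the last dim, reducing all dims
    except the second-to-last one (mirrors the loop in Observer.get_qparams)."""
    assert observed.ndim >= 2
    rows, columns = observed.shape[-2:]
    assert columns % group_size == 0, "sketch assumes evenly divisible columns"
    num_groups = columns // group_size

    q_max = 2 ** (n_bits - 1) - 1
    scales = torch.empty((rows, num_groups), dtype=observed.dtype)

    end = 0
    for group_index in range(num_groups):
        start, end = end, end + group_size
        group = observed[..., start:end]
        # reduce every dim except dim=-2 (the "rows" / output-channel dim)
        reduce_dims = tuple(d for d in range(group.ndim) if d != group.ndim - 2)
        abs_max = group.abs().amax(dim=reduce_dims)
        scales[:, group_index] = abs_max / q_max
    return scales

# e.g. a (1, 1, 6) activation-shaped input with group_size=3 -> scales of shape (1, 2)
print(groupwise_minmax_scales(torch.randn(1, 1, 6), group_size=3).shape)
```

Note that iterating over `range(num_groups)` with a fixed `group_size` relies on activation ordering having already been handled, as the diff does by permuting `observed` with `torch.argsort(g_idx)` before the loop.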
42 changes: 40 additions & 2 deletions tests/llmcompressor/modifiers/calibration/test_observers.py
@@ -294,8 +294,46 @@ def test_static_weight_quantization(
0.06,
),
# channel is not supported, but is in principle equivalent to token/tensor
# group is not yet supported
# tensor_group is not yet supported
(
QuantizationArgs(
num_bits=4,
type="int",
symmetric=True,
strategy="group",
group_size=3,
observer="minmax",
),
{
"default": torch.tensor([[0]]),
1: torch.tensor([[3]]),
},
{
"default": torch.tensor([[2]]),
1: torch.tensor([[5]]),
},
torch.tensor([[0.0000, 1.0703, 1.8750, 2.6719, 4.0000, 4.6875]]),
0.04,
),
(
QuantizationArgs(
num_bits=4,
type="float", # tensor group requires FP4
symmetric=True,
strategy="tensor_group",
group_size=3,
observer="minmax",
),
{
"default": torch.tensor([[0]]),
1: torch.tensor([[3]]),
},
{
"default": torch.tensor([[2]]),
1: torch.tensor([[5]]),
},
torch.tensor([[0.0000, 0.9766, 1.9531, 3.3125, 3.3125, 4.9688]]),
0.1,
),
# block is not supported, but is in principle similar to group
],
)
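For orientation (a hedged sketch, not part of the test file): the min/max dictionaries in the new group and tensor_group cases appear to record per-group running minimums and maximums keyed by tensor_id. Assuming a hypothetical 1x6 weight of 0..5 (consistent with the expected dequantized values, though the real fixture may differ), they would come out as shown below.

```python
import torch

# Hypothetical 1x6 weight consistent with the expected values above
weight = torch.tensor([[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]])
group_size = 3

for group_index in range(weight.shape[-1] // group_size):
    group = weight[..., group_index * group_size:(group_index + 1) * group_size]
    print(group_index, group.min().item(), group.max().item())
# 0 -> min 0.0, max 2.0   (the "default" entries)
# 1 -> min 3.0, max 5.0   (the tensor_id=1 entries)
```

The exact dequantized expectations in the parametrization also depend on the weight and scale dtypes used by compressed-tensors, so a plain float32 reproduction will differ slightly; hence the per-case tolerances.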